Google Summer of Code 2021: Integration of JAX into DeepChem

Hey everyone!

I’m Vignesh Venkataraman, an incoming final-year student at IIT Roorkee, India. In the past, I have worked on building generative models for molecular strings (SMILES) using NLP, and on graph-based machine learning with molecules.

This summer, I will be working with DeepChem as a GSoC student. Some of my best learning experiences have come from contributing to DeepChem in the past, and I’m hoping to learn even more from the program.

I will be working on integrating JAX support into the DeepChem infrastructure. JAX is a comparatively new library that has received a warm welcome from machine learning researchers because of its core functional approach and explicit state management. Looking forward to learning more about JAX!
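To give a flavour of the functional style mentioned above, here is a minimal sketch (not DeepChem code) of how JAX handles state: parameters are passed in and returned explicitly rather than mutated in place, and gradients come from transforming a pure loss function.

```python
import jax
import jax.numpy as jnp

# State (params) flows through pure functions instead of being mutated.
def update(params, lr, grads):
    return params - lr * grads

params = jnp.array([1.0, -1.0])
# jax.grad transforms the pure loss function into its gradient function.
grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)  # gradient of sum(p^2) is 2p
params = update(params, 0.1, grads)
print(params)  # [0.8, -0.8]
```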


Week 1 & 2 Updates

I have been working on building the core functionality of JaxModel, which involves training neural networks using the JAX framework. The backpropagation loop is implemented in the fit method, and I have also added a few other support functions that interact with the rest of the DeepChem infrastructure, such as the Dataset, Model, and Losses classes.

With the advent of JAX, many optimized neural network libraries have emerged, like Haiku (from DeepMind) and Flax (from Google Brain). But we want our JaxModel to be library agnostic, and we achieve this by adding _create_gradient_fn (inspired by the Keras Model). This feature allows users to choose the correct gradient transformation for their library of choice, or even use pure JAX functions. I expect this will also play very well with training on multiple GPUs, TPUs, etc., thanks to JAX features like vmap & pmap. Apart from these, there are a few other functions, like get_trainable_params, which give more control for splitting & freezing models.
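The library-agnostic idea can be sketched roughly as follows. This is a hypothetical simplification, not the actual DeepChem implementation: `create_gradient_fn` and `grad_transform` are illustrative names, and the real `_create_gradient_fn` in the PR handles losses and parameters more generally.

```python
import jax
import jax.numpy as jnp

def create_gradient_fn(loss_fn, grad_transform=None):
    # Default to plain jax.grad; a Haiku or Flax user can instead pass
    # their own transformation of the loss function.
    if grad_transform is None:
        return jax.grad(loss_fn)
    return grad_transform(loss_fn)

# Pure-JAX usage: params is just an array here for illustration.
loss = lambda params: jnp.sum(params ** 2)
grad_fn = create_gradient_fn(loss)
print(grad_fn(jnp.array([1.0, 2.0])))  # gradient of sum(p^2) is 2p -> [2., 4.]
```

Because the gradient function is just a callable, the same training loop can later be wrapped in pmap for multi-device training without changing the model code.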

The changes can be found in this PR - deepchem/deepchem#2549


Week 3-5 Updates

One of DeepChem’s recent challenges was to set up a new environment that could build all its dependencies smoothly for users and keep the CI running. DeepChem has model utilities in a bunch of NN frameworks like TensorFlow, PyTorch, and JAX, and things became harder to manage due to TensorFlow’s compatibility issues with NumPy>=1.20. With the addition of JAX and related packages, the dependency matrix was getting harder to manage, and hence we decided to move to an approach inspired by other OSS libraries like OpenAI Gym (later followed by HuggingFace as well), where we have made 4 separate install targets within deepchem:

  1. deepchem
  2. deepchem[tensorflow]
  3. deepchem[torch]
  4. deepchem[jax]
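Assuming these extras are published on PyPI as described, installation would look like the following; only the framework you ask for is pulled in on top of the lightweight core.

```shell
pip install deepchem              # core only
pip install "deepchem[tensorflow]"  # core + TensorFlow models
pip install "deepchem[torch]"       # core + PyTorch models
pip install "deepchem[jax]"         # core + JAX models
```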

In order to make these changes smoothly, I had to update the setuptools installation, the CI workflows for the different environments, and the test suites so that tests can run independently of the other dependencies.
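On the setuptools side, the split boils down to an `extras_require` mapping. The sketch below is illustrative only; the actual package lists and version pins in DeepChem’s setup.py differ.

```python
# Hypothetical extras_require mapping backing the four install targets;
# real pins in DeepChem's setup.py differ.
extras_require = {
    "tensorflow": ["tensorflow"],
    "torch": ["torch"],
    "jax": ["jax", "jaxlib"],
}
# Passed to setuptools.setup(..., extras_require=extras_require) so that
# e.g. `pip install deepchem[jax]` adds the JAX stack on top of the core deps.
print(sorted(extras_require))
```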

PR for Jax Environment - deepchem/deepchem#2560
PR for Torch Environment - deepchem/deepchem#2563
PR for Tensorflow & Common Environment - deepchem/deepchem#2573

Miscellaneous PR for making independent environments - deepchem/deepchem#2567


Week 6 Updates

I worked on building the evaluation stage of the JaxModel API and added many test suites covering metrics, uncertainty, predict_batch, etc. These test suites also act as good examples showcasing the usage of JaxModel.

I faced a few challenges with JAX’s style of passing static_argnums when compiling code with JIT, as it caused some design issues.
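To illustrate the kind of issue involved (this is a generic JAX example, not the actual JaxModel code): when an argument controls output shapes, JIT cannot trace it as a regular array value, so it must be marked static, and JAX then recompiles the function for each distinct value of that argument. The function and argument names below are hypothetical.

```python
import jax
import jax.numpy as jnp
from functools import partial

# n_tasks determines the output shape, so it cannot be a traced argument.
# Marking it static makes JIT compile one version per distinct n_tasks.
@partial(jax.jit, static_argnums=(1,))
def tile_over_tasks(x, n_tasks):
    return jnp.tile(x[:, None], (1, n_tasks))

out = tile_over_tasks(jnp.ones(3), 4)
print(out.shape)  # (3, 4)
```

Threading such static arguments through a generic Model API is awkward, which is the design friction referred to above.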
The changes can be found in this PR - deepchem/deepchem#2604
