Google Summer of Code 2021: Integration of Jax into Deepchem

Hey everyone!

I’m Vignesh Venkataraman, an incoming final-year student at IIT Roorkee, India. In the past, I have worked on building generative models for molecular strings (SMILES) using NLP, and on graph-based machine learning with molecules.

This summer, I will be working with DeepChem as a GSoC student. Some of my best learning experiences have come from contributing to DeepChem in the past, and I’m hoping to learn even more through the program.

I will be working on integrating Jax support into the DeepChem infrastructure. JAX is a comparatively new library that has received a warm welcome from machine learning researchers because of its core functional approach and explicit state management. Looking forward to learning more about Jax!


Week 1 & 2 Updates

I have been working on building the core functionality of JaxModel, which involves training neural networks using the Jax framework. The code for calling the backpropagation algorithm lives in the fit method, and I have also added a few other support functions that interact with the rest of the DeepChem infrastructure, like the Dataset, Model, and Losses classes.

With the advent of Jax, many optimized neural network libraries have emerged, like Haiku (from DeepMind), Flax (from Google Brain), etc. But we want our JaxModel to be library agnostic, and we do this by adding _create_gradient_fn (inspired by the Keras Model). This feature allows the user to choose the correct gradient transformation depending on their library of choice, or even pure Jax functions. I personally feel this will play very well with training on multiple GPUs, TPUs, etc. due to Jax features like vmap & pmap. Apart from these, there are a few other functions like get_trainable_params which give more control in splitting & freezing models.
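To give a flavour of what training with pure Jax looks like, here is a minimal, self-contained sketch of a jitted gradient-descent training step on a toy linear regression. The names (`train_step`, `mse_loss`) are illustrative only and are not the actual JaxModel API:

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    # params is a (weights, bias) pair for a linear model
    w, b = params
    preds = x @ w + b
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    # value_and_grad returns the loss and its gradient w.r.t. params
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # plain SGD update applied leaf-by-leaf over the parameter pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
true_w = jnp.array([[1.0], [2.0], [3.0]])
y = x @ true_w

params = (jnp.zeros((3, 1)), jnp.zeros(1))
for _ in range(200):
    params, loss = train_step(params, x, y)
```

Because the update is expressed as a pure function of `(params, x, y)`, swapping in a library-specific gradient transformation (Haiku, Flax, optax, etc.) only changes how `train_step` is constructed, not how it is called.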

The changes can be found in this PR - deepchem/deepchem#2549


Week 3-5 Updates

One of DeepChem's recent challenges was to set up a new environment that could build all of its dependencies smoothly for users and get the CI running. DeepChem has model utilities across a bunch of NN frameworks like Tensorflow, Pytorch, and Jax, and things became harder to manage due to Tensorflow's compatibility issues with Numpy>=1.20. With the addition of Jax and related packages, the dependency matrix was getting harder to manage, so we decided to move to an approach inspired by other OSS libraries like OpenAI Gym (later followed by HuggingFace as well), where we provide 4 separate install targets within deepchem:

  1. deepchem
  2. deepchem[tensorflow]
  3. deepchem[torch]
  4. deepchem[jax]

To make these changes smoothly, I had to modify the setuptools installation, add CI workflows for the different environments, and restructure the test suites so that tests can run independently of the other dependencies.
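For readers unfamiliar with pip "extras", here is a hypothetical sketch of how such optional install targets can be declared via setuptools. The dependency lists below are illustrative only; the actual lists live in the PRs linked underneath:

```python
# Hypothetical extras_require sketch; the real dependency pins in the
# DeepChem setup.py may differ.
extras_require = {
    'tensorflow': ['tensorflow'],
    'torch': ['torch'],
    'jax': ['jax', 'jaxlib'],
}

# Passing this dict as setuptools.setup(..., extras_require=extras_require)
# lets users run e.g. `pip install deepchem[jax]` to pull in only the Jax stack.
```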

PR for Jax Environment - deepchem/deepchem#2560
PR for Torch Environment - deepchem/deepchem#2563
PR for Tensorflow & Common Environment - deepchem/deepchem#2573

Miscellaneous PR for making the environments independent - deepchem/deepchem#2567


Week 6 Updates

I worked on building the evaluation stage of the JaxModel API and added many test suites covering metrics, uncertainty, predict_batch, etc. These test suites also act as good examples showcasing the usage of JaxModel.

I faced a few challenges with Jax's style of passing static_argnums when compiling code with JIT, as it caused some design issues.
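For context, here is a minimal illustration of the static_argnums behaviour (not the actual JaxModel code). JIT treats such arguments as compile-time constants, and every new value triggers a recompilation, which is what complicates threading them through a generic fit/predict path:

```python
import jax
import jax.numpy as jnp
from functools import partial

# static_argnums=(1,) tells JIT to specialise the compiled function on the
# concrete value of `factor`; calling with a new factor recompiles.
@partial(jax.jit, static_argnums=(1,))
def scale(x, factor):
    return x * factor

out = scale(jnp.arange(3.0), 2)  # compiled with factor fixed to 2
```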
The changes can be found in this PR - deepchem/deepchem#2604


Week 7-8 Updates

Along with completing the evaluation stage of the JaxModel API, I went ahead and implemented a small PINNs network using the existing JaxModel API. I implemented the Burgers' equation and the Schrödinger equation and brought both down to decent accuracy.

Challenges Faced -
I’m currently using the optax library to perform all NN optimization with Adam. The authors' implementation in Tensorflow uses L-BFGS here, which is a relatively more powerful (second-order) optimizer.

Burgers' Equation (plot omitted)

Schrödinger Equation (plot omitted)

The exact and predicted plots do not coincide due to the relatively weaker performance of the Adam optimizer.
The dataset and the plotting code were taken from the authors' repo.


I don’t think the optimizer would change the result. Adam will converge more slowly since it’s first order instead of second order, but it should still reach the same result. Maybe you need to add some learning rate decay?

Yes, I agree it will still converge, just more slowly, so ideally I should set the correct learning rate decay. I was building this model for the sake of a tutorial, and convergence was taking some time, hence I decided to move ahead with it. The current results using Adam were pretty decent, as there is very little deviation between the exact and predicted plots.

I will continue to work on improving the accuracy for next week.

I was able to solve the problem using schedulers. Initially I tried exponential and cosine schedulers, but they weren't really helpful. Later I tried piecewise_constant_schedule, which can be thought of as a step function whose boundaries we control (I had to tune the boundary values manually). With it, I was able to reproduce accurate curves in very little time.


Week 9-10 Updates

Having finished building multiple PINN models (Burgers' and Schrödinger's equations), we identified a lot of similar boilerplate code in both models. The only things that changed were:

  1. The number of arguments of the differential equation (For example - F(x,y,z,t) has 4 arguments and F(x,t) has 2 arguments)
  2. The function definition of the differential equation, which defines the final loss used as a regulariser in PINNs.

To accommodate these variations, I added a PINNModel class for solving differential equations using PINNs, with a test case solving a simple differential equation (f'(x) = - sin(x)). Along with this, I also worked on a tutorial notebook explaining the usage of PINNModel with the example of Burgers' equation.
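To make the PINN idea concrete for the test equation f'(x) = -sin(x) (whose solution with f(0) = 1 is f(x) = cos(x)), here is a self-contained sketch of the physics-informed loss. The network and names are illustrative and not the PINNModel API:

```python
import jax
import jax.numpy as jnp

# Tiny MLP; params is a list of (weights, bias) pairs.
def init_params(key, sizes=(1, 32, 32, 1)):
    params = []
    for i in range(len(sizes) - 1):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (sizes[i], sizes[i + 1])) * 0.1
        params.append((w, jnp.zeros(sizes[i + 1])))
    return params

def net(params, x):
    h = jnp.atleast_1d(x)          # treat a scalar input as a 1-vector
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]          # scalar output f(x)

def pinn_loss(params, xs):
    # Residual of the differential equation f'(x) + sin(x) = 0 at the
    # collocation points, computed with autodiff through the network...
    dfdx = jax.vmap(jax.grad(net, argnums=1), in_axes=(None, 0))(params, xs)
    residual = dfdx + jnp.sin(xs)
    # ...plus the boundary condition f(0) = 1 as a penalty term.
    bc = net(params, 0.0) - 1.0
    return jnp.mean(residual ** 2) + bc ** 2

key = jax.random.PRNGKey(0)
params = init_params(key)
xs = jnp.linspace(0.0, float(jnp.pi), 16)
loss = pinn_loss(params, xs)
```

The two items above map directly onto this sketch: the number of arguments of the equation fixes the network's input size, and the user-supplied equation definition fixes the residual term in the loss.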

PR for PINNModel:
Link for the Tutorial Notebook (Soon to start a PR):
