Summary of GSoC-21: JAX Integration into Deepchem

Over the summer, I had the opportunity to contribute to Deepchem as a student developer in Google Summer of Code. During the program, I worked on integrating JAX into Deepchem by building a JaxModel API that works with all of Deepchem's existing infrastructure for developing neural networks.

Main Contributions

  1. Addition of the JaxModel API for building neural networks
  2. Revamping the GitHub Workflows CI to handle the dependencies of the different backends (TensorFlow, PyTorch, JAX)
  3. Addition of PINNModel and a tutorial notebook explaining its usage

JAX is a comparatively new library that has been warmly received by machine learning researchers because of its functional core and its explicit approach to state management.
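To make "functional approach and state management" concrete, here is a minimal, generic JAX sketch (not Deepchem code). The loss is a pure function of its arguments, `jax.grad` turns it into a gradient function, and randomness is carried by explicit PRNG keys instead of hidden global state. The least-squares loss and the array shapes are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Pure function: the result depends only on its arguments."""
    return jnp.mean((x @ w - y) ** 2)

# jax.grad returns a new function that computes d(loss)/dw.
grad_loss = jax.grad(loss)

# Randomness is explicit state: keys are created and split, never mutated.
key = jax.random.PRNGKey(0)
key, w_key, x_key = jax.random.split(key, 3)
w = jax.random.normal(w_key, (3,))
x = jax.random.normal(x_key, (5, 3))
y = jnp.zeros(5)

g = grad_loss(w, x, y)

# Reusing the same key reproduces exactly the same random numbers.
w_again = jax.random.normal(w_key, (3,))
```

Because nothing is hidden inside objects, the same function can be differentiated, compiled with `jax.jit`, or vectorised with `jax.vmap` without any special handling.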

My weekly updates can be found here.

JaxModel API

The main objective of the JaxModel API was to let users build models on top of Deepchem's core infrastructure, such as MolNet, featurisers, splitters, transformers and the other deep learning utilities. Because JAX is a functional framework, it is difficult to build a neural network as a stateful object, the way it is usually done in PyTorch or TensorFlow. Instead, in this API a neural network is defined by two things:

  1. Weights: the parameters (weight matrices) that are learned during training.
  2. forward_fn: a function describing how the weight matrices are applied to the inputs to produce predictions.
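As a minimal sketch of this two-part design in plain JAX (independent of Deepchem and Haiku), the weights can be an ordinary nested structure of arrays, and `forward_fn` can be a pure function that receives them explicitly. The helper `init_params`, the layer sizes and the dictionary layout are illustrative assumptions. Haiku's `hk.transform` produces an `apply` function that follows the same pattern, except that it also takes an `rng` argument.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    """Hypothetical helper: one {'w', 'b'} dict per layer, e.g. sizes=[in, hidden, out]."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append({"w": jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in),
                       "b": jnp.zeros(n_out)})
    return params

def forward_fn(params, x):
    """Pure function: describes how the weights process the inputs."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])  # hidden layers with ReLU
    last = params[-1]
    return x @ last["w"] + last["b"]  # linear output layer

params = init_params(jax.random.PRNGKey(0), [10, 32, 2])
preds = forward_fn(params, jnp.ones((4, 10)))
```

Because the weights are passed in explicitly, the same `forward_fn` can be differentiated with respect to `params` or compiled with `jax.jit` without any hidden state.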

Here is a simple example of using JaxModel:

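  import deepchem as dc
  import haiku as hk
  import jax
  import jax.numpy as jnp
  import numpy as np
  import optax
  from deepchem.models import JaxModel

  # Illustrative toy regression data (the original snippet assumed that
  # `dataset` and `metric` were already defined): 256 samples, 2 targets
  X = np.random.rand(256, 10)
  y = np.random.rand(256, 2)
  dataset = dc.data.NumpyDataset(X, y)
  metric = dc.metrics.Metric(dc.metrics.mean_absolute_error)

  # Forward function: a Haiku MLP with 2 outputs
  def forward_model(x):
    net = hk.nets.MLP([512, 256, 128, 2])
    return net(x)

  # Loss with JaxModel's (predictions, targets, weights) signature.
  # optax.l2_loss is 0.5 * squared error, so this is half the mean squared error.
  def rms_loss(pred, tar, w):
    return jnp.mean(optax.l2_loss(pred, tar))

  # Model initialization: hk.transform returns a pure (init, apply) pair
  params_init, forward_fn = hk.transform(forward_model)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
  # Cast float64 rows to float32, JAX's default precision
  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = params_init(rng, modified_inputs)

  # Loss function
  criterion = rms_loss

  # Train and evaluate with JaxModel
  j_m = JaxModel(
      forward_fn,
      params,
      criterion,
      batch_size=256,
      learning_rate=0.001,
      log_frequency=2)
  _ = j_m.fit(dataset, nb_epochs=25, deterministic=True)
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] < 0.5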

PR for building the training stage of the model: deepchem/deepchem#2549
PR for building the evaluation stage of the model: deepchem/deepchem#2604

CI/CD Updates and Dependency Management

One of Deepchem's recent challenges was setting up an environment in which all of its dependencies build smoothly for users and the CI runs reliably. Deepchem has model utilities in several neural network frameworks (TensorFlow, PyTorch and JAX), and things became harder to manage because of TensorFlow's compatibility issues with NumPy>=1.20. With the addition of JAX and its related packages, the dependency matrix was becoming difficult to manage, so we moved to an approach inspired by other open-source libraries such as OpenAI Gym (and later also followed by HuggingFace), in which Deepchem can be installed in four variants:

  1. deepchem
  2. deepchem[tensorflow]
  3. deepchem[torch]
  4. deepchem[jax]
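The bracket syntax is pip's "extras" mechanism, usually declared in `setup.py` through setuptools' `extras_require`, so that for example `pip install "deepchem[jax]"` installs the core package plus the JAX stack. The sketch below illustrates the idea; the package lists are illustrative assumptions, not Deepchem's actual pinned requirements, and `requirements_for` is a hypothetical helper that mimics what pip would pull in.

```python
# Core requirements shared by every installation (illustrative list).
INSTALL_REQUIRES = ["joblib", "numpy", "pandas", "scikit-learn", "scipy"]

# Backend-specific extras (illustrative lists, not Deepchem's exact pins).
EXTRAS_REQUIRE = {
    "tensorflow": ["tensorflow", "tensorflow-probability"],
    "torch": ["torch", "torchvision"],
    "jax": ["jax", "jaxlib", "dm-haiku", "optax"],
}

def requirements_for(extra=None):
    """Hypothetical helper: packages pulled in by `deepchem` or `deepchem[extra]`."""
    reqs = list(INSTALL_REQUIRES)
    if extra is not None:
        reqs.extend(EXTRAS_REQUIRE[extra])
    return reqs

# In setup.py, both would be handed to setuptools, e.g.
# setuptools.setup(name="deepchem", install_requires=INSTALL_REQUIRES,
#                  extras_require=EXTRAS_REQUIRE, ...)
```

With this layout, a plain `pip install deepchem` installs none of the deep learning frameworks, and each CI job can install only the extra it tests.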

To make these changes smoothly, I had to update the setuptools installation, the CI workflows for the different environments, and the test suites, so that the tests for each backend can run independently of the other backends' dependencies.
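One way to let each test suite run without the other backends installed is to skip any test whose backend cannot be imported. The sketch below uses only the standard library (`importlib.util.find_spec` and `unittest.skipUnless`); the `requires_backend` decorator and the test classes are hypothetical, and Deepchem's own suite may organise this differently (pytest offers `pytest.importorskip` and custom markers for the same purpose).

```python
import importlib.util
import unittest

def backend_available(module_name):
    """True if the given backend package can be imported in this environment."""
    return importlib.util.find_spec(module_name) is not None

def requires_backend(module_name):
    """Hypothetical decorator: skip a test when its backend is not installed."""
    return unittest.skipUnless(backend_available(module_name),
                               f"{module_name} is not installed")

class TestJaxOnly(unittest.TestCase):
    @requires_backend("jax")
    def test_jax_add(self):
        import jax.numpy as jnp  # imported lazily, only when JAX is present
        self.assertEqual(int(jnp.add(1, 2)), 3)

class TestMissingBackend(unittest.TestCase):
    @requires_backend("surely_not_installed_backend")
    def test_never_runs(self):
        raise AssertionError("this test should have been skipped")
```

In a JAX-only CI job, the TensorFlow and PyTorch tests are then reported as skipped instead of failing at import time.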

PR for the JAX environment: deepchem/deepchem#2560
PR for the Torch environment: deepchem/deepchem#2563
PR for the TensorFlow and common environment: deepchem/deepchem#2573

Miscellaneous PR for making the environments independent: deepchem/deepchem#2567

PINNModel for Solving Differential Equations

Physics-informed neural networks (PINNs) solve supervised learning tasks while also obeying a differential equation derived from the underlying physics. In simpler terms, a neural network is trained to approximate the solution of a differential equation, and the equation itself is used as a regulariser in the loss function. Using the JaxModel API, I added a PINNModel API for solving data-driven differential equations.
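The idea of a physics residual acting as a regulariser can be sketched in plain JAX. This is a generic illustration, not PINNModel's implementation: it assumes a toy ODE du/dt + u = 0 with u(0) = 1, a tiny one-hidden-layer network `u_net`, and an illustrative weighting `lam` between the data term and the physics term. The derivative du/dt is obtained by automatic differentiation of the network itself.

```python
import jax
import jax.numpy as jnp

def u_net(params, t):
    """Tiny MLP approximating u(t) for a scalar input t."""
    h = jnp.tanh(params["w1"] * t + params["b1"])  # hidden layer, shape (H,)
    return jnp.dot(params["w2"], h) + params["b2"]  # scalar output

# Vectorised versions: evaluate u and du/dt at many points t.
u_batch = jax.vmap(u_net, in_axes=(None, 0))
du_dt_batch = jax.vmap(jax.grad(u_net, argnums=1), in_axes=(None, 0))

def pinn_loss(params, t_colloc, t_data, u_data, lam=1.0):
    # Data term: ordinary supervised mean squared error on observed points.
    data_loss = jnp.mean((u_batch(params, t_data) - u_data) ** 2)
    # Physics term: residual of du/dt + u = 0 at the collocation points.
    residual = du_dt_batch(params, t_colloc) + u_batch(params, t_colloc)
    return data_loss + lam * jnp.mean(residual ** 2)

H = 16
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w1": jax.random.normal(k1, (H,)), "b1": jnp.zeros(H),
          "w2": 0.1 * jax.random.normal(k2, (H,)), "b2": jnp.array(0.0)}
t_colloc = jnp.linspace(0.0, 1.0, 32)
t_data, u_data = jnp.array([0.0]), jnp.array([1.0])  # initial condition u(0) = 1

# Plain gradient descent on the combined loss.
loss_and_grad = jax.jit(jax.value_and_grad(pinn_loss))
initial_loss, _ = loss_and_grad(params, t_colloc, t_data, u_data)
for _ in range(200):
    loss, grads = loss_and_grad(params, t_colloc, t_data, u_data)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.005 * g, params, grads)
final_loss, _ = loss_and_grad(params, t_colloc, t_data, u_data)
```

The collocation points need no labels: the physics term only asks the network to satisfy the equation there, which is what lets a PINN learn from very few observed data points.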

Here are two important differential equations we solved using PINNModel:

  1. Burgers' Equation

$$
\begin{array}{l} u_t + u u_x - (0.01/\pi) u_{xx} = 0,\ \ \ x \in [-1,1],\ \ \ t \in [0,1],\\ u(0,x) = -\sin(\pi x),\\ u(t,-1) = u(t,1) = 0. \end{array}
$$

Results obtained using PINNModel:
[Figure: PINNModel results for Burgers' equation]
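For Burgers' equation, the physics term penalises the residual u_t + u u_x - (0.01/π) u_xx at collocation points, together with the initial and boundary conditions. Below is a minimal JAX sketch of those terms for any scalar function `u_fn(t, x)`, for example a neural network with its parameters closed over. The function names are hypothetical and this is not PINNModel's API.

```python
import jax
import jax.numpy as jnp

NU = 0.01 / jnp.pi  # viscosity coefficient in Burgers' equation

def burgers_residual(u_fn, t, x):
    """f = u_t + u * u_x - (0.01/pi) * u_xx at a single point (t, x)."""
    u = u_fn(t, x)
    u_t = jax.grad(u_fn, argnums=0)(t, x)
    u_x = jax.grad(u_fn, argnums=1)(t, x)
    u_xx = jax.grad(jax.grad(u_fn, argnums=1), argnums=1)(t, x)
    return u_t + u * u_x - NU * u_xx

def burgers_loss(u_fn, t_c, x_c, x_0, t_b):
    """Mean squared PDE residual plus initial- and boundary-condition errors."""
    f = jax.vmap(lambda t, x: burgers_residual(u_fn, t, x))(t_c, x_c)
    ic = jax.vmap(lambda x: u_fn(0.0, x))(x_0) + jnp.sin(jnp.pi * x_0)  # u(0,x) = -sin(pi x)
    left = jax.vmap(lambda t: u_fn(t, -1.0))(t_b)   # u(t,-1) = 0
    right = jax.vmap(lambda t: u_fn(t, 1.0))(t_b)   # u(t, 1) = 0
    return (jnp.mean(f ** 2) + jnp.mean(ic ** 2)
            + jnp.mean(left ** 2) + jnp.mean(right ** 2))
```

A quick way to check the derivative code is a hand-picked function such as u = t + x^2, for which u_t = 1, u_x = 2x and u_xx = 2, so the residual is 1 + 2x(t + x^2) - 2(0.01/π).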

  2. Schrödinger Equation

$$
\begin{array}{l} i h_t + 0.5 h_{xx} + |h|^2 h = 0,\ \ \ x \in [-5, 5],\ \ \ t \in [0, \pi/2],\\ h(0,x) = 2\ \text{sech}(x),\\ h(t,-5) = h(t, 5),\\ h_x(t,-5) = h_x(t, 5). \end{array}
$$

Results obtained using PINNModel:
[Figure: PINNModel results for the Schrödinger equation]
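Because h is complex-valued, a common approach is to let the network output its real and imaginary parts, h = u + iv, and split the residual i h_t + 0.5 h_xx + |h|^2 h into two real equations: the real part -v_t + 0.5 u_xx + (u^2 + v^2) u and the imaginary part u_t + 0.5 v_xx + (u^2 + v^2) v. The sketch below computes these parts, the initial condition and the periodic boundary conditions with automatic differentiation for any function `h_fn(t, x)` returning the pair (u, v). All function names are hypothetical, not part of PINNModel. As a sanity check it uses the known soliton solution h = sech(x) e^{it/2}, for which the residual vanishes.

```python
import jax
import jax.numpy as jnp

def schrodinger_residual(h_fn, t, x):
    """Real and imaginary parts of i*h_t + 0.5*h_xx + |h|^2*h, with h = u + i*v.

    h_fn(t, x) must return a length-2 array (u, v).
    """
    u_fn = lambda t, x: h_fn(t, x)[0]
    v_fn = lambda t, x: h_fn(t, x)[1]
    u, v = u_fn(t, x), v_fn(t, x)
    u_t = jax.grad(u_fn, argnums=0)(t, x)
    v_t = jax.grad(v_fn, argnums=0)(t, x)
    u_xx = jax.grad(jax.grad(u_fn, argnums=1), argnums=1)(t, x)
    v_xx = jax.grad(jax.grad(v_fn, argnums=1), argnums=1)(t, x)
    mod2 = u ** 2 + v ** 2  # |h|^2
    real = -v_t + 0.5 * u_xx + mod2 * u
    imag = u_t + 0.5 * v_xx + mod2 * v
    return real, imag

def initial_condition(h_fn, x):
    """Residual of h(0, x) = 2 sech(x), i.e. u = 2 sech(x), v = 0."""
    return h_fn(0.0, x) - jnp.stack([2.0 / jnp.cosh(x), 0.0])

def periodic_bc(h_fn, t):
    """Residuals of h(t,-5) = h(t,5) and h_x(t,-5) = h_x(t,5)."""
    x_l, x_r = jnp.array(-5.0), jnp.array(5.0)
    h_x = jax.jacfwd(h_fn, argnums=1)  # derivative of (u, v) with respect to x
    return h_fn(t, x_l) - h_fn(t, x_r), h_x(t, x_l) - h_x(t, x_r)

def soliton(t, x):
    """Known exact solution h = sech(x) * exp(i t / 2), as (real, imag)."""
    s = 1.0 / jnp.cosh(x)
    return jnp.stack([s * jnp.cos(t / 2), s * jnp.sin(t / 2)])

real, imag = schrodinger_residual(soliton, 0.3, 0.7)
```

In a full PINN loss, the squared real and imaginary residuals at collocation points would be averaged and added to the squared initial- and boundary-condition residuals, as in the Burgers' case.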

The datasets and the code for plotting the graphs were taken from the authors' repository.

Future Work

  1. Add more documentation on how to use PINNModel.
  2. Implement Deepchem's custom losses, optimizers and learning-rate schedulers in JAX.
  3. Connect JaxModel with Weights & Biases (wandb) for improved visualisations.

Acknowledgement

Deepchem was the first open-source organisation I actively contributed to, and I'm really grateful to the community for its continuous support and for everything I learned. I want to thank @bharath, @ncfrey and @peastman for their valuable mentorship and feedback throughout the program.
