Over the summer, I got the opportunity to contribute to DeepChem as a student developer in Google Summer of Code. During this period, I worked on integrating Jax into DeepChem by building a JaxModel API that interacts with all of the existing DeepChem infrastructure for developing neural networks.
Main Contributions
- Addition of the JaxModel API for building neural networks
- Revamping the GitHub Workflows CI to handle the dependencies of the different backends (Tensorflow, Pytorch, Jax)
- Addition of PINNModel and a tutorial notebook explaining its usage
JAX is a comparatively new library and has received a warm welcome from machine learning researchers because of its core functional approach and state management.
My weekly updates can be found here.
JaxModel API
The main objective of the JaxModel API was to allow users to develop models using DeepChem's core infrastructure such as Molnet, Featurisers, Splitters, Transformers and the rest of its deep learning stack. Due to the functional nature of the Jax framework, it was difficult to build a model in the usual Pytorch or Tensorflow style. In this API, a neural network is defined by two things:
- Weights - which act as the parameters or matrices learned during neural network training.
- Forward_fn - which tells us how the weight matrices are used to process the inputs (the forward pass).
Here is a simple usage of the JaxModel API (here, dataset is a DeepChem Dataset and metric a DeepChem Metric prepared beforehand):

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

from deepchem.models import JaxModel  # available with the deepchem[jax] install

# Forward function: a simple Haiku MLP mapping the featurised inputs to 2 outputs
def forward_model(x):
    net = hk.nets.MLP([512, 256, 128, 2])
    return net(x)

# Loss function: mean squared-error loss between predictions and targets
def rms_loss(pred, tar, w):
    return jnp.mean(optax.l2_loss(pred, tar))

# Model Initialization: hk.transform splits the model into an init function
# (which creates the weights) and a pure forward function (which applies them)
params_init, forward_fn = hk.transform(forward_model)
rng = jax.random.PRNGKey(500)
inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
modified_inputs = jnp.array(
    [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
params = params_init(rng, modified_inputs)

# Loss Function
criterion = rms_loss

# JaxModel Working: wrap the forward function, weights and loss into a JaxModel
j_m = JaxModel(
    forward_fn,
    params,
    criterion,
    batch_size=256,
    learning_rate=0.001,
    log_frequency=2)

# Train, then evaluate with a DeepChem metric
_ = j_m.fit(dataset, nb_epochs=25, deterministic=True)
scores = j_m.evaluate(dataset, [metric])
assert scores[metric.name] < 0.5
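Because the trained model is kept as a plain (params, forward_fn) pair, the same objects can also be used directly outside the wrapper; here is a minimal sketch using the names defined above.

# Hedged sketch: the functional design means the forward pass can also be run
# directly through Haiku's transformed apply function, outside of JaxModel.
raw_preds = forward_fn(params, rng, modified_inputs)  # shape (256, 2)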
PR for building the training stage of the model - deepchem/deepchem#2549
PR for building the evaluation stage of the model - deepchem/deepchem#2604
CI/CD Updates and Dependency Management
One of DeepChem's recent challenges was to set up an environment that builds all of its dependencies smoothly for users and keeps the CI running. DeepChem has model utilities in several NN frameworks like Tensorflow, Pytorch and Jax, and things became harder to manage due to Tensorflow's compatibility issues with Numpy>=1.20. With the addition of Jax and its related packages, the dependency matrix was getting harder to manage, and hence we decided to move to an approach inspired by other OSS libraries like OpenAI Gym (later followed by HuggingFace as well), where the installation is split into 4 separate options within deepchem:
- deepchem
- deepchem[tensorflow]
- deepchem[torch]
- deepchem[jax]
To make these changes smoothly, I had to modify the setuptools installation, the CI workflows for the different environments, and the test suites so that each backend's tests can run independently of the other dependencies.
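For illustration, optional backends of this kind are usually declared as pip extras through extras_require in setup.py; the sketch below shows that pattern, with package lists and names that are illustrative rather than DeepChem's exact ones.

# Hedged sketch of declaring optional backend dependencies as pip extras in
# setup.py; the exact package lists and version pins used by DeepChem differ.
from setuptools import setup, find_packages

setup(
    name='deepchem',
    packages=find_packages(),
    install_requires=['numpy', 'pandas', 'scikit-learn'],  # common core install
    extras_require={
        'tensorflow': ['tensorflow'],                   # pip install deepchem[tensorflow]
        'torch': ['torch'],                             # pip install deepchem[torch]
        'jax': ['jax', 'jaxlib', 'dm-haiku', 'optax'],  # pip install deepchem[jax]
    },
)

With this layout, users install only the backend they need (e.g. pip install deepchem[jax]), which keeps the base install light and lets the CI test each environment independently.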
PR for Jax Environment - deepchem/deepchem#2560
PR for Torch Environment - deepchem/deepchem#2563
PR for Tensorflow & Common Environment - deepchem/deepchem#2573
Miscellaneous PR for making independent environments - deepchem/deepchem#2567
PINNs Model for Solving Differential Equations
PINNs (Physics-Informed Neural Networks) are used for solving supervised learning tasks while also satisfying an underlying differential equation derived from the physics of the problem. In simpler terms, we solve a differential equation with a neural network, using the differential equation as a regulariser in the loss function. Using the JaxModel API, I added a PINNModel API for solving data-driven differential equations.
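To make the idea concrete, here is a minimal sketch of how the residual of the first equation below (Burgers' equation) can be computed with JAX autodiff; the function u(params, t, x) and the way PINNModel actually consumes such a residual are assumptions for illustration, not the exact API.

# Hedged sketch: Burgers' equation residual u_t + u * u_x - (0.01/pi) * u_xx
# computed with JAX autodiff. u(params, t, x) is assumed to return the scalar
# network prediction at a single (t, x) point; the real PINNModel interface
# may differ.
import jax
import jax.numpy as jnp

def burgers_residual(u, params, t, x):
    u_val = u(params, t, x)
    u_t = jax.grad(u, argnums=1)(params, t, x)                        # du/dt
    u_x = jax.grad(u, argnums=2)(params, t, x)                        # du/dx
    u_xx = jax.grad(jax.grad(u, argnums=2), argnums=2)(params, t, x)  # d2u/dx2
    return u_t + u_val * u_x - (0.01 / jnp.pi) * u_xx

The squared residual is then added to the data-fitting term of the loss, which is how the differential equation acts as a regulariser during training.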
Here are some important differential equations we solved using PINNModel:

- Burgers' Equation:
\begin{array}{l} u_t + u u_x - (0.01/\pi) u_{xx} = 0,\ \ \ x \in [-1,1],\ \ \ t \in [0,1],\\ u(0,x) = -\sin(\pi x),\\ u(t,-1) = u(t,1) = 0. \end{array}
Results obtained using PINNModel:
- Schrödinger's Equation:
\begin{array}{l} i h_t + 0.5 h_{xx} + |h|^2 h = 0,\ \ \ x \in [-5, 5],\ \ \ t \in [0, \pi/2],\\ h(0,x) = 2\ \text{sech}(x),\\ h(t,-5) = h(t, 5),\\ h_x(t,-5) = h_x(t, 5), \end{array}
Results obtained using PINNModel:
The dataset and the code for plotting the graphs have been taken from the authors' repo.
Future Work
- Add more documentation about the usage of the PINNModel.
- Include DeepChem's custom Losses, Optimizers and Schedulers using Jax.
- Connect JaxModel with Wandb for further improved visualisations.
Acknowledgement
DeepChem was the first open source organisation I have actively contributed to, and I'm really grateful to the community for the continuous support and knowledge. I want to thank @bharath, @ncfrey and @peastman for their valuable mentorship and feedback throughout the program.