JAXChem: Still Active?

Hello :wave: everyone, I just joined the community and have been looking for ways to contribute. I came across the jaxchem project and it seems to have become dormant. I know that @nd-02110114 worked on it for GSOC and I was wondering if there are plans to continue the work on the library?

I’d love to get involved :blush:.


Thank you for your interest!
I participated in this GSoC and decided to stop developing JAXChem during the program, for a few reasons.

The first reason is that I felt JAX and Haiku were not mature enough (as of July) to implement a high-level API for JAX like DeepChem's KerasModel or TorchModel. JAX sometimes introduces breaking changes even in a patch-level bump like 0.1.69 -> 0.1.70, and it is hard to keep the JAX versions required by Haiku and Google Colab in sync. Many DeepChem users, such as chemists (including me), don't have a GPU environment, so being able to set up Google Colab easily is really important.

The second reason is that I ran into performance issues when constructing GCN models in the style of DGL and PyG (the details are here). The performance issue was hard to resolve, so I shifted my effort to PyTorch support, which was a higher priority.

But after GSoC, DeepMind released some good libraries like Optax and Jraph, so it may be possible to restart JAX modeling support.

In particular, it would be really worthwhile to implement a JAXModel with the same API as TorchModel or KerasModel, using Optax or Haiku, and contributions are welcome. Currently, I have only implemented a simple GCN model, and jaxchem doesn't provide a high-level API for training or evaluating models like DeepChem's KerasModel or TorchModel. (I hard-coded a training loop and model evaluation in the example.)
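To make the idea concrete, here is a rough sketch of what such a high-level JAXModel interface could look like. Everything here is an illustrative assumption, not existing DeepChem or jaxchem API; finite-difference gradient descent stands in for `jax.grad` plus an Optax optimizer so the sketch runs without JAX installed.

```python
class JAXModel:
    """Hypothetical wrapper: a pure forward function plus a loss,
    with fit/evaluate mirroring KerasModel/TorchModel."""

    def __init__(self, forward_fn, loss_fn, learning_rate=0.1):
        self.forward_fn = forward_fn   # (params, inputs) -> predictions
        self.loss_fn = loss_fn         # (predictions, labels) -> scalar
        self.learning_rate = learning_rate

    def fit(self, params, inputs, labels, steps=500, eps=1e-5):
        # Forward-difference gradients stand in for jax.grad here.
        params = list(params)
        for _ in range(steps):
            base = self.loss_fn(self.forward_fn(params, inputs), labels)
            grads = []
            for i in range(len(params)):
                shifted = list(params)
                shifted[i] += eps
                grads.append(
                    (self.loss_fn(self.forward_fn(shifted, inputs), labels) - base) / eps
                )
            # Plain SGD update, the role optax.sgd would play.
            params = [p - self.learning_rate * g for p, g in zip(params, grads)]
        return params

    def evaluate(self, params, inputs, labels):
        return self.loss_fn(self.forward_fn(params, inputs), labels)


# Toy linear model standing in for a GCN: params = [w, b]
def forward(params, xs):
    w, b = params
    return [w * x + b for x in xs]

def mse(preds, labels):
    return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(preds)

xs, ys = [0.0, 1.0, 2.0], [1.0, 3.0, 5.0]   # y = 2x + 1
model = JAXModel(forward, mse)
trained = model.fit([0.0, 0.0], xs, ys)
```

The point is only the interface shape: the hard-coded training loop and evaluation from the example move behind `fit` and `evaluate`, the way KerasModel and TorchModel hide theirs.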


When that repository was created, DeepChem was mostly based on TensorFlow and we were still deciding how to support other frameworks. One idea was to have multiple parallel projects, each built around a different framework, like TorchChem and JaxChem. We ended up taking a different approach though: we abstracted the DeepChem code to make it framework-agnostic and tried to make it interoperate well with TensorFlow, PyTorch, and scikit-learn. See https://github.com/deepchem/deepchem/blob/master/examples/tutorials/05_Creating_Models_with_TensorFlow_and_PyTorch.ipynb for more information about this. If we decide to add Jax support, it would be done in the same way.
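The framework-agnostic idea can be sketched as one neutral interface with framework-specific wrappers behind it. All class names below are illustrative assumptions, not DeepChem's actual classes, and a tiny stand-in estimator replaces scikit-learn so the sketch runs on its own.

```python
class Model:
    """Framework-neutral interface the rest of the codebase targets."""

    def fit(self, dataset):
        raise NotImplementedError

    def predict(self, inputs):
        raise NotImplementedError


class SklearnStyleModel(Model):
    """Wraps any object exposing a scikit-learn-like fit/predict;
    a real TorchModel or KerasModel would delegate similarly."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def fit(self, dataset):
        xs, ys = dataset
        self.wrapped.fit(xs, ys)

    def predict(self, inputs):
        return self.wrapped.predict(inputs)


class MeanRegressor:
    """Tiny stand-in estimator: always predicts the training mean."""

    def fit(self, xs, ys):
        self.mean = sum(ys) / len(ys)

    def predict(self, inputs):
        return [self.mean for _ in inputs]


model = SklearnStyleModel(MeanRegressor())
model.fit(([1, 2, 3], [2.0, 4.0, 6.0]))
preds = model.predict([10])  # [4.0]
```

Adding Jax support in this scheme would mean one more wrapper subclass, rather than a separate JaxChem project.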


Ahh okay, so the plan would be to have something like this:

```python
dc.models.JAXModel(jax_model, <Loss Function>)
```

where jax_model is something like the net_fn in Haiku?
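For anyone unfamiliar with the net_fn pattern being referenced: Haiku's `hk.transform` turns an impure network function into a pair of pure `(init, apply)` functions, which is the kind of object a JAXModel could wrap. Below is a plain-Python imitation of that pattern, not real Haiku code; all names are illustrative.

```python
def transform(net_fn):
    """Mimic hk.transform: split net_fn into pure init/apply functions."""

    def init(x):
        params = {}
        net_fn(params, x, create=True)   # populate params as a side effect
        return params

    def apply(params, x):
        return net_fn(params, x, create=False)

    return init, apply


def linear_net(params, x, create):
    # A one-parameter "network": y = w * x
    if create:
        params["w"] = 1.0                # default initialization
    return params["w"] * x


init_fn, apply_fn = transform(linear_net)
params = init_fn(3.0)
out_before = apply_fn(params, 3.0)   # 3.0
params["w"] = 2.0
out_after = apply_fn(params, 3.0)    # 6.0
```

Because `apply` is a pure function of explicit params, it is exactly the shape a generic training wrapper can optimize.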

Sounds good if we can make something like a JAXModel object.