Thank you for your interest!
I took part in GSoC and decided to pause JAXChem development during the program, for a few reasons.
The first reason is that, as of July, I felt JAX and Haiku were not mature enough for implementing a high-level JAX API like DeepChem's KerasModel or TorchModel. JAX sometimes introduces breaking changes even in a minor version bump such as 0.1.69 -> 0.1.70, and it is hard to keep the JAX version that Haiku requires in sync with the one on Google Colab. Many DeepChem users, such as chemists (including me), don't have a GPU environment and rely on Colab's free GPUs, so it is really important that setting up the Google Colab environment is easy.
The second reason is that I ran into performance issues when building GCN models the way DGL and PyG do. (The details are here.) The performance issue was hard to resolve, so I worked on PyTorch support instead, which had a higher priority.
However, after GSoC I learned that DeepMind has released some good libraries such as Optax and Jraph, so it may be possible to restart support for JAX modeling.
In particular, it would be really worthwhile to implement a JAXModel with the same API as TorchModel or KerasModel using Optax or Haiku, and we welcome contributions. Currently, I have only implemented a simple GCN model, and JAXChem doesn't provide a high-level API for training or evaluating models like DeepChem's KerasModel or TorchModel. (I hard-coded the training loop and model evaluation in the example.)