Towards Differentiable DeepChem

In the last few years, there has been a steady push towards making differentiable programming a more fundamental language primitive. The Differentiable Swift project (implementation progress) has started work on adding differentiable programming as a fundamental part of the Swift language. Julia has likewise moved towards automatic differentiation at the language level (see Zygote).

Python as a language and community is still far from being fully differentiable, but progress has been made. JAX's automatic differentiation nests through standard Python data structures such as lists, tuples, and dicts (link), and JAX has some support for constructing derivatives for custom data classes (link).
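As a minimal sketch of what this looks like in practice: `jax.grad` returns gradients with the same container structure as its input, and a custom class can opt in by registering itself as a pytree node. (The `Sample` class below is a hypothetical illustration, not a DeepChem API.)

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Gradients flow through standard Python containers (lists, tuples, dicts).
def loss(params):
    return jnp.sum(params["w"] * params["x"])

params = {"w": jnp.array([1.0, 2.0]), "x": jnp.array([3.0, 4.0])}
grads = jax.grad(loss)(params)  # a dict with the same structure as params
# grads["w"] == [3., 4.] and grads["x"] == [1., 2.]

# A custom data class participates by registering as a pytree node.
@tree_util.register_pytree_node_class
class Sample:
    def __init__(self, features, weight):
        self.features = features
        self.weight = weight

    def tree_flatten(self):
        # Children are the differentiable leaves; no static auxiliary data.
        return (self.features, self.weight), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

sample = Sample(jnp.array([1.0, 2.0]), 0.5)
g = jax.grad(lambda s: s.weight * jnp.sum(s.features))(sample)
# g is itself a Sample whose fields hold the gradients
```

This pytree mechanism is one plausible route for making DeepChem's own dataset and feature objects transparent to autodiff.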

Thinking about a longer-term developer roadmap for DeepChem (building on Making DeepChem a Better Framework for AI-Driven Science), I would like to see DeepChem become a fully differentiable framework. That is, all DeepChem featurizers/transformers/models should have gradient functions defined. Since this is a longer-term goal, I'm starting this thread to brainstorm ideas.

One challenge with a differentiable DeepChem is that there are multiple array types used in DeepChem:

  1. NumPy arrays
  2. TensorFlow tensors
  3. PyTorch tensors
  4. JAX arrays

Featurizers and transformers operate primarily on NumPy arrays, while models operate on different tensor types depending on the backend implementation. For each DeepChem class/transformation, it may therefore become necessary to define manual gradients separately for each backend.
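To make the "manual gradients" idea concrete, here is a sketch in JAX of wrapping a featurizer-like transformation whose forward pass we treat as a black box, supplying the backward pass by hand with `jax.custom_vjp`. (The `log_featurize` function is a hypothetical stand-in for a DeepChem transformation; PyTorch and TensorFlow have analogous mechanisms in `torch.autograd.Function` and `tf.custom_gradient`.)

```python
import jax
import jax.numpy as jnp

# Hypothetical featurizer-like transform with a manually defined gradient.
@jax.custom_vjp
def log_featurize(x):
    return jnp.log1p(x)

def log_featurize_fwd(x):
    # Save x as the residual needed by the backward pass.
    return jnp.log1p(x), x

def log_featurize_bwd(x, g):
    # d/dx log(1 + x) = 1 / (1 + x), scaled by the incoming cotangent g.
    return (g / (1.0 + x),)

log_featurize.defvjp(log_featurize_fwd, log_featurize_bwd)

grad = jax.grad(lambda x: jnp.sum(log_featurize(x)))(jnp.array([0.0, 1.0]))
# grad == [1.0, 0.5]
```

A differentiable DeepChem would need a pattern like this per backend, or a single canonical backend with bridges to the others.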