Towards Differentiable DeepChem

In the last few years, there has been a steady push towards making differentiable programming a more fundamental language primitive. Work on Differentiable Swift (see the implementation progress thread) aims to add differentiable programming as a fundamental part of the Swift programming language. Julia has likewise moved towards automatic differentiation at the language level (see Zygote).

Python as a language and community is still far from being a fully differentiable programming language, but progress has been made. JAX's automatic differentiation nests through standard Python data structures like lists, tuples, and dicts (link), and JAX has some support for constructing derivatives for custom data classes (link).
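As a small sketch of what that pytree support looks like in practice (the parameter names below are made up for illustration, not a DeepChem API), `jax.grad` can differentiate a loss whose parameters live in an ordinary Python dict:

```python
import jax
import jax.numpy as jnp

def loss(params):
    # params is a plain Python dict; JAX treats it as a pytree and
    # differentiates through each of its array leaves.
    return jnp.sum(params["w"] * params["x"])

params = {"w": jnp.array([1.0, 2.0]), "x": jnp.array([3.0, 4.0])}
grads = jax.grad(loss)(params)
# grads is a dict with the same structure as params
```

The gradient comes back as a dict mirroring the input structure, which is what makes nesting through lists/tuples/dicts so convenient.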

Thinking about a longer-term developer roadmap for DeepChem (building on Making DeepChem a Better Framework for AI-Driven Science), I would like to see DeepChem become a fully differentiable framework. That is, all DeepChem featurizers/transformers/models should have gradient functions defined. This is a longer-term goal, so I'm starting this thread to brainstorm ideas.

1 Like

One challenge with a differentiable DeepChem is that there are multiple array types used in DeepChem:

  1. NumPy arrays
  2. TensorFlow tensors
  3. PyTorch tensors
  4. JAX tensors

Featurizers and transformers operate primarily on NumPy arrays, while models operate on whichever tensor type their backend implementation uses. For each DeepChem class/transformation, it may therefore become necessary to define manual gradients separately in each backend.
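To make this concrete, here is a minimal sketch of a manually defined gradient in the JAX backend, using `jax.custom_vjp` (the featurizer-like function is hypothetical, not an existing DeepChem transformation); TensorFlow (`tf.custom_gradient`) and PyTorch (`torch.autograd.Function`) would each need their own analogous definition:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log_featurize(x):
    # Hypothetical featurizer-style transformation.
    return jnp.log1p(x)

def _fwd(x):
    # Forward pass: return the output plus residuals for the backward pass.
    return jnp.log1p(x), x

def _bwd(x, g):
    # Manually defined gradient: d/dx log(1 + x) = 1 / (1 + x).
    return (g / (1.0 + x),)

log_featurize.defvjp(_fwd, _bwd)

grad_fn = jax.grad(lambda x: jnp.sum(log_featurize(x)))
```

Keeping one such definition per backend in sync is exactly the maintenance burden a shared standard would reduce.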

A core part of the transition towards differentiable DeepChem will be making DeepChem layers first-class parts of DeepChem. Our current layers lack documentation/tutorials. Adding proper documentation/tutorials will help new users confidently use DeepChem layers in their own code. Over time we can build up a collection of differentiable layers for scientific AI applications.


NumPy 1.22 added preliminary support for the new array API standard, and support for it in PyTorch is in progress. If we're willing to wait a little while, it will provide a clean way to support all the different backends with the same code.
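As an illustration of the kind of backend-agnostic code the standard enables (the helper below is a sketch, not an existing DeepChem utility), a function can fetch the array's namespace through the standard's `__array_namespace__` hook and fall back to plain NumPy for arrays that predate it:

```python
import numpy as np

def get_namespace(x):
    # Arrays that implement the array API standard expose their
    # namespace through __array_namespace__.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    # Fall back to plain NumPy for arrays that predate the standard.
    return np

def standardize(x):
    # Written only against functions the array API standard defines,
    # so the same code can run on any conforming backend.
    xp = get_namespace(x)
    return (x - xp.mean(x)) / xp.std(x)

out = standardize(np.array([1.0, 2.0, 3.0]))
```

The same `standardize` body would then run unchanged on any conforming backend's arrays.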


This is a great suggestion! We could plan to move towards a standard where all DeepChem layers obey the new array API, so that the implementations are multi-backend out of the box.

Here are a couple of suggested steps:

  1. Move towards a long-term target of making all DeepChem array manipulations (featurizers/transformers/layers/models/metrics) follow the Python array API standard. That way, in principle, all moving parts of DeepChem will become differentiable if suitable PyTorch/JAX tensors are passed through.
  2. Introduce a new top-level dc.layers module for differentiable layers. Keep the existing TensorFlow/PyTorch/JAX layers for backwards compatibility, but focus development on the new layers. Introduce a new Layer superclass that follows this rough structure:
class Layer(object):

  def __init__(self, *args, **kwargs):
    # Store layer configuration parameters.
    ...

  def __call__(self, *args, **kwargs):
    # Some implementation that obeys the Python array API standard.
    # This transformation should be functional since the layer isn't
    # stateful.
    ...

These simple layers should be callable from PyTorch/Keras code as plain functions.
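As a sketch of how such a layer might look in practice (ScaleShift is a hypothetical example, not a proposed DeepChem layer), a stateless subclass stores its configuration in `__init__` and applies a purely functional transformation in `__call__`:

```python
import numpy as np

class ScaleShift(object):
    """Hypothetical stateless layer computing scale * x + shift,
    following the Layer sketch above."""

    def __init__(self, scale, shift):
        # Store layer configuration parameters.
        self.scale = scale
        self.shift = shift

    def __call__(self, x):
        # Purely functional: uses only arithmetic covered by the array
        # API standard, so x could be a NumPy array or a PyTorch/JAX
        # tensor.
        return self.scale * x + self.shift

layer = ScaleShift(scale=2.0, shift=1.0)
out = layer(np.array([1.0, 2.0]))
```

Because `__call__` holds no state, the layer behaves like a plain function and can be dropped into PyTorch or Keras code directly.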

(link) is a really neat reference. It provides some suggestions for code styles to follow when coding with the generic array API.