DeepChem - PyTorch Lightning Proposal

Hi everyone, in this post I would like to propose a PyTorch Lightning integration for DeepChem, a feature I hope to work on. PyTorch Lightning is a framework that simplifies the process of experimenting with PyTorch models. If we decide to pursue this integration project, there are certain issues we will have to deal with related to the backward compatibility of DeepChem. These issues stem specifically from the overlap between DeepChem's TorchModel and different modules of PyTorch Lightning.

A helpful analogy for understanding the PyTorch <-> PyTorch Lightning relationship is the TensorFlow <-> Keras relationship. Keras provides higher-level ML training and inference functionality. Lightning provides the same for PyTorch but is more hands-on, in the sense that its API still lets you work with the lower-level details of training a PyTorch model.

I have created an initial tutorial for integrating the GCNModel from DeepChem with PyTorch Lightning: https://github.com/deepchem/deepchem/pull/2826. Hopefully this notebook gives an idea of what the integration can look like.
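To make the shape of that integration concrete, here is a minimal sketch of a LightningModule wrapping a DeepChem TorchModel. The class name is a placeholder, and `_prepare_batch` and `_loss_fn` are internal TorchModel helpers, so the actual tutorial may wire things up differently:

```python
import deepchem as dc
import pytorch_lightning as pl
import torch

class DeepChemLightningModule(pl.LightningModule):
    """Sketch: drive a DeepChem TorchModel with the Lightning Trainer."""

    def __init__(self, dc_model: dc.models.TorchModel):
        super().__init__()
        self.dc_model = dc_model        # the DeepChem wrapper
        self.pt_model = dc_model.model  # the underlying torch.nn.Module

    def training_step(self, batch, batch_idx):
        # Convert a DeepChem-style batch into (inputs, labels, weights);
        # _prepare_batch is a private TorchModel helper, used here only
        # for illustration.
        inputs, labels, weights = self.dc_model._prepare_batch(batch)
        outputs = self.pt_model(inputs[0] if len(inputs) == 1 else inputs)
        if not isinstance(outputs, list):
            outputs = [outputs]
        # _loss_fn is likewise a TorchModel internal; a real integration
        # would expose a cleaner hook for the loss.
        loss = self.dc_model._loss_fn(outputs, labels, weights)
        self.log("train_loss", loss)  # works on CPU, GPU, or multi-node
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.pt_model.parameters(), lr=1e-3)
```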

Motivation

PyTorch Lightning provides a framework for writing PyTorch code. It reduces the amount of code needed to train and track models and removes much of the burden of writing boilerplate PyTorch. Below I describe some of Lightning's functionality. If integrated with DeepChem, PyTorch Lightning can reduce the workload of implementing ML model functionality in the DeepChem library (e.g. a number of DeepChem's trainer functionalities can be abstracted away inside PyTorch Lightning). There is additional functionality (discussed below) that would extend DeepChem's feature set.

  1. Multi-GPU training: PyTorch Lightning provides easy multi-GPU, multi-node training, and simplifies launching multi-GPU, multi-node jobs across different cluster infrastructures, e.g. AWS or SLURM-based clusters.
  • One specific advantage of using Lightning here over pure PyTorch is that Lightning takes care of a number of the details of multi-node training. For example, logging and syncing metrics across the multiple GPUs used in the same run needs some care; Lightning implements this out of the box. Similarly, correctly distributing the train/val dataset inside the dataloader also requires care, and this again is handled by Lightning.

  2. Reducing boilerplate PyTorch code: Lightning takes care of details like optimizer.zero_grad(), model.train(), and model.eval(). Lightning also provides experiment logging: regardless of whether training runs on CPU, a GPU, or multiple nodes, the user can call self.log inside the module and the metrics will be logged appropriately. There is definitely an overlap between this functionality and the TorchModel provided by DeepChem; I discuss these concerns in the sections below.

  3. Other basic training functionality provided by Lightning, usually enabled by setting a few keyword arguments or implementing a few framework-provided hooks (see the Trainer sketch after this list):
  • Half-precision training.
  • Gradient checkpointing.
  • Code profiling.
  • Learning-rate schedulers.

  4. The Lightning ecosystem contains other libraries, e.g. lightning-bolts and torchmetrics, which users can leverage. We are not planning to integrate these libraries in the current project, but being compatible with the ecosystem can bring in more traction.
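As a rough illustration of how lightweight enabling these features is, here is a sketch of a pl.Trainer configuration. The exact argument names vary somewhat across Lightning releases, so treat this as indicative rather than definitive:

```python
import pytorch_lightning as pl

# Most of the features above are switched on via Trainer keyword arguments.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,          # multi-GPU training on one machine
    num_nodes=2,        # multi-node training, e.g. on a SLURM cluster
    strategy="ddp",     # distributed data parallel; metric syncing is handled
    precision=16,       # half-precision (mixed-precision) training
    profiler="simple",  # built-in code profiling
    max_epochs=50,
)
# trainer.fit(lightning_module, train_dataloader, val_dataloader)
```

Learning-rate schedulers are returned from the LightningModule's configure_optimizers hook alongside the optimizer, and gradient checkpointing is applied inside the model itself rather than through the Trainer.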

DeepChem and PyTorch Lightning

The TorchModel in DeepChem overlaps in functionality with the Trainer in Lightning. DeepChem's TorchModel takes in an nn.Module PyTorch model and trains it via the TorchModel.fit method, which takes the dataset to train on. Similarly, the PyTorch Lightning Trainer takes an nn.Module (wrapped in a LightningModule) and dataloaders as input and performs training via the trainer.fit method.
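To show the overlap concretely, here is a side-by-side sketch of the two entry points; `train_dataset`, `lightning_module`, and `train_dataloader` are placeholders:

```python
import deepchem as dc
import pytorch_lightning as pl

# DeepChem today: the TorchModel subclass owns the training loop and
# consumes a dc.data.Dataset directly.
dc_model = dc.models.GCNModel(mode="classification", n_tasks=1)
dc_model.fit(train_dataset, nb_epoch=10)

# Lightning: the Trainer owns the loop and consumes a LightningModule
# plus standard PyTorch dataloaders.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(lightning_module, train_dataloader)
```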

If we are to perform the integration, the API around the DeepChem TorchModel will need some redesign. In my view there are two approaches for doing this:

  1. Write wrappers so that the current TorchModel can be called inside a LightningTrainer (a sketch of this follows the list). This would let us expose the LightningTrainer to the user as the intended place to work with the lower-level functionality of training models. One downside of this wrapper approach is that, for a user working with PyTorch Lightning, the internals are exposed in two steps: first the API inside the LightningTrainer, and second the API inside TorchModel.
  2. The second option is more transformative: remove the TorchModel class completely. I think the LightningTrainer can provide the functionality of the TorchModel class, and this would reduce the amount of code in DeepChem since we could offload those functionalities to PyTorch Lightning. This change would be especially disruptive for backwards compatibility, as we would have to figure out how to support previously trained TorchModel objects. The one advantage of this option is that it would enable seamless integration with Lightning, letting us use its functionality (highlighted in the previous section) directly. With option 1 we would still be able to use that functionality, but doing so would require wrapper code to handle the interactions.
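A hypothetical user-facing API under option 1 might look like the following; LightningTrainer is a placeholder name for the wrapper, not an existing DeepChem or Lightning class:

```python
import deepchem as dc

# Option 1 sketch: the user builds a TorchModel exactly as today, then hands
# it to a hypothetical LightningTrainer wrapper that drives training through
# pytorch-lightning internally.
dc_model = dc.models.GCNModel(mode="classification", n_tasks=1)
lit_trainer = LightningTrainer(max_epochs=10, devices=2)  # hypothetical wrapper
lit_trainer.fit(dc_model, train_dataset)
```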

Some other interesting interaction points I came across while exploring both libraries:

  • Both DeepChem and PyTorch Lightning leverage callbacks to plug functions into different objects; this design pattern will remain consistent after the integration (see the sketch after this list).
  • As a future feature we could also explore integrating Hydra (https://github.com/facebookresearch/hydra) with the Lightning + DeepChem port. Hydra can be used to build a config system for the DeepChem library, which would streamline experimentation with models.
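For reference, a minimal Lightning callback looks like this; it is only meant to show that the hook-based pattern lines up with DeepChem's existing callback style:

```python
import pytorch_lightning as pl

class PrintEpochLoss(pl.Callback):
    """Toy callback that prints the training loss at the end of each epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        loss = trainer.callback_metrics.get("train_loss")
        print(f"epoch {trainer.current_epoch}: train_loss={loss}")

# Callbacks are passed to the Trainer, mirroring how DeepChem passes
# callbacks to Model.fit.
trainer = pl.Trainer(callbacks=[PrintEpochLoss()])
```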

Conclusion

In this proposal I have highlighted functionality of PyTorch Lightning that can be useful for DeepChem. If we decide to follow through on this project, there are some important design decisions to make around the integration with the DeepChem TorchModel object. It would be very helpful to have further discussion around the proposal, as that will also help us figure out the scope and timeline of the project. Thanks.


Thanks for writing this up! At first thought, I think option #1 sounds like a better start since it has fewer backwards compatibility issues (we have users using DeepChem in live settings and we want to minimize breakage).

I’d be curious to hear what other folks think as well, and to have a deeper discussion on the developer calls.


One issue to keep in mind is that DeepChem has models implemented with several different frameworks: PyTorch, TensorFlow, JAX, and Scikit-learn. The Model class is designed to provide a unified interface for all of them. As much as possible, a user shouldn’t need to know or care what framework was used to implement a particular model. If PyTorch models used a completely different API from other models, it would break that design.


Thanks, the comments and suggestions were helpful. I agree that option #1, building on top of the current TorchModel API for the LightningTrainer, is the better way forward. I might have some follow-up questions, which I can discuss on the thread or pull request. I will work on implementing option #1.