Hi everyone, in this post I would like to propose a pytorch-lightning integration for deepchem, a feature that I hope to work on. Pytorch-Lightning is a framework that simplifies the process of experimenting with pytorch models. If we decide to pursue this integration project, there are certain issues we will have to deal with related to the backward compatibility of deepchem. These issues stem specifically from the overlap between the deepchem TorchModel and different modules of pytorch lightning.
An analogy that can help in understanding the pytorch <-> pytorch-lightning relationship is the tensorflow <-> keras relationship. Keras provides higher-level ML training and inference functionality. Lightning provides the same for pytorch, but it is more hands-on, in the sense that its API lets you work with the lower-level details of training a pytorch model.
I have created an initial tutorial for integrating the GCNModel from deepchem with pytorch lightning: https://github.com/deepchem/deepchem/pull/2826. Hopefully this notebook gives an idea of what the integration can look like.
Motivation
Pytorch lightning provides a framework for writing pytorch code. It reduces the amount of code needed to train and track models, and it removes much of the boilerplate of plain pytorch. Below I describe some functionalities of lightning. If integrated with deepchem, pytorch lightning can reduce the workload of implementing ML model functionality in the deepchem library (e.g. a number of deepchem's trainer functionalities can be abstracted away inside pytorch lightning). There are additional functionalities (discussed below) which would extend the feature set of deepchem.
- Multi-gpu training functionalities: pytorch-lightning provides easy multi-gpu, multi-node training. It also simplifies launching multi-gpu, multi-node jobs across different cluster infrastructures, e.g. AWS or slurm-based clusters.
  - One specific advantage of lightning over pure pytorch is that lightning takes care of a number of multi-node training details. For example, logging and syncing metrics across the GPUs used in the same run needs some care; lightning implements this out of the box. Similarly, correctly distributing the train/val dataset inside the dataloader also requires care, and this too is handled by lightning.
- Reducing boilerplate pytorch code: lightning takes care of details like `optimizer.zero_grad()`, `model.train()`, and `model.eval()`. Lightning also provides experiment logging: irrespective of whether training runs on a CPU, a GPU, or multiple nodes, the user can call `self.log` inside the trainer and the metrics will be logged appropriately. There is definitely an overlap between this functionality and the `TorchModel` provided by deepchem; I discuss these concerns in the sections below.
- Some other basic training functionalities provided by lightning. Using these involves just modifying some keyword arguments and implementing a few functions provided by the framework:
  - Half-precision training.
  - Gradient checkpointing.
  - Code profiling.
  - LR schedulers.
- The lightning ecosystem contains other libraries, e.g. lightning-bolts and torchmetrics, which users can leverage. We are not planning to integrate these libraries in the current project, but being compatible with the ecosystem can bring in more traction.
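To illustrate, most of the features listed above are exposed as `Trainer` keyword arguments. A sketch of what that configuration looks like (the exact argument names vary between pytorch-lightning versions, so treat this as approximate rather than a tested invocation):

```python
import pytorch_lightning as pl

# Sketch of a Trainer configuration; keyword names differ across
# pytorch-lightning releases, so check the docs for your version.
trainer = pl.Trainer(
    gpus=4,               # multi-gpu training on one node
    num_nodes=2,          # multi-node training
    precision=16,         # half-precision training
    profiler="simple",    # built-in code profiling
    max_epochs=100,
)
```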
Deepchem and PytorchLightning
The `TorchModel` in deepchem overlaps in functionality with the trainer module in lightning. The deepchem `TorchModel` takes in an `nn.Module` pytorch model and trains it with the `TorchModel.fit` method, which takes in the `dataset` to train on. Similarly, the pytorch-lightning trainer also takes an `nn.Module` and a `dataset` as input and trains on them with the `trainer.fit` method.
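Concretely, the two entry points look roughly like this (a non-runnable sketch; the `GCNModel` arguments and the variable names are assumptions on my part, see the tutorial PR linked above for a working version):

```python
# deepchem today: TorchModel.fit consumes a deepchem Dataset directly.
dc_model = GCNModel(mode="classification", n_tasks=1)
dc_model.fit(train_dataset, nb_epoch=30)

# pytorch-lightning: Trainer.fit consumes a LightningModule plus dataloaders,
# so the integration has to bridge deepchem Datasets to pytorch dataloaders.
trainer = pl.Trainer(max_epochs=30)
trainer.fit(lightning_module, train_dataloader)
```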
If we are to perform the integration, then the API around the deepchem `TorchModel` will need some redesign. In my view there are 2 approaches for doing this:
- Write wrappers so that the current `TorchModel` can be called inside the `LightningTrainer`. This allows us to expose the `LightningTrainer` to the user, which is the intended place for the user to play with the lower-level functionality of training models. One downside of this wrapper approach is that the exposure of internal modules becomes two-step for a user working with pytorch lightning: the first step is working with the API inside the `LightningTrainer`, and the second is working with the API inside `TorchModel`.
- The second option is more transformative: we remove the `TorchModel` class completely. I think the `LightningTrainer` can provide the functionality of the `TorchModel` class, and this will reduce the amount of code in `deepchem`, as we can offload all those functionalities to pytorch lightning. This change would be especially disruptive for backwards compatibility, as we would have to figure out how to support previously trained `TorchModel` objects. The one advantage of this option is that it enables seamless integration with lightning, and we can easily use its functionalities (highlighted in the previous section). With option 1 we would still be able to use those functionalities, but that usage would require wrapper code to handle the interactions.
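To make option 1 concrete, here is a minimal, dependency-free sketch of the wrapper idea. The class and attribute names (`DCLightningModule`, `dc_model.model`, `dc_model.loss`) are hypothetical stand-ins, and plain functions replace `nn.Module`s so the sketch runs without torch or lightning installed; a real wrapper would subclass `pl.LightningModule` and implement `training_step`/`configure_optimizers` against the actual `TorchModel` internals.

```python
class TorchModelStub:
    """Stand-in for deepchem's TorchModel: holds a model and a loss.
    (Hypothetical attributes; the real TorchModel API differs.)"""
    def __init__(self):
        self.model = lambda x: 2 * x                         # plays the nn.Module role
        self.loss = lambda pred, label: (pred - label) ** 2  # plays the loss-fn role


class DCLightningModule:
    """Option-1 wrapper: delegates to the wrapped TorchModel so existing
    deepchem models need no changes. In a real port this class would
    subclass pytorch_lightning.LightningModule."""
    def __init__(self, dc_model):
        self.dc_model = dc_model

    def forward(self, x):
        return self.dc_model.model(x)

    def training_step(self, batch, batch_idx):
        # Lightning calls this once per batch; the surrounding loop,
        # zero_grad/backward/step, and device placement are handled for us.
        x, y = batch
        return self.dc_model.loss(self.forward(x), y)


wrapped = DCLightningModule(TorchModelStub())
print(wrapped.training_step((3.0, 5.0), batch_idx=0))  # (2*3 - 5)^2 = 1.0
```

The point of the sketch is that the wrapper only forwards calls, so the two-step API exposure mentioned above is visible here: users configure the trainer through lightning, but model internals still live behind the `TorchModel`-shaped object.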
Some other interesting interaction points that I came across while exploring both libraries:
- Both deepchem and pytorch-lightning leverage callbacks to plug in functions inside different objects; this design pattern will remain consistent after the integration.
- As a future feature, we can also explore integrating hydra (https://github.com/facebookresearch/hydra) with the lightning + deepchem port. Hydra can be used to build a config system for the deepchem library, which can streamline the process of experimenting with models.
Conclusion
In this proposal I have highlighted functionalities of pytorch-lightning that can be useful for deepchem. If we do decide to follow through on this project, there are some important design decisions to make around the integration with the deepchem `TorchModel` object. It would be very helpful to have further discussions around this proposal, as that will also help figure out the scope and timeline of the project. Thanks.