Multi-GPU Distributed Training in DeepChem

Support for multi-GPU distributed training will make DeepChem more powerful for scientific applications. It will let DeepChem users train models at a much larger scale, opening the door to scientific foundation models and other applications of large models.

End User API

import deepchem as dc
from deepchem.models.torch_models import DistributedTrainer

# SNAPModel stands in here for any DeepChem PyTorch model
model = SNAPModel(**model_parameters)
trainer = DistributedTrainer(model)
trainer.fit(dataset)

Implementation details

The implementation takes ideas from the work of Abhishek Kadian and Princy Chahal on integrating multi-GPU support into DeepChem. The API is designed similarly to PyTorch Lightning's API, but it deviates from the earlier work by proposing torch.distributed for distributed training instead of PyTorch Lightning. By building on torch.distributed directly, DeepChem can be more powerful as a framework rather than depending on a third-party library.

To this end, the two main sub-tasks are:

  1. Update DeepChem's DiskDataset to handle distributed loading of data
  2. Build a distributed trainer that drives the underlying PyTorch model (a minimal torch.distributed sketch follows this list)
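A minimal, self-contained sketch of the torch.distributed primitives the trainer builds on is shown below: one process per GPU spawned with torch.multiprocessing, a process group joined via init_process_group, and a model wrapped in DistributedDataParallel so that gradients are averaged across ranks. This is a toy example rather than DeepChem code; the linear model, master address/port and NCCL backend are illustrative assumptions.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank: int, world_size: int):
    # Every process joins the same process group, identified by MASTER_ADDR/PORT.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Wrapping the model in DDP makes backward() all-reduce gradients across ranks.
    model = torch.nn.Linear(16, 1).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)

    x = torch.randn(8, 16, device=rank)
    loss = ddp_model(x).sum()
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    # mp.spawn passes the process index as the first argument to worker
    mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus)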

Updating DiskDataset to handle distributed training

Note: This approach was earlier shared by Peter Eastman here.


from typing import Optional

import torch
import torch.distributed as dist

import deepchem as dc


class TorchDiskDataset(torch.utils.data.IterableDataset):

    def __init__(self,
                 disk_dataset: dc.data.DiskDataset,
                 epochs: int,
                 deterministic: bool = True,
                 batch_size: Optional[int] = None):
        self.disk_dataset = disk_dataset
        self.epochs = epochs
        self.deterministic = deterministic
        self.batch_size = batch_size

    def __len__(self):
        return len(self.disk_dataset)

    def __iter__(self):
        # Each time an iterator is created, i.e. when we call enumerate(dataloader),
        # num_workers worker processes get created.
        worker_info = torch.utils.data.get_worker_info()
        n_shards = self.disk_dataset.get_number_shards()
        if worker_info is None:
            process_id = 0
            num_processes = 1
        else:
            process_id = worker_info.id
            num_processes = worker_info.num_workers

        # When running under torch.distributed, fold the DDP rank into the
        # process id so that every (rank, worker) pair reads a disjoint
        # slice of shards.
        if dist.is_initialized():
            process_id += dist.get_rank() * num_processes
            num_processes *= dist.get_world_size()

        first_shard = process_id * n_shards // num_processes
        last_shard = (process_id + 1) * n_shards // num_processes

        if first_shard == last_shard:
            return

        # last_shard is exclusive
        shard_indices = list(range(first_shard, last_shard))
        for X, y, w, ids in self.disk_dataset._iterbatches_from_shards(
                shard_indices,
                batch_size=self.batch_size,
                epochs=self.epochs,
                deterministic=self.deterministic):
            if self.batch_size is None:
                for i in range(X.shape[0]):
                    yield (X[i], y[i], w[i], ids[i])
            else:
                yield (X, y, w, ids)
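As an example of the shard split above: with a world size of 2 and two DataLoader workers per rank, __iter__ sees four readers in total, so a dataset with eight shards is divided into four contiguous slices of two shards each and no shard is read twice. The usage sketch below is hedged; the dataset path, batch size and worker count are illustrative assumptions.

import deepchem as dc
import torch

disk_dataset = dc.data.DiskDataset('featurized_data/')  # hypothetical on-disk dataset
torch_dataset = TorchDiskDataset(disk_dataset, epochs=1, batch_size=32)

# batch_size=None because TorchDiskDataset already yields whole batches;
# each of the two worker processes iterates over its own slice of shards.
loader = torch.utils.data.DataLoader(torch_dataset, batch_size=None, num_workers=2)

for X, y, w, ids in loader:
    pass  # the forward/backward pass would go here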

Distributed Trainer

Pseudocode (not yet rigorously tested) for the distributed trainer is shown below.

import os
from typing import Callable, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import deepchem as dc


class DistributedTrainer():

    def __init__(self,
                 model: dc.models.Model,
                 distributed_strategy: str = 'ddp',
                 n_gpus: int = 1,
                 batch_size: Optional[int] = None,
                 n_epochs: int = 1,
                 collate_fn: Optional[Callable] = None):
        self.dc_model = model
        self.distributed_strategy = distributed_strategy
        self.n_gpus = n_gpus  # or world_size
        self.batch_size = batch_size
        self.total_epochs = n_epochs
        self.collate_fn = collate_fn
        self.optimizer = None
        self.lr_scheduler = None
        self.optim_steps = 0

    def fit(self, data_dir: Optional[str] = None):
        # Processes should be spawned here before fitting data
        mp.spawn(DistributedTrainer.run, args=(self, data_dir), nprocs=self.n_gpus)

    def load_training_data(self, data_dir):
        disk_dataset = dc.data.DiskDataset(data_dir)
        # Wrap the DiskDataset in the TorchDiskDataset defined above so that
        # shards are split across DataLoader workers and DDP ranks.
        dataset = TorchDiskDataset(disk_dataset,
                                   epochs=1,
                                   batch_size=self.batch_size)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self.collate_fn,
            shuffle=False,
        )
        return dataloader

    def ddp_setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

    def configure_optimizer(self, model):
        # TODO More options for choosing optimizers to be included
        self.optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        self.configure_lr_scheduler()

    def configure_lr_scheduler(self):
        # Initialize a learning rate scheduler here
        self.lr_scheduler = None

    def step_schedulers(self, metrics=None):
        if self.lr_scheduler is None:
            return
        try:
            # Schedulers such as ReduceLROnPlateau expect a metric to step on
            self.lr_scheduler.step(metrics=metrics)
        except TypeError:
            self.lr_scheduler.step()

    def run_batch(self, batch):
        # Follows the pattern of TorchModel.fit_generator: prepare the batch,
        # run the forward pass, compute the loss and take an optimizer step.
        inputs, targets, weights = self.dc_model._prepare_batch(batch)
        outputs = self.model(inputs)
        loss = self.dc_model._loss_fn(outputs, targets, weights)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.optim_steps += 1
        self.step_schedulers()

    def run_epoch(self, data):
        for batch in data:
            self.run_batch(batch)
            # Additional per-batch utilities like saving checkpoints go here

    def train(self, data_dir):
        data = self.load_training_data(data_dir)
        for epoch in range(self.total_epochs):
            self.run_epoch(data)
            # Additional logging or functions like evaluating validation loss
            # for early stopping go here.

    @staticmethod
    def run(local_rank: int, trainer: "DistributedTrainer", data_dir: str):
        trainer.ddp_setup(local_rank, trainer.n_gpus)
        model = trainer.dc_model.model.to(local_rank)
        trainer.model = DDP(model, device_ids=[local_rank])
        trainer.configure_optimizer(trainer.model)
        trainer.train(data_dir)

    @classmethod
    def restore(cls, path: str):
        # Resumes training from a checkpoint, allowing for fault tolerance
        pass
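Since the trainer above is still pseudocode, the snippet below only shows the intended call pattern; the hyperparameters and data directory are illustrative assumptions, and SNAPModel is the same placeholder model used in the End User API section.

model = SNAPModel(**model_parameters)  # placeholder model, as in the End User API sketch
trainer = DistributedTrainer(model,
                             distributed_strategy='ddp',
                             n_gpus=4,
                             batch_size=64,
                             n_epochs=10,
                             collate_fn=None)  # a real run would pass a model-specific collate_fn
trainer.fit(data_dir='featurized_data/')  # spawns one training process per GPU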

Updates:

After some initial prototypes with torch.distributed, we decided to go with PyTorch Lightning for distributed training. The reasons for choosing Lightning are:

  • Lower maintenance effort with Lightning compared to torch.distributed
  • Lightning's modular approach to training models
  • Easier integration with the existing DeepChem ecosystem (DeepChem already has some preliminary support for PyTorch Lightning)

Here is an updated DistributedTrainer pseudo-code:

import deepchem as dc
from deepchem.models.lightning.dc_lightning_module import DCLightningModule
from deepchem.models.lightning.dc_lightning_dataset_module import DCLightningDatasetModule, collate_dataset_wrapper


class DistributedTrainer():
    r"""DistributedTrainer provides an interface for scaling the training of DeepChem
    model beyongs multiple GPUs and nodes. 

    Example
    -------
    .. code-block:: python

        import deepchem as dc
        from deepchem.models.trainer import DistributedTrainer
        from deepchem.models.torch_models import GroverModel
        from deepchem.feat.vocabulary_builders import (
            GroverAtomVocabularyBuilder, GroverBondVocabularyBuilder)

        dataset = dc.data.DiskDataset('zinc100k')

        atom_vocab = GroverAtomVocabularyBuilder.load('zinc100k_atom_vocab.json')
        bond_vocab = GroverBondVocabularyBuilder.load('zinc100k_bond_vocab.json')

        model = GroverModel(task='pretraining',
                            mode='pretraining',
                            node_fdim=151,
                            edge_fdim=165,
                            features_dim=2048,
                            functional_group_size=85,
                            hidden_size=128,
                            learning_rate=0.0001,
                            batch_size=1,
                            dropout=0.1,
                            ffn_num_layers=2,
                            ffn_hidden_size=5,
                            num_attn_heads=8,
                            attn_hidden_size=128,
                            atom_vocab=atom_vocab,
                            bond_vocab=bond_vocab)

        trainer = DistributedTrainer(max_epochs=1,
                                     batch_size=64,
                                     num_workers=0,
                                     accelerator='gpu',
                                     distributed_strategy='ddp')
        trainer.fit(model, dataset)
    """

    def __init__(self,
                 max_epochs,
                 batch_size,
                 collate_fn=collate_dataset_wrapper,
                 num_workers: int = 0,
                 devices: str = 'auto',
                 accelerator: str = 'auto',
                 distributed_strategy: str = 'auto'):
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.devices = devices
        self.accelerator = accelerator
        self.distributed_strategy = distributed_strategy
        self.collate_fn = collate_fn

    def fit(self, model: dc.models.Model, dataset: dc.data.DiskDataset):
        # Lightning is imported lazily so that it remains an optional dependency
        import lightning as L
        lit_model = DCLightningModule(model)
        data_module = DCLightningDatasetModule(dataset,
                                               batch_size=self.batch_size,
                                               collate_fn=self.collate_fn,
                                               num_workers=self.num_workers)
        trainer = L.Trainer(max_epochs=self.max_epochs,
                            devices=self.devices,
                            accelerator=self.accelerator,
                            strategy=self.distributed_strategy)
        trainer.fit(lit_model, data_module)
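As a usage note, scaling out with this design is a matter of forwarding the device count and strategy to Lightning's Trainer. The sketch below reuses the model and dataset constructed in the docstring example; the device count and epoch count are illustrative assumptions.

trainer = DistributedTrainer(max_epochs=10,
                             batch_size=64,
                             devices=4,
                             accelerator='gpu',
                             distributed_strategy='ddp')
trainer.fit(model, dataset)  # model and dataset as in the docstring example above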