Support for distributed training is a feature that will make DeepChem more powerful for scientific applications. Multi-GPU distributed training will let DeepChem users train models at a larger scale, enabling experimentation with scientific foundation models and other applications of large models.
End User API
import deepchem as dc
from deepchem.models.torch_models import DistributedTrainer
model = SNAPModel(**model_parameters)
trainer = DistributedTrainer(model)
trainer.fit(dataset)
Implementation details
The implementation takes ideas from the work of Abhishek Kadian and Princy Chahal on integrating multi-GPU support into DeepChem. The API is designed to be similar to PyTorch Lightning's API, but it deviates from the earlier work by proposing torch.distributed for distributed training instead of pytorch-lightning. By building on torch.distributed directly, we believe DeepChem becomes more capable as a framework without depending on a third-party training library.
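For context, the core torch.distributed building blocks this proposal relies on (process-group initialization, DistributedDataParallel, and mp.spawn) look roughly as follows. This is a minimal standalone sketch on a toy linear model, not part of the proposed API; the address, port, and backend choices are illustrative assumptions.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank: int, world_size: int):
    # Every spawned process joins the same process group. Real multi-GPU
    # training would use the 'nccl' backend; 'gloo' keeps this sketch CPU-only.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # DDP wraps a plain torch.nn.Module and all-reduces gradients across
    # processes on every backward pass.
    model = DDP(torch.nn.Linear(8, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)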
To this end, the two main sub-tasks in implementing it are:
- Convert DeepChem's DiskDataset to handle distributed loading of data
- Build a distributed trainer that drives the underlying PyTorch model during training
Updating DiskDataset to handle distributed training
Note: This approach was shared earlier by Peter Eastman here.
from typing import Optional

import torch
import torch.distributed as dist

import deepchem as dc


class TorchDiskDataset(torch.utils.data.IterableDataset):

    def __init__(self,
                 disk_dataset: dc.data.DiskDataset,
                 epochs: int,
                 deterministic: bool = True,
                 batch_size: Optional[int] = None):
        self.disk_dataset = disk_dataset
        self.epochs = epochs
        self.deterministic = deterministic
        self.batch_size = batch_size

    def __len__(self):
        return len(self.disk_dataset)

    def __iter__(self):
        # Each time an iterator is created, i.e. when enumerate(dataloader)
        # is called, num_workers worker processes get created.
        worker_info = torch.utils.data.get_worker_info()
        n_shards = self.disk_dataset.get_number_shards()
        if worker_info is None:
            process_id = 0
            num_processes = 1
        else:
            process_id = worker_info.id
            num_processes = worker_info.num_workers
        # In a distributed run, fold the rank into the split so that every
        # (rank, worker) pair reads a disjoint subset of shards.
        if dist.is_initialized():
            process_id += dist.get_rank() * num_processes
            num_processes *= dist.get_world_size()
        first_shard = process_id * n_shards // num_processes
        last_shard = (process_id + 1) * n_shards // num_processes
        if first_shard == last_shard:
            return
        # last_shard is exclusive.
        shard_indices = list(range(first_shard, last_shard))
        for X, y, w, ids in self.disk_dataset._iterbatches_from_shards(
                shard_indices,
                batch_size=self.batch_size,
                epochs=self.epochs,
                deterministic=self.deterministic):
            if self.batch_size is None:
                for i in range(X.shape[0]):
                    yield (X[i], y[i], w[i], ids[i])
            else:
                yield (X, y, w, ids)
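As a usage sketch (the dataset path is a placeholder), the wrapper plugs directly into a torch.utils.data.DataLoader. Passing batch_size=None tells the DataLoader not to batch again, since TorchDiskDataset already yields whole batches.

import torch
import deepchem as dc

# Placeholder path: any directory containing a dc.data.DiskDataset works here.
disk_dataset = dc.data.DiskDataset('/path/to/disk_dataset')
torch_dataset = TorchDiskDataset(disk_dataset, epochs=1, batch_size=32)

# In a distributed run, dist.init_process_group() must have been called
# before iteration so that the rank-aware shard split above takes effect.
loader = torch.utils.data.DataLoader(torch_dataset,
                                     batch_size=None,
                                     num_workers=2)
for X, y, w, ids in loader:
    pass  # forward/backward pass on this batch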
Distributed Trainer
Pseudocode for the distributed trainer (yet to be rigorously tested) is shown below.
import os
from typing import Callable, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import deepchem as dc


class DistributedTrainer():

    def __init__(self,
                 model,
                 distributed_strategy: str = 'ddp',
                 n_gpus: int = 1,
                 batch_size: Optional[int] = None,
                 n_epochs: int = 1,
                 collate_fn: Optional[Callable] = None):
        self.dc_model = model
        self.distributed_strategy = distributed_strategy
        self.n_gpus = n_gpus  # or world_size
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.collate_fn = collate_fn
        self.optimizer = None
        self.lr_scheduler = None
        self.optim_steps = 0

    def fit(self, data_dir: Optional[str] = None):
        # Processes should be spawned here before fitting data.
        mp.spawn(DistributedTrainer.run,
                 args=(self, data_dir),
                 nprocs=self.n_gpus)

    def load_training_data(self, data_dir):
        # Wrap the DiskDataset in the TorchDiskDataset defined above so that
        # each (rank, worker) pair reads its own subset of shards. epochs=1
        # because the trainer loops over epochs itself.
        disk_dataset = dc.data.DiskDataset(data_dir)
        dataset = TorchDiskDataset(disk_dataset,
                                   epochs=1,
                                   batch_size=self.batch_size)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self.collate_fn,
            shuffle=False,
        )
        return dataloader

    def ddp_setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        # 'nccl' is the standard backend for multi-GPU training.
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

    def configure_optimizer(self, model):
        # TODO More options for choosing optimizers to be included
        self.optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        self.configure_lr_scheduler()

    def configure_lr_scheduler(self):
        # Initialize learning rate scheduler here
        pass

    def step_schedulers(self, metrics=None):
        if self.lr_scheduler is None:
            return
        try:
            self.lr_scheduler.step(metrics=metrics)
        except TypeError:
            self.lr_scheduler.step()

    def run_batch(self, X, y, w):
        # _prepare_batch converts the batch into tensors on the right device;
        # this assumes the wrapped TorchModel exposes its loss via _loss_fn.
        inputs, labels, weights = self.dc_model._prepare_batch(([X], [y], [w]))
        out = self.model(inputs)
        loss = self.dc_model._loss_fn(out, labels, weights)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.optim_steps += 1
        self.step_schedulers()

    def run_epoch(self, dataloader):
        for i, (X, y, w, ids) in enumerate(dataloader):
            self.run_batch(X, y, w)
            # Additional per-batch utilities like saving checkpoints go here

    def train(self, data_dir):
        dataloader = self.load_training_data(data_dir)
        for i in range(self.n_epochs):
            self.run_epoch(dataloader)
            # Additional logging or functions like evaluating validation loss
            # for early stopping go here.

    @staticmethod
    def run(local_rank: int, trainer: "DistributedTrainer", data_dir: str):
        trainer.ddp_setup(local_rank, trainer.n_gpus)
        model = trainer.dc_model.model.to(local_rank)
        trainer.model = DDP(model, device_ids=[local_rank])
        trainer.configure_optimizer(trainer.model)
        trainer.train(data_dir)

    @classmethod
    def restore(cls, path: str):
        # Resumes training from a checkpoint, allowing for fault tolerance
        pass
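Under these assumptions, launching a run would look roughly like the End User API sketch above. SNAPModel, model_parameters, my_collate_fn, and the data directory are placeholders for whatever DeepChem TorchModel and dataset are being trained.

model = SNAPModel(**model_parameters)  # any DeepChem TorchModel
trainer = DistributedTrainer(model,
                             distributed_strategy='ddp',
                             n_gpus=4,
                             batch_size=32,
                             n_epochs=10,
                             collate_fn=my_collate_fn)
trainer.fit(data_dir='/path/to/disk_dataset')  # spawns n_gpus worker processes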