GSoC'25 Project: Model-Parallel DeepChem Model Training

Hey everyone!

I’m Bhuvan, a 3rd-year Computer Science student from Bangalore, India. Super excited to share that I’ll be working on multi-GPU support for DeepChem models for my GSoC 2025 project! This is something I’ve been tinkering with for a while, and I can’t wait to dive deeper into the challenges of distributed training. I’ll be posting weekly updates here, so be sure to check them out!

Looking forward to learning and building this. Catch you in the updates! :slight_smile:

Week 1: June 2 - June 8

Work Done:

  • Implemented custom __getitem__ and __len__ methods for DiskDataset to ensure compatibility with torch.utils.data.DataLoader.
  • Developed a custom collate function for models inheriting the TorchModel base class, specifically for the Lightning Data wrapper.
  • Created a Lightning data wrapper to utilize DiskDataset data with PyTorch Lightning.
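
The indexing idea behind the first bullet can be sketched roughly as follows. This is a minimal, framework-free sketch, assuming the data lives in a list of shards of possibly different sizes. `ShardedIndexView` and its attributes are hypothetical names for illustration, not DeepChem's actual implementation. A map-style dataset needs `__getitem__` and `__len__`, so a global index is mapped to a (shard, offset) pair using cumulative shard sizes and a binary search.

```python
from bisect import bisect_right
from itertools import accumulate


class ShardedIndexView:
    """Hypothetical map-style view over a list of in-memory shards."""

    def __init__(self, shards):
        self.shards = shards
        # Running totals of shard sizes, e.g. sizes [3, 1, 2] -> [3, 4, 6].
        self.cumulative = list(accumulate(len(s) for s in shards))

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, index):
        if index < 0:
            index += len(self)  # support negative indices like a list
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range")
        # First shard whose cumulative size is strictly greater than index.
        shard_id = bisect_right(self.cumulative, index)
        start = self.cumulative[shard_id - 1] if shard_id > 0 else 0
        return self.shards[shard_id][index - start]


# Three shards with uneven sizes.
view = ShardedIndexView([["a", "b", "c"], ["d"], ["e", "f"]])
```

Because the lookup only relies on the cumulative sizes, uneven or even empty shards need no special handling.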

Issues Faced:

  • The iterable method implemented by DeepChem currently causes GPU deadlocks in a multi-GPU environment.
  • Designing a single, general collate function that works for every model inheriting directly from the TorchModel base class.
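
To make the collate problem concrete, here is a minimal sketch, assuming each sample is an (X, y, w, id) tuple of features, labels, weights, and an identifier. `collate_samples` is a hypothetical name, and the sketch uses plain lists so it stays framework-free. A real collate function would likely convert the fields to arrays or tensors, and models with graph inputs would presumably need their own batching step, which is part of what makes a fully general version hard.

```python
def collate_samples(batch):
    """Group a list of (X, y, w, id) samples into per-field batches."""
    if not batch:
        raise ValueError("cannot collate an empty batch")
    # zip(*batch) transposes [(X0, y0, w0, id0), (X1, y1, w1, id1), ...]
    # into (X0, X1, ...), (y0, y1, ...), (w0, w1, ...), (id0, id1, ...).
    xs, ys, ws, ids = (list(field) for field in zip(*batch))
    return xs, ys, ws, ids


samples = [
    ([0.1, 0.2], [1.0], [1.0], "mol_0"),
    ([0.3, 0.4], [0.0], [1.0], "mol_1"),
]
xs, ys, ws, ids = collate_samples(samples)
```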

Open Pull Requests:

Week 2: June 9 - June 13

Work Done

  • Wrote a custom test case for checkpoint saving/loading and model fit/prediction, using two Torch models as examples:
    • MultitaskClassifier: A simple model that inherits from the TorchModel base class.
    • GCN: A model that inherits from TorchModel and also uses DGL.
  • Created a Lightning model wrapper for models inheriting the TorchModel base class, using PyTorch Lightning.
  • Developed a simple trainer, integrating previously written Lightning data and model wrappers.

Issues Faced

  • Ensuring the wrapper worked with models utilizing DGL. This was later resolved by installing DGL correctly.
  • Getting the train/predict functionality to work as expected with the TorchModel base class took some time initially.

PRs Open

Week 3: June 14 - June 20

Work Done

  • DiskDataset Refactor:
    • Separated DiskDataset enhancements from the main PR (#4454) into a dedicated PR (#4458).
    • Added detailed docstrings for better clarity and maintainability.
  • Testing Improvements:
    • Created a pytest fixture dummy_disk_dataset for testing datasets with uneven shard sizes.
    • Added unit tests for:
      • __getitem__: Including edge cases and out-of-bounds access.
      • _cumulative_sum: Verifying correctness across a variety of inputs.
  • PyTorch Lightning Integration:
    • Introduced DeepChemLightningDataModule and DeepChemLightningModule in PR #4461.
    • The PyTorch Lightning integration tests with the GCN model cover:
      • Training and prediction workflows
      • Checkpointing capabilities
      • GPU support for accelerated computation
    • Added comprehensive tests to ensure the Lightning modules behave as expected.
  • Package Exposure:
    • Updated __init__.py to expose the new Lightning module classes.

Issues Faced

  • Minor integration issues and strict type-checking challenges.

Open Pull Requests

Slide Link

Week 4: June 21 - June 27

Work Done

  • DeepChemLightningTrainer class implementation:

    • Adds a wrapper class to train, predict, and manage checkpoints for DeepChem models using PyTorch Lightning. Includes methods like fit, predict, save_checkpoint, and load_checkpoint.
    • Includes a clean docstring with examples.
  • Unit tests for Lightning integration:

    • Multitask classifier test: Implements a test for training, checkpointing, and reloading a multitask classifier using the new DeepChemLightningTrainer class. Ensures weight consistency after reloading.
    • GCN model test: Adds a test for training, checkpointing, and reloading a graph convolutional network model using the DeepChemLightningTrainer. Also verifies weight consistency after reloading.
  • Indexable dataset wrapper

    • Made a major change to PR #4458 by moving the indexing logic into a separate class, IndexDiskDatasetWrapper, and updated the test cases accordingly.

Issues Faced

  • Integration issues and strict type-checking challenges, similar to the previous week.

Open Pull Requests

Slide Link

Week 5: June 28 – July 04

Work Done

  • PR #4458 Final Review:

    • Refactored the collate function by removing the nested class structure and adding it to utils.rst.
    • Optimized performance using list comprehensions instead of traditional loops.
  • Custom Trainer Class Enhancements (PR #4469):

    • Made load_checkpoints a @staticmethod to support training–prediction strategy switching.
    • Added an evaluate method via PR #4473:
      • Accepts model, metric, data, and transformers for performance scoring.
      • Verified functionality with test cases designed for multi-GPU environments.
    • Included an overfit test case that uses the evaluate method.
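
The reason a static checkpoint loader helps with strategy switching can be sketched as below. This is a toy illustration under stated assumptions: `TinyTrainer`, its JSON checkpoint format, and the `strategy` labels are all hypothetical and are not the DeepChemLightningTrainer API. The point is only that a static method needs no existing trainer instance, so the caller can build a fresh trainer with a different strategy for prediction than the one used for training.

```python
import json
import os
import tempfile


class TinyTrainer:
    """Hypothetical stand-in for a trainer; not the DeepChem class."""

    def __init__(self, weights, strategy="fsdp"):
        self.weights = weights
        self.strategy = strategy  # only a label in this sketch

    def save_checkpoint(self, path):
        with open(path, "w") as handle:
            json.dump({"weights": self.weights}, handle)

    @staticmethod
    def load_checkpoint(path, strategy="auto"):
        # No existing instance is required, so the strategy chosen here
        # is independent of the one used when the checkpoint was saved.
        with open(path) as handle:
            state = json.load(handle)
        return TinyTrainer(state["weights"], strategy=strategy)


train_trainer = TinyTrainer([0.5, -1.25, 3.0], strategy="fsdp")
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "model.json")
    train_trainer.save_checkpoint(ckpt_path)
    predict_trainer = TinyTrainer.load_checkpoint(ckpt_path, strategy="auto")
```

Comparing `predict_trainer.weights` with `train_trainer.weights` is the same kind of weight-consistency check the reload tests perform.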

Issues Faced

  • Challenges in integrating DeepChem-based scikit-learn functions into the evaluate method.
  • Encountered and addressed circular import errors.

Open Pull Requests

Slide Link

Week 6: July 05 – July 11

Work done

  • Made structural changes to all Lightning modules, Lightning utils, custom Torch-based IndexDiskDataset, and test case files.
  • Created a single large draft PR (#4481) that includes all modules, the trainer, utils, and test cases for FSDP training.
  • Created a detailed slide deck explaining my entire workflow for enabling FSDP, highlighting areas where the current implementation needs improvement.

Issues faced

  • Challenges in deciding the structure and proper modularization for modules, trainer, and utils files.
  • Fixing CI issues across PRs.

PRs open

Slide Link

Week 7: July 12 – July 18

Work Done

  • Lightning Utils PR:

    • Merged Lightning utilities into DeepChem (PR #4483).
  • Lightning Modules:

    • Enhanced module robustness with additional test cases covering:
      • Checkpointing
      • Model reloading
      • Correctness
  • Prediction Support:

    • Added predict method and corresponding dataloader to Lightning model and data wrappers.
    • Designed to support multi-GPU prediction workflows.

Issues Faced

  • Fixing CI issues across multiple PRs.

Open Pull Requests

Slide Link

Week 8: July 19 – July 25

Work done

  • Reworked the modules PR to ensure the older Lightning implementation is properly deprecated.
  • Updated the trainer PR to reflect changes from the modules PR, and verified that all test cases pass.
  • Explored DeepChem’s HuggingFaceModel base class to make models imported from Hugging Face compatible with the PyTorch Lightning infrastructure in DeepChem.

Issues faced

  • Ensuring older features are softly deprecated while maintaining compatibility with the existing Lightning framework.
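
One common way to soft-deprecate a feature in Python is to keep it working while emitting a `DeprecationWarning`. The sketch below shows that general pattern using only the standard library. The decorator `soft_deprecated` and the function names `old_fit` and `new_fit` are hypothetical, and this is not necessarily how the PR handles it.

```python
import functools
import warnings


def soft_deprecated(replacement):
    """Mark a callable as deprecated without changing its behavior."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,  # point the warning at the caller's line
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@soft_deprecated("new_fit")
def old_fit(epochs):
    return f"trained for {epochs} epochs"


# Existing callers keep working; they just see a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_fit(2)
```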

PRs open

Slide Link