Custom callback - integration with Optuna

Hi,
I’ve been experimenting with nested cross-validation in DeepChem, where the model’s hyperparameters are optimized in an inner loop and performance is estimated in an outer loop.
I needed a bit more flexibility in my training loop, so I wrote my own cross-validation loops and decided to use Optuna instead of DeepChem’s HyperparamOpt. This approach worked fine, but I realized that cross-validating a dataset of ~2K molecules with GCNModel (the PyTorch implementation of the graph convolutional network) was taking more than 24 hours!
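To make the setup concrete, here is a simplified sketch of what the nested loops look like. This is not my actual script: the Delaney dataset stands in for my ~2K-molecule dataset, and the hyperparameter ranges, fold counts and epoch numbers are placeholders.

import deepchem as dc
import numpy as np
import optuna

# Placeholder dataset, featurized for GCNModel with MolGraphConvFeaturizer.
tasks, (dataset,), transformers = dc.molnet.load_delaney(
    featurizer=dc.feat.MolGraphConvFeaturizer(), splitter=None)

metric = dc.metrics.Metric(dc.metrics.mean_squared_error)

def build_model(params):
    return dc.models.GCNModel(n_tasks=len(tasks), mode='regression',
                              graph_conv_layers=[params['layer_size']] * 2,
                              dropout=params['dropout'],
                              learning_rate=params['lr'])

def make_objective(train_fold):
    # Inner loop: Optuna searches hyperparameters on a split of the outer training fold.
    inner_train, inner_valid = dc.splits.RandomSplitter().train_test_split(train_fold)
    def objective(trial):
        params = {'layer_size': trial.suggest_int('layer_size', 32, 256),
                  'dropout': trial.suggest_float('dropout', 0.0, 0.5),
                  'lr': trial.suggest_float('lr', 1e-4, 1e-2, log=True)}
        model = build_model(params)
        model.fit(inner_train, nb_epoch=50)
        return model.evaluate(inner_valid, [metric])['mean_squared_error']
    return objective

# Outer loop: estimate performance on folds never seen by the hyperparameter search.
outer_scores = []
for train_fold, test_fold in dc.splits.RandomSplitter().k_fold_split(dataset, k=5):
    study = optuna.create_study(direction='minimize')
    study.optimize(make_objective(train_fold), n_trials=20)
    model = build_model(study.best_params)
    model.fit(train_fold, nb_epoch=50)
    outer_scores.append(model.evaluate(test_fold, [metric])['mean_squared_error'])

print('Nested CV MSE:', np.mean(outer_scores))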

The problem was that I had not included a pruning strategy to skip unpromising trials, so I wrote a custom callback that integrates DeepChem with Optuna’s hyperparameter optimization. In the callback below, I use fastai + PyTorch:

import optuna
from fastai.callback.core import CancelFitException  # fastai exception for stopping training when a condition is met (not strictly needed below, since TrialPruned already aborts fit)

class DeepChemPruningCallback():
    """DeepChem callback to prune unpromising Optuna trials.

    .. note::
        This callback is for deepchem>=2.5.0.

    Example:

        Register a pruning callback to ``model.fit``.

        .. code::

            model.fit(train_dataset, nb_epoch=n_epochs,
                      callbacks=DeepChemPruningCallback(valid_dataset, trial,
                                                        metrics=metrics,
                                                        monitor='mean_squared_error'))

    Args:
        dataset:
            The dataset on which performance is monitored.
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        metrics:
            Metrics used to evaluate performance.
        monitor:
            Name of the evaluation metric used for pruning, e.g.
            ``mean_squared_error`` or ``roc_auc_score``.
            Please refer to https://deepchem.readthedocs.io/en/latest/api_reference/metrics.html.
    """

    def __init__(self, dataset, trial: optuna.Trial, metrics,
                 monitor: str = "mean_squared_error"):
        self.trial = trial
        self.dataset = dataset
        self.metrics = metrics
        self.monitor = monitor
    
    def __call__(self, model, step):
        # DeepChem invokes callbacks after every optimization step.
        score = model.evaluate(self.dataset, self.metrics)[self.monitor]

        # Report the intermediate score to Optuna and ask whether to prune.
        self.trial.report(score, step=step)

        if self.trial.should_prune():
            # TrialPruned propagates out of model.fit, so training stops here
            # and Optuna records the trial as pruned.
            raise optuna.TrialPruned(f"Trial was pruned at step {step}.")
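For context, here is roughly how the callback plugs into an Optuna objective. This is a simplified sketch: train_dataset and valid_dataset are the inner-loop splits, metric is a dc.metrics.Metric wrapping mean_squared_error, and the hyperparameter ranges and epoch count are placeholders. Note that optuna.create_study uses a median pruner by default, so no extra configuration is needed for pruning to kick in.

def objective(trial):
    model = dc.models.GCNModel(n_tasks=1, mode='regression',
                               dropout=trial.suggest_float('dropout', 0.0, 0.5),
                               learning_rate=trial.suggest_float('lr', 1e-4, 1e-2, log=True))
    # The callback reports the validation score to the trial after every step;
    # raising TrialPruned inside it aborts fit and marks the trial as pruned.
    model.fit(train_dataset, nb_epoch=50,
              callbacks=DeepChemPruningCallback(valid_dataset, trial,
                                                metrics=[metric],
                                                monitor='mean_squared_error'))
    return model.evaluate(valid_dataset, [metric])['mean_squared_error']

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20)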

Please feel free to test it and give some feedback :slight_smile:


Oh this looks really neat! How well does the Optuna callback work in your experience so far? If this provides a big advantage, we could consider offering similar functionality in dc.hyper directly (our hyperparameter optimization toolkit is still pretty rudimentary, as you’ve seen!)


It seems to be working fine. I’m running nested cross-validation and the callback is pruning bad trials using the median of intermediate score values (e.g. mean squared error). It usually prunes trials very early, before training reaches 50 steps.

In terms of performance, for my datasets, Optuna usually converges within 10 trials.
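By default Optuna uses this median-based pruner; if it turns out to be too aggressive, the pruner can be configured explicitly when creating the study, e.g. to require some startup trials and warm-up steps before pruning (the values below are just examples):

study = optuna.create_study(
    direction='minimize',
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=50))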
