Hi,
I’ve been experimenting with nested cross-validation in DeepChem, where the model’s hyperparameters are optimized in an inner loop and generalization performance is estimated in an outer loop.
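For readers new to the setup, here is a minimal, framework-free sketch of the nested loop structure. It makes several assumptions. The model is a toy 1-D ridge regression written as a hypothetical `fit` helper, and the inner loop uses an exhaustive grid search as a stand-in for Optuna. In the workflow described in this post, `GCNModel` would replace `fit`, and an Optuna study would replace the `min(grid, ...)` line.

```python
import random
from statistics import mean


def kfold(n, k, seed=0):
    """Split indices 0..n-1 into k shuffled, disjoint folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def split(xs, ys, test_idx):
    """Return (train_x, train_y, test_x, test_y) for the given test indices."""
    test = set(test_idx)
    tr = [i for i in range(len(xs)) if i not in test]
    return ([xs[i] for i in tr], [ys[i] for i in tr],
            [xs[i] for i in test_idx], [ys[i] for i in test_idx])


def fit(xs, ys, lam):
    """Hypothetical stand-in model: 1-D ridge regression without intercept (closed form)."""
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + lam)


def mse(w, xs, ys):
    return mean((w * x - y) ** 2 for x, y in zip(xs, ys))


def cv_mse(xs, ys, lam, k, seed):
    """Inner loop: average validation MSE of one hyperparameter value."""
    scores = []
    for fold in kfold(len(xs), k, seed):
        xtr, ytr, xva, yva = split(xs, ys, fold)
        scores.append(mse(fit(xtr, ytr, lam), xva, yva))
    return mean(scores)


def nested_cv(xs, ys, grid, outer_k=5, inner_k=3, seed=0):
    """Outer loop: estimate performance of the whole 'tune, then fit' procedure."""
    outer_scores, chosen = [], []
    for fold in kfold(len(xs), outer_k, seed):
        xtr, ytr, xte, yte = split(xs, ys, fold)
        # The hyperparameter search only ever sees the outer-training data.
        best = min(grid, key=lambda lam: cv_mse(xtr, ytr, lam, inner_k, seed))
        # Refit on all outer-training data and score once on the held-out outer fold.
        outer_scores.append(mse(fit(xtr, ytr, best), xte, yte))
        chosen.append(best)
    return mean(outer_scores), chosen


# Toy data: y = 2x + small Gaussian noise.
rng = random.Random(42)
xs = [rng.uniform(-1, 1) for _ in range(60)]
ys = [2.0 * x + rng.gauss(0, 0.1) for x in xs]
score, lams = nested_cv(xs, ys, grid=[0.0, 0.1, 1.0, 10.0])
print(score, lams)
```

The key design point is that the outer test fold never influences the choice of hyperparameters. The reported outer score therefore estimates the performance of the tuning procedure itself rather than of one lucky configuration.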
I needed a bit more flexibility in my training loop, so I wrote my own cross-validation loops and decided to use Optuna instead of DeepChem's HyperparamOpt. This approach worked fine, but I realized that cross-validating a dataset of ~2K molecules with GCNModel (the PyTorch implementation of a graph convolutional network) was taking more than 24 h!
The problem was that I had not included a pruning strategy to skip unpromising trials. So I wrote a custom callback that integrates DeepChem with Optuna's trial pruning. The callback below is meant for DeepChem's PyTorch-based models such as GCNModel:
```python
import optuna


class DeepChemPruningCallback:
    """DeepChem callback to prune unpromising Optuna trials.

    .. note::
        This callback is for deepchem>=2.5.0.

    Example:
        Register a pruning callback to ``model.fit``.

        .. code::

            model.fit(train_dataset, nb_epoch=n_epochs,
                      callbacks=DeepChemPruningCallback(valid_dataset, trial, metrics=metrics,
                                                        monitor='mean_squared_error'))

    Args:
        dataset:
            The dataset used to monitor performance (e.g. a validation set).
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        metrics:
            List of ``dc.metrics.Metric`` objects used to evaluate performance.
        monitor:
            Name of the metric reported to Optuna, e.g. ``mean_squared_error``.
            It must match the name of one of ``metrics``, because it is used as a
            key into the dict returned by ``model.evaluate``.
            Please refer to https://deepchem.readthedocs.io/en/latest/api_reference/metrics.html.
    """

    def __init__(self, dataset, trial: optuna.Trial, metrics, monitor: str = "mean_squared_error"):
        self.trial = trial
        self.dataset = dataset
        self.metrics = metrics
        self.monitor = monitor

    def __call__(self, model, step):
        # DeepChem calls each callback as callback(model, step) after every training
        # step (batch), so the monitored dataset is evaluated at every step.
        score = model.evaluate(self.dataset, self.metrics)[self.monitor]
        self.trial.report(score, step=step)
        if self.trial.should_prune():
            # Raising TrialPruned stops model.fit(); study.optimize() catches it and marks
            # the trial as pruned, so no extra (e.g. Fastai) cancel exception is needed.
            raise optuna.TrialPruned(f"Trial was pruned at step {step}.")
```
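The callback only reports scores. Whether a trial is actually cut is decided by the study's pruner, and `optuna.create_study()` uses `MedianPruner` by default. The stdlib-only sketch below is a simplified illustration of that median rule, not Optuna's code. It assumes a minimized metric such as MSE, made-up toy learning curves, and a hypothetical `n_startup_trials=2` (Optuna's own default is 5). A running trial is stopped when its best value so far is worse than the median of the completed trials' values at the same step.

```python
from statistics import median


def should_prune(history, step, completed, n_startup_trials=2):
    """Simplified median rule for a minimized metric (illustration only).

    history:   {step: value} reported so far by the running trial
    completed: list of {step: value} dicts from trials that finished unpruned
    """
    if len(completed) < n_startup_trials:
        return False  # not enough finished trials to compare against yet
    peers = [h[step] for h in completed if step in h]
    if not peers:
        return False
    best_so_far = min(history.values())
    return best_so_far > median(peers)


def run_trial(curve, completed):
    """Report one value per step and stop early when the rule says to prune."""
    history = {}
    for step, value in enumerate(curve):
        history[step] = value  # analogous to trial.report(value, step)
        if should_prune(history, step, completed):
            return history, True  # analogous to raising optuna.TrialPruned
    return history, False


# Made-up validation-MSE curves for four hyperparameter settings.
curves = [
    [1.0, 0.6, 0.4, 0.3],
    [0.9, 0.5, 0.35, 0.25],
    [2.0, 1.8, 1.7, 1.6],  # clearly worse than the others
    [0.8, 0.45, 0.3, 0.2],
]
completed, steps_run = [], []
for curve in curves:
    history, pruned = run_trial(curve, completed)
    steps_run.append(len(history))
    if not pruned:
        completed.append(history)
print(steps_run)  # → [4, 4, 1, 4]
```

With these toy curves, the third setting stops after its first report. Skipping the remaining steps of such trials is where the time savings would come from on a slow model like GCNModel.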
Please feel free to test it and give me some feedback.