Enabling Pretraining in DeepChem
Background:
Transfer learning is a technique in machine learning where a model trained on one task is used as the starting point for a model on a second, related task; the two stages are commonly known as pretraining and fine-tuning. This is useful because the model starts with knowledge already learned from the first task rather than from scratch, which can save time and resources. It is particularly valuable when labeled data for the second task is scarce.
Abstract:
This proposal outlines the addition of Pretrainer and PretrainableTorchModel classes to DeepChem. The Pretrainer class allows users to easily pretrain a TorchModel on a given task and dataset and then fine-tune it on a different task and dataset. This will be achieved by creating PretrainableTorchModel as a subclass of TorchModel, which specifies how a model builds its embedding layers and prediction head, and then swapping out the head layers and loss to suit the desired pretraining task. This allows a model to be trained on different tasks with the same underlying pretrained embedding.
Motivation:
In many applications, we may want to fine-tune a pretrained model on a new task while reusing most of the model’s knowledge and architecture. Pretrained models act as a foundation for development on more specific tasks. Fine-tuning can save time and computational resources compared to training a new model from scratch each time, and it enables transfer learning between datasets. By removing friction in generating new pretrained models, users can focus more effort on innovative research code and less on engineering pretraining pipelines.
Outline:
- Implement the PretrainableTorchModel abstract class as a subclass of TorchModel. This class identifies which TorchModels are compatible with a Pretrainer by implementing three methods and one public attribute needed for pretraining:
- build_embedding(): This will generate the embedding layers of the model
- build_head(): This will generate the prediction head layers of the model
- build_model(): This will attach the embedding and head to produce the final predicting model
- embedding: An attribute that returns the embedding layers.
- Add an abstract Pretrainer class as a subclass of TorchModel:
- Takes a PretrainableTorchModel as input to form the basis of the pretrainer. The head layers and loss function will be replaced to enable a pretraining task. Training will be run as a normal TorchModel with .fit().
- Implements stub methods for build_head() and build_pretrain_loss(), which will be specified in any Pretrainer subclass.
- Modify TorchModel.load_from_pretrained() to be compatible with Pretrainer models:
- load_from_pretrained() will check whether the source_model is Pretrainer-compatible by checking for an embedding attribute. If it is, and the user has specified include_top=False, only the embedding layers will be loaded. Otherwise the method will function as before, loading the entire model’s weights (a sketch of this logic follows the outline).
- If loading from a Pretrainer, include_top must be False; otherwise load_from_pretrained() will attempt to load a head of a different dimension and fail.
- Test that the Pretrainer functions similarly to TorchModel, with the addition of tests for freezing the embedding and finetuning (a test sketch also follows this outline):
- Overfit test
- Fit restore test
- Load from pretrained test
- Freeze embedding and finetune test
No breaking changes will be made.
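Sketch of the Proposed load_from_pretrained() Logic:
A minimal sketch of the include_top check described above, written as a hypothetical standalone helper (_load_pretrained_weights is illustrative only and not part of DeepChem); the actual change would live inside TorchModel.load_from_pretrained() and reuse its existing checkpoint-restore machinery.

def _load_pretrained_weights(dest, source, include_top=True):
    # Illustrative helper: `dest` stands in for the destination TorchModel and
    # `source` for the source_model passed to load_from_pretrained().
    if hasattr(source, 'embedding') and not include_top:
        # Pretrainer-compatible source: copy only the embedding weights.
        dest.embedding.load_state_dict(source.embedding.state_dict())
    else:
        # Existing behavior: load the entire model's weights.
        dest.model.load_state_dict(source.model.state_dict())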
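Sketch of the Freeze-and-Finetune Test:
A rough sketch of the freeze-embedding-and-finetune test, assuming the ExampleTorchModel and ExamplePretrainer defined later in this proposal; dataset sizes, epoch counts, and directory handling are placeholders rather than final test settings.

import numpy as np
import torch
import deepchem as dc


def test_freeze_embedding_and_finetune():
    X = np.random.rand(10, 15)
    y_pt = np.random.randint(2, size=(10, 5)).astype(np.float32)
    y_ft = np.random.randint(2, size=(10, 3)).astype(np.float32)
    pt_dataset = dc.data.NumpyDataset(X, y_pt)
    ft_dataset = dc.data.NumpyDataset(X, y_ft)

    # Pretrain, then load only the embedding into a fresh model.
    source = ExampleTorchModel(15, d_hidden=2, n_layers=2, d_output=3)
    pretrainer = ExamplePretrainer(source, pt_tasks=5)
    pretrainer.fit(pt_dataset, nb_epoch=100, checkpoint_interval=10)
    dest = ExampleTorchModel(15, d_hidden=2, n_layers=2, d_output=3)
    dest.load_from_pretrained(pretrainer, include_top=False,
                              model_dir=pretrainer.model_dir)

    # Freeze the embedding, record its weights, finetune, and verify that the
    # frozen weights did not change.
    for param in dest.embedding.parameters():
        param.requires_grad = False
    before = [p.detach().cpu().clone() for p in dest.embedding.parameters()]
    dest.fit(ft_dataset, nb_epoch=10)
    for b, a in zip(before, dest.embedding.parameters()):
        assert torch.allclose(b, a.detach().cpu())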
Abstract PretrainableTorchModel Implementation:
from deepchem.models import TorchModel


class PretrainableTorchModel(TorchModel):
    """Abstract TorchModel that builds its embedding and head separately."""

    @property
    def embedding(self):
        raise NotImplementedError("Subclass must define the embedding")

    def build_embedding(self):
        raise NotImplementedError("Subclass must define the embedding")

    def build_head(self):
        raise NotImplementedError("Subclass must define the head")

    def build_model(self):
        raise NotImplementedError("Subclass must define the model")
Abstract Pretrainer Implementation:
class Pretrainer(TorchModel):

    def __init__(self, torchmodel: PretrainableTorchModel, **kwargs):
        super().__init__(torchmodel.model, torchmodel.loss, **kwargs)

    @property
    def embedding(self):
        raise NotImplementedError("Subclass must define the embedding")

    def build_head(self):
        raise NotImplementedError("Subclass must define the pretraining head")

    def build_pretrain_loss(self):
        raise NotImplementedError("Subclass must define the pretrain loss")
Example PretrainableTorchModel:
import deepchem as dc
import torch.nn as nn


class ExampleTorchModel(PretrainableTorchModel):
    """Example TorchModel for testing pretraining."""

    def __init__(self, input_dim, d_hidden, n_layers, d_output, **kwargs):
        self.input_dim = input_dim
        self.d_hidden = d_hidden
        self.n_layers = n_layers
        self.d_output = d_output
        self.loss = dc.models.losses.SigmoidCrossEntropy()
        self._head = self.build_head()
        self._embedding = self.build_embedding()
        self._model = self.build_model(self._embedding, self._head)
        super().__init__(self._model, self.loss, **kwargs)

    @property
    def embedding(self):
        return self._embedding

    def build_embedding(self):
        embedding = []
        for i in range(self.n_layers):
            if i == 0:
                embedding.append(nn.Linear(self.input_dim, self.d_hidden))
                embedding.append(nn.ReLU())
            else:
                embedding.append(nn.Linear(self.d_hidden, self.d_hidden))
                embedding.append(nn.ReLU())
        return nn.Sequential(*embedding)

    def build_head(self):
        linear = nn.Linear(self.d_hidden, self.d_output)
        af = nn.Sigmoid()
        return nn.Sequential(linear, af)

    def build_model(self, embedding, head):
        return nn.Sequential(embedding, head)
Example Pretrainer:
class ExamplePretrainer(Pretrainer):
    """Example Pretrainer for testing."""

    def __init__(self, model: ExampleTorchModel, pt_tasks: int, **kwargs):
        # Reuse the source model's embedding architecture, but attach a
        # pretraining head and loss sized for the pretraining task.
        self._embedding = model.build_embedding()
        self._head = self.build_head(model.d_hidden, pt_tasks)
        self._model = model.build_model(self._embedding, self._head)
        self.loss = self.build_pretrain_loss()
        torchmodel = TorchModel(self._model, self.loss, **kwargs)
        super().__init__(torchmodel, **kwargs)

    @property
    def embedding(self):
        return self._embedding

    def build_pretrain_loss(self):
        return dc.models.losses.L2Loss()

    def build_head(self, d_hidden, pt_tasks):
        linear = nn.Linear(d_hidden, pt_tasks)
        af = nn.Sigmoid()
        return nn.Sequential(linear, af)
Usage:
import numpy as np

n_samples = 10
input_size = 15
d_hidden = 2
n_layers = 1
n_tasks = 3
pt_tasks = 5

# Pretraining dataset (pt_tasks targets)
X = np.random.rand(n_samples, input_size)
y = np.random.randint(2, size=(n_samples, pt_tasks)).astype(np.float32)
pt_dataset = dc.data.NumpyDataset(X, y)

# Fine-tuning dataset (n_tasks targets)
X = np.random.rand(n_samples, input_size)
y = np.random.randint(2, size=(n_samples, n_tasks)).astype(np.float32)
ft_dataset = dc.data.NumpyDataset(X, y)

# Test dataset (features only)
X = np.random.rand(n_samples, input_size)
test_dataset = dc.data.NumpyDataset(X)

toy = ExampleTorchModel(input_size, d_hidden, n_layers, n_tasks, model_dir='./folder1')
toy2 = ExampleTorchModel(input_size, d_hidden, n_layers, n_tasks)

# Pretrain, then load only the embedding weights into the fresh model
pretrainer = ExamplePretrainer(toy, pt_tasks=pt_tasks, model_dir='./folder2')
pretrainer.fit(pt_dataset, nb_epoch=100, checkpoint_interval=10)
toy2.load_from_pretrained(pretrainer, include_top=False, model_dir='./folder2')

# Freeze embedding for finetuning
for param in toy2.embedding.parameters():
    param.requires_grad = False

# Finetune
toy2.fit(ft_dataset, nb_epoch=100, checkpoint_interval=10)
preds = toy2.predict(test_dataset)
print('preds: \n', preds)