Problem using predict_on_generator

I’m having some trouble understanding predict_on_generator. The code was mostly adapted from the tutorials, and both the code below and the data-preparation sections can be found in this Colab notebook, ready for execution.

(Please note that in the notebook I generate two sub-datasets, (1) smiles and (2) porosity, and merge them, but to keep the example simple only the smiles data is used.)

The dataset is first split:

import deepchem as dc

splitter = dc.splits.RandomSplitter()
train_dataset, valid_dataset, test_dataset = splitter.train_valid_test_split(dataset, 
                                                                             #frac_train=0.8, frac_valid=0.05, frac_test=0.15
                                                                             )
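
(If no fractions are passed, I believe train_valid_test_split defaults to frac_train=0.8, frac_valid=0.1, frac_test=0.1.) A quick check of the resulting split sizes:

print(len(train_dataset), len(valid_dataset), len(test_dataset))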

A data generator is introduced:

from deepchem.metrics import to_one_hot
from deepchem.feat.mol_graphs import ConvMol
import numpy as np

batch_size = 100

def data_generator(dataset, epochs=1):
  for ind, (X_b, y_b, w_b, ids_b) in enumerate(dataset.iterbatches(batch_size, epochs,
                                                                   deterministic=False, pad_batches=True)):
    # Merge the batch of ConvMol objects into a single multi-molecule graph
    multiConvMol = ConvMol.agglomerate_mols(X_b)
    inputs = [multiConvMol.get_atom_features(), multiConvMol.deg_slice, np.array(multiConvMol.membership)]
    for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
      inputs.append(multiConvMol.get_deg_adjacency_lists()[i])
    # Regression targets, so the one-hot encoding from the classification tutorial is disabled
    #labels = [to_one_hot(y_b.flatten(), 2).reshape(-1, n_tasks, 2)]
    labels = [y_b]
    weights = [w_b]

    yield (inputs, labels, weights)
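
Pulling a single batch out of the generator shows the padding at work (a quick sanity check using the train_dataset from the split above):

inputs, labels, weights = next(data_generator(train_dataset))
# pad_batches=True pads every batch, including the last partial one, up to batch_size samples
print(len(inputs), labels[0].shape, weights[0].shape)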

And the model is defined as:

import warnings
warnings.filterwarnings('ignore')
#warnings.filterwarnings('default')

from deepchem.models.layers import GraphConv, GraphPool, GraphGather
import tensorflow as tf
import tensorflow.keras.layers as layers

batch_size = 100
n_tasks = 1
class MyGraphConvModel(tf.keras.Model):

  def __init__(self):
    super(MyGraphConvModel, self).__init__()
    self.gc1 = GraphConv(128, activation_fn=tf.nn.tanh)
    self.batch_norm1 = layers.BatchNormalization()
    self.gp1 = GraphPool()

    self.gc2 = GraphConv(128, activation_fn=tf.nn.tanh)
    self.batch_norm2 = layers.BatchNormalization()
    self.gp2 = GraphPool()

    self.dense1 = layers.Dense(256, activation=tf.nn.tanh)
    self.batch_norm3 = layers.BatchNormalization()
    # GraphGather segments atoms back into molecules; its batch_size must match the generator's
    self.readout = GraphGather(batch_size=batch_size, activation_fn=tf.nn.tanh)

    self.dense2 = layers.Dense(1)
    self.relu = layers.ReLU()


  def call(self, inputs):
    gc1_output = self.gc1(inputs)
    batch_norm1_output = self.batch_norm1(gc1_output)
    gp1_output = self.gp1([batch_norm1_output] + inputs[1:])

    gc2_output = self.gc2([gp1_output] + inputs[1:])
    batch_norm2_output = self.batch_norm2(gc2_output)
    gp2_output = self.gp2([batch_norm2_output] + inputs[1:])

    dense1_output = self.dense1(gp2_output)
    batch_norm3_output = self.batch_norm3(dense1_output)
    readout_output = self.readout([batch_norm3_output]+ inputs[1:])

    dense2_output = self.dense2(readout_output)
    relu_output = self.relu(dense2_output)

    return relu_output

Fitting and evaluation go fine:

model2 = dc.models.KerasModel(MyGraphConvModel(), loss=dc.models.losses.L2Loss())
model2.fit_generator(data_generator(train_dataset, epochs=50))
metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)
print(model2.evaluate_generator(data_generator(valid_dataset, epochs=50), [metric]))  #transformers
print(model2.evaluate_generator(data_generator(test_dataset, epochs=50), [metric]))   #transformers

But the prediction:

pred = model2.predict_on_generator(data_generator(train_dataset, epochs = 1))
print(train_dataset.y.shape, pred.shape)

Ends up giving a different shape from train_dataset.y:
(10,) (100, 1)

Moreover, the shape of pred depends on epochs: setting epochs=2 doubles the number of predicted rows.
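
My current guess is that pad_batches=True is responsible: the 10 training samples get padded up to one full batch of 100, predict_on_generator returns one row per (padded) input, and each extra epoch appends another batch. If that is right, something like the following sketch should recover the original count (this assumes the generator is switched to deterministic=True so the prediction rows line up with train_dataset.y):

n_real = len(train_dataset)  # 10 real samples; the rest of the batch is padding
pred = model2.predict_on_generator(data_generator(train_dataset, epochs=1))
pred = pred[:n_real]  # drop the padded rows introduced by pad_batches=True
print(train_dataset.y.shape, pred.shape)  # expecting (10,) and (10, 1)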

I am not sure why this is happening, or whether my padding guess above is correct. Can someone direct me to a tutorial of some sort that might help?