Hi, I have been trying to rewrite your tutorial example on Tox21 prediction from TensorGraph to KerasModel, but I seem to have missed something. Using the code below, I can create a KerasModel object. However, when I feed in the same generator as in the tutorial, fitting fails with a KeyError.
from tensorflow.keras.layers import Input, Reshape, Conv2D, Flatten, Dense, Softmax
# Symbolic Keras inputs for one batch; the batch dimension is implicit.
atom_features = Input(shape=(75,))
degree_slice = Input(shape=(2,), dtype=tf.int32)
membership = Input(shape=(), dtype=tf.int32)

# One integer adjacency input per index i in 0..10; input i has i + 1 columns.
# NOTE(review): this creates 11 adjacency inputs -- confirm that the data
# generator supplies the same number of arrays, in the same order.
deg_adjs = []
for i in range(0, 10 + 1):
    deg_adj = Input(shape=(i + 1,), dtype=tf.int32)
    deg_adjs.append(deg_adj)
from deepchem.models.layers import GraphPool, GraphGather, GraphConv
from tensorflow.keras.layers import BatchNormalization
# Batch size handed to the GraphGather readout layer below.
batch_size = 50

# Every graph layer takes its feature tensor followed by the same structural
# inputs, so build that shared tail once.
graph_inputs = [degree_slice, membership] + deg_adjs

# First graph convolution -> batch norm -> graph pooling.
gc1 = GraphConv(64, activation_fn=tf.nn.relu)([atom_features] + graph_inputs)
batch_norm1 = BatchNormalization()(gc1)
gp1 = GraphPool()([batch_norm1] + graph_inputs)

# Second graph convolution -> batch norm -> graph pooling.
gc2 = GraphConv(64, activation_fn=tf.nn.relu)([gp1] + graph_inputs)
batch_norm2 = BatchNormalization()(gc2)
gp2 = GraphPool()([batch_norm2] + graph_inputs)

# Dense layer and batch norm ahead of the GraphGather readout.
dense = Dense(128, activation=tf.nn.relu)(gp2)
batch_norm3 = BatchNormalization()(dense)
readout = GraphGather(batch_size=batch_size, activation_fn=tf.nn.tanh)(
    [batch_norm3] + graph_inputs
)
from deepchem.models.layers import Stack
from tensorflow.keras.layers import Softmax
from tensorflow.keras.losses import CategoricalCrossentropy
# Not used by anything in this section; kept in case later code refers to it.
costs = []

# One two-unit softmax head per Tox21 task, all reading the shared readout.
outputs = []
for _ in tox21_tasks:
    classification = Dense(2, activation=None)(readout)
    # The Keras layer class is spelled `Softmax`; `SoftMax` is a NameError.
    softmax = Softmax()(classification)
    outputs.append(softmax)

# Pass the inputs as one flat list, in the same order used for the graph
# layers, instead of nesting `deg_adjs` as a list inside the list.
keras_model = tf.keras.Model(
    inputs=[atom_features, degree_slice, membership] + deg_adjs,
    outputs=outputs,
)
model = dc.models.KerasModel(keras_model, dc.models.losses.CategoricalCrossEntropy())

# KerasModel reads the inputs of each batch as `batch[0]`, i.e. it expects the
# generator to yield `(inputs, labels, weights)` tuples. A generator that
# yields a feed dict (as the TensorGraph tutorial did) fails with `KeyError: 0`.
# NOTE(review): `data_generator` is defined elsewhere -- confirm it yields
# tuples whose `inputs` list matches the order of the model inputs above.
model.fit_generator(data_generator(train_dataset, epochs=1))
The last line results in the following error:
KeyError Traceback (most recent call last)
in ()
----> 1 model.fit_generator(data_generator(train_dataset, epochs=1))
1 frames
/usr/local/lib/python3.7/site-packages/deepchem/models/keras_model.py in _create_training_ops(self, example_batch)
221 if self._training_ops_built:
222 return
--> 223 self._create_inputs(example_batch[0])
224 self._training_ops_built = True
225 self._label_dtypes = [
KeyError: 0
Could you please help me fix this?