Molecular Attention Transformers

The Molecular Attention Transformer is an intriguing recent paper that argues for using a modified transformer architecture to predict molecular properties. There are a few key ingredients. First, the authors augment the attention mechanism of the transformer with interatomic distances and the molecular adjacency matrix. Second, they use self-supervised pretraining to improve performance. The authors have released a PyTorch implementation of their code at https://github.com/gmum/MAT.

Here’s a diagram of the Molecular Attention Transformer taken from their paper.

Unlike previous papers that have applied transformers to molecules, here compounds are represented as lists of atoms rather than as SMILES strings. I like this representation better than SMILES since it feels like there’s less overhead for the model to learn (with SMILES, the model has to acquire a working knowledge of SMILES syntax before it can predict meaningful atomic properties).

As a brief review of the transformer architecture, each attention layer of a transformer consists of H heads. Head i takes in the hidden state \mathbf{H} and computes \mathbf{Q}_i = \mathbf{H}\mathbf{W}_i^Q, \mathbf{K}_i = \mathbf{H}\mathbf{W}_i^K and \mathbf{V}_i = \mathbf{H}\mathbf{W}_i^V. With \rho denoting the row-wise softmax and d_k the dimension of the keys, this leads to the attention operation

\mathcal{A}^{(i)} = \rho \left ( \frac{\mathbf{Q}_i\mathbf{K}_i^T}{\sqrt{d_k}} \right ) \mathbf{V}_i.
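For readers who prefer code, here is a minimal single-head sketch of this equation in plain Python. It uses lists of rows instead of tensors so it runs without PyTorch; the function names are hypothetical and this is not taken from the authors’ repository.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_rows(m):
    """Row-wise softmax (rho), shifted by each row's max for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention_head(H, W_q, W_k, W_v):
    """One head: rho(Q K^T / sqrt(d_k)) V, with Q = H W_q, K = H W_k, V = H W_v."""
    Q, K, V = matmul(H, W_q), matmul(H, W_k), matmul(H, W_v)
    d_k = len(W_k[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    return matmul(softmax_rows(scores), V)


# Three identical atoms with identity projections: attention is uniform,
# so every output row is the average of the rows of V.
H = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
I2 = [[1.0, 0.0], [0.0, 1.0]]
print(attention_head(H, I2, I2, I2))
```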

The idea behind the molecular attention transformer is to use the adjacency matrix of the molecule and the interatomic distances as additional inputs to the attention operation. This results in the equation

\mathcal{A}^{(i)} = \left ( \lambda_a \rho \left ( \frac{\mathbf{Q}_i\mathbf{K}_i^T}{\sqrt{d_k}} \right ) + \lambda_d g(\mathbf{D}) + \lambda_g \mathbf{A} \right )\mathbf{V}_i.

Here the \lambda coefficients are scalar weights on the three terms, g is a kernel applied to the distances (something like a row-wise softmax), \mathbf{D} is the interatomic distance matrix, and \mathbf{A} is the adjacency matrix. Note how this equation blends information about the molecular structure into the attention. This version seems to require knowing the 3D pose of the molecule, but I suspect that the MAT could work reasonably well without the distance matrix (and indeed, the ablation experiments later in the paper bear this out). There’s a nice elegance to this operation, since it folds molecular information naturally into the mathematical structure of the transformer.
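A minimal plain-Python sketch of one MAT attention head, following the equation above, might look like this. Two assumptions to flag: the distance kernel g is taken to be a row-wise softmax over -D (so nearer atoms receive more weight), and the three \lambda weights are passed in as plain numbers. All names are hypothetical; see https://github.com/gmum/MAT for the authors’ actual implementation.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(m):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def mat_attention_head(H, W_q, W_k, W_v, D, A, lam_a, lam_d, lam_g):
    """(lam_a * rho(Q K^T / sqrt(d_k)) + lam_d * g(D) + lam_g * A) V for one head."""
    Q, K, V = matmul(H, W_q), matmul(H, W_k), matmul(H, W_v)
    d_k = len(W_k[0])
    # Q K^T computed by dotting each query row with each key row.
    scores = [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K]
              for q_row in Q]
    self_attn = softmax_rows(scores)
    # Assumed kernel: g(D) = row-wise softmax of -D, so closer atoms get larger weights.
    g_D = softmax_rows([[-d for d in row] for row in D])
    n = len(H)
    mixed = [
        [lam_a * self_attn[i][j] + lam_d * g_D[i][j] + lam_g * A[i][j] for j in range(n)]
        for i in range(n)
    ]
    return matmul(mixed, V)
```

As a sanity check, setting the attention and distance weights to zero, the adjacency weight to one, and passing an identity adjacency matrix should return V unchanged.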

As in graph convolutions, each atom in the molecule is encoded as a vector of chemical descriptors to start.
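To make that concrete, here is a small sketch of one-hot atom featurization. The descriptor set and vocabularies are illustrative assumptions, not the exact features used in the paper; in practice the raw values (element, neighbor counts, ring and aromatic flags) would come from a toolkit such as RDKit.

```python
# Hypothetical vocabularies; the last entry of each list doubles as a catch-all bucket.
ELEMENTS = ["C", "N", "O", "F", "S", "Cl", "other"]
HEAVY_NEIGHBORS = [0, 1, 2, 3, 4, 5]
NUM_HYDROGENS = [0, 1, 2, 3, 4]


def one_hot(value, choices):
    """One-hot encode value; anything not in choices maps to the last slot."""
    vec = [0] * len(choices)
    vec[choices.index(value) if value in choices else len(choices) - 1] = 1
    return vec


def atom_features(symbol, heavy_neighbors, num_hydrogens, in_ring, aromatic):
    """Concatenate one-hot descriptors and binary flags into a single atom vector."""
    return (
        one_hot(symbol, ELEMENTS)
        + one_hot(heavy_neighbors, HEAVY_NEIGHBORS)
        + one_hot(num_hydrogens, NUM_HYDROGENS)
        + [int(in_ring), int(aromatic)]
    )


# A pyridine-like aromatic ring nitrogen: two heavy neighbors, no hydrogens.
print(atom_features("N", 2, 0, True, True))
```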

There are a couple of additional tricks in this work. One is a dummy node added to each molecule, a type of “null atom” that operates roughly analogously to the separator token in BERT-style models. The model also uses masked node pretraining. Since many datasets don’t have experimentally determined 3D structures, 3D conformers are computed using RDKit’s UFFOptimizeMolecule function.
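A rough sketch of the first two tricks on plain lists follows. Assumptions to flag: the dummy node is marked by an extra indicator feature, has no bonds, and is placed at a large hypothetical distance (`far`) from every real atom; masking simply zeroes an atom’s feature vector. The authors’ code may differ in these details.

```python
import random


def add_dummy_node(features, adjacency, distances, far=1e6):
    """Prepend a 'null atom' (index 0) to the feature, adjacency and distance matrices."""
    n = len(adjacency)
    # Extra indicator column: 1 for the dummy node, 0 for real atoms.
    new_features = [[0] * len(features[0]) + [1]] + [row + [0] for row in features]
    # The dummy node is not bonded to anything: zero row and column.
    new_adj = [[0] * (n + 1)] + [[0] + row for row in adjacency]
    # Assumed: the dummy node sits "far away" from every real atom.
    new_dist = [[0.0] + [far] * n] + [[far] + row for row in distances]
    return new_features, new_adj, new_dist


def mask_atoms(features, mask_prob=0.15, rng=None):
    """Zero out a random subset of atom feature vectors for masked-node pretraining.

    Returns the corrupted features and the indices whose original features
    the model should learn to reconstruct.
    """
    rng = rng or random.Random()
    masked, targets = [], []
    for i, row in enumerate(features):
        if rng.random() < mask_prob:
            masked.append([0] * len(row))
            targets.append(i)
        else:
            masked.append(list(row))
    return masked, targets
```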

Models are tested on a number of MoleculeNet and other recent datasets. A number of baselines including random forests, edge-attention graph convolutional networks, and weave networks are used. The baseline implementations are those from DeepChem.

Hyperparameters were tuned by random search, sampling either 150 or 500 different hyperparameter configurations.
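For readers unfamiliar with random-search tuning, here is a minimal sketch. The search space and the `evaluate` callback are hypothetical placeholders (in practice `evaluate` would train a model and return a validation loss); the paper’s actual hyperparameter ranges are not reproduced here.

```python
import random

# Hypothetical search space; not the ranges used in the paper.
SEARCH_SPACE = {
    "learning_rate": [1e-5, 5e-5, 1e-4, 5e-4, 1e-3],
    "num_layers": [2, 4, 6, 8],
    "num_heads": [4, 8, 16],
    "dropout": [0.0, 0.1, 0.2],
}


def random_search(evaluate, space, n_trials, seed=0):
    """Sample n_trials random configurations; return the best one and its (lowest) score."""
    rng = random.Random(seed)
    best_config, best_score = None, float("inf")
    for _ in range(n_trials):
        config = {name: rng.choice(values) for name, values in space.items()}
        score = evaluate(config)
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score


def toy_validation_loss(config):
    """Stand-in for 'train a model and return its validation loss'."""
    return config["learning_rate"] + config["dropout"]


best, loss = random_search(toy_validation_loss, SEARCH_SPACE, n_trials=150)
print(best, loss)
```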

The paper does a few ablation studies where they test how the presence of the dummy node, the distance matrix, and the use of additional edge features impact the model.

The dummy node seems to add some performance improvement, but the distance matrix seems to be of middling value. This intuitively makes sense to me since the RDKit conformers are generated without any knowledge of the underlying physical system and we’re using only one conformer for modeling.

Looking at the attention layers of the molecular transformer also seems to yield some insight into what the model is focusing on.

As an overall summary, the molecular attention transformer is a nice addition to the literature. It has strong performance that beats out a number of the standard DeepChem models on a collection of benchmarks. I think it would be worth adding an implementation to the DeepChem codebase.

Elegant approaches that respect the chemistry do work well! I can get on board with this variant of the Transformer :slight_smile:

On first read this was one of the most elegant and attractive approaches to molecule representation I’ve seen, much better than the bag-of-attributes style. After spending some time with it, I’m seeing many areas for improvement, but still love the idea of augmenting attention with domain priors.

First, I think their adjacency and distance matrices weight the identity edge (the diagonal) too heavily: it is set to the maximal value in both matrices, and then the two are summed. I’m testing with the diagonal set to zero and am getting improved results, but I’ve also made many other changes, so I’m not sure how much can be attributed to this.
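A small sketch of the diagonal issue, assuming an adjacency matrix with self-loops and, for simplicity, an element-wise exp(-d) distance kernel (both are assumptions for illustration): the i == j entry gets the maximum from both terms. The `zero_diagonal` flag is a hypothetical switch for the experiment of zeroing it out.

```python
import math


def structural_bias(D, A, lam_d, lam_g, zero_diagonal=False):
    """Return lam_d * exp(-D) + lam_g * A, optionally zeroing the identity (i == j) entries."""
    n = len(D)
    bias = [[lam_d * math.exp(-D[i][j]) + lam_g * A[i][j] for j in range(n)] for i in range(n)]
    if zero_diagonal:
        for i in range(n):
            bias[i][i] = 0.0
    return bias


D = [[0.0, 1.5], [1.5, 0.0]]  # interatomic distances (hypothetical values)
A = [[1, 1], [1, 1]]          # bonded pair, with self-loops on the diagonal
print(structural_bias(D, A, 0.5, 0.5))  # the diagonal is the largest entry in each row
print(structural_bias(D, A, 0.5, 0.5, zero_diagonal=True))
```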

Also, I’m not sure what to think about their toy distance modeling problem (appendix F). The fact that the model did best when the distance lambda was 100% makes me suspicious that it’s just learning to query the input for the answer (because the answer exists there). The model isn’t really learning to predict interatomic distances, and is therefore limited by the quality of RDKit’s estimates. I would like a way to model distance prediction as an explicit pretraining task. It could be pretrained on masses of RDKit-generated data and fine-tuned on experimental measurements. This has nice conceptual parallels to AlphaFold’s approach.
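One way to make distance prediction an explicit pretraining task, in the spirit of AlphaFold’s distograms, is to bin each pairwise distance and train a per-pair classifier with cross-entropy. The sketch below only builds the targets and the loss; the bin edges are hypothetical, and the per-pair logits would come from some prediction head that is not shown.

```python
import bisect
import math


def distogram_targets(distances, bin_edges):
    """Map every pairwise distance to the index of its bin (len(bin_edges) + 1 bins)."""
    return [[bisect.bisect_right(bin_edges, d) for d in row] for row in distances]


def distogram_loss(logits, targets):
    """Mean cross-entropy over all (i, j) pairs; logits[i][j] holds one score per bin."""
    total, count = 0.0, 0
    for logit_row, target_row in zip(logits, targets):
        for scores, target in zip(logit_row, target_row):
            top = max(scores)
            log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
            total += log_norm - scores[target]  # -log softmax(scores)[target]
            count += 1
    return total / count


edges = [2.0, 4.0, 6.0]  # hypothetical bin edges in angstroms
D = [[0.0, 1.5, 3.0], [1.5, 0.0, 7.2], [3.0, 7.2, 0.0]]
targets = distogram_targets(D, edges)
uniform = [[[0.0] * (len(edges) + 1) for _ in row] for row in D]
print(targets, distogram_loss(uniform, targets))
```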

If distogram prediction is not easy, at least try masking the rows and columns of the adjacency/distance matrices for the substructures in question during training.
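The masking idea could be sketched as follows: for a chosen set of atoms, hide their bonds and distances so the model has to infer that part of the structure from the rest. Setting masked distances to a large hypothetical `far` value (so a decaying distance kernel contributes almost nothing) is an assumption, as is returning copies rather than editing the matrices in place.

```python
def mask_substructure(adjacency, distances, atoms, far=1e6):
    """Return copies of the matrices with every row and column of `atoms` hidden."""
    hidden = set(atoms)
    n = len(adjacency)

    def is_hidden(i, j):
        return i in hidden or j in hidden

    adj = [[0 if is_hidden(i, j) else adjacency[i][j] for j in range(n)] for i in range(n)]
    dist = [[far if is_hidden(i, j) else distances[i][j] for j in range(n)] for i in range(n)]
    return adj, dist


# Propane-like chain C0-C1-C2; hide the terminal atom 2.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
dist = [[0.0, 1.5, 2.5], [1.5, 0.0, 1.5], [2.5, 1.5, 0.0]]
print(mask_substructure(adj, dist, [2]))
```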

the RDKit conformers are generated without any knowledge of the underlying physical system and we’re using only one conformer for modeling

@bharath, what do you mean by “without knowledge of the underlying system”? I thought the conformers modeled bond rotation freedom and interatomic forces?

There was a very interesting follow-up discussion on Twitter about the toy distance modeling problem, with the RDKit lead developer commenting. It seems likely that the 3D information is mainly acting as a noise regularizer rather than contributing meaningful physical information: https://twitter.com/kudkudakpl/status/1236081257374773251

Good point, I wasn’t being clear when I said “without any knowledge of the underlying physical system.” What I meant was without any 3D experimental structures for the underlying system.

If 3D conformation on the whole adds any signal at all (which I think it should, via the biases built into the model), the noise can be handled via augmentation. Multiple conformations seem like a great candidate for Google’s UDA. If it is truly just pure noise on top of the 2D structure… one wonders why algorithmic conformers exist at all (or what commercial virtual docking programs are doing), and molecule conformation ~= protein folding.
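One hedged way to use multiple conformers in a UDA-style setup is a consistency term: predict on two different conformers of the same unlabeled molecule and penalize disagreement. In this sketch `predict` is a hypothetical model function returning class probabilities, and the KL direction and random choice of two conformers are assumptions; in UDA proper, the target side is treated as a fixed (no-gradient) distribution.

```python
import math
import random


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete probability distributions."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def conformer_consistency_loss(predict, conformers, rng):
    """Compare predictions on two randomly chosen conformers of one molecule."""
    first, second = rng.sample(conformers, 2)
    return kl_divergence(predict(first), predict(second))


def toy_predict(energy):
    """Hypothetical model: a two-class probability that depends on a conformer 'energy'."""
    p = 1.0 / (1.0 + math.exp(-energy))
    return [p, 1.0 - p]


print(conformer_consistency_loss(toy_predict, [0.1, 0.4, 2.0], random.Random(0)))
```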

A really simple way to get something that scales reasonably well with distance is to just count the number of bonds between atoms. RDKit’s 2D (topological) distance matrix does this.
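Counting bonds between atoms is an all-pairs shortest-path computation on the molecular graph. A stdlib-only sketch using breadth-first search is below; the input format (atom count plus a list of bonded index pairs) is an assumption. With RDKit installed, `Chem.GetDistanceMatrix(mol)` returns this topological matrix directly.

```python
from collections import deque


def topological_distance_matrix(n_atoms, bonds):
    """All-pairs bond-count distances via BFS from every atom.

    Unreachable pairs (disconnected fragments) are left as float('inf').
    """
    neighbors = [[] for _ in range(n_atoms)]
    for a, b in bonds:
        neighbors[a].append(b)
        neighbors[b].append(a)
    dist = [[float("inf")] * n_atoms for _ in range(n_atoms)]
    for source in range(n_atoms):
        dist[source][source] = 0
        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in neighbors[u]:
                if dist[source][v] == float("inf"):
                    dist[source][v] = dist[source][u] + 1
                    queue.append(v)
    return dist


# Heavy atoms of a six-membered ring: opposite atoms are 3 bonds apart.
ring = [(i, (i + 1) % 6) for i in range(6)]
print(topological_distance_matrix(6, ring)[0])
```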

I guess molecules don’t really end up in shapes that fold back on themselves, where Cartesian distance << edge hop distance?

Definitely agree multiple conformations could be used with UDA. In general, I’d love to see more self-learning techniques for molecules.

Molecular conformations aren’t just noise on top of the 2D structure; I don’t think I was being clear above. Rather, the molecular conformation is determined by the binding environment. In the absence of information about the binding pocket, we don’t really have much reason to favor one conformer over another, which suggests that conformation information improves performance by acting as a regularizer.

For more complex molecules like proteins, you definitely do get folds over the molecule: long-range contacts will have Cartesian distance << edge hops. For small molecules this is less common, although it can still happen.

Ah, binding environment meaning everything about the pocket that will distort the molecule and vice versa. Yeah accurately modeling that is the holy grail.

What do you think about the way adj and dist are just added to the attention weights? If it works, it works, but I can’t help thinking that, conceptually, the adjacency and distance information is equivalent to a relative position encoding and should be formalized as such. This paper outlines a way of encoding edge labels within the attention weight softmax, not outside of it, and Transformer-XL has an even more involved scheme. The key idea is that you learn more weights (N edge types * inner_dim) than just the three lambdas in MAT.
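To make the contrast concrete, here is a minimal sketch of edge labels entering inside the softmax, in the style of relative position representations: each edge type t gets a learned vector a_t of size d_k that is added to the key before the dot product, so the model learns N edge types * d_k weights instead of a single scalar. The edge-type coding (0 = no bond, 1 = bonded) and all names are hypothetical, and this is one possible variant, not the exact method of any specific paper or of Transformer-XL.

```python
import math


def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def edge_aware_attention(Q, K, V, edge_types, edge_emb):
    """e_ij = q_i . (k_j + a_{t(i,j)}) / sqrt(d_k), softmax over j, then a weighted sum of V.

    edge_types[i][j] indexes into edge_emb, a list of learned vectors of length d_k.
    """
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        logits = []
        for j, k in enumerate(K):
            shifted_key = [kc + ac for kc, ac in zip(k, edge_emb[edge_types[i][j]])]
            logits.append(sum(qc * kc for qc, kc in zip(q, shifted_key)) / math.sqrt(d_k))
        weights = softmax(logits)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out


# Two atoms, identity values; edge type 1 ("bonded") has an embedding that raises its logit.
Q = [[1.0, 0.0], [1.0, 0.0]]
K = [[0.0, 0.0], [0.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0]]
edge_types = [[0, 1], [1, 0]]
edge_emb = [[0.0, 0.0], [5.0, 0.0]]
print(edge_aware_attention(Q, K, V, edge_types, edge_emb))
```

Here each atom ends up attending mostly to its bonded neighbor purely because of the learned edge embedding, which is the kind of per-edge-type flexibility that a single adjacency lambda cannot express.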

It’s probably worth trying both encodings to see which does better! This will need some careful benchmarking but should be very feasible.
