---
library_name: transformers
license: apache-2.0
datasets:
- vector-institute/s2ef-15m
- vector-institute/atom3d-smp
metrics:
- mae
---

# AtomFormer base model fine-tuned on the Small Molecule Prediction (SMP) task

This model is a transformer-based model that leverages Gaussian pair-wise positional embeddings to train on atomistic graph
data. It is part of a suite of datasets, models, and utilities in the AtomGen project that supports other methods for
pre-training and fine-tuning models on atomistic graphs. This particular model is pre-trained on the `s2ef-15m` dataset and
fine-tuned on the `atom3d-smp` dataset.

## Model description

AtomFormer is a transformer model with modifications for training on atomistic graphs. It builds primarily on the work of
Uni-Mol+, adding pair-wise positional embeddings to the attention mask to leverage 3-D positional information. The model was
pre-trained on a diverse set of aggregated atomistic datasets where the target tasks are per-atom force prediction and
per-system energy prediction.
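
To make the mechanism concrete, here is a minimal, illustrative sketch of Gaussian pair-wise positional embeddings used as an
additive attention bias. It is not the actual AtomFormer code; the module name, kernel count, and head count are assumptions
for illustration.

```python
import torch
import torch.nn as nn

class GaussianPairwiseBias(nn.Module):
    """Expand pair-wise distances with Gaussian kernels and project them to
    one additive attention-bias term per head (illustrative sketch)."""

    def __init__(self, num_kernels: int = 128, num_heads: int = 8):
        super().__init__()
        self.means = nn.Parameter(torch.linspace(0.0, 10.0, num_kernels))
        self.stds = nn.Parameter(torch.ones(num_kernels))
        self.proj = nn.Linear(num_kernels, num_heads)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (batch, n_atoms, 3) -> pair-wise distances: (batch, n, n)
        dist = torch.cdist(coords, coords)
        # Gaussian expansion of each distance: (batch, n, n, num_kernels)
        gauss = torch.exp(-((dist.unsqueeze(-1) - self.means) ** 2) / (2 * self.stds**2))
        # Project to per-head biases: (batch, num_heads, n, n)
        return self.proj(gauss).permute(0, 3, 1, 2)

# The resulting bias is added to the attention logits before softmax, which is
# how 3-D structure enters an otherwise permutation-invariant transformer.
```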

The model also includes metadata about the atomic species being modeled, such as atomic radius, electronegativity, and
valency. The metadata is normalized and projected so that it can be added to the atom embeddings in the model.
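
A minimal sketch of that idea, using random placeholder values and illustrative names in place of the real per-species
constants and the actual AtomGen code:

```python
import torch
import torch.nn as nn

NUM_SPECIES, NUM_FEATURES, HIDDEN = 119, 4, 768   # illustrative sizes

# Placeholder metadata table (radius, electronegativity, valency, ...).
metadata = torch.rand(NUM_SPECIES, NUM_FEATURES)
metadata = (metadata - metadata.mean(dim=0)) / metadata.std(dim=0)  # normalize per feature

atom_embed = nn.Embedding(NUM_SPECIES, HIDDEN)  # learned species embeddings
meta_proj = nn.Linear(NUM_FEATURES, HIDDEN)     # projects metadata to the hidden size

input_ids = torch.randint(0, NUM_SPECIES, (1, 10))
hidden_states = atom_embed(input_ids) + meta_proj(metadata[input_ids])  # (1, 10, HIDDEN)
```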

## Intended uses & limitations

You can use the model to predict properties for small molecules.

### How to use

Here is how to use the fine-tuned model to predict small-molecule properties:

```python
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "vector-institute/atomformer-base-smp", trust_remote_code=True
)

# Toy inputs: one structure with 10 atoms, random species tokens and 3-D coordinates
input_ids = torch.randint(0, 50, (1, 10))
coords = torch.randn(1, 10, 3)
attn_mask = torch.ones(1, 10)

output = model(input_ids, coords=coords, attention_mask=attn_mask)
output[1].shape  # torch.Size([1, 20]) -- one prediction per SMP property
```

## Training data

AtomFormer is pre-trained on an aggregated S2EF (structure-to-energy-and-forces) dataset drawn from multiple sources such as
OC20, OC22, ODAC23, MPtrj, and SPICE, with structures and their energies/forces. The pre-training data includes both total
energies and formation energies, but training uses the formation energy (which is not available for OC22, as indicated by the
"has_formation_energy" column).
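
For example, systems without a formation-energy label can be filtered out (or masked in the loss) using that column. A sketch
with the `datasets` library, assuming a `train` split and the default column layout of `s2ef-15m`:

```python
from datasets import load_dataset

# Stream the pre-training data and keep only systems with a formation-energy label.
ds = load_dataset("vector-institute/s2ef-15m", split="train", streaming=True)
ds = ds.filter(lambda example: example["has_formation_energy"])
```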

This model variant is fine-tuned on the small molecule prediction (SMP) task, where it outputs 20 different properties for each sample.

### Preprocessing

The model expects input in the form of tokenized atomic symbols, represented as `input_ids`, and 3-D coordinates, represented
as `coords`. For the pre-training task it also expects `forces` and `formation_energy` labels.
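
A single pre-training example might therefore look like the following (shapes for a hypothetical 10-atom system; values are
random placeholders):

```python
import torch

sample = {
    "input_ids": torch.randint(0, 50, (10,)),  # tokenized atomic species
    "coords": torch.randn(10, 3),              # 3-D coordinates per atom
    "forces": torch.randn(10, 3),              # per-atom force labels (pre-training only)
    "formation_energy": torch.tensor(-1.23),   # per-system energy label (pre-training only)
}
```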

The `DataCollatorForAtomModeling` utility in the AtomGen library can dynamically pad the data to batch it together, as
sketched below. It also offers the option to flatten the data and provide a `batch` index column for GNN-style training.
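
A hedged sketch of batching with that collator, assuming it follows the usual Hugging Face data-collator pattern of being
callable on a list of examples (the exact import path and constructor arguments may differ; see the AtomGen documentation):

```python
from atomgen.data import DataCollatorForAtomModeling  # illustrative import path

# Assumed to accept a list of examples (e.g. `sample` dicts as sketched above)
# and return a dynamically padded batch of tensors.
collator = DataCollatorForAtomModeling()
batch = collator([sample, sample])
```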

## Evaluation results

The model is trained for 300 epochs with a batch size of 512, a learning rate of 1e-3 with cosine decay to zero, a
max_grad_norm of 5.0, and a weight decay of 1e-2.
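
Expressed as Hugging Face `TrainingArguments`, those hyperparameters would look roughly like this (a reconstruction for
reference, not the original training script):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="atomformer-base-smp",
    num_train_epochs=300,
    per_device_train_batch_size=512,
    learning_rate=1e-3,
    lr_scheduler_type="cosine",  # cosine decay to zero
    max_grad_norm=5.0,
    weight_decay=1e-2,
)
```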

Comparison between leveraging the pre-trained base model and training from scratch on the SMP task (MAE, lower is better):

| split | base   | scratch |
|:-----:|:------:|:-------:|
| val   | 0.1766 | 0.2304  |
| test  | 1.077  | 1.13    |