---
tags:
- chemistry
- molecule
- drug
license: mit
---

# Roberta Zinc Decoder

This is a GPT-2 decoder model designed to reconstruct SMILES strings from embeddings created by the 
[roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model. The decoder model was 
trained on 30 million compounds from the [ZINC Database](https://zinc.docking.org/).

The decoder conditions generation on mean-pooled embeddings from the encoder model. Mean pooling is 
used because it produces fixed-length embeddings, which allows them to be stored in and queried from 
vector databases.
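
For example, a batch of mean-pooled embeddings can be indexed for nearest-neighbor search. The sketch below is illustrative only and assumes FAISS, which is not a dependency of this repo; any vector database that accepts fixed-length float vectors works the same way.

```python
# Illustrative sketch only: FAISS is an assumption here, not part of this repo.
import faiss
import numpy as np

# In practice these would be the mean-pooled encoder embeddings computed as in
# the example below; random vectors of a placeholder dimension stand in here.
embeddings = np.random.rand(1000, 768).astype("float32")

faiss.normalize_L2(embeddings)                   # normalize so inner product = cosine similarity
index = faiss.IndexFlatIP(embeddings.shape[1])   # flat inner-product index
index.add(embeddings)

scores, neighbor_ids = index.search(embeddings[:1], 5)  # 5 nearest neighbors of the first vector
```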

Conditioning embeddings are passed to the decoder via the `encoder_hidden_states` argument. 
The standard `GPT2LMHeadModel` does not support generation with encoder hidden states, so this repo 
includes a custom `ConditionalGPT2LMHeadModel`. See the example below for how to instantiate and run the model.

```python
import torch
from transformers import AutoModelForCausalLM, RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", max_len=256)
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

encoder_model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
encoder_model.eval();

commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
decoder_model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
                                                     trust_remote_code=True, revision=commit_hash)
decoder_model.eval();


smiles = ['Brc1cc2c(NCc3ccccc3)ncnc2s1',
 'Brc1cc2c(NCc3ccccn3)ncnc2s1',
 'Brc1cc2c(NCc3cccs3)ncnc2s1',
 'Brc1cc2c(NCc3ccncc3)ncnc2s1',
 'Brc1cc2c(Nc3ccccc3)ncnc2s1']

inputs = collator(tokenizer(smiles))

with torch.no_grad():
    outputs = encoder_model(**inputs, output_hidden_states=True)

full_embeddings = outputs[1][-1]   # last hidden layer, shape (bs, seq_len, hidden_size)
mask = inputs['attention_mask']
# mean pool over non-padding tokens to get one fixed-length embedding per molecule
mean_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))

# start each generation from the BOS token
decoder_inputs = torch.tensor([[tokenizer.bos_token_id] for i in range(len(smiles))])

hidden_states = mean_embeddings[:,None]   # shape (bs, 1, hidden_size)

gen = decoder_model.generate(
              decoder_inputs,
              encoder_hidden_states=hidden_states,
              do_sample=False, # greedy decoding is recommended
              max_length=100, 
              temperature=1.,
              early_stopping=True,
              pad_token_id=tokenizer.pad_token_id,
                         )

reconstructed_smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)
```
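
Continuing from the example above, a quick sanity check (not part of the original card) is to compare the reconstructions with the inputs directly:

```python
# Compare each reconstruction to its input SMILES string
for inp, out in zip(smiles, reconstructed_smiles):
    print(inp == out, inp, '->', out)

exact = sum(inp == out for inp, out in zip(smiles, reconstructed_smiles)) / len(smiles)
print(f'exact match rate: {exact:.2f}')
```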

## Model Performance

The decoder model was evaluated on a test set of 1m compounds from ZINC. Compounds 
were encoded with the [roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model 
and reconstructed with the decoder model.

The following metrics are computed:
* `exact_match` - percent of inputs reconstructed exactly
* `token_accuracy` - percent of output tokens exactly matching the input tokens (excluding padding)
* `valid_structure` - percent of generated outputs that resolve to a valid SMILES string
* `tanimoto` - Tanimoto similarity between input and generated structures (invalid structures excluded)
* `cos_sim` - cosine similarity between the input encoder embeddings and the encoder embeddings of the generated outputs

`eval_type=full` reports metrics for the full 1m compound test set.

`eval_type=failed` reports the same metrics restricted to generated outputs that failed to exactly reproduce their inputs.


|eval_type|exact_match|token_accuracy|valid_structure|tanimoto|cos_sim |
|---------|-----------|--------------|---------------|--------|--------|
|full     |0.948277   |0.990704      |0.994278       |0.987698|0.998224|
|failed   |0.000000   |0.820293      |0.889372       |0.734097|0.965668|
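
The metrics above can be reproduced roughly as follows, continuing from the generation example. This is a hedged sketch: it assumes RDKit with Morgan fingerprints for the Tanimoto comparison (the exact fingerprint settings behind the reported numbers are not specified here), and it omits `token_accuracy` for brevity.

```python
# Sketch of the reconstruction metrics; RDKit and the Morgan fingerprint
# settings (radius=2, 2048 bits) are assumptions, not the card's exact setup.
import torch
import torch.nn.functional as F
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def tanimoto(smi_a, smi_b, radius=2, n_bits=2048):
    mol_a, mol_b = Chem.MolFromSmiles(smi_a), Chem.MolFromSmiles(smi_b)
    if mol_a is None or mol_b is None:
        return None  # invalid structure, excluded from the metric
    fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, radius, nBits=n_bits)
    fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, radius, nBits=n_bits)
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)

exact_match = sum(i == o for i, o in zip(smiles, reconstructed_smiles)) / len(smiles)
valid_structure = sum(Chem.MolFromSmiles(o) is not None for o in reconstructed_smiles) / len(smiles)

sims = [tanimoto(i, o) for i, o in zip(smiles, reconstructed_smiles)]
sims = [s for s in sims if s is not None]
mean_tanimoto = sum(sims) / len(sims)

# cos_sim: re-encode the generated SMILES and compare embeddings to the inputs
out_inputs = collator(tokenizer(reconstructed_smiles))
with torch.no_grad():
    out_hidden = encoder_model(**out_inputs, output_hidden_states=True)[1][-1]
out_mask = out_inputs['attention_mask']
out_embeddings = (out_hidden * out_mask.unsqueeze(-1)).sum(1) / out_mask.sum(-1).unsqueeze(-1)
mean_cos_sim = F.cosine_similarity(mean_embeddings, out_embeddings, dim=-1).mean()
```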

