entropy commited on
Commit
c889db3
1 Parent(s): 0ba5847

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +64 -0
README.md CHANGED
@@ -1,3 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
 
1
+ ---
2
+ tags:
3
+ - chemistry
4
+ - molecule
5
+ - drug
6
+ ---
7
+
8
+ # Roberta Zinc Decoder
9
+
10
+ This model is a GPT2 decoder model designed to reconstruct SMILES strings from embeddings created by the
11
+ [roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model.
12
+
13
+ The decoder model conditions generation on mean pooled embeddings from the encoder model. Mean pooled
14
+ embeddings are used to allow for integration with vector databases, which require fixed length embeddings.
15
+
16
+ Condition embeddings are passed to the decoder model using the `encoder_hidden_states` attribute.
17
+ The standard `GPT2LMHeadModel` does not support generation with encoder hidden states, so this repo
18
+ includes a custom `ConditionalGPT2LMHeadModel`. See example below for how to instantiate the model.
19
+
20
+ ```python
21
+ import torch
22
+ from transformers import AutoModelForCausalLM, RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding
23
+
24
+ tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", max_len=256)
25
+ collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')
26
+
27
+ encoder_model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
28
+ encoder_model.eval();
29
+
30
+ commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
31
+ decoder_model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
32
+ trust_remote_code=True, revision=commit_hash)
33
+ decoder_model.eval();
34
+
35
+
36
+ smiles = ['Brc1cc2c(NCc3ccccc3)ncnc2s1',
37
+ 'Brc1cc2c(NCc3ccccn3)ncnc2s1',
38
+ 'Brc1cc2c(NCc3cccs3)ncnc2s1',
39
+ 'Brc1cc2c(NCc3ccncc3)ncnc2s1',
40
+ 'Brc1cc2c(Nc3ccccc3)ncnc2s1']
41
+
42
+ inputs = collator(tokenizer(smiles))
43
+ outputs = encoder_model(**inputs, output_hidden_states=True)
44
+ full_embeddings = outputs[1][-1]
45
+ mask = inputs['attention_mask']
46
+ mean_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))
47
+
48
+ decoder_inputs = torch.tensor([[tokenizer.bos_token_id] for i in range(len(smiles))])
49
+
50
+ hidden_states = mean_embeddings[:,None] # hidden states shape (bs, 1, -1)
51
+
52
+ gen = decoder_model.generate(
53
+ decoder_inputs,
54
+ encoder_hidden_states=hidden_states,
55
+ do_sample=False, # greedy decoding is recommended
56
+ max_length=100,
57
+ temperature=1.,
58
+ early_stopping=True,
59
+ pad_token_id=tokenizer.pad_token_id,
60
+ )
61
+
62
+ reconstructed_smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)
63
+ ```
64
+
65
  ---
66
  license: mit
67
  ---