AmelieSchreiber committed
Commit 27d2c5d
Parent(s): 14b33ff

Update README.md

Files changed (1): README.md +90 -0
README.md CHANGED
---
license: mit
---

# ESM-2 for Generating Peptide Binders for Proteins

This is a retraining of PepMLM using this [forked repo](https://github.com/Amelie-Schreiber/pepmlm/tree/main). The original PepMLM is also available on HuggingFace [here](https://huggingface.co/TianlaiChen/PepMLM-650M).

## Using the Model

To use the model, run the following:

```python
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import pandas as pd
import numpy as np
from torch.distributions import Categorical

def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    # Create a mask covering the binder positions (excluding the final special token)
    binder_mask = torch.zeros(tensor_input.shape).to(model.device)
    binder_mask[0, -len(binder_seq)-1:-1] = 1

    # Mask the binder sequence in the input and create labels
    masked_input = tensor_input.clone().masked_fill_(binder_mask.bool(), tokenizer.mask_token_id)
    labels = tensor_input.clone().masked_fill_(~binder_mask.bool(), -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())


def generate_peptide_for_single_sequence(protein_seq, peptide_length=15, top_k=3, num_binders=4):
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders_with_ppl = []

    for _ in range(num_binders):
        # Generate binder: append masked positions to the protein sequence
        masked_peptide = '<mask>' * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]

        # Apply top-k sampling at each masked position
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)

        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Compute the pseudo-perplexity of the generated binder
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)

        # Add the generated binder and its PPL to the results list
        binders_with_ppl.append([generated_binder, ppl_value])

    return binders_with_ppl


def generate_peptide(input_seqs, peptide_length=15, top_k=3, num_binders=4):
    if isinstance(input_seqs, str):  # Single sequence
        binders = generate_peptide_for_single_sequence(input_seqs, peptide_length, top_k, num_binders)
        return pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])

    elif isinstance(input_seqs, list):  # List of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(seq, peptide_length, top_k, num_binders)
            for binder, ppl in binders:
                results.append([seq, binder, ppl])
        return pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])


# Load the fine-tuned checkpoint ("output/checkpoint-255" is a local training
# output; substitute this repo's Hub ID to load directly from HuggingFace)
model = EsmForMaskedLM.from_pretrained("output/checkpoint-255")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

protein_seq = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

results_df = generate_peptide(protein_seq, peptide_length=15, top_k=3, num_binders=5)
print(results_df)
```
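
For reference, `compute_pseudo_perplexity` masks every binder position at once and exponentiates the mean loss over those masked tokens, so the reported score is

$$\mathrm{PPL}(b \mid p) = \exp\left(-\frac{1}{|b|}\sum_{i=1}^{|b|}\log P\left(b_i \mid p,\, b_{\text{masked}}\right)\right),$$

with lower values indicating binders the model finds more plausible for the target.

`generate_peptide` also accepts a list of target sequences, in which case the returned DataFrame carries an `Input Sequence` column so each binder can be traced back to its target. A minimal sketch (the second sequence below is a made-up placeholder, for illustration only):

```python
# Batch mode: one row per (target, binder) pair in the resulting DataFrame.
protein_seqs = [
    protein_seq,                          # the example target from above
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIE",    # placeholder target, illustration only
]
batch_df = generate_peptide(protein_seqs, peptide_length=12, top_k=3, num_binders=2)
print(batch_df.sort_values('Pseudo Perplexity'))
```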