wenkai committed
Commit: 3daa625
Parent(s): 4e11d5e

Update app.py

Files changed (1):
  1. app.py +68 -4
app.py CHANGED
@@ -10,20 +10,84 @@ import gradio as gr
 from esm_scripts.extract import run_demo
 from esm import pretrained, FastaBatchedDataset
 
-# from transformers import EsmTokenizer, EsmModel
-
 
 # Load the model
 model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
 model.load_checkpoint("model/checkpoint_mf2.pth")
 model.to('cuda')
 
+model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+model_esm.to('cuda')
+model_esm.eval()
+
 
 @spaces.GPU
 def generate_caption(protein, prompt):
+    # Process the image and the prompt
+    # with open('/home/user/app/example.fasta', 'w') as f:
+    #     f.write('>{}\n'.format("protein_name"))
+    #     f.write('{}\n'.format(protein.strip()))
+    # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
+    # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
+    #                    model=model_esm, alphabet=alphabet,
+    #                    include='per_tok', repr_layers=[36], truncation_seq_length=1024)
+
+    protein_name = 'protein_name'
+    protein_seq = protein
+    include = 'per_tok'
+    repr_layers = [36]
+    truncation_seq_length = 1024
+    toks_per_batch = 4096
+    print("start")
+    dataset = FastaBatchedDataset([protein_name], [protein_seq])
+    print("dataset prepared")
+    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
+    print("batches prepared")
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
+    )
+    print(f"Read sequences")
+    return_contacts = "contacts" in include
+
+    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
+    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
 
-    esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
-    torch.save(esm_emb, 'data/emb_esm2_3b/example.pt')
+    with torch.no_grad():
+        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
+            print(
+                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
+            )
+            if torch.cuda.is_available():
+                toks = toks.to(device="cuda", non_blocking=True)
+            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
+            representations = {
+                layer: t.to(device="cpu") for layer, t in out["representations"].items()
+            }
+            if return_contacts:
+                contacts = out["contacts"].to(device="cpu")
+            for i, label in enumerate(labels):
+                result = {"label": label}
+                truncate_len = min(truncation_seq_length, len(strs[i]))
+                # Call clone on tensors to ensure tensors are not views into a larger representation
+                # See https://github.com/pytorch/pytorch/issues/1995
+                if "per_tok" in include:
+                    result["representations"] = {
+                        layer: t[i, 1: truncate_len + 1].clone()
+                        for layer, t in representations.items()
+                    }
+                if "mean" in include:
+                    result["mean_representations"] = {
+                        layer: t[i, 1: truncate_len + 1].mean(0).clone()
+                        for layer, t in representations.items()
+                    }
+                if "bos" in include:
+                    result["bos_representations"] = {
+                        layer: t[i, 0].clone() for layer, t in representations.items()
+                    }
+                if return_contacts:
+                    result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
+    esm_emb = result['representations'][36]
     '''
     inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
     with torch.no_grad():
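
The diff replaces a precomputed embedding loaded from data/emb_esm2_3b/P18281.pt with on-the-fly extraction: esm2_t36_3B_UR50D is loaded once at module import, the single input sequence is wrapped in a FastaBatchedDataset, the model is run with repr_layers=[36], and the BOS/EOS positions are sliced off to get per-residue embeddings. The sketch below is not part of the commit; it shows the same per-token extraction for one sequence using fair-esm's batch converter directly, without the DataLoader machinery. The example sequence is a made-up placeholder, and it runs on CPU unless you move the model and tokens to CUDA as the app does.

# Minimal sketch (illustrative, not from the commit): single-sequence ESM-2
# embedding extraction equivalent to the loop the diff adds to generate_caption.
import torch
from esm import pretrained

model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.eval()

batch_converter = alphabet.get_batch_converter()
# Placeholder label and sequence; the app uses the user-supplied protein string.
labels, strs, toks = batch_converter([('protein_name', 'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ')])

with torch.no_grad():
    out = model_esm(toks, repr_layers=[36], return_contacts=False)

# Positions 0 and len+1 are BOS/EOS tokens; keep one embedding per residue,
# matching the 'per_tok' slicing in the diff.
esm_emb = out['representations'][36][0, 1:len(strs[0]) + 1]
print(esm_emb.shape)  # torch.Size([sequence_length, embedding_dim])

Loading the 3B ESM-2 checkpoint at module level, as the diff does, pays the load cost once at startup; only the per-sequence forward pass runs inside the @spaces.GPU-decorated function on each request.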