wenkai commited on
Commit
3705c34
1 Parent(s): 3ef1459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -9,15 +9,21 @@ import spaces
9
  import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
 
 
12
 
13
  # Load the model
14
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
15
  model.load_checkpoint("model/checkpoint_mf2.pth")
16
  model.to('cuda')
17
 
18
- model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
19
- model_esm.to('cuda')
20
- model_esm.eval()
 
 
 
 
21
 
22
  @spaces.GPU
23
  def generate_caption(protein, prompt):
@@ -29,7 +35,8 @@ def generate_caption(protein, prompt):
29
  # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
30
  # model=model_esm, alphabet=alphabet,
31
  # include='per_tok', repr_layers=[36], truncation_seq_length=1024)
32
- protein_name='protein_name'
 
33
  protein_seq=protein
34
  include='per_tok'
35
  repr_layers=[36]
@@ -86,6 +93,11 @@ def generate_caption(protein, prompt):
86
  if return_contacts:
87
  result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
88
  esm_emb = result['representations'][36]
 
 
 
 
 
89
 
90
  print("esm embedding generated")
91
  esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
 
9
  import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
+ from transformers import EsmTokenizer, EsmModel
13
+
14
 
15
  # Load the model
16
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
17
  model.load_checkpoint("model/checkpoint_mf2.pth")
18
  model.to('cuda')
19
 
20
+ # model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
21
+ # model_esm.to('cuda')
22
+ # model_esm.eval()
23
+ tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
24
+ model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
25
+ model.to('cuda')
26
+ model.eval()
27
 
28
  @spaces.GPU
29
  def generate_caption(protein, prompt):
 
35
  # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
36
  # model=model_esm, alphabet=alphabet,
37
  # include='per_tok', repr_layers=[36], truncation_seq_length=1024)
38
+ '''
39
+ protein_name='protein_name'
40
  protein_seq=protein
41
  include='per_tok'
42
  repr_layers=[36]
 
93
  if return_contacts:
94
  result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
95
  esm_emb = result['representations'][36]
96
+ '''
97
+ inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True)
98
+ with torch.no_grad():
99
+ outputs = model(**inputs)
100
+ esm_emb = outputs.last_hidden_state.detach()[0]
101
 
102
  print("esm embedding generated")
103
  esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')