wenkai committed on
Commit
c8e59d5
1 Parent(s): 9b993cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -10,14 +10,10 @@ import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
 
13
-
14
- model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
15
- model_esm.to('cuda')
16
-
17
  # Load the model
18
- # model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
19
- # model.load_checkpoint("model/checkpoint_mf2.pth")
20
- # model.to('cuda')
21
 
22
 
23
  @spaces.GPU
@@ -46,9 +42,13 @@ def generate_caption(protein, prompt):
46
  )
47
  print(f"Read sequences")
48
  return_contacts = "contacts" in include
 
 
 
 
49
  assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
50
  repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
51
- model_esm.eval()
52
  with torch.no_grad():
53
  for batch_idx, (labels, strs, toks) in enumerate(data_loader):
54
  print(
@@ -57,6 +57,7 @@ def generate_caption(protein, prompt):
57
  if torch.cuda.is_available():
58
  toks = toks.to(device="cuda", non_blocking=True)
59
  out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
 
60
  logits = out["logits"].to(device="cpu")
61
  representations = {
62
  layer: t.to(device="cpu") for layer, t in out["representations"].items()
@@ -94,10 +95,10 @@ def generate_caption(protein, prompt):
94
  'text_input': ['none'],
95
  'prompt': [prompt]}
96
  # Generate the output
97
- # prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
98
 
99
- # return prediction
100
- return "test"
101
 
102
  # Define the FAPM interface
103
  description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
 
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
 
 
 
 
 
13
  # Load the model
14
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
15
+ model.load_checkpoint("model/checkpoint_mf2.pth")
16
+ model.to('cuda')
17
 
18
 
19
  @spaces.GPU
 
42
  )
43
  print(f"Read sequences")
44
  return_contacts = "contacts" in include
45
+
46
+ model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
47
+ model_esm.to('cuda')
48
+ model_esm.eval()
49
  assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
50
  repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
51
+
52
  with torch.no_grad():
53
  for batch_idx, (labels, strs, toks) in enumerate(data_loader):
54
  print(
 
57
  if torch.cuda.is_available():
58
  toks = toks.to(device="cuda", non_blocking=True)
59
  out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
60
+ del model_esm
61
  logits = out["logits"].to(device="cpu")
62
  representations = {
63
  layer: t.to(device="cpu") for layer, t in out["representations"].items()
 
95
  'text_input': ['none'],
96
  'prompt': [prompt]}
97
  # Generate the output
98
+ prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
99
 
100
+ return prediction
101
+ # return "test"
102
 
103
  # Define the FAPM interface
104
  description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.