wenkai committed
Commit d3edc5a
1 Parent(s): 1a0324b

Update app.py

Files changed (1): app.py (+5 -9)
app.py CHANGED
@@ -14,13 +14,13 @@ from esm import pretrained, FastaBatchedDataset
 
 
 # Load the model
-# model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
-# model.load_checkpoint("model/checkpoint_mf2.pth")
+model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
+model.load_checkpoint("model/checkpoint_mf2.pth")
 # model.to('cuda')
 
-# model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
 # model_esm.to('cuda')
-# model_esm.eval()
+model_esm.eval()
 
 
 # tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
@@ -51,9 +51,7 @@ def generate_caption(protein, prompt):
     batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
     print("batches prepared")
 
-    model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
     model_esm.to('cuda')
-    model_esm.eval()
 
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
@@ -115,9 +113,7 @@ def generate_caption(protein, prompt):
               'prompt': [prompt]}
 
     del model_esm
-
-    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
-    model.load_checkpoint("model/checkpoint_mf2.pth")
+
     model.to('cuda')
     # Generate the output
     prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
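
The net effect of the commit: both models are now constructed once when app.py is imported (i.e., when the Space starts), and generate_caption only moves them onto the GPU per request, instead of reloading the ESM-2 3B encoder and the FAPM checkpoint on every call. Below is a minimal, self-contained sketch of that load-once / move-per-request pattern; it is an illustration only, with stand-in nn.Linear modules in place of Blip2ProteinMistral and ESM-2 (those checkpoints are not available here), and the CPU-fallback device logic is likewise an assumption, not part of app.py.

import torch
import torch.nn as nn

# Stand-ins for the real models in app.py (Blip2ProteinMistral and ESM-2 3B);
# only the load-once / move-per-request structure is illustrated.
model = nn.Linear(8, 4)        # built once, on CPU, at module import time
model_esm = nn.Linear(8, 8)
model_esm.eval()

def generate_caption(protein_tensor):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The encoder moves to the accelerator only while this request is served,
    # mirroring model_esm.to('cuda') inside app.py's generate_caption.
    model_esm.to(device)
    with torch.no_grad():
        embedding = model_esm(protein_tensor.to(device))

    # Release the encoder from the device before the captioning model runs,
    # analogous to the del model_esm / model.to('cuda') sequence in app.py.
    model_esm.to('cpu')
    model.to(device)
    with torch.no_grad():
        out = model(embedding)
    return out.cpu()

print(generate_caption(torch.randn(2, 8)).shape)   # torch.Size([2, 4])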