Update app.py
Browse files
app.py
CHANGED
@@ -14,13 +14,13 @@ from esm import pretrained, FastaBatchedDataset
|
|
14 |
|
15 |
|
16 |
# Load the model
|
17 |
-
|
18 |
-
|
19 |
# model.to('cuda')
|
20 |
|
21 |
-
|
22 |
# model_esm.to('cuda')
|
23 |
-
|
24 |
|
25 |
|
26 |
# tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
@@ -51,9 +51,7 @@ def generate_caption(protein, prompt):
|
|
51 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
52 |
print("batches prepared")
|
53 |
|
54 |
-
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
55 |
model_esm.to('cuda')
|
56 |
-
model_esm.eval()
|
57 |
|
58 |
data_loader = torch.utils.data.DataLoader(
|
59 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
@@ -115,9 +113,7 @@ def generate_caption(protein, prompt):
|
|
115 |
'prompt': [prompt]}
|
116 |
|
117 |
del model_esm
|
118 |
-
|
119 |
-
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
120 |
-
model.load_checkpoint("model/checkpoint_mf2.pth")
|
121 |
model.to('cuda')
|
122 |
# Generate the output
|
123 |
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
|
|
14 |
|
15 |
|
16 |
# Load the model
|
17 |
+
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
18 |
+
model.load_checkpoint("model/checkpoint_mf2.pth")
|
19 |
# model.to('cuda')
|
20 |
|
21 |
+
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
22 |
# model_esm.to('cuda')
|
23 |
+
model_esm.eval()
|
24 |
|
25 |
|
26 |
# tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
|
|
51 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
52 |
print("batches prepared")
|
53 |
|
|
|
54 |
model_esm.to('cuda')
|
|
|
55 |
|
56 |
data_loader = torch.utils.data.DataLoader(
|
57 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
|
|
113 |
'prompt': [prompt]}
|
114 |
|
115 |
del model_esm
|
116 |
+
|
|
|
|
|
117 |
model.to('cuda')
|
118 |
# Generate the output
|
119 |
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|