Update app.py
Browse files
app.py
CHANGED
@@ -10,14 +10,10 @@ import gradio as gr
|
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained, FastaBatchedDataset
|
12 |
|
13 |
-
|
14 |
-
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
15 |
-
model_esm.to('cuda')
|
16 |
-
|
17 |
# Load the model
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
|
22 |
|
23 |
@spaces.GPU
|
@@ -46,9 +42,13 @@ def generate_caption(protein, prompt):
|
|
46 |
)
|
47 |
print(f"Read sequences")
|
48 |
return_contacts = "contacts" in include
|
|
|
|
|
|
|
|
|
49 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
50 |
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
51 |
-
|
52 |
with torch.no_grad():
|
53 |
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
|
54 |
print(
|
@@ -57,6 +57,7 @@ def generate_caption(protein, prompt):
|
|
57 |
if torch.cuda.is_available():
|
58 |
toks = toks.to(device="cuda", non_blocking=True)
|
59 |
out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
|
|
|
60 |
logits = out["logits"].to(device="cpu")
|
61 |
representations = {
|
62 |
layer: t.to(device="cpu") for layer, t in out["representations"].items()
|
@@ -94,10 +95,10 @@ def generate_caption(protein, prompt):
|
|
94 |
'text_input': ['none'],
|
95 |
'prompt': [prompt]}
|
96 |
# Generate the output
|
97 |
-
|
98 |
|
99 |
-
|
100 |
-
return "test"
|
101 |
|
102 |
# Define the FAPM interface
|
103 |
description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
|
|
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained, FastaBatchedDataset
|
12 |
|
|
|
|
|
|
|
|
|
13 |
# Load the model
|
14 |
+
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
15 |
+
model.load_checkpoint("model/checkpoint_mf2.pth")
|
16 |
+
model.to('cuda')
|
17 |
|
18 |
|
19 |
@spaces.GPU
|
|
|
42 |
)
|
43 |
print(f"Read sequences")
|
44 |
return_contacts = "contacts" in include
|
45 |
+
|
46 |
+
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
47 |
+
model_esm.to('cuda')
|
48 |
+
model_esm.eval()
|
49 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
50 |
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
51 |
+
|
52 |
with torch.no_grad():
|
53 |
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
|
54 |
print(
|
|
|
57 |
if torch.cuda.is_available():
|
58 |
toks = toks.to(device="cuda", non_blocking=True)
|
59 |
out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
|
60 |
+
del model_esm
|
61 |
logits = out["logits"].to(device="cpu")
|
62 |
representations = {
|
63 |
layer: t.to(device="cpu") for layer, t in out["representations"].items()
|
|
|
95 |
'text_input': ['none'],
|
96 |
'prompt': [prompt]}
|
97 |
# Generate the output
|
98 |
+
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
|
99 |
|
100 |
+
return prediction
|
101 |
+
# return "test"
|
102 |
|
103 |
# Define the FAPM interface
|
104 |
description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
|