Update app.py
app.py CHANGED
@@ -15,9 +15,9 @@ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
 model.load_checkpoint("model/checkpoint_mf2.pth")
 model.to('cuda')
 
-
-
-
+model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+model_esm.to('cuda')
+model_esm.eval()
 
 @spaces.GPU
 def generate_caption(protein, prompt):
@@ -42,13 +42,13 @@ def generate_caption(protein, prompt):
     print("batches prepared")
 
     data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=
+        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
     )
     print(f"Read sequences")
     return_contacts = "contacts" in include
 
-    assert all(-(
-    repr_layers = [(i +
+    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
+    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
 
     with torch.no_grad():
         for batch_idx, (labels, strs, toks) in enumerate(data_loader):
@@ -57,7 +57,7 @@ def generate_caption(protein, prompt):
             )
             if torch.cuda.is_available():
                 toks = toks.to(device="cuda", non_blocking=True)
-            out =
+            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
             logits = out["logits"].to(device="cpu")
             representations = {
                 layer: t.to(device="cpu") for layer, t in out["representations"].items()
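For context, here is a minimal standalone sketch of the embedding-extraction flow this commit wires into `generate_caption`: ESM-2 is loaded once at module scope and reused per request, the batch converter from its alphabet drives the `DataLoader`, and per-residue representations are pulled from the requested layers. The sketch assumes the fair-esm package; the smaller `esm2_t12_35M_UR50D` checkpoint, the helper name `extract_embeddings`, and the default parameter values are illustrative choices, not taken from the app itself.

```python
# Minimal sketch of the ESM-2 embedding extraction used in app.py (assumptions noted above).
import torch
from esm import FastaBatchedDataset, pretrained

# Load the protein language model once at module scope, as the commit does for
# esm2_t36_3B_UR50D; a small checkpoint is used here so the sketch runs quickly.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t12_35M_UR50D')
model_esm.eval()
if torch.cuda.is_available():
    model_esm = model_esm.cuda()


def extract_embeddings(sequences, repr_layers=(-1,), toks_per_batch=1024,
                       truncation_seq_length=1022, include=("per_tok",)):
    """Return per-residue representations for each sequence (hypothetical helper)."""
    seq_labels = [f"seq{i}" for i in range(len(sequences))]
    dataset = FastaBatchedDataset(seq_labels, list(sequences))
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )
    return_contacts = "contacts" in include

    # Map negative layer indices (e.g. -1 for the final layer) onto absolute layer numbers.
    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]

    results = {}
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            for i, label in enumerate(labels):
                # Position 0 is the BOS token; keep only the per-residue rows.
                truncate_len = min(truncation_seq_length, len(strs[i]))
                results[label] = {
                    layer: t[i, 1:truncate_len + 1] for layer, t in representations.items()
                }
    return results


if __name__ == "__main__":
    embs = extract_embeddings(["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"])
    print({label: {layer: t.shape for layer, t in d.items()} for label, d in embs.items()})
```

Keeping `pretrained.load_model_and_alphabet` at module scope, as the commit does, means the 3B checkpoint is read from disk once at startup rather than on every call to the `@spaces.GPU`-decorated function.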