wenkai commited on
Commit
d376f39
·
verified ·
1 Parent(s): e4c6c5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -15,9 +15,9 @@ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
15
  model.load_checkpoint("model/checkpoint_mf2.pth")
16
  model.to('cuda')
17
 
18
- # model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
19
- # model_esm.to('cuda')
20
- # model_esm.eval()
21
 
22
  @spaces.GPU
23
  def generate_caption(protein, prompt):
@@ -42,13 +42,13 @@ def generate_caption(protein, prompt):
42
  print("batches prepared")
43
 
44
  data_loader = torch.utils.data.DataLoader(
45
- dataset, collate_fn=model.alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
46
  )
47
  print(f"Read sequences")
48
  return_contacts = "contacts" in include
49
 
50
- assert all(-(model.model_esm.num_layers + 1) <= i <= model.model_esm.num_layers for i in repr_layers)
51
- repr_layers = [(i + model.model_esm.num_layers + 1) % (model.model_esm.num_layers + 1) for i in repr_layers]
52
 
53
  with torch.no_grad():
54
  for batch_idx, (labels, strs, toks) in enumerate(data_loader):
@@ -57,7 +57,7 @@ def generate_caption(protein, prompt):
57
  )
58
  if torch.cuda.is_available():
59
  toks = toks.to(device="cuda", non_blocking=True)
60
- out = model.model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
61
  logits = out["logits"].to(device="cpu")
62
  representations = {
63
  layer: t.to(device="cpu") for layer, t in out["representations"].items()
 
15
  model.load_checkpoint("model/checkpoint_mf2.pth")
16
  model.to('cuda')
17
 
18
+ model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
19
+ model_esm.to('cuda')
20
+ model_esm.eval()
21
 
22
  @spaces.GPU
23
  def generate_caption(protein, prompt):
 
42
  print("batches prepared")
43
 
44
  data_loader = torch.utils.data.DataLoader(
45
+ dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
46
  )
47
  print(f"Read sequences")
48
  return_contacts = "contacts" in include
49
 
50
+ assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
51
+ repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
52
 
53
  with torch.no_grad():
54
  for batch_idx, (labels, strs, toks) in enumerate(data_loader):
 
57
  )
58
  if torch.cuda.is_available():
59
  toks = toks.to(device="cuda", non_blocking=True)
60
+ out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
61
  logits = out["logits"].to(device="cpu")
62
  representations = {
63
  layer: t.to(device="cpu") for layer, t in out["representations"].items()