wenkai committed
Commit 5a69d92
1 Parent(s): d6d5757

Update esm_scripts/extract.py

Files changed (1):
  1. esm_scripts/extract.py +2 -10
esm_scripts/extract.py CHANGED
@@ -132,38 +132,31 @@ def run(args):
 
 
 def run_demo(protein_name, protein_seq, model, alphabet, include,
-             repr_layers=-1, truncation_seq_length=1022, toks_per_batch=4096):
-
+             repr_layers=[-1], truncation_seq_length=1022, toks_per_batch=4096):
     dataset = FastaBatchedDataset([protein_name], [protein_seq])
     batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
     )
     print(f"Read sequences")
-
     # output_dir.mkdir(parents=True, exist_ok=True)
     return_contacts = "contacts" in include
-
     assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
     repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]
-
     with torch.no_grad():
         for batch_idx, (labels, strs, toks) in enumerate(data_loader):
             print(
                 f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
             )
-            if torch.cuda.is_available() and not nogpu:
+            if torch.cuda.is_available():
                 toks = toks.to(device="cuda", non_blocking=True)
-
             out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)
-
             logits = out["logits"].to(device="cpu")
             representations = {
                 layer: t.to(device="cpu") for layer, t in out["representations"].items()
             }
             if return_contacts:
                 contacts = out["contacts"].to(device="cpu")
-
             for i, label in enumerate(labels):
                 result = {"label": label}
                 truncate_len = min(truncation_seq_length, len(strs[i]))
@@ -185,7 +178,6 @@ def run_demo(protein_name, protein_seq, model, alphabet, include,
                     }
                 if return_contacts:
                     result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
-
     return result['representations'][36]
 
 
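Two of the changed lines are real bug fixes rather than whitespace cleanup. The old default repr_layers=-1 was a plain int, so the all(...) assertion and the modulo normalization, both of which iterate over repr_layers, would raise a TypeError; the new default [-1] is iterable. The modulo expression then maps negative indices onto absolute layer numbers. A minimal standalone sketch of that normalization (the helper name is ours, not the script's):

def normalize_repr_layers(repr_layers, num_layers):
    # Same bounds check and modulo mapping as run_demo: valid indices
    # range from -(num_layers + 1) up to num_layers.
    assert all(-(num_layers + 1) <= i <= num_layers for i in repr_layers)
    return [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]

# For a 36-layer model, -1 resolves to the final layer:
print(normalize_repr_layers([-1], 36))     # [36]
print(normalize_repr_layers([0, 12], 36))  # [0, 12] -- non-negative indices pass through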
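The other fix drops "and not nogpu" from the GPU check: nogpu is only defined as args.nogpu inside run, so referencing it from run_demo raised a NameError whenever CUDA was available. With both fixes applied, the function can be driven as below; a hedged sketch, assuming the fair-esm package and its 36-layer esm2_t36_3B_UR50D checkpoint (chosen because the return statement hard-codes representations[36]), with a made-up protein name and sequence:

import torch
import esm  # fair-esm package

# Load a 36-layer ESM-2 model so layer index 36 (the hard-coded return) exists.
model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
model.eval()
if torch.cuda.is_available():
    model = model.cuda()  # run_demo moves tokens to CUDA when available

# repr_layers defaults to [-1], which normalizes to layer 36 here.
emb = run_demo(
    "example_protein",                    # hypothetical name
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # hypothetical sequence
    model,
    alphabet,
    include=["per_tok"],  # add "contacts" to also compute contact maps
)
print(emb.shape)  # per-residue embeddings from the final layer

Note that the function still returns only result['representations'][36] for the last sequence processed, so as written it remains tied to 36-layer checkpoints.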