Update esm_scripts/extract.py
Browse files- esm_scripts/extract.py +1 -10
esm_scripts/extract.py
CHANGED
@@ -131,17 +131,8 @@ def run(args):
|
|
131 |
)
|
132 |
|
133 |
|
134 |
-
def run_demo(protein_name, protein_seq,
|
135 |
repr_layers=-1, truncation_seq_length=1022, toks_per_batch=4096):
|
136 |
-
model, alphabet = pretrained.load_model_and_alphabet(model_location)
|
137 |
-
model.eval()
|
138 |
-
if isinstance(model, MSATransformer):
|
139 |
-
raise ValueError(
|
140 |
-
"This script currently does not handle models with MSA input (MSA Transformer)."
|
141 |
-
)
|
142 |
-
if torch.cuda.is_available() and not nogpu:
|
143 |
-
model = model.cuda()
|
144 |
-
print("Transferred model to GPU")
|
145 |
|
146 |
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
147 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
|
|
131 |
)
|
132 |
|
133 |
|
134 |
+
def run_demo(protein_name, protein_seq, model, alphabet, include,
|
135 |
repr_layers=-1, truncation_seq_length=1022, toks_per_batch=4096):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
138 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|