---
license: mit
---

<!-- ##### 🔴 <font color=red>Note: SaProt needs structure-aware (SA token) input for optimal performance. Amino-acid-only input works, but the model must then be fine-tuned: frozen embeddings are only useful for SA sequences, not for AA-only sequences. With structural input, SaProt outperforms ESM2 on most tasks.</font> -->

We provide two ways to use SaProt: through the Hugging Face `transformers` classes, or in the same way as in the [esm GitHub repository](https://github.com/facebookresearch/esm). Users can choose either one.

### Hugging Face model

The following code shows how to load the model and run it on an example SA sequence.

```python
import torch
from transformers import EsmTokenizer, EsmForMaskedLM

model_path = "/your/path/to/SaProt_35M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path)

#################### Example ####################
# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Each SA token pairs an uppercase amino acid with a lowercase Foldseek 3Di structure letter.
# Here "#" replaces the structure letter in low-confidence regions (pLDDT < 70).
seq = "M#EvVpQpL#VyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Inference only, so skip gradient tracking
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)

# Expected output:
# ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv']
# torch.Size([1, 11, 446])  -> 9 SA tokens plus <cls> and <eos>; last dim is the vocabulary size
```
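
The `#` in the example marks residues whose structure letter is masked because the predicted structure is unreliable there. The following is a minimal sketch of how such an SA string could be assembled by hand. It assumes you already have the amino-acid sequence, a 3Di structure string of the same length (for example from Foldseek), and per-residue pLDDT scores; the helpers `build_sa_sequence` and `split_sa_tokens` are hypothetical illustrations, not part of SaProt's code, and the input values in the usage lines are made up to reproduce the example above.

```python
from typing import List, Sequence


def build_sa_sequence(aa_seq: str, struc_seq: str, plddt: Sequence[float],
                      threshold: float = 70.0) -> str:
    """Hypothetical helper: interleave residues with 3Di letters, masking low-pLDDT positions."""
    if not (len(aa_seq) == len(struc_seq) == len(plddt)):
        raise ValueError("aa_seq, struc_seq and plddt must have the same length")
    pieces = []
    for aa, struc, score in zip(aa_seq, struc_seq, plddt):
        # Residue stays uppercase; the structure letter is lowercased,
        # or replaced by "#" when the pLDDT score is below the threshold.
        struc_token = "#" if score < threshold else struc.lower()
        pieces.append(aa.upper() + struc_token)
    return "".join(pieces)


def split_sa_tokens(sa_seq: str) -> List[str]:
    """Hypothetical helper: split a plain SA string into two-character tokens."""
    if len(sa_seq) % 2:
        raise ValueError("an SA sequence must have an even number of characters")
    return [sa_seq[i:i + 2] for i in range(0, len(sa_seq), 2)]


# Hypothetical inputs chosen to reproduce the example sequence above
sa_seq = build_sa_sequence("MEVQLVQYK", "dvppdydav", [55, 90, 90, 90, 60, 90, 90, 90, 90])
print(sa_seq)                   # M#EvVpQpL#VyQdYaKv
print(split_sa_tokens(sa_seq))  # ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv']
```

For plain SA strings, the two-character split gives the same tokens that `tokenizer.tokenize` prints in the example; the real tokenizer additionally adds special tokens such as `<cls>` and `<eos>` when encoding.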

### esm model

The esm version of the weights is stored in the same folder under the name `SaProt_35M_AF2.pt`, and can be loaded with the `load_esm_saprot` function we provide:

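```python
# Assumes the SaProt source code (which provides utils/esm_loader.py) is on your Python path
from utils.esm_loader import load_esm_saprot

model_path = "/your/path/to/SaProt_35M_AF2.pt"
# Returns the esm-style model together with its alphabet (vocabulary)
model, alphabet = load_esm_saprot(model_path)
```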