Update README.md
Browse files
README.md
CHANGED
@@ -82,12 +82,16 @@ if __name__ == "__main__":
|
|
82 |
# get datasets
|
83 |
raw_datasets = load_dataset("mila-intel/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"
|
84 |
|
85 |
-
|
|
|
86 |
|
87 |
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
88 |
protein_model = protst_model.protein_model
|
89 |
text_model = protst_model.text_model
|
90 |
logit_scale = protst_model.logit_scale
|
|
|
|
|
|
|
91 |
logit_scale.requires_grad = False
|
92 |
logit_scale = logit_scale.to(device)
|
93 |
logit_scale = logit_scale.exp()
|
|
|
82 |
# get datasets
|
83 |
raw_datasets = load_dataset("mila-intel/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"
|
84 |
|
85 |
+
- device = torch.device("cpu")
|
86 |
+
+ device = torch.device("hpu")
|
87 |
|
88 |
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
89 |
protein_model = protst_model.protein_model
|
90 |
text_model = protst_model.text_model
|
91 |
logit_scale = protst_model.logit_scale
|
92 |
+
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
93 |
+
+ protein_model = wrap_in_hpu_graph(protein_model)
|
94 |
+
+ text_model = wrap_in_hpu_graph(text_model)
|
95 |
logit_scale.requires_grad = False
|
96 |
logit_scale = logit_scale.to(device)
|
97 |
logit_scale = logit_scale.exp()
|