Jiqing commited on
Commit
921086f
1 Parent(s): b72e236

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -1
README.md CHANGED
@@ -82,12 +82,16 @@ if __name__ == "__main__":
82
  # get datasets
83
  raw_datasets = load_dataset("mila-intel/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"
84
 
85
- device = torch.device("cpu")
 
86
 
87
  protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
88
  protein_model = protst_model.protein_model
89
  text_model = protst_model.text_model
90
  logit_scale = protst_model.logit_scale
 
 
 
91
  logit_scale.requires_grad = False
92
  logit_scale = logit_scale.to(device)
93
  logit_scale = logit_scale.exp()
 
82
  # get datasets
83
  raw_datasets = load_dataset("mila-intel/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"
84
 
85
+ - device = torch.device("cpu")
86
+ + device = torch.device("hpu")
87
 
88
  protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
89
  protein_model = protst_model.protein_model
90
  text_model = protst_model.text_model
91
  logit_scale = protst_model.logit_scale
92
+ + from habana_frameworks.torch.hpu import wrap_in_hpu_graph
93
+ + protein_model = wrap_in_hpu_graph(protein_model)
94
+ + text_model = wrap_in_hpu_graph(text_model)
95
  logit_scale.requires_grad = False
96
  logit_scale = logit_scale.to(device)
97
  logit_scale = logit_scale.exp()