Spaces:
Running
Running
edit mhg checkpoint
Browse files- models/mhg_model/load.py +2 -2
models/mhg_model/load.py
CHANGED
@@ -76,9 +76,9 @@ class PretrainedModelWrapper:
|
|
76 |
def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
|
77 |
PretrainedModelWrapper]:
|
78 |
repo_id = "ibm/materials.mhg-ged"
|
79 |
-
filename = "mhggnn_pretrained_model_0724_2023.pickle"
|
80 |
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
81 |
with open(file_path, "rb") as f:
|
82 |
-
model_dict =
|
83 |
return PretrainedModelWrapper(model_dict)
|
84 |
return None
|
|
|
76 |
def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
|
77 |
PretrainedModelWrapper]:
|
78 |
repo_id = "ibm/materials.mhg-ged"
|
79 |
+
filename = "pytorch_model.bin" #"mhggnn_pretrained_model_0724_2023.pickle"
|
80 |
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
81 |
with open(file_path, "rb") as f:
|
82 |
+
model_dict = torch.load(f)
|
83 |
return PretrainedModelWrapper(model_dict)
|
84 |
return None
|