Update convert.py
Browse files- convert.py +4 -2
convert.py
CHANGED
@@ -95,7 +95,9 @@ def convert_file(
|
|
95 |
pt_filename: str,
|
96 |
sf_filename: str,
|
97 |
):
|
98 |
-
loaded = torch.load(pt_filename
|
|
|
|
|
99 |
shared = shared_pointers(loaded)
|
100 |
for shared_weights in shared:
|
101 |
for name in shared_weights[1:]:
|
@@ -238,7 +240,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
238 |
operations = convert_multi(model_id, folder)
|
239 |
else:
|
240 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
241 |
-
|
242 |
else:
|
243 |
operations = convert_generic(model_id, folder, filenames)
|
244 |
|
|
|
95 |
pt_filename: str,
|
96 |
sf_filename: str,
|
97 |
):
|
98 |
+
loaded = torch.load(pt_filename)
|
99 |
+
if "state_dict" in loaded:
|
100 |
+
loaded = loaded["state_dict"]
|
101 |
shared = shared_pointers(loaded)
|
102 |
for shared_weights in shared:
|
103 |
for name in shared_weights[1:]:
|
|
|
240 |
operations = convert_multi(model_id, folder)
|
241 |
else:
|
242 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
243 |
+
check_final_model(model_id, folder)
|
244 |
else:
|
245 |
operations = convert_generic(model_id, folder, filenames)
|
246 |
|