Spaces:
Runtime error
Runtime error
Update convert.py
Browse files- convert.py +9 -2
convert.py
CHANGED
@@ -113,13 +113,20 @@ def convert_file(
|
|
113 |
loaded = torch.load(pt_filename, map_location="cpu")
|
114 |
if "state_dict" in loaded:
|
115 |
loaded = loaded["state_dict"]
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
# For tensors to be contiguous
|
118 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
119 |
|
120 |
dirname = os.path.dirname(sf_filename)
|
121 |
os.makedirs(dirname, exist_ok=True)
|
122 |
-
save_file(loaded, sf_filename, metadata=
|
123 |
reloaded = load_file(sf_filename)
|
124 |
for k in loaded:
|
125 |
pt_tensor = loaded[k]
|
|
|
113 |
loaded = torch.load(pt_filename, map_location="cpu")
|
114 |
if "state_dict" in loaded:
|
115 |
loaded = loaded["state_dict"]
|
116 |
+
to_removes = _remove_duplicate_names(loaded)
|
117 |
+
|
118 |
+
metadata = {"format": "pt"}
|
119 |
+
for kept_name, to_remove_group in to_removes.items():
|
120 |
+
for to_remove in to_remove_group:
|
121 |
+
if to_remove not in metadata:
|
122 |
+
metadata[to_remove] = kept_name
|
123 |
+
del loaded[to_remove]
|
124 |
# For tensors to be contiguous
|
125 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
126 |
|
127 |
dirname = os.path.dirname(sf_filename)
|
128 |
os.makedirs(dirname, exist_ok=True)
|
129 |
+
save_file(loaded, sf_filename, metadata=metadata)
|
130 |
reloaded = load_file(sf_filename)
|
131 |
for k in loaded:
|
132 |
pt_tensor = loaded[k]
|