Narsil HF staff commited on
Commit
4a04f8f
·
1 Parent(s): 1788c41

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +9 -2
convert.py CHANGED
@@ -113,13 +113,20 @@ def convert_file(
113
  loaded = torch.load(pt_filename, map_location="cpu")
114
  if "state_dict" in loaded:
115
  loaded = loaded["state_dict"]
116
- loaded = _remove_duplicate_names(loaded)
 
 
 
 
 
 
 
117
  # For tensors to be contiguous
118
  loaded = {k: v.contiguous() for k, v in loaded.items()}
119
 
120
  dirname = os.path.dirname(sf_filename)
121
  os.makedirs(dirname, exist_ok=True)
122
- save_file(loaded, sf_filename, metadata={"format": "pt"})
123
  reloaded = load_file(sf_filename)
124
  for k in loaded:
125
  pt_tensor = loaded[k]
 
113
  loaded = torch.load(pt_filename, map_location="cpu")
114
  if "state_dict" in loaded:
115
  loaded = loaded["state_dict"]
116
+ to_removes = _remove_duplicate_names(loaded)
117
+
118
+ metadata = {"format": "pt"}
119
+ for kept_name, to_remove_group in to_removes.items():
120
+ for to_remove in to_remove_group:
121
+ if to_remove not in metadata:
122
+ metadata[to_remove] = kept_name
123
+ del loaded[to_remove]
124
  # For tensors to be contiguous
125
  loaded = {k: v.contiguous() for k, v in loaded.items()}
126
 
127
  dirname = os.path.dirname(sf_filename)
128
  os.makedirs(dirname, exist_ok=True)
129
+ save_file(loaded, sf_filename, metadata=metadata)
130
  reloaded = load_file(sf_filename)
131
  for k in loaded:
132
  pt_tensor = loaded[k]