# Scraped from a Hugging Face Spaces file view (Space status: "Running on Zero").
# File size: 1,260 bytes; commit 184241a.
import os
import sys
import torch
# Substrings marking model parameters that are dropped from converted checkpoints.
_DROPPED_KEY_PATTERNS = (
    'spatial_pos_encoder',
    'skeleton_head.MLP',
    'skeleton_head.adj_output_mlp',
)

# (old, new) substring renames applied, in this order, to every kept key.
_KEY_RENAMES = (
    ("keypoint_head.", "keypoint_head_module."),
    ('bias_function_prior_weight', 'markov_structural_mlp'),
)


def _remap_state_dict(model_state):
    """Return a new dict without dropped keys and with the remaining keys renamed."""
    remapped = {}
    for key, value in model_state.items():
        if any(pattern in key for pattern in _DROPPED_KEY_PATTERNS):
            continue
        new_key = key
        for old, new in _KEY_RENAMES:
            new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped


def _atomic_torch_save(obj, path):
    """Save *obj* to *path* through a sibling temp file so *path* is never half-written."""
    tmp_path = path + ".tmp"
    try:
        torch.save(obj, tmp_path)
        # os.replace is atomic on the same filesystem: readers see either the
        # old checkpoint or the complete new one, never a truncated file.
        os.replace(tmp_path, path)
    except BaseException:
        # The original file at *path* is untouched; just clean up and re-raise.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise


def load_state_dicts(folder_path, *, map_location=None):
    """Rewrite every ``*.pth`` checkpoint in *folder_path* in place with remapped keys.

    Each checkpoint must be a mapping with ``'state_dict'``, ``'optimizer'``
    and ``'meta'`` entries (a ``KeyError`` is raised otherwise). The
    ``'optimizer'`` and ``'meta'`` entries are copied unchanged. In
    ``'state_dict'``, keys containing any of ``_DROPPED_KEY_PATTERNS`` are
    removed and the rest are renamed via ``_KEY_RENAMES``. This presumably
    migrates checkpoints from an older module layout; confirm against the
    model code. Re-running on already-converted files is a no-op for the keys.

    Args:
        folder_path: Directory whose ``.pth`` files are converted. The files
            are overwritten in place, atomically per file.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to convert GPU-saved checkpoints on a CPU-only machine. The
            default ``None`` matches plain ``torch.load``.

    Returns:
        An empty dict. Nothing is collected; it is kept for backward
        compatibility with the original return value.
    """
    state_dicts = {}
    # Sorted so processing order (and log output) is deterministic.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith(".pth"):
            continue
        print('Processing {}'.format(filename))
        file_path = os.path.join(folder_path, filename)
        # NOTE: torch.load unpickles; only run this on trusted checkpoints.
        checkpoint = torch.load(file_path, map_location=map_location)
        new_state_dict = {
            "state_dict": _remap_state_dict(checkpoint['state_dict']),
            "optimizer": checkpoint['optimizer'],
            "meta": checkpoint['meta'],
        }
        print(f'Saving to {file_path}')
        _atomic_torch_save(new_state_dict, file_path)
    return state_dicts
if __name__ == "__main__":
    # Command-line entry point: convert every .pth checkpoint in the given folder.
    if len(sys.argv) < 2:
        # Clear usage error instead of a bare IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} FOLDER_PATH")
    folder_path = sys.argv[1]
    load_state_dicts(folder_path)