Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import sys | |
import torch | |
# Migration rules for converting old-format checkpoints.
# Parameters whose key contains any of these substrings are dropped entirely.
_DROPPED_KEY_PARTS = (
    "spatial_pos_encoder",
    "skeleton_head.MLP",
    "skeleton_head.adj_output_mlp",
)
# (old substring, new substring) pairs applied, in order, to every kept key.
# Both rules are idempotent, so re-running the migration is harmless.
_KEY_RENAMES = (
    ("keypoint_head.", "keypoint_head_module."),
    ("bias_function_prior_weight", "markov_structural_mlp"),
)
# Top-level checkpoint entries (besides "state_dict") copied over unchanged
# when present; any other top-level entries are not carried over.
_PASSTHROUGH_ENTRIES = ("optimizer", "meta")


def _convert_key(key):
    """Return the migrated name for parameter *key*, or None if it is dropped."""
    if any(part in key for part in _DROPPED_KEY_PARTS):
        return None
    for old, new in _KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _convert_checkpoint(checkpoint):
    """Build the migrated checkpoint dict from a loaded *checkpoint* dict.

    Raises:
        KeyError: if *checkpoint* has no "state_dict" entry.
    """
    converted = {"state_dict": {}}
    for entry in _PASSTHROUGH_ENTRIES:
        # Tolerate checkpoints saved without optimizer state or metadata.
        if entry in checkpoint:
            converted[entry] = checkpoint[entry]
    for key, value in checkpoint["state_dict"].items():
        new_key = _convert_key(key)
        if new_key is not None:
            converted["state_dict"][new_key] = value
    return converted


def load_state_dicts(folder_path):
    """Migrate every ``*.pth`` checkpoint file in *folder_path* in place.

    For each checkpoint, parameters matching ``_DROPPED_KEY_PARTS`` are
    removed, the remaining keys are renamed per ``_KEY_RENAMES``, and the
    "optimizer"/"meta" entries (if present) are kept. The result overwrites
    the original file.

    Args:
        folder_path: directory containing the ``.pth`` files to migrate.

    Returns:
        dict mapping each processed filename to the path it was written to.

    Raises:
        KeyError: if a ``.pth`` file has no "state_dict" entry.
    """
    written = {}
    # Sorted so the processing order (and log output) is deterministic.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        if not filename.endswith(".pth") or not os.path.isfile(file_path):
            continue
        print(f"Processing {filename}")
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        checkpoint = torch.load(file_path, map_location="cpu")
        new_state_dict = _convert_checkpoint(checkpoint)
        print(f"Saving to {file_path}")
        # Save to a sibling temp file, then atomically swap it in, so an
        # interrupted save cannot corrupt the only copy of the checkpoint.
        tmp_path = file_path + ".tmp"
        torch.save(new_state_dict, tmp_path)
        os.replace(tmp_path, file_path)
        written[filename] = file_path
    return written
if __name__ == "__main__":
    # Usage: python <script> <checkpoint_folder>
    # Fail with a readable message instead of an IndexError traceback when
    # the folder argument is missing; extra arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <checkpoint_folder>")
    folder_path = sys.argv[1]
    load_state_dicts(folder_path)