EdgeCape / rename_ckpt.py
orhir's picture
Upload 114 files
184241a verified
import os
import sys
import torch
# Substrings identifying parameters that were removed from the model; any
# state-dict key containing one of them is dropped during conversion.
_DROPPED_KEY_PATTERNS = (
    'spatial_pos_encoder',
    'skeleton_head.MLP',
    'skeleton_head.adj_output_mlp',
)

# (old, new) substring renames applied, in this order, to every kept key.
_KEY_RENAMES = (
    ("keypoint_head.", "keypoint_head_module."),
    ('bias_function_prior_weight', 'markov_structural_mlp'),
)


def _rename_key(key):
    """Return *key* with every substring rename in ``_KEY_RENAMES`` applied."""
    for old, new in _KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _convert_checkpoint(checkpoint):
    """Return a new checkpoint dict with migrated ``state_dict`` keys.

    Keys containing any pattern in ``_DROPPED_KEY_PATTERNS`` are removed, the
    remaining keys are renamed with ``_rename_key``, and the top-level
    ``optimizer`` / ``meta`` entries are carried over unchanged when present
    (inference-only checkpoints may omit them).
    """
    converted = {
        "state_dict": {
            _rename_key(key): value
            for key, value in checkpoint['state_dict'].items()
            if not any(pattern in key for pattern in _DROPPED_KEY_PATTERNS)
        },
    }
    for extra in ("optimizer", "meta"):
        if extra in checkpoint:
            converted[extra] = checkpoint[extra]
    return converted


def load_state_dicts(folder_path):
    """Migrate every ``.pth`` checkpoint directly inside *folder_path* in place.

    Each regular file whose name ends in ``.pth`` is processed in sorted name
    order. The file is loaded onto the CPU and its ``state_dict`` keys are
    migrated by ``_convert_checkpoint``. The result then replaces the original
    file. The new checkpoint is first written to a ``<name>.pth.tmp`` sibling
    and moved over the original with ``os.replace``. This ensures a failed
    save never leaves a truncated checkpoint behind. The conversion is
    idempotent, so re-running it on already-converted files is harmless.

    Args:
        folder_path: Directory containing the checkpoints (not searched
            recursively).

    Returns:
        An empty dict. Converted checkpoints are written to disk, not returned;
        the return value is kept only for backward compatibility.

    Raises:
        KeyError: If a ``.pth`` file has no ``'state_dict'`` entry.
    """
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        if not filename.endswith(".pth") or not os.path.isfile(file_path):
            continue
        print(f'Processing {filename}')
        # torch.load unpickles the file: only run this on trusted checkpoints.
        # map_location="cpu": conversion needs no GPU, and checkpoints saved
        # from CUDA tensors would otherwise fail to load on CPU-only hosts.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which may
        # reject checkpoints whose 'meta' holds arbitrary Python objects.
        checkpoint = torch.load(file_path, map_location="cpu")
        new_checkpoint = _convert_checkpoint(checkpoint)
        print(f'Saving to {file_path}')
        tmp_path = file_path + ".tmp"
        try:
            torch.save(new_checkpoint, tmp_path)
            os.replace(tmp_path, file_path)
        finally:
            # Remove the partial temp file if saving or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    return {}
def main(argv=None):
    """Command-line entry point: migrate the checkpoints in one folder.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``. The first element is the checkpoint folder;
            any extra arguments are ignored, as before.

    Returns:
        Process exit code: 0 on success, 2 if no folder was given, 1 if the
        given path is not a directory.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        prog = os.path.basename(sys.argv[0]) if sys.argv else "rename_ckpt.py"
        print(f"Usage: python {prog} <checkpoint_folder>", file=sys.stderr)
        return 2
    folder_path = args[0]
    if not os.path.isdir(folder_path):
        print(f"Error: {folder_path!r} is not a directory", file=sys.stderr)
        return 1
    load_state_dicts(folder_path)
    return 0


if __name__ == "__main__":
    sys.exit(main())