File size: 1,260 Bytes
184241a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import sys

import torch

# Weight names containing any of these substrings are dropped during conversion.
_DROPPED_KEY_PATTERNS = (
    'spatial_pos_encoder',
    'skeleton_head.MLP',
    'skeleton_head.adj_output_mlp',
)

# (old, new) substring renames applied, in order, to every kept weight name.
_KEY_RENAMES = (
    ("keypoint_head.", "keypoint_head_module."),
    ('bias_function_prior_weight', 'markov_structural_mlp'),
)


def _rename_key(key):
    """Return *key* with every substring rename in ``_KEY_RENAMES`` applied."""
    for old, new in _KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _convert_checkpoint(checkpoint):
    """Return a new checkpoint dict with filtered and renamed model weights.

    Entries of ``checkpoint['state_dict']`` whose names contain a pattern from
    ``_DROPPED_KEY_PATTERNS`` are removed; the rest are renamed with
    ``_rename_key``. Every other top-level entry (e.g. ``'optimizer'``,
    ``'meta'``) is carried over unchanged, so checkpoints saved without an
    optimizer state are converted instead of raising ``KeyError``.
    """
    weights = {
        _rename_key(key): value
        for key, value in checkpoint['state_dict'].items()
        if not any(pattern in key for pattern in _DROPPED_KEY_PATTERNS)
    }
    converted = {"state_dict": weights}
    converted.update(
        (name, value) for name, value in checkpoint.items() if name != 'state_dict'
    )
    return converted


def _save_atomically(obj, path):
    """Save *obj* to *path* through a temporary sibling file and ``os.replace``.

    The checkpoint is overwritten in place, so writing directly to *path*
    would leave a truncated file if ``torch.save`` failed or was interrupted.
    """
    tmp_path = path + '.tmp'
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind, then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_state_dicts(folder_path):
    """Convert every ``.pth`` checkpoint file in *folder_path* in place.

    For each checkpoint, model weights matching ``_DROPPED_KEY_PATTERNS`` are
    removed, the remaining weight names are rewritten via ``_KEY_RENAMES``,
    and the result overwrites the original file.

    Args:
        folder_path: Directory scanned (non-recursively) for ``*.pth`` files.

    Returns:
        dict mapping each processed filename to its converted checkpoint.
        Note: every converted checkpoint is kept in memory until return.
    """
    state_dicts = {}
    # Sorted so files are processed (and logged) in a deterministic order.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        if not filename.endswith(".pth") or not os.path.isfile(file_path):
            continue
        print('Processing {}'.format(filename))
        # NOTE: torch.load unpickles the file; only run on trusted checkpoints.
        checkpoint = torch.load(file_path)
        new_state_dict = _convert_checkpoint(checkpoint)
        print(f'Saving to {file_path}')
        _save_atomically(new_state_dict, file_path)
        state_dicts[filename] = new_state_dict
    return state_dicts

if __name__ == "__main__":
    # Local import: only the command-line entry point needs argparse.
    import argparse

    parser = argparse.ArgumentParser(
        description="Rewrite every .pth checkpoint in a folder in place, "
                    "dropping and renaming model weight keys.")
    parser.add_argument("folder_path",
                        help="directory containing the .pth checkpoints")
    args = parser.parse_args()
    # Fail with a usage message instead of a traceback from os.listdir.
    if not os.path.isdir(args.folder_path):
        parser.error(f"not a directory: {args.folder_path}")
    load_state_dicts(args.folder_path)