rawalkhirodkar's picture
Add initial commit
history blame
2.02 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import tempfile
from collections import OrderedDict
import torch
from mmengine import Config
from mmengine.utils import digit_version
def parse_config(config_strings):
    """Check that a serialized config describes an SSD model.

    The config text is written to a temporary ``.py`` file so that
    ``Config.fromfile`` can parse it. The temporary directory, and the
    file inside it, is removed as soon as parsing has finished.

    Args:
        config_strings (str): Source text of the config to validate.

    Raises:
        AssertionError: If ``model.bbox_head.type`` of the parsed config
            is not ``'SSDHead'``.
    """
    # A temporary directory (rather than NamedTemporaryFile plus a sibling
    # "<name>.py" file) guarantees nothing is left behind on disk.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = f'{tmp_dir}/config.py'
        with open(config_path, 'w', encoding='utf-8') as f:
            # Without this write the parsed file would be empty.
            f.write(config_strings)
        config = Config.fromfile(config_path)
    # check whether it is SSD
    if config.model.bbox_head.type != 'SSDHead':
        raise AssertionError('This is not a SSD model.')
def _upgrade_key(key):
    """Return the upgraded name for one key of the old state dict."""
    if 'extra' in key:
        parts = key.split('.')
        layer_idx = int(parts[2])
        # Consecutive pairs of old layer indices share one group index
        # (idx // 2); idx % 2 is the position inside that group.
        return 'neck.extra_layers.{}.{}.conv.'.format(
            layer_idx // 2, layer_idx % 2) + parts[-1]
    if 'l2_norm' in key:
        return 'neck.l2_norm.weight'
    if 'bbox_head' in key:
        # Insert an extra '.0' level after the first 21 characters.
        # NOTE(review): the fixed offset presumably matches a prefix such
        # as 'bbox_head.cls_convs.N' with a single-digit N -- confirm.
        return key[:21] + '.0' + key[21:]
    # Every other key keeps its name.
    return key


def convert(in_file, out_file):
    """Rewrite the state-dict keys of a checkpoint and save the result.

    Keys containing ``'extra'``, ``'l2_norm'`` or ``'bbox_head'`` are
    renamed; all other keys are copied unchanged. The config stored in
    ``checkpoint['meta']['config']`` is validated with ``parse_config``
    before any key is converted.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Path the converted checkpoint is written to.

    Raises:
        AssertionError: Propagated from ``parse_config`` when the stored
            config is not an SSD model.
    """
    # NOTE: torch.load unpickles the file, so only run this on checkpoints
    # from a trusted source.
    checkpoint = torch.load(in_file)
    in_state_dict = checkpoint.pop('state_dict')
    meta_info = checkpoint['meta']
    # NOTE(review): the leading '#' presumably comments out the first line
    # of the stored config text -- confirm against how it was saved.
    parse_config('#' + meta_info['config'])

    out_state_dict = OrderedDict(
        (_upgrade_key(key), value) for key, value in in_state_dict.items())
    checkpoint['state_dict'] = out_state_dict

    # From torch 1.6 on, explicitly request the legacy (non-zipfile)
    # serialization; earlier versions do not take the keyword.
    if digit_version(torch.__version__) >= digit_version('1.6'):
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)
def main():
    """Parse the command line and run the checkpoint conversion.

    Two positional arguments are expected: ``in_file`` (the checkpoint to
    read) and ``out_file`` (where the converted checkpoint is written).
    """
    cli = argparse.ArgumentParser(description='Upgrade SSD version')
    positional_args = (
        ('in_file', 'input checkpoint file'),
        ('out_file', 'output checkpoint file'),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()
    convert(options.in_file, options.out_file)
if __name__ == '__main__':