Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
from collections import OrderedDict | |
import torch | |
def moco_convert(src, dst): | |
"""Convert keys in pycls pretrained moco models to mmdet style.""" | |
# load caffe model | |
moco_model = torch.load(src) | |
blobs = moco_model['state_dict'] | |
# convert to pytorch style | |
state_dict = OrderedDict() | |
for k, v in blobs.items(): | |
if not k.startswith('module.encoder_q.'): | |
continue | |
old_k = k | |
k = k.replace('module.encoder_q.', '') | |
state_dict[k] = v | |
print(old_k, '->', k) | |
# save checkpoint | |
checkpoint = dict() | |
checkpoint['state_dict'] = state_dict | |
torch.save(checkpoint, dst) | |
def main(): | |
parser = argparse.ArgumentParser(description='Convert model keys') | |
parser.add_argument('src', help='src detectron model path') | |
parser.add_argument('dst', help='save path') | |
parser.add_argument( | |
'--selfsup', type=str, choices=['moco', 'swav'], help='save path') | |
args = parser.parse_args() | |
if args.selfsup == 'moco': | |
moco_convert(args.src, args.dst) | |
elif args.selfsup == 'swav': | |
print('SWAV does not need to convert the keys') | |
if __name__ == '__main__': | |
main() | |