Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,394 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from collections import OrderedDict
import torch
def moco_convert(src, dst, prefix='module.encoder_q.'):
    """Convert a MoCo pretrained checkpoint into an mmdet-style checkpoint.

    Only entries of the checkpoint's ``'state_dict'`` whose key starts with
    ``prefix`` are kept. The prefix is stripped from each kept key, and every
    rename is printed as ``old -> new``. All other entries are dropped. The
    default ``'module.encoder_q.'`` presumably selects the query encoder of a
    (Distributed)DataParallel-wrapped MoCo model — confirm for other sources.

    Args:
        src (str): Path of the MoCo checkpoint. It must be a dict holding
            the weights under a ``'state_dict'`` key.
        dst (str): Path where the converted checkpoint is written.
        prefix (str): Key prefix selecting the weights to keep.
            Defaults to ``'module.encoder_q.'``.
    """
    # Map every tensor onto the CPU so checkpoints saved during GPU training
    # can be converted on machines without the same CUDA devices.
    moco_model = torch.load(src, map_location='cpu')
    blobs = moco_model['state_dict']
    state_dict = OrderedDict()
    for old_k, v in blobs.items():
        if not old_k.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also delete any
        # later occurrence of the same substring inside the key.
        new_k = old_k[len(prefix):]
        state_dict[new_k] = v
        print(old_k, '->', new_k)
    # The output checkpoint keeps the renamed weights under 'state_dict'.
    torch.save({'state_dict': state_dict}, dst)
def main():
    """Command-line entry point: convert a self-supervised checkpoint.

    Usage: ``python <script> SRC DST --selfsup {moco,swav}``.
    MoCo checkpoints are rewritten with ``moco_convert``. For SwAV, only a
    message is printed: its keys need no conversion, and nothing is written.
    """
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='path of the self-supervised checkpoint')
    parser.add_argument('dst', help='save path')
    # Required: without it the script used to exit silently without doing
    # anything.
    parser.add_argument(
        '--selfsup',
        type=str,
        choices=['moco', 'swav'],
        required=True,
        help='self-supervised method that produced the checkpoint')
    args = parser.parse_args()
    if args.selfsup == 'moco':
        moco_convert(args.src, args.dst)
    elif args.selfsup == 'swav':
        print('SWAV does not need to convert the keys')


if __name__ == '__main__':
    main()
|