File size: 1,394 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from collections import OrderedDict

import torch


def moco_convert(src, dst):
    """Convert keys in pycls pretrained moco models to mmdet style.

    Only the entries of the checkpoint's ``'state_dict'`` whose key starts
    with ``'module.encoder_q.'`` are kept; that leading prefix is stripped
    and every other entry is dropped. The result is written to ``dst`` as
    ``{'state_dict': OrderedDict(...)}``, preserving the original key order.
    Each renamed key is printed as ``old -> new``.

    Args:
        src: Path of the checkpoint to read. It must contain a
            ``'state_dict'`` mapping.
        dst: Path the converted checkpoint is saved to.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    prefix = 'module.encoder_q.'
    # NOTE: torch.load unpickles the file; only run this on checkpoints from
    # a trusted source. map_location='cpu' lets checkpoints that were saved
    # from a GPU be converted on a machine without CUDA.
    moco_model = torch.load(src, map_location='cpu')
    blobs = moco_model['state_dict']
    state_dict = OrderedDict()
    for old_k, v in blobs.items():
        if not old_k.startswith(prefix):
            continue
        # Slice off the leading prefix only; str.replace would also remove
        # any later occurrence of the same substring inside the key.
        new_k = old_k[len(prefix):]
        state_dict[new_k] = v
        print(old_k, '->', new_k)
    # Wrap the weights the same way the input was wrapped.
    checkpoint = {'state_dict': state_dict}
    torch.save(checkpoint, dst)


def main():
    """Parse the command line and convert the checkpoint keys if needed.

    Positional arguments ``src`` and ``dst`` are the input and output
    checkpoint paths. ``--selfsup`` selects the self-supervised method the
    checkpoint comes from: ``moco`` runs ``moco_convert(src, dst)``, while
    ``swav`` only prints a notice because no key conversion is performed.
    """
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='path of the pretrained checkpoint')
    parser.add_argument('dst', help='save path')
    # Required: without it neither branch below runs and the script would
    # exit successfully without ever writing ``dst``.
    parser.add_argument(
        '--selfsup',
        type=str,
        choices=['moco', 'swav'],
        required=True,
        help='self-supervised method that produced the checkpoint')
    args = parser.parse_args()
    if args.selfsup == 'moco':
        moco_convert(args.src, args.dst)
    elif args.selfsup == 'swav':
        print('SWAV does not need to convert the keys')

# Run the command-line interface only when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()