"""Rewrite the ``reg_branches`` tensors of a saved checkpoint and re-save it.

The script loads a checkpoint dict (expected to hold a ``state_dict`` mapping
and, optionally, a ``meta`` dict), permutes rows of every tensor whose key
contains ``reg_branches`` and whose first dimension equals ``--code_size``,
marks the checkpoint with ``meta['detr3d_convert_tag'] = True`` and writes the
result to the destination path.  A checkpoint that already carries the tag is
left untouched and nothing is written.

Usage::

    python <script> [src] [dst] [--code_size N]

``src`` defaults to ``old.pth``, ``dst`` to ``new.pth`` and ``--code_size``
to ``10``.
"""
from argparse import ArgumentParser

import torch


def _convert_reg_branches(state_dict, code_size):
    """Permute, in place, the rows of every matching ``reg_branches`` tensor.

    A tensor matches when its key contains ``'reg_branches'`` and its first
    dimension has length ``code_size``.  For each match, rows 2 and 3 are
    swapped, and rows 6 and 7 are swapped and negated.
    NOTE(review): presumably this mirrors a change in the box-code layout
    between code versions -- confirm against the model definition.

    Args:
        state_dict: Mapping from parameter name to tensor; modified in place.
        code_size: Required length of a tensor's first dimension.
    """
    for key, tsr in state_dict.items():
        if 'reg_branches' not in key:
            continue
        # 0-dim tensors have no first dimension to compare or permute.
        if tsr.dim() == 0 or tsr.shape[0] != code_size:
            continue
        print(key, ' with ', tsr.shape, 'has changed')
        # Indexing with a list yields a copy on the right-hand side, so each
        # assignment reads the old rows before overwriting them.
        tsr[[2, 3], ...] = tsr[[3, 2], ...]
        tsr[[6, 7], ...] = -tsr[[7, 6], ...]


def main(argv=None):
    """Parse the command line, convert the checkpoint and save it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = ArgumentParser()
    # nargs='?' makes the positionals optional; without it argparse ignores
    # their ``default`` and always demands both paths.
    parser.add_argument('src', nargs='?', default='old.pth')
    parser.add_argument('dst', nargs='?', default='new.pth')
    parser.add_argument('--code_size', type=int, default=10)
    args = parser.parse_args(argv)

    # Load onto the CPU so the conversion also works on machines that lack
    # the device the checkpoint was saved from; the re-saved tensors are
    # therefore CPU tensors.
    model = torch.load(args.src, map_location='cpu')

    # Tolerate checkpoints without a 'meta' entry.
    meta = model.get('meta')
    if meta is None:
        meta = model['meta'] = {}

    if meta.get('detr3d_convert_tag') is not None:
        print('this model has already converted!')
        return

    print('converting...')
    _convert_reg_branches(model['state_dict'], args.code_size)
    meta['detr3d_convert_tag'] = True
    torch.save(model, args.dst)
    print('done...')


if __name__ == '__main__':
    main()
|