mm3dtest / projects /DETR3D /old_detr3d_converter.py
giantmonkeyTC
2344
34d1f8b
from argparse import ArgumentParser
import torch
parser = ArgumentParser()
parser.add_argument('src', default='old.pth')
parser.add_argument('dst', default='new.pth') # ('training','validation')
parser.add_argument('--code_size', type=int, default='10')
args = parser.parse_args()
model = torch.load(args.src)
code_size = args.code_size
if model['meta'].get('detr3d_convert_tag') is not None:
print('this model has already converted!')
else:
print('converting...')
# (cx, cy, w, l, cz, h, sin(φ), cos(φ), vx, vy)
for key in model['state_dict']:
tsr = model['state_dict'][key]
if 'reg_branches' in key and tsr.shape[0] == code_size:
print(key, ' with ', tsr.shape, 'has changed')
tsr[[2, 3], ...] = tsr[[3, 2], ...]
tsr[[6, 7], ...] = -tsr[[7, 6], ...]
model['meta']['detr3d_convert_tag'] = True
torch.save(model, args.dst)
print('done...')