Spaces:
Runtime error
Runtime error
File size: 1,696 Bytes
128757a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
def _remove_bn_statics(state_dict):
    """Strip batch-norm running-statistic entries from ``state_dict``.

    Every key containing ``running_mean``, ``running_var`` or
    ``num_batches_tracked`` (the buffer names PyTorch batch-norm layers
    register) is deleted. All other entries are left untouched.

    Args:
        state_dict: mutable mapping keyed by parameter name. It is modified
            in place.

    Returns:
        The same ``state_dict`` object, minus the matching entries.
    """
    bn_stat_markers = ('running_mean', 'running_var', 'num_batches_tracked')
    # Collect the names first: deleting while iterating over the mapping
    # itself is not allowed.
    stale_names = [
        name
        for name in sorted(state_dict.keys())
        if any(marker in name for marker in bn_stat_markers)
    ]
    for name in stale_names:
        del state_dict[name]
    return state_dict
def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
    """Rename ``conv2.<param>`` keys to ``conv2.conv.<param>`` for DCN stages.

    ``cfg.MODEL.RESNETS.STAGE_WITH_DCN`` is read as a sequence of flags, one
    per stage, numbered from 1. For every stage whose flag is truthy, each
    key matching ``.*layer<stage>.*conv2.*`` has its ``conv2.weight`` /
    ``conv2.bias`` suffix rewritten to ``conv2.conv.weight`` /
    ``conv2.conv.bias``. Keys containing ``unit01`` are never renamed.
    NOTE(review): presumably the deformable-conv module keeps its convolution
    under a ``conv`` attribute -- confirm against the model definition.

    Args:
        state_dict: mutable mapping keyed by parameter name. It is modified
            in place.
        cfg: config object exposing ``MODEL.RESNETS.STAGE_WITH_DCN``.

    Returns:
        The same ``state_dict`` object with the affected keys renamed.
    """
    import re

    # Snapshot of the key names: the mapping is mutated while we walk them.
    layer_keys = sorted(state_dict.keys())
    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
        if not stage_with_dcn:
            continue
        # The pattern depends only on the stage, so build it once per stage.
        pattern = ".*layer{}.*conv2.*".format(ix)
        stage_matcher = re.compile(pattern)
        for old_key in layer_keys:
            if stage_matcher.match(old_key) is None:
                continue
            if 'unit01' in old_key:
                continue
            for param in ["weight", "bias"]:
                # Plain substring test; the previous ``find(...) is -1``
                # compared ints by identity.
                if param not in old_key:
                    continue
                new_key = old_key.replace(
                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
                )
                if new_key == old_key:
                    # Nothing to rename (e.g. the key is already in
                    # ``conv2.conv.<param>`` form). Re-assigning and then
                    # deleting the same key would drop the entry.
                    continue
                print("pattern: {}, old_key: {}, new_key: {}".format(
                    pattern, old_key, new_key
                ))
                state_dict[new_key] = state_dict.pop(old_key)
                # The old key no longer exists, so stop probing it.
                break
    return state_dict
def load_pretrain_format(cfg, f):
    """Load a pretrained checkpoint and adapt its keys.

    The object returned by ``torch.load(f)`` is passed through
    ``_remove_bn_statics`` and then through
    ``_rename_conv_weights_for_deformable_conv_layers``.

    Args:
        cfg: config object forwarded to the deformable-conv renaming step.
        f: anything ``torch.load`` accepts (e.g. a path or file-like object).

    Returns:
        dict: a single-entry dict ``{"model": <processed weights>}``.
    """
    # NOTE(review): torch.load unpickles the file -- only feed it checkpoints
    # from a trusted source.
    # NOTE(review): no map_location is passed, so tensors are restored to the
    # device they were saved from; confirm that is intended on CPU-only hosts.
    # Assumes the checkpoint is a flat name -> tensor mapping -- TODO confirm.
    weights = torch.load(f)
    weights = _rename_conv_weights_for_deformable_conv_layers(
        _remove_bn_statics(weights), cfg
    )
    return {"model": weights}
|