# Reparameterize YOLO-World: bake pre-computed text embeddings into the
# detection head and RepVL-PAN neck so the text encoder can be dropped.
import os | |
import argparse | |
import torch | |
import numpy as np | |
def parse_args():
    """Parse command-line options for the reparameterization script."""
    ap = argparse.ArgumentParser("Reparameterize YOLO-World")
    ap.add_argument('--model', help='model checkpoints to reparameterize')
    ap.add_argument('--out-dir', help='output checkpoints')
    ap.add_argument(
        '--text-embed',
        help='text embeddings to reparameterized into YOLO-World')
    ap.add_argument('--conv-neck',
                    action='store_true',
                    help='whether using 1x1 conv in RepVL-PAN')
    return ap.parse_args()
def convert_head(scale, bias, text_embed):
    """Fold a contrastive head's logit scale/bias into 1x1-conv parameters.

    Args:
        scale: learned logit-scale tensor (applied as ``scale.exp()``).
        bias: scalar bias shared across all classes.
        text_embed: (num_classes, dim) text embedding matrix.

    Returns:
        Tuple of (conv weight of shape (N, D, 1, 1), per-class bias of shape (N,)).
    """
    num_classes, dim = text_embed.shape
    conv_weight = (text_embed * scale.exp()).view(num_classes, dim, 1, 1)
    conv_bias = bias * torch.ones(num_classes)
    return conv_weight, conv_bias
def reparameterize_head(state_dict, embeds):
    """Replace each contrastive classification head with an equivalent 1x1 conv.

    The ``logit_scale``/``bias`` entries of every ``cls_contrasts`` level are
    removed and replaced by ``conv.weight``/``conv.bias`` computed from the
    text embeddings. Mutates and returns ``state_dict``.
    """
    for prefix in ('bbox_head.head_module.cls_contrasts.0',
                   'bbox_head.head_module.cls_contrasts.1',
                   'bbox_head.head_module.cls_contrasts.2'):
        # pop removes the old entries while handing us their values
        logit_scale = state_dict.pop(prefix + '.logit_scale')
        old_bias = state_dict.pop(prefix + '.bias')
        weight, bias = convert_head(logit_scale, old_bias, embeds)
        state_dict[prefix + '.conv.weight'] = weight
        state_dict[prefix + '.conv.bias'] = bias
    return state_dict
def convert_neck_split_conv(input_state_dict, block_name, text_embeds,
                            num_heads):
    """Turn an attn block's guide FC into per-head 1x1 ``guide_convs`` weights.

    No-op when the block has no ``guide_fc`` (already converted or absent).
    Mutates and returns ``input_state_dict``.
    """
    weight_key = block_name + '.guide_fc.weight'
    if weight_key not in input_state_dict:
        return input_state_dict
    fc_weight = input_state_dict.pop(weight_key)
    fc_bias = input_state_dict.pop(block_name + '.guide_fc.bias')
    # project the text embeddings through the guide FC once, offline
    guide = text_embeds @ fc_weight.t() + fc_bias.unsqueeze(0)
    _, total_dim = guide.shape
    # one (N, D/heads, 1, 1) conv kernel per attention head
    per_head = guide.split(total_dim // num_heads, dim=1)
    for head_idx, chunk in enumerate(per_head):
        key = block_name + f'.guide_convs.{head_idx}.weight'
        input_state_dict[key] = chunk[:, :, None, None]
    return input_state_dict
def convert_neck_weight(input_state_dict, block_name, embeds, num_heads):
    """Turn an attn block's guide FC into a precomputed ``guide_weight`` tensor.

    Unlike :func:`convert_neck_split_conv` this keeps the projected guide as a
    single (N, D/heads, heads) tensor. Mutates and returns ``input_state_dict``.
    """
    fc_weight = input_state_dict.pop(block_name + '.guide_fc.weight')
    fc_bias = input_state_dict.pop(block_name + '.guide_fc.bias')
    guide = embeds @ fc_weight.t() + fc_bias[None]
    num_embeds, total_dim = guide.shape
    input_state_dict[block_name + '.guide_weight'] = guide.view(
        num_embeds, total_dim // num_heads, num_heads)
    return input_state_dict
def reparameterize_neck(state_dict, embeds, type='conv'):
    """Reparameterize every RepVL-PAN attn block with the text embeddings.

    ``type='conv'`` writes per-head 1x1 conv weights; anything else writes a
    single ``guide_weight`` tensor. No-op when the neck has no attn blocks.
    Mutates and returns ``state_dict``.
    """
    # NOTE: parameter name ``type`` shadows the builtin but is kept for
    # backward compatibility with existing callers.
    if "neck.top_down_layers.0.attn_block.bias" not in state_dict:
        return state_dict
    convert = convert_neck_split_conv if type == 'conv' else convert_neck_weight
    for block in ('neck.top_down_layers.0.attn_block',
                  'neck.top_down_layers.1.attn_block',
                  'neck.bottom_up_layers.0.attn_block',
                  'neck.bottom_up_layers.1.attn_block'):
        # the per-block bias length gives the number of attention heads
        num_heads = state_dict[block + '.bias'].shape[0]
        convert(state_dict, block, embeds, num_heads)
    return state_dict
def main():
    """Load a YOLO-World checkpoint, bake in text embeddings, and save it.

    Reads the checkpoint and the offline text embeddings named on the command
    line, strips the CLIP text encoder, reparameterizes the head and neck,
    and writes ``<model>_rep_<neck_type>.pth`` into ``--out-dir``.
    """
    args = parse_args()
    # load checkpoint on CPU so no GPU is required
    model = torch.load(args.model, map_location='cpu')
    state_dict = model['state_dict']
    # load the pre-computed text embeddings (num_classes x embed_dim)
    embeddings = torch.from_numpy(np.load(args.text_embed))
    # drop the text encoder weights -- unused once embeddings are baked in
    state_dict_wo_text = {
        k: v
        for k, v in state_dict.items() if "text_model" not in k
    }
    print("removing text encoder")
    state_dict_wo_text = reparameterize_head(state_dict_wo_text, embeddings)
    print("reparameterizing head")
    neck_type = "conv" if args.conv_neck else "linear"
    state_dict_wo_text = reparameterize_neck(state_dict_wo_text, embeddings,
                                             neck_type)
    print("reparameterizing neck")
    model['state_dict'] = state_dict_wo_text
    model_name = os.path.basename(args.model)
    model_name = model_name.replace('.pth', f'_rep_{neck_type}.pth')
    # fix: create the output directory if missing so torch.save cannot fail
    os.makedirs(args.out_dir, exist_ok=True)
    torch.save(model, os.path.join(args.out_dir, model_name))
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()