# unit_test/tools/convert_checkpoint.py
# herrius's picture
# Upload 259 files
# 32b542e
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
# In[2]:
# Input/output locations for the conversion.
# The commented-out pair below is the L12/H768 ("base") variant; the active
# pair is the L24/H1024 ("large") variant.
# ckpt_path = '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/16task_90k_bertbase_lr1e-3_wd0.2_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_changeweight_stage2_224size/bertbase_womoe_pretrain2/89999/mp_rank_00_model_states.pt'
# save_path = '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth'

# Source checkpoint; adjacent literals are concatenated into one path string.
ckpt_path = (
    '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe'
    '/BERT_L24_H1024_experiments'
    '/16task_90k_bertlarge_lr2e-5_wd0.05_gc0.1_prenorm_warm5k_layerscale1e-3_uniformdp0.2_maeinit_fixedpos_torchfp16_unifieddataset_pretrain_stage2_224size_bw128_all0.5_accum2_bwv2_k700_8frames_yfccfixcap_womixup'
    '/all0.5_rmmixup_from430'
    '/89999'
    '/mp_rank_00_model_states.pt'
)

# Destination file for the converted weights.
save_path = (
    '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models'
    '/uni-perceiver-large-L24-H1024-224size-pretrained.pth'
)

# Alias under which the loading code refers to the source checkpoint.
origin_checkpoint_path = ckpt_path
# In[3]:
# Deserialize the source checkpoint with all tensors mapped onto the CPU.
origin_checkpoint = torch.load(origin_checkpoint_path, map_location='cpu')
# Notebook leftover: lists the top-level entries (result is discarded when
# this file runs as a plain script).
origin_checkpoint.keys()
# list(origin_checkpoint['module'].keys())
# In[4]:
# Notebook leftover: number of entries stored under 'module'.
len(origin_checkpoint['module'])
# In[8]:
# new_checkpoint_path = 'new_exp/model_Epoch_00160_Iter_0000159.pth'
# new_checkpoint = torch.load(new_checkpoint_path, 'cpu')
# new_checkpoint.keys()
# list(new_checkpoint['model'].keys())
# In[10]:
# len(list(new_checkpoint['model'].keys()))
# In[5]:
# Substring rewrite rules for parameter names: each pair is
# (fragment found in the source checkpoint key, replacement fragment).
# The conversion loop applies the rules in this insertion order.
mapping_dict = dict((
    # top-level module prefixes
    ('encoder.', 'fused_encoder.'),
    # attention projections
    ('attention.self.qkv_proj.weight', 'self_attn.in_proj_weight'),
    ('attention.self.qkv_proj.bias', 'self_attn.in_proj_bias'),
    ('attention.output.dense', 'self_attn.out_proj'),
    ('attention_output.residual_scale', 'gamma_1'),
    # feed-forward layers
    ('ffn.dense.', 'linear1.'),
    ('ffn.dense2.', 'linear2.'),
    ('ffn_output.residual_scale', 'gamma_2'),
    # layer norms
    ('LayerNormModules.0.', 'norm1.'),
    ('LayerNormModules.1.', 'norm2.'),
    ('predictor.', 'loss_prepare.'),
))
# In[6]:
# Build the converted state dict: re-key every entry stored under 'module'
# according to mapping_dict and cast the values with .float().
new_checkpoint = dict()
module_checkpoint = origin_checkpoint['module']
for param_name, tensor in module_checkpoint.items():
    # residual_scale tensors get dim 1 and then dim 0 squeezed away, in place
    # (so the tensor held by origin_checkpoint is modified as well).
    if param_name.endswith('residual_scale'):
        tensor.squeeze_(1)
        tensor.squeeze_(0)

    # Entries whose name starts with 'visual_embed' are not carried over.
    if param_name.startswith('visual_embed'):
        continue

    # Apply every rewrite rule in mapping_dict order; str.replace leaves the
    # name untouched when the fragment is absent.
    renamed = param_name
    for origin_str, target_str in mapping_dict.items():
        renamed = renamed.replace(origin_str, target_str)

    new_checkpoint[renamed] = tensor.float()

# merge type embedding in video_embed: row 0 of the type-embedding weight is
# added to the embedding bias (out of place, producing a new tensor).
bias_key = 'video_embed.embeddings.bias'
type_key = 'video_embed.embeddings_type.weight'
new_checkpoint[bias_key] = new_checkpoint[bias_key] + new_checkpoint[type_key][0]
# In[7]:
# Write the converted weights to save_path, wrapped under a single 'model' key.
torch.save(dict(model=new_checkpoint), save_path)