{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "5fa051c8-8f5e-4809-b90e-bf129a701352",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d549cf85-c638-4dec-a436-254da7060ee3",
"metadata": {},
"outputs": [],
"source": [
"checkpoints = torch.load('../output/model_Epoch_00030_Iter_0000029.pth', 'cpu')['model']\n",
"new_keys = list(list(checkpoints.keys()))\n",
"new_keys[:10]"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "81fd1bd8-d8ed-457f-a9ec-d2b0bf91c8c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'/mnt/lustre/zhujinguo/jinguo_data/codes/Uni-Perceiver/work_dirs'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"os.getcwd()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d9a698e1-de54-4257-a45f-65cd2d7cf095",
"metadata": {},
"outputs": [],
"source": [
"ds_checkpoints = torch.load('deepspeed_moe/BERT_L12_H192_experiments/7task_150k_berttiny_lr1e-3_wd0.05_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_224inputsize_tagmoe_alllayer/tagmoe_alllayer_exp4/149999/mp_rank_00_model_states.pt', 'cpu')['module']\n",
"# ds_checkpoints = torch.load('/nfs/zhujinguo/codes/xmodaler/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/basetagmoe_pretrainstage2/89999/mp_rank_00_model_states.pt', 'cpu')['module']\n",
"# ds_checkpoints = torch.load('/nfs/zhujinguo/codes/xmodaler/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/bertbase_womoe/89999/mp_rank_00_model_states.pt', 'cpu')['module']\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d5f02e7e-95eb-4386-ac68-f8be0f7ac3c1",
"metadata": {},
"outputs": [],
"source": [
"oldkeys = list(ds_checkpoints.keys())\n",
"# oldkeys[:20]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6a9c4928-e808-4e5d-a8cb-c4d133cc9c6c",
"metadata": {},
"outputs": [],
"source": [
"mapping_dict = {\n",
"\n",
" 'encoder.': 'fused_encoder.',\n",
" # 'attention.self.qkv_proj.weight': 'self_attn.in_proj_weight',\n",
" # 'attention.self.qkv_proj.bias': 'self_attn.in_proj_bias',\n",
" 'attention.self.qkv_proj': 'self_attn.qkv_proj',\n",
" 'deepspeed_moe.gate': 'gate',\n",
" 'deepspeed_moe.experts': 'experts',\n",
" 'attention.output.dense': 'self_attn.dense',\n",
" 'attention_output.residual_scale': 'gamma_1',\n",
" 'ffn.dense.': 'linear1.',\n",
" 'ffn.dense2.': 'linear2.',\n",
" 'ffn_output.residual_scale': 'gamma_2',\n",
" 'LayerNormModules.0.': 'norm1.',\n",
" 'LayerNormModules.1.': 'norm2.',\n",
" 'predictor.': 'loss_prepare.',\n",
" \n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9a84319a-d13c-411a-bd21-c0a9c1adb872",
"metadata": {},
"outputs": [],
"source": [
"new_checkpoint = {}\n",
"for k, v in ds_checkpoints.items():\n",
" if k.endswith('residual_scale'):\n",
" v = v.squeeze(1).squeeze(0)\n",
" # print(v.shape )\n",
" \n",
" if k.startswith('visual_embed'):\n",
" continue\n",
" \n",
" \n",
" for origin_str, target_str in mapping_dict.items():\n",
" if origin_str in k:\n",
" k = k.replace(origin_str, target_str)\n",
" # merge type embedding in video_embed \n",
" # if k=='video_embed.embeddings.bias':\n",
" # v = v + ds_checkpoints['video_embed.embeddings_type.weight'][0]\n",
"\n",
" new_checkpoint[k] = v.float()\n",
" # if 'wg' in k:\n",
" # print(f'{k}, {v}')\n",
"# new_checkpoint['video_embed.embeddings.bias'] = new_checkpoint['video_embed.embeddings.bias'] + new_checkpoint['video_embed.embeddings_type.weight'][0]\n",
" "
]
},
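{
"cell_type": "code",
"execution_count": null,
"id": "f3d7c2a1-5b8e-4c6a-9d21-7e4a0c9b1a2f",
"metadata": {},
"outputs": [],
"source": [
"# Sketch added for illustration (not part of the original notebook): a quick sanity check\n",
"# that the remapping above left no DeepSpeed-style names behind before the checkpoint is saved.\n",
"leftover = [k for k in new_checkpoint if 'deepspeed_moe' in k or k.startswith('encoder.')]\n",
"print('unmapped keys:', leftover[:10])\n",
"print('total tensors in converted checkpoint:', len(new_checkpoint))"
]
},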
{
"cell_type": "code",
"execution_count": 8,
"id": "b5c999e3-e4b1-4949-b89e-4ee2259db8fa",
"metadata": {},
"outputs": [],
"source": [
"torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained-withvtype.pth')\n",
"\n",
"\n",
"# torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained.pth')\n",
"# torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-base-L12-H768-224size-pretrained.pth')\n",
"# torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained-custom-attn-module.pth')\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}