"""Convert a Baichuan checkpoint to the LLaMA parameter layout.

Every parameter whose name contains ``W_pack`` (the fused attention
projection) is split into separate ``q_proj``, ``k_proj`` and ``v_proj``
entries; all other parameters are copied through unchanged.
"""
from collections import OrderedDict

import torch

CHECKPOINT_PATH = "pytorch_model.bin"


def split_w_pack(state_dict):
    """Return a new state dict with each fused ``W_pack`` entry split in three.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        An ``OrderedDict`` in the same key order as ``state_dict``, where each
        key containing ``W_pack`` is replaced by three consecutive keys
        (``q_proj``, ``k_proj``, ``v_proj``) holding equal-sized slices of the
        original tensor along dimension 0.

    Raises:
        ValueError: If a ``W_pack`` tensor's first dimension is not divisible
            by 3, so it cannot be split into equal q/k/v parts.
    """
    converted = OrderedDict()
    for key, tensor in state_dict.items():
        if 'W_pack' not in key:
            converted[key] = tensor
            continue

        rows = tensor.shape[0]
        if rows % 3 != 0:
            raise ValueError(
                f"cannot split {key!r}: first dimension {rows} "
                "is not divisible by 3"
            )
        # Derive the slice size from the tensor itself instead of assuming a
        # hidden size of 4096, so other model sizes convert correctly too.
        size = rows // 3
        converted[key.replace('W_pack', 'q_proj')] = tensor[:size]
        converted[key.replace('W_pack', 'k_proj')] = tensor[size:size * 2]
        converted[key.replace('W_pack', 'v_proj')] = tensor[size * 2:]
    return converted


def main():
    """Load the checkpoint, convert it, and write it back to the same path."""
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    baichuan = torch.load(CHECKPOINT_PATH)
    llama = split_w_pack(baichuan)
    # Save the converted dict (not the original one), overwriting the input.
    torch.save(llama, CHECKPOINT_PATH)


if __name__ == "__main__":
    main()