File size: 467 Bytes
3cdc5c4 7502c3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
from collections import OrderedDict
import torch
# Convert a Baichuan checkpoint into LLaMA parameter naming by splitting each
# packed ``W_pack`` tensor into separate q/k/v projections.  Run as a script it
# rewrites ``pytorch_model.bin`` in the current directory in place.
CHECKPOINT_PATH = "pytorch_model.bin"


def convert_state_dict(baichuan):
    """Return a LLaMA-style state dict built from a Baichuan state dict.

    Every tensor whose key contains ``'W_pack'`` is split along dim 0 into
    three equal parts, stored (in this order) under the same key with
    ``'W_pack'`` replaced by ``'q_proj'``, ``'k_proj'`` and ``'v_proj'``.
    All other entries are copied through unchanged. Key order is preserved.

    The split size is derived from the tensor itself (``shape[0] // 3``)
    rather than assuming a fixed hidden size. The returned q/k/v tensors
    are slices (views) of the packed tensor.

    Args:
        baichuan: mapping of parameter name -> tensor.

    Returns:
        An ``OrderedDict`` with the renamed / split parameters.

    Raises:
        ValueError: if a ``W_pack`` tensor's first dimension is not
            divisible by 3, so it cannot be an equal q/k/v packing.
    """
    llama = OrderedDict()
    for key, tensor in baichuan.items():
        if 'W_pack' not in key:
            llama[key] = tensor
            continue
        rows = tensor.shape[0]
        if rows % 3:
            raise ValueError(
                f"{key}: expected packed dim 0 divisible by 3, "
                f"got shape {tuple(tensor.shape)}"
            )
        size = rows // 3
        llama[key.replace('W_pack', 'q_proj')] = tensor[:size]
        llama[key.replace('W_pack', 'k_proj')] = tensor[size:2 * size]
        llama[key.replace('W_pack', 'v_proj')] = tensor[2 * size:]
    return llama


def main(src=CHECKPOINT_PATH, dst=CHECKPOINT_PATH):
    """Load the checkpoint at *src*, convert it, and save it to *dst*.

    The result is written to a temporary file first and then moved over
    *dst*, so a failed save never destroys the original checkpoint (which
    matters because *src* and *dst* are the same file by default).
    """
    import os

    # map_location="cpu": without it torch.load restores tensors to the
    # device they were saved from, which fails on machines lacking it.
    # NOTE: torch.load unpickles the file -- only convert trusted checkpoints.
    baichuan = torch.load(src, map_location="cpu")
    llama = convert_state_dict(baichuan)
    tmp_path = dst + ".tmp"
    torch.save(llama, tmp_path)
    os.replace(tmp_path, dst)


if __name__ == "__main__":
    main()
|