File size: 3,244 Bytes
a85f909 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# This script converts the standard LLaMA PyTorch checkpoint released by Meta
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
# by EasyLM for fine-tuning or inference.
# This script is largely borrowed from https://github.com/Sea-Snell/JAX_llama
from pathlib import Path
import json
import numpy as np
import torch
import flax
import mlxu
from EasyLM.checkpoint import StreamingCheckpointer
# Command-line flags and their defaults:
#   checkpoint_dir: directory searched for the *.pth shards and params.json.
#   output_file:    path the converted checkpoint is written to.
#   streaming:      when true, main() saves through StreamingCheckpointer;
#                   otherwise it writes one flax msgpack-serialized blob.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    checkpoint_dir='',
    output_file='',
    streaming=True,
)
def _merge_shards(shards, name, axis):
    """Concatenate tensor `name` from every shard along `axis` as a NumPy array."""
    return np.concatenate([shard[name].numpy() for shard in shards], axis=axis)


def _linear_kernel(shards, name, axis):
    """Merge sharded weight `name` along `axis`, transpose it, wrap as a kernel."""
    return {'kernel': _merge_shards(shards, name, axis).transpose()}


def _convert_layer(shards, layer):
    """Build the weight dict for transformer block number `layer`."""
    prefix = 'layers.%d.' % layer
    return {
        'attention': {
            'wq': _linear_kernel(shards, prefix + 'attention.wq.weight', 0),
            'wk': _linear_kernel(shards, prefix + 'attention.wk.weight', 0),
            'wv': _linear_kernel(shards, prefix + 'attention.wv.weight', 0),
            'wo': _linear_kernel(shards, prefix + 'attention.wo.weight', 1),
        },
        'feed_forward': {
            'w1': _linear_kernel(shards, prefix + 'feed_forward.w1.weight', 0),
            'w2': _linear_kernel(shards, prefix + 'feed_forward.w2.weight', 1),
            'w3': _linear_kernel(shards, prefix + 'feed_forward.w3.weight', 0),
        },
        # Norm weights are not concatenated; they are read from the first
        # shard only.
        'attention_norm': {
            'kernel': shards[0][prefix + 'attention_norm.weight'].numpy(),
        },
        'ffn_norm': {
            'kernel': shards[0][prefix + 'ffn_norm.weight'].numpy(),
        },
    }


def main(argv):
    """Convert a sharded LLaMA PyTorch checkpoint into an EasyLM checkpoint.

    Loads every ``*.pth`` shard and ``params.json`` from
    ``FLAGS.checkpoint_dir``, merges the shards into one nested dict of NumPy
    arrays, and writes the result to ``FLAGS.output_file``. The output goes
    through ``StreamingCheckpointer`` when ``FLAGS.streaming`` is true and is
    written as a flax msgpack blob otherwise.

    Args:
        argv: Command-line arguments handed over by the runner; unused.

    Raises:
        FileNotFoundError: If the checkpoint directory contains no ``*.pth``
            file, or ``params.json`` is missing.
    """
    checkpoint_dir = Path(FLAGS.checkpoint_dir)
    ckpt_paths = sorted(checkpoint_dir.glob("*.pth"))
    if not ckpt_paths:
        # Fail early with a clear message instead of the opaque
        # "need at least one array to concatenate" raised further down.
        raise FileNotFoundError(
            'No *.pth checkpoint shards found in %r' % str(checkpoint_dir)
        )

    shards_by_index = {}
    for ckpt_path in ckpt_paths:
        # NOTE(security): torch.load unpickles the file; only run this script
        # on checkpoints from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # The shard index is the integer between the first and second dots of
        # the file name.
        shard_index = int(ckpt_path.name.split('.', maxsplit=2)[1])
        shards_by_index[shard_index] = checkpoint
    # Order shards by numeric index so concatenation is deterministic.
    ckpts = [shards_by_index[index] for index in sorted(shards_by_index)]

    with open(checkpoint_dir / "params.json", "r") as f:
        params = json.load(f)

    jax_weights = {
        'transformer': {
            # The embedding is merged along axis 1 and is not transposed.
            'wte': {
                'embedding': _merge_shards(ckpts, 'tok_embeddings.weight', 1),
            },
            'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
            'h': {
                '%d' % (layer): _convert_layer(ckpts, layer)
                for layer in range(params['n_layers'])
            },
        },
        'lm_head': _linear_kernel(ckpts, 'output.weight', 0),
    }

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            jax_weights, FLAGS.output_file
        )
    else:
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(
                flax.serialization.msgpack_serialize(jax_weights, in_place=True)
            )
# Script entry point: hand control to mlxu.run with main as the callback.
if __name__ == '__main__':
    mlxu.run(main)