chargoddard committed on
Commit
7d774e5
1 Parent(s): e3c365a

Add script for weight conversion

Files changed (1)
  1. convert_weights.py +100 -0
convert_weights.py ADDED
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# 1/17/2024
# Charles O. Goddard
"""Convert internlm2 weights to Llama format."""

import json
import os
import einops
import tqdm
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.common import ModelReference
from transformers import LlamaTokenizer

MODEL_IN = "internlm/internlm2-20b"
OUT_PATH = "./internlm2-20b-llama"

model_ref = ModelReference.parse(MODEL_IN)
cfg = model_ref.config(trust_remote_code=True)
head_dim = cfg.hidden_size // cfg.num_attention_heads
num_key_value_groups = cfg.num_attention_heads // cfg.num_key_value_heads
loader = LazyTensorLoader(model_ref.tensor_index(), lazy_unpickle=True)
writer = TensorWriter(OUT_PATH)

SIMPLE_REPLACEMENTS = {
    "feed_forward.w1": "mlp.gate_proj",
    "feed_forward.w2": "mlp.down_proj",
    "feed_forward.w3": "mlp.up_proj",
    "attention.wo": "self_attn.o_proj",
    "ffn_norm": "post_attention_layernorm",
    "attention_norm": "input_layernorm",
    "tok_embeddings": "embed_tokens",
    "output.weight": "lm_head.weight",
}
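# For example (assuming the usual internlm2 checkpoint naming), this maps
# "model.layers.0.feed_forward.w1.weight" -> "model.layers.0.mlp.gate_proj.weight".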

for tensor_name in tqdm.tqdm(loader.index.tensor_paths):
    tensor = loader.get_tensor(tensor_name)
    if "attention.wqkv" in tensor_name:
        # make me think about tensor shapes will you >:(

        # ((cfg.num_attention_heads + 2 * cfg.num_key_value_heads) * head_dim, cfg.hidden_size) x (batch_sz, sq_len, cfg.hidden_size)
        # -> (batch_sz, sq_len, (cfg.num_attention_heads + 2 * cfg.num_key_value_heads) * head_dim)
        # qkv_states = rearrange(
        #     qkv_states,
        #     "b q (h gs d) -> b q h gs d",
        #     gs=2 + self.num_key_value_groups,
        #     d=self.head_dim,
        # )
        # -> (batch_sz, sq_len, h, 2 + self.num_key_value_groups, head_dim)
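        #
        # In other words: with H = num_attention_heads, K = num_key_value_heads,
        # g = H // K query heads per KV group, and d = head_dim, the fused wqkv
        # weight has shape ((H + 2 * K) * d, hidden_size). The rearrange below
        # views it as (K, g + 2, d, hidden_size); the slice [:, :g] holds the
        # query heads, and the last two entries along dim 1 are key and value.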
        qkv_vecs = einops.rearrange(
            tensor, "(h gs d) z -> h gs d z", gs=2 + num_key_value_groups, d=head_dim
        )
        q_proj = (
            qkv_vecs[:, :num_key_value_groups, ...]
            .reshape(-1, cfg.hidden_size)
            .contiguous()
        )
        k_proj = qkv_vecs[:, -2, ...].reshape(-1, cfg.hidden_size).contiguous()
        v_proj = qkv_vecs[:, -1, ...].reshape(-1, cfg.hidden_size).contiguous()
        assert k_proj.shape == v_proj.shape

        writer.save_tensor(
            tensor_name.replace("attention.wqkv", "self_attn.q_proj"),
            q_proj,
            clone=True,
        )
        writer.save_tensor(
            tensor_name.replace("attention.wqkv", "self_attn.k_proj"),
            k_proj,
            clone=True,
        )
        writer.save_tensor(
            tensor_name.replace("attention.wqkv", "self_attn.v_proj"),
            v_proj,
            clone=True,
        )
        continue

    out_name = tensor_name
    for pattern, sub in SIMPLE_REPLACEMENTS.items():
        if pattern in out_name:
            out_name = out_name.replace(pattern, sub)
    writer.save_tensor(out_name, tensor)
writer.finalize()

cfg_dict = json.loads(cfg.to_json_string())
del cfg_dict["auto_map"]
cfg_dict["architectures"] = ["LlamaForCausalLM"]  # must be a list of class names
cfg_dict["model_type"] = "llama"
if "rope_scaling" in cfg_dict and cfg_dict["rope_scaling"]["factor"] == 1.0:
    # LlamaConfig requires a rope_scaling factor > 1.0, so drop a no-op entry
    del cfg_dict["rope_scaling"]
with open(os.path.join(OUT_PATH, "config.json"), "w", encoding="utf-8") as fp:
    json.dump(cfg_dict, fp, indent=2)

# InternLMTokenizer differences:
# 1. clean_up_tokenization() hardcoded to always be called
# 2. might prepend a space to some tokens that LlamaTokenizer doesn't if they're the first token
# 1 is easy to fix, 2... is not important
tok = LlamaTokenizer.from_pretrained(MODEL_IN, trust_remote_code=False, legacy=True)
tok.clean_up_tokenization_spaces = True
tok.save_pretrained(OUT_PATH)
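
Once converted, the output directory is meant to load with stock Llama classes. A minimal sanity-check sketch (the prompt and generation settings are arbitrary; device_map="auto" assumes accelerate is installed):

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

tok = LlamaTokenizer.from_pretrained("./internlm2-20b-llama")
model = LlamaForCausalLM.from_pretrained(
    "./internlm2-20b-llama", torch_dtype=torch.bfloat16, device_map="auto"
)
inputs = tok("The capital of France is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=8)
print(tok.decode(out[0], skip_special_tokens=True))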