Update convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,9 +23,6 @@ from transformers.modeling_utils import dtype_byte_size
 from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 
 
-# 'encoder.layers.7.moe_layer.experts.0.fc2.bias', 'encoder.layers.11.moe_layer.experts.0.fc1.weight',
-
-
 def remove_ignore_keys_(state_dict):
     ignore_keys = [
         "encoder.version",
@@ -48,30 +45,30 @@ def make_linear_from_emb(emb):
     return lin_layer
 
 
-def rename_fairseq_keys(state_dict, expert_idx
-    # 'encoder.layers.7.moe_layer.experts.0.fc2.bias' ->'encoder.layers.7.ffn.mlp.experts.0.fc2.bias'
-    # 'encoder.layers.7.fc2.bias' -> 'encoder.layers.7.ffn.mlp.fc2.bias'
-    # encoder.layers.7.wg -> encoder.layers.7.ffn.mlp.router.classifier
+def rename_fairseq_keys(state_dict, expert_idx=None):
     new_dict = {}
     for old_key in state_dict.keys():
         key = old_key
         if "experts" in key:
-            key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.{expert_idx}")
-        elif "fc2" :
-            key = key.replace(".fc2.", ".ffn.mlp.fc2")
-        elif "fc1" :
-            key = key.replace(".fc1.", ".ffn.mlp.fc1")
+            key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.expert_{expert_idx}")
         elif "gate" in key:
             key = key.replace(".moe_layer.gate.wg", ".ffn.mlp.router.classifier")
+        if "fc2" and "experts" not in key:
+            key = key.replace(".fc2.", ".ffn.mlp.fc2.")
+        if "fc1" and "experts" not in key:
+            key = key.replace(".fc1.", ".ffn.mlp.fc1.")
+        if ".encoder_attn." in key:
+            key = key.replace(".encoder_attn.", ".cross_attention.")
+        if "encoder_attn_layer_norm" in key:
+            key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
+        if "final_layer_norm" in key:
+            key = key.replace("final_layer_norm", "ffn.layer_norm")
         new_dict[key] = state_dict[old_key]
     return new_dict
 
 
-def shard_on_the_fly(
-    switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME
-):
+def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME):
     sharded_state_dicts = []
-    current_block = {}
     total_size = 0
     os.makedirs(dump_path, exist_ok=True)
 
@@ -105,7 +102,6 @@ def shard_on_the_fly(
 
     # Otherwise, let's build the index
    weight_map = {}
-    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
@@ -143,23 +139,17 @@ if __name__ == "__main__":
         help="Path to the output pytorch model.",
     )
     args = parser.parse_args()
-
-
-
-
-
-
-
+    metadata, index = shard_on_the_fly(
+        args.nllb_moe_checkpoint_path,
+        args.pytorch_dump_folder_path,
+        128,
+        args.dtype,
+    )
 
     config = NllbMoeConfig.from_pretrained(
-        "facebook/nllb-200-3.3B",
-        num_sparse_encoder_layers=4,
-        num_sparse_decoder_layers=4,
+        "facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128
     )
     config.save_pretrained(args.pytorch_dump_folder_path)
-
-
-    model = NllbMoeModel(config)
+    model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path)
+    print("Done")
     model.save_pretrained(args.pytorch_dump_folder_path)
-    # model.push_to_hub("ArthurZ/nllb-moe-54b", use_auth_token="")
-    # model.save_pretrained(args.pytorch_dump_folder_path)
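Not part of the commit: a minimal sanity check of the renaming logic added above. It assumes the conversion script can be imported as a plain module (hypothetical sys.path setup) and feeds it a toy fairseq-style state dict with dummy values; rename_fairseq_keys only rewrites the keys.

from convert_nllb_moe_sharded_original_checkpoint_to_pytorch import rename_fairseq_keys  # assumed importable

toy_state_dict = {
    "encoder.layers.7.moe_layer.experts.0.fc2.bias": 0.0,
    "encoder.layers.7.fc2.bias": 0.0,
    "decoder.layers.3.moe_layer.gate.wg.weight": 0.0,
    "decoder.layers.3.encoder_attn_layer_norm.weight": 0.0,
}
renamed = rename_fairseq_keys(toy_state_dict, expert_idx=11)

# Expert weights pick up the expert index, dense fc1/fc2 move under ffn.mlp,
# the gate becomes the router classifier, and encoder_attn* becomes cross_attention*.
assert "encoder.layers.7.ffn.mlp.experts.expert_11.fc2.bias" in renamed
assert "encoder.layers.7.ffn.mlp.fc2.bias" in renamed
assert "decoder.layers.3.ffn.mlp.router.classifier.weight" in renamed
assert "decoder.layers.3.cross_attention_layer_norm.weight" in renamed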
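Also not part of the commit: a rough way to inspect what shard_on_the_fly produced, assuming it writes the standard sharded-checkpoint index (WEIGHTS_INDEX_NAME, i.e. pytorch_model.bin.index.json) into the dump folder; the code that saves the index file sits outside the hunks shown above.

import json
import os

from transformers.utils import WEIGHTS_INDEX_NAME

dump_path = "converted_nllb_moe"  # hypothetical value of args.pytorch_dump_folder_path
with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME)) as f:
    index = json.load(f)

# The index maps every parameter name to a shard file such as
# pytorch_model-00001-of-000NN.bin, matching the weights_name.replace(...) pattern above.
print("total_size (bytes):", index["metadata"]["total_size"])
print("number of shards:", len(set(index["weight_map"].values())))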
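Finally, a hedged sketch of reloading the converted checkpoint, mirroring the tail of the __main__ block. The dump folder name and the CLI flags are assumptions inferred from the args.* attributes used in the diff.

# Hypothetical invocation (flag names inferred from args.nllb_moe_checkpoint_path /
# args.pytorch_dump_folder_path; any --dtype flag the parser requires is omitted here):
#   python convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py \
#       --nllb_moe_checkpoint_path /path/to/fairseq/nllb_moe_checkpoint \
#       --pytorch_dump_folder_path converted_nllb_moe
from transformers import NllbMoeConfig, NllbMoeModel

dump_path = "converted_nllb_moe"
config = NllbMoeConfig.from_pretrained(dump_path)  # written by config.save_pretrained(...)
print(config.num_experts, config.encoder_sparse_step, config.decoder_sparse_step)

model = NllbMoeModel.from_pretrained(dump_path)  # resolves the shards through the index file
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")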