Commit 7efae60 (parent 0edb56d), committed by ArthurZ (HF staff)

Update convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,9 +23,6 @@ from transformers.modeling_utils import dtype_byte_size
 from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME


-# 'encoder.layers.7.moe_layer.experts.0.fc2.bias', 'encoder.layers.11.moe_layer.experts.0.fc1.weight',
-
-
 def remove_ignore_keys_(state_dict):
     ignore_keys = [
         "encoder.version",
@@ -48,30 +45,30 @@ def make_linear_from_emb(emb):
     return lin_layer


-def rename_fairseq_keys(state_dict, expert_idx = None):
-    # 'encoder.layers.7.moe_layer.experts.0.fc2.bias' ->'encoder.layers.7.ffn.mlp.experts.0.fc2.bias'
-    # 'encoder.layers.7.fc2.bias' -> 'encoder.layers.7.ffn.mlp.fc2.bias'
-    # encoder.layers.7.wg -> encoder.layers.7.ffn.mlp.router.classifier
+def rename_fairseq_keys(state_dict, expert_idx=None):
     new_dict = {}
     for old_key in state_dict.keys():
         key = old_key
         if "experts" in key:
-            key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.{expert_idx}")
-        elif "fc2" :
-            key = key.replace(".fc2.", ".ffn.mlp.fc2")
-        elif "fc1" :
-            key = key.replace(".fc1.", ".ffn.mlp.fc1")
+            key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.expert_{expert_idx}")
         elif "gate" in key:
             key = key.replace(".moe_layer.gate.wg", ".ffn.mlp.router.classifier")
+        if "fc2" and "experts" not in key:
+            key = key.replace(".fc2.", ".ffn.mlp.fc2.")
+        if "fc1" and "experts" not in key:
+            key = key.replace(".fc1.", ".ffn.mlp.fc1.")
+        if ".encoder_attn." in key:
+            key = key.replace(".encoder_attn.", ".cross_attention.")
+        if "encoder_attn_layer_norm" in key:
+            key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
+        if "final_layer_norm" in key:
+            key = key.replace("final_layer_norm", "ffn.layer_norm")
         new_dict[key] = state_dict[old_key]
     return new_dict


-def shard_on_the_fly(
-    switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME
-):
+def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME):
     sharded_state_dicts = []
-    current_block = {}
     total_size = 0
     os.makedirs(dump_path, exist_ok=True)

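Aside: a minimal sketch of what the updated rename_fairseq_keys produces for a few representative fairseq keys. The replacement rules mirror the ones visible in the hunk above; the sample keys, the expert_idx value, and the helper name rename_sketch are illustrative, not part of the script.

# Mirrors the replacement rules in the hunk above; sample keys are illustrative.
def rename_sketch(key, expert_idx=None):
    if "experts" in key:
        key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.expert_{expert_idx}")
    elif "gate" in key:
        key = key.replace(".moe_layer.gate.wg", ".ffn.mlp.router.classifier")
    if "experts" not in key:  # the `if "fc2" and ...` checks above reduce to this test
        key = key.replace(".fc2.", ".ffn.mlp.fc2.").replace(".fc1.", ".ffn.mlp.fc1.")
    key = key.replace(".encoder_attn.", ".cross_attention.")
    key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
    return key.replace("final_layer_norm", "ffn.layer_norm")

samples = [
    ("encoder.layers.7.moe_layer.experts.0.fc2.bias", 3),    # expert weights, converted with expert_idx=3
    ("encoder.layers.3.fc1.weight", None),                   # dense FFN layer
    ("decoder.layers.5.moe_layer.gate.wg.weight", None),     # router
    ("decoder.layers.5.encoder_attn.out_proj.weight", None), # cross-attention
    ("decoder.layers.5.final_layer_norm.weight", None),
]
for old, idx in samples:
    print(old, "->", rename_sketch(old, idx))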
 
@@ -105,7 +102,6 @@ def shard_on_the_fly(
 
     # Otherwise, let's build the index
     weight_map = {}
-    shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
         shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
         temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
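The weight_map built in this hunk ends up in the sharded-checkpoint index written under WEIGHTS_INDEX_NAME. A sketch of the shape that file typically takes; the file names, sizes, and parameter names below are made up.

# Illustrative index content; in transformers of this era WEIGHTS_INDEX_NAME resolves to
# "pytorch_model.bin.index.json". All values below are made up.
import json

index = {
    "metadata": {"total_size": 123456789},  # total parameter bytes across all shards
    "weight_map": {
        # parameter name -> shard file that stores it
        "encoder.layers.7.ffn.mlp.experts.expert_3.fc2.bias": "pytorch_model-00004-of-00010.bin",
        "decoder.layers.5.cross_attention.out_proj.weight": "pytorch_model-00010-of-00010.bin",
    },
}
print(json.dumps(index, indent=2))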
@@ -143,23 +139,17 @@ if __name__ == "__main__":
         help="Path to the output pytorch model.",
     )
     args = parser.parse_args()
-    # metadata, index = shard_on_the_fly(
-    #     args.nllb_moe_checkpoint_path,
-    #     args.pytorch_dump_folder_path,
-    #     128,
-    #     args.dtype,
-    # )
-
+    metadata, index = shard_on_the_fly(
+        args.nllb_moe_checkpoint_path,
+        args.pytorch_dump_folder_path,
+        128,
+        args.dtype,
+    )
 
     config = NllbMoeConfig.from_pretrained(
-        "facebook/nllb-200-3.3B",
-        num_sparse_encoder_layers=4,
-        num_sparse_decoder_layers=4,
+        "facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128
     )
     config.save_pretrained(args.pytorch_dump_folder_path)
-
-
-    model = NllbMoeModel(config)
+    model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path)
+    print("Done")
     model.save_pretrained(args.pytorch_dump_folder_path)
-    # model.push_to_hub("ArthurZ/nllb-moe-54b", use_auth_token="")
-    # model.save_pretrained(args.pytorch_dump_folder_path)
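The hunks shown here do not include the size-accounting loop inside shard_on_the_fly, but the dtype_byte_size import visible near the top of the file hints at the usual pattern: sum per-tensor byte counts and start a new shard once a size budget is hit. A rough sketch of that pattern; the budget, the helper name split_into_shards, and the toy tensors are assumptions, not the script's actual code.

import torch
from transformers.modeling_utils import dtype_byte_size

MAX_SHARD_BYTES = 10 * 1024**3  # illustrative 10 GB budget per shard

def split_into_shards(state_dict, max_bytes=MAX_SHARD_BYTES):
    shards, current, current_size = [], {}, 0
    for name, tensor in state_dict.items():
        weight_size = tensor.numel() * dtype_byte_size(tensor.dtype)
        # Close the current shard once adding this tensor would blow the budget.
        if current and current_size + weight_size > max_bytes:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += weight_size
    if current:
        shards.append(current)
    return shards

toy = {"a.weight": torch.zeros(4, 4), "b.weight": torch.zeros(8, 8)}
print(len(split_into_shards(toy)))  # -> 1; the toy tensors fit in a single shard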
 
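After the conversion has run, one quick sanity check on the dump folder is to read the shard index back and confirm that no fairseq-style names survived the renaming. A minimal sketch, assuming the default weights_name so the index file is pytorch_model.bin.index.json, with an illustrative folder path.

import json
import os

dump_folder = "./nllb_moe_converted"  # illustrative; use the pytorch_dump_folder_path you passed
with open(os.path.join(dump_folder, "pytorch_model.bin.index.json")) as f:
    weight_map = json.load(f)["weight_map"]

# Every renamed key should use the new ffn.mlp / cross_attention naming.
leftovers = [k for k in weight_map if "moe_layer" in k or "encoder_attn" in k]
print(f"{len(weight_map)} parameters mapped, {len(leftovers)} keys left un-renamed")
assert not leftovers, leftovers[:5]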