Spico commited on
Commit
c1ac78d
1 Parent(s): 86a52dd

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLaMA-MoE-v1-3.0B (2/16)
2
+
3
+ [[💻 Code]](https://github.com/pjlab-sys4nlp/llama-moe) | [[📜 Technical Report]]()
4
+
5
+ 👋 Very nice to meet you here~
6
+
7
+ ❤️ This repo contains the model `LLaMA-MoE-v1-3.0B (2/16)`, which activates 2 out of 16 experts (3.0B parameters).
8
+ This model is NOT fine-tuned by instruction pairs, so it may not be good enough to act like a chatbot.
9
+
10
+ 📢 LLaMA-MoE is a series of Mixture-of-Expert (MoE) models based on [LLaMA-2](https://huggingface.co/meta-llama/Llama-2-7b-hf).
11
+ You can find the code for training this model at [this repo](https://github.com/pjlab-sys4nlp/llama-moe).
12
+
13
+ 💎 This series of models are obtained by partitioning original LLaMA FFNs into experts and further continual pre-training.
14
+ The total model size is only 6.7B parameters, which is very convenient for deployment and research usage.
15
+ More details could be found at [our technical report](https://arxiv.org/).
16
+
17
+ ## 🚀 QuickStart
18
+
19
+ ```python
20
+ import torch
21
+ from transformers import AutoTokenizer, AutoModelForCausalLM
22
+
23
+ model_dir = "llama-moe/LLaMA-MoE-v1-3_0B-2_16"
24
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
25
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True)
26
+ model.eval()
27
+ model.to("cuda:0")
28
+
29
+ input_text = "Suzhou is famous of"
30
+ inputs = tokenizer(input_text, return_tensors="pt")
31
+ inputs = inputs.to("cuda:0")
32
+
33
+ pred = model.generate(**inputs, max_length=50, temperature=0.0)
34
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
35
+ # Suzhou is famous of its beautiful gardens. The most famous one is the Humble Administrator's Garden. It is a classical Chinese garden with a history of more than 600 years. The garden is divided into three
36
+ ```
37
+
38
+ ## 📊 Performance
39
+
40
+ | Model | \#Activated Experts | \#Experts | \#Activated Params | Links |
41
+ | :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: |
42
+ | **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) |
43
+ | **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) |
44
+ | **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) |
45
+
46
+ | Model | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMNLU (5) | Average |
47
+ | :------------------------------------------------------------------------------------ | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :-------: | :-----: |
48
+ | [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) | 78.9 | 74.8 | 60.8 | 54.4 | 34.0 | 61.4 | 25.8 | 63.3 | 63.6 | 10.7 | 25.8 | 50.3 |
49
+ | [Pythia-2.8B](https://huggingface.co/EleutherAI/pythia-2.8b) | 83.2 | 73.6 | 59.6 | 58.8 | 36.7 | 60.7 | 28.1 | 65.9 | 64.6 | 8.7 | 26.8 | 51.5 |
50
+ | [INCITE-BASE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1) | 85.6 | 73.9 | 63.5 | 61.7 | 40.3 | 64.7 | 27.5 | 65.8 | 65.4 | 15.2 | 27.2 | 53.7 |
51
+ | [Open-LLaMA-3B-v2](https://huggingface.co/openlm-research/open_llama_3b_v2) | 88.0 | 77.9 | 63.1 | 63.3 | 40.1 | 71.4 | 28.1 | 69.2 | 67.4 | 16.0 | 26.8 | 55.6 |
52
+ | [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B) | 87.5 | 76.9 | 65.0 | 63.3 | 41.6 | 71.0 | 28.3 | 73.6 | 68.3 | 17.6 | **27.3** | 56.4 |
53
+ | **LLaMA-MoE-3.0B** | 84.2 | 77.5 | 63.6 | 60.2 | 40.9 | 70.8 | **30.6** | 71.9 | 66.6 | 17.0 | 26.8 | 55.5 |
54
+ | **LLaMA-MoE-3.5B (4/16)** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 | 57.7 |
55
+ | **LLaMA-MoE-3.5B (2/8)** | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 | 57.6 |
56
+
57
+ ## 📖 Details
58
+
59
+ Training Data: 200B tokens from [SlimPajama](https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama) with the same data sampling weights as [Sheared LLaMA](https://arxiv.org/abs/2310.06694).
60
+
61
+ ## 📃 Citation
62
+
63
+ ```bibtex
64
+ @article{llama-moe,
65
+ title={LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training},
66
+ author={LLaMA-MoE Team},
67
+ journal={arXiv},
68
+ year={2023},
69
+ volume={abs/},
70
+ url={https://arxiv.org}
71
+ }
72
+ ```
config.json ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "llama-moe/LLaMA-MoE-v1-3_0B-2_16",
3
+ "model_type": "llama_moe",
4
+ "add_weight_norm": false,
5
+ "architectures": [
6
+ "LlamaMoEForCausalLM"
7
+ ],
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_llama_moe.LlamaMoEConfig",
10
+ "AutoModel": "modeling_llama_moe_hf.LlamaMoEModel",
11
+ "AutoModelForCausalLM": "modeling_llama_moe_hf.LlamaMoEForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "calculator_type": "UniversalCalculator",
15
+ "capacity_factor": 1.25,
16
+ "drop_tokens": true,
17
+ "dropped_padding": "zero",
18
+ "eos_token_id": 2,
19
+ "gate_add_noise": true,
20
+ "gate_balance_loss_weight": 0.01,
21
+ "gate_network": "mlp",
22
+ "gate_noise_epsilon": 0.01,
23
+ "gate_type": "TopKBalancedNoisyGate",
24
+ "gate_use_balance": true,
25
+ "gate_use_softmax": true,
26
+ "gates": "mlp",
27
+ "hidden_act": "silu",
28
+ "hidden_size": 4096,
29
+ "initializer_range": 0.02,
30
+ "intermediate_size": 11008,
31
+ "max_position_embeddings": 4096,
32
+ "multiply_gate_scores": true,
33
+ "num_attention_heads": 32,
34
+ "num_experts": 16,
35
+ "num_hidden_layers": 32,
36
+ "num_key_value_heads": 32,
37
+ "num_selects": 2,
38
+ "pad_token_id": 0,
39
+ "pretraining_tp": 1,
40
+ "rms_norm_eps": 1e-05,
41
+ "rope_scaling": null,
42
+ "score_scale_factor": 8.0,
43
+ "attention_dropout": 0.0,
44
+ "size_experts": [
45
+ [
46
+ 688,
47
+ 688,
48
+ 688,
49
+ 688,
50
+ 688,
51
+ 688,
52
+ 688,
53
+ 688,
54
+ 688,
55
+ 688,
56
+ 688,
57
+ 688,
58
+ 688,
59
+ 688,
60
+ 688,
61
+ 688
62
+ ],
63
+ [
64
+ 688,
65
+ 688,
66
+ 688,
67
+ 688,
68
+ 688,
69
+ 688,
70
+ 688,
71
+ 688,
72
+ 688,
73
+ 688,
74
+ 688,
75
+ 688,
76
+ 688,
77
+ 688,
78
+ 688,
79
+ 688
80
+ ],
81
+ [
82
+ 688,
83
+ 688,
84
+ 688,
85
+ 688,
86
+ 688,
87
+ 688,
88
+ 688,
89
+ 688,
90
+ 688,
91
+ 688,
92
+ 688,
93
+ 688,
94
+ 688,
95
+ 688,
96
+ 688,
97
+ 688
98
+ ],
99
+ [
100
+ 688,
101
+ 688,
102
+ 688,
103
+ 688,
104
+ 688,
105
+ 688,
106
+ 688,
107
+ 688,
108
+ 688,
109
+ 688,
110
+ 688,
111
+ 688,
112
+ 688,
113
+ 688,
114
+ 688,
115
+ 688
116
+ ],
117
+ [
118
+ 688,
119
+ 688,
120
+ 688,
121
+ 688,
122
+ 688,
123
+ 688,
124
+ 688,
125
+ 688,
126
+ 688,
127
+ 688,
128
+ 688,
129
+ 688,
130
+ 688,
131
+ 688,
132
+ 688,
133
+ 688
134
+ ],
135
+ [
136
+ 688,
137
+ 688,
138
+ 688,
139
+ 688,
140
+ 688,
141
+ 688,
142
+ 688,
143
+ 688,
144
+ 688,
145
+ 688,
146
+ 688,
147
+ 688,
148
+ 688,
149
+ 688,
150
+ 688,
151
+ 688
152
+ ],
153
+ [
154
+ 688,
155
+ 688,
156
+ 688,
157
+ 688,
158
+ 688,
159
+ 688,
160
+ 688,
161
+ 688,
162
+ 688,
163
+ 688,
164
+ 688,
165
+ 688,
166
+ 688,
167
+ 688,
168
+ 688,
169
+ 688
170
+ ],
171
+ [
172
+ 688,
173
+ 688,
174
+ 688,
175
+ 688,
176
+ 688,
177
+ 688,
178
+ 688,
179
+ 688,
180
+ 688,
181
+ 688,
182
+ 688,
183
+ 688,
184
+ 688,
185
+ 688,
186
+ 688,
187
+ 688
188
+ ],
189
+ [
190
+ 688,
191
+ 688,
192
+ 688,
193
+ 688,
194
+ 688,
195
+ 688,
196
+ 688,
197
+ 688,
198
+ 688,
199
+ 688,
200
+ 688,
201
+ 688,
202
+ 688,
203
+ 688,
204
+ 688,
205
+ 688
206
+ ],
207
+ [
208
+ 688,
209
+ 688,
210
+ 688,
211
+ 688,
212
+ 688,
213
+ 688,
214
+ 688,
215
+ 688,
216
+ 688,
217
+ 688,
218
+ 688,
219
+ 688,
220
+ 688,
221
+ 688,
222
+ 688,
223
+ 688
224
+ ],
225
+ [
226
+ 688,
227
+ 688,
228
+ 688,
229
+ 688,
230
+ 688,
231
+ 688,
232
+ 688,
233
+ 688,
234
+ 688,
235
+ 688,
236
+ 688,
237
+ 688,
238
+ 688,
239
+ 688,
240
+ 688,
241
+ 688
242
+ ],
243
+ [
244
+ 688,
245
+ 688,
246
+ 688,
247
+ 688,
248
+ 688,
249
+ 688,
250
+ 688,
251
+ 688,
252
+ 688,
253
+ 688,
254
+ 688,
255
+ 688,
256
+ 688,
257
+ 688,
258
+ 688,
259
+ 688
260
+ ],
261
+ [
262
+ 688,
263
+ 688,
264
+ 688,
265
+ 688,
266
+ 688,
267
+ 688,
268
+ 688,
269
+ 688,
270
+ 688,
271
+ 688,
272
+ 688,
273
+ 688,
274
+ 688,
275
+ 688,
276
+ 688,
277
+ 688
278
+ ],
279
+ [
280
+ 688,
281
+ 688,
282
+ 688,
283
+ 688,
284
+ 688,
285
+ 688,
286
+ 688,
287
+ 688,
288
+ 688,
289
+ 688,
290
+ 688,
291
+ 688,
292
+ 688,
293
+ 688,
294
+ 688,
295
+ 688
296
+ ],
297
+ [
298
+ 688,
299
+ 688,
300
+ 688,
301
+ 688,
302
+ 688,
303
+ 688,
304
+ 688,
305
+ 688,
306
+ 688,
307
+ 688,
308
+ 688,
309
+ 688,
310
+ 688,
311
+ 688,
312
+ 688,
313
+ 688
314
+ ],
315
+ [
316
+ 688,
317
+ 688,
318
+ 688,
319
+ 688,
320
+ 688,
321
+ 688,
322
+ 688,
323
+ 688,
324
+ 688,
325
+ 688,
326
+ 688,
327
+ 688,
328
+ 688,
329
+ 688,
330
+ 688,
331
+ 688
332
+ ],
333
+ [
334
+ 688,
335
+ 688,
336
+ 688,
337
+ 688,
338
+ 688,
339
+ 688,
340
+ 688,
341
+ 688,
342
+ 688,
343
+ 688,
344
+ 688,
345
+ 688,
346
+ 688,
347
+ 688,
348
+ 688,
349
+ 688
350
+ ],
351
+ [
352
+ 688,
353
+ 688,
354
+ 688,
355
+ 688,
356
+ 688,
357
+ 688,
358
+ 688,
359
+ 688,
360
+ 688,
361
+ 688,
362
+ 688,
363
+ 688,
364
+ 688,
365
+ 688,
366
+ 688,
367
+ 688
368
+ ],
369
+ [
370
+ 688,
371
+ 688,
372
+ 688,
373
+ 688,
374
+ 688,
375
+ 688,
376
+ 688,
377
+ 688,
378
+ 688,
379
+ 688,
380
+ 688,
381
+ 688,
382
+ 688,
383
+ 688,
384
+ 688,
385
+ 688
386
+ ],
387
+ [
388
+ 688,
389
+ 688,
390
+ 688,
391
+ 688,
392
+ 688,
393
+ 688,
394
+ 688,
395
+ 688,
396
+ 688,
397
+ 688,
398
+ 688,
399
+ 688,
400
+ 688,
401
+ 688,
402
+ 688,
403
+ 688
404
+ ],
405
+ [
406
+ 688,
407
+ 688,
408
+ 688,
409
+ 688,
410
+ 688,
411
+ 688,
412
+ 688,
413
+ 688,
414
+ 688,
415
+ 688,
416
+ 688,
417
+ 688,
418
+ 688,
419
+ 688,
420
+ 688,
421
+ 688
422
+ ],
423
+ [
424
+ 688,
425
+ 688,
426
+ 688,
427
+ 688,
428
+ 688,
429
+ 688,
430
+ 688,
431
+ 688,
432
+ 688,
433
+ 688,
434
+ 688,
435
+ 688,
436
+ 688,
437
+ 688,
438
+ 688,
439
+ 688
440
+ ],
441
+ [
442
+ 688,
443
+ 688,
444
+ 688,
445
+ 688,
446
+ 688,
447
+ 688,
448
+ 688,
449
+ 688,
450
+ 688,
451
+ 688,
452
+ 688,
453
+ 688,
454
+ 688,
455
+ 688,
456
+ 688,
457
+ 688
458
+ ],
459
+ [
460
+ 688,
461
+ 688,
462
+ 688,
463
+ 688,
464
+ 688,
465
+ 688,
466
+ 688,
467
+ 688,
468
+ 688,
469
+ 688,
470
+ 688,
471
+ 688,
472
+ 688,
473
+ 688,
474
+ 688,
475
+ 688
476
+ ],
477
+ [
478
+ 688,
479
+ 688,
480
+ 688,
481
+ 688,
482
+ 688,
483
+ 688,
484
+ 688,
485
+ 688,
486
+ 688,
487
+ 688,
488
+ 688,
489
+ 688,
490
+ 688,
491
+ 688,
492
+ 688,
493
+ 688
494
+ ],
495
+ [
496
+ 688,
497
+ 688,
498
+ 688,
499
+ 688,
500
+ 688,
501
+ 688,
502
+ 688,
503
+ 688,
504
+ 688,
505
+ 688,
506
+ 688,
507
+ 688,
508
+ 688,
509
+ 688,
510
+ 688,
511
+ 688
512
+ ],
513
+ [
514
+ 688,
515
+ 688,
516
+ 688,
517
+ 688,
518
+ 688,
519
+ 688,
520
+ 688,
521
+ 688,
522
+ 688,
523
+ 688,
524
+ 688,
525
+ 688,
526
+ 688,
527
+ 688,
528
+ 688,
529
+ 688
530
+ ],
531
+ [
532
+ 688,
533
+ 688,
534
+ 688,
535
+ 688,
536
+ 688,
537
+ 688,
538
+ 688,
539
+ 688,
540
+ 688,
541
+ 688,
542
+ 688,
543
+ 688,
544
+ 688,
545
+ 688,
546
+ 688,
547
+ 688
548
+ ],
549
+ [
550
+ 688,
551
+ 688,
552
+ 688,
553
+ 688,
554
+ 688,
555
+ 688,
556
+ 688,
557
+ 688,
558
+ 688,
559
+ 688,
560
+ 688,
561
+ 688,
562
+ 688,
563
+ 688,
564
+ 688,
565
+ 688
566
+ ],
567
+ [
568
+ 688,
569
+ 688,
570
+ 688,
571
+ 688,
572
+ 688,
573
+ 688,
574
+ 688,
575
+ 688,
576
+ 688,
577
+ 688,
578
+ 688,
579
+ 688,
580
+ 688,
581
+ 688,
582
+ 688,
583
+ 688
584
+ ],
585
+ [
586
+ 688,
587
+ 688,
588
+ 688,
589
+ 688,
590
+ 688,
591
+ 688,
592
+ 688,
593
+ 688,
594
+ 688,
595
+ 688,
596
+ 688,
597
+ 688,
598
+ 688,
599
+ 688,
600
+ 688,
601
+ 688
602
+ ],
603
+ [
604
+ 688,
605
+ 688,
606
+ 688,
607
+ 688,
608
+ 688,
609
+ 688,
610
+ 688,
611
+ 688,
612
+ 688,
613
+ 688,
614
+ 688,
615
+ 688,
616
+ 688,
617
+ 688,
618
+ 688,
619
+ 688
620
+ ]
621
+ ],
622
+ "tie_word_embeddings": false,
623
+ "torch_dtype": "bfloat16",
624
+ "transformers_version": "4.31.0",
625
+ "use_cache": true,
626
+ "vocab_size": 32000
627
+ }
configuration_llama_moe.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class LlamaMoEConfig(PretrainedConfig):
5
+ model_type = "llama_moe"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ intermediate_size=11008,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=None,
16
+ hidden_act="silu",
17
+ max_position_embeddings=2048,
18
+ initializer_range=0.02,
19
+ rms_norm_eps=1e-6,
20
+ use_cache=True,
21
+ pad_token_id=0,
22
+ bos_token_id=1,
23
+ eos_token_id=2,
24
+ pretraining_tp=1,
25
+ tie_word_embeddings=False,
26
+ rope_scaling=None,
27
+ # -------- moe expert configs --------
28
+ num_experts=16,
29
+ num_selects=4,
30
+ size_experts=None,
31
+ # -------- moe gate configs --------
32
+ gate_type="TopKBalancedNoisyGate",
33
+ gate_network="mlp",
34
+ gate_use_softmax=True,
35
+ gate_use_balance=True,
36
+ gate_balance_loss_weight=1e-2,
37
+ gate_add_noise=True,
38
+ # TopKBalancedNoisyGate
39
+ gate_noise_epsilon=1e-2,
40
+ # -------- moe calculator configs --------
41
+ calculator_type="UniversalCalculator",
42
+ multiply_gate_scores=True,
43
+ score_scale_factor=1.0,
44
+ add_weight_norm=False,
45
+ # SwitchDropTokenCalculator
46
+ drop_tokens=True,
47
+ dropped_padding="zero",
48
+ capacity_factor=1.25,
49
+ **kwargs,
50
+ ):
51
+ self.vocab_size = vocab_size
52
+ self.max_position_embeddings = max_position_embeddings
53
+ self.hidden_size = hidden_size
54
+ self.intermediate_size = intermediate_size
55
+ self.num_hidden_layers = num_hidden_layers
56
+ self.num_attention_heads = num_attention_heads
57
+ self.hidden_act = hidden_act
58
+ self.initializer_range = initializer_range
59
+ self.rms_norm_eps = rms_norm_eps
60
+ self.pretraining_tp = pretraining_tp
61
+ self.use_cache = use_cache
62
+ self.rope_scaling = rope_scaling
63
+ self._rope_scaling_validation()
64
+
65
+ self.num_experts = num_experts
66
+ self.num_selects = num_selects
67
+ self.size_experts = size_experts
68
+
69
+ self.gate_type = gate_type
70
+ self.gate_network = gate_network
71
+ self.gate_use_softmax = gate_use_softmax
72
+ self.gate_use_balance = gate_use_balance
73
+ self.gate_balance_loss_weight = gate_balance_loss_weight
74
+ self.gate_add_noise = gate_add_noise
75
+ self.gate_noise_epsilon = gate_noise_epsilon
76
+
77
+ self.calculator_type = calculator_type
78
+ self.multiply_gate_scores = multiply_gate_scores
79
+ self.score_scale_factor = score_scale_factor
80
+ self.add_weight_norm = add_weight_norm
81
+ self.drop_tokens = drop_tokens
82
+ self.dropped_padding = dropped_padding
83
+ self.capacity_factor = capacity_factor
84
+
85
+ # for backward compatibility
86
+ if num_key_value_heads is None:
87
+ num_key_value_heads = num_attention_heads
88
+
89
+ self.num_key_value_heads = num_key_value_heads
90
+
91
+ super().__init__(
92
+ pad_token_id=pad_token_id,
93
+ bos_token_id=bos_token_id,
94
+ eos_token_id=eos_token_id,
95
+ tie_word_embeddings=tie_word_embeddings,
96
+ **kwargs,
97
+ )
98
+
99
+ def _rope_scaling_validation(self):
100
+ """
101
+ Validate the `rope_scaling` configuration.
102
+ """
103
+ if self.rope_scaling is None:
104
+ return
105
+
106
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
107
+ raise ValueError(
108
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
109
+ f"got {self.rope_scaling}"
110
+ )
111
+ rope_scaling_type = self.rope_scaling.get("type", None)
112
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
113
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
114
+ raise ValueError(
115
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
116
+ )
117
+ if (
118
+ rope_scaling_factor is None
119
+ or not isinstance(rope_scaling_factor, float)
120
+ or rope_scaling_factor <= 1.0
121
+ ):
122
+ raise ValueError(
123
+ f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}"
124
+ )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.31.0"
7
+ }
modeling_llama_moe_hf.py ADDED
@@ -0,0 +1,1664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.distributions.normal import Normal
11
+ from transformers.modeling_outputs import (
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.activations import ACT2FN
16
+ from transformers.utils import ModelOutput, logging
17
+
18
+ from .configuration_llama_moe import LlamaMoEConfig
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ _CONFIG_FOR_DOC = "LlamaMoEConfig"
23
+
24
+
25
+ @dataclass
26
+ class CalculatorOutput(ModelOutput):
27
+ hidden_states: Optional[torch.FloatTensor] = None
28
+ num_dropped_tokens: Optional[int] = None
29
+
30
+
31
+ @dataclass
32
+ class BaseMoEModelOutputWithPast(ModelOutput):
33
+ """
34
+ Args:
35
+ num_dropped_tokens: layer idx to the number of dropped tokens
36
+ """
37
+
38
+ last_hidden_state: torch.FloatTensor = None
39
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
40
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
41
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
42
+ balance_loss: Optional[float] = None
43
+ num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None
44
+ gate_load: Optional[Tuple[list]] = None
45
+ gate_importance: Optional[Tuple[list]] = None
46
+
47
+
48
+ @dataclass
49
+ class MoECausalLMOutputWithPast(CausalLMOutputWithPast):
50
+ balance_loss: Optional[float] = None
51
+ num_dropped_tokens: Optional[Tuple[int]] = None
52
+ gate_load: Optional[Tuple[list[torch.Tensor]]] = None
53
+ gate_importance: Optional[Tuple[list[torch.Tensor]]] = None
54
+
55
+
56
+ @dataclass
57
+ class MoEMlpOutput(ModelOutput):
58
+ hidden_states: Optional[torch.FloatTensor] = None
59
+ balance_loss: Optional[torch.FloatTensor] = None
60
+ num_dropped_tokens: Optional[int] = None
61
+ gate_load: Optional[list] = None
62
+ gate_importance: Optional[list] = None
63
+
64
+
65
+ def _make_causal_mask(
66
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
67
+ ):
68
+ """
69
+ Make causal mask used for bi-directional self-attention.
70
+ """
71
+ bsz, tgt_len = input_ids_shape
72
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
73
+ mask_cond = torch.arange(mask.size(-1), device=device)
74
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
75
+ mask = mask.to(dtype)
76
+
77
+ if past_key_values_length > 0:
78
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
79
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
80
+
81
+
82
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
83
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
+ """
85
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
86
+ """
87
+ bsz, src_len = mask.size()
88
+ tgt_len = tgt_len if tgt_len is not None else src_len
89
+
90
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
91
+
92
+ inverted_mask = 1.0 - expanded_mask
93
+
94
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
95
+
96
+
97
+ class LlamaRMSNorm(nn.Module):
98
+ def __init__(self, hidden_size, eps=1e-6):
99
+ """
100
+ LlamaRMSNorm is equivalent to T5LayerNorm
101
+ """
102
+ super().__init__()
103
+ self.weight = nn.Parameter(torch.ones(hidden_size))
104
+ self.variance_epsilon = eps
105
+
106
+ def forward(self, hidden_states):
107
+ input_dtype = hidden_states.dtype
108
+ hidden_states = hidden_states.to(torch.float32)
109
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
110
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
111
+ return self.weight * hidden_states.to(input_dtype)
112
+
113
+
114
+ class LlamaRotaryEmbedding(torch.nn.Module):
115
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
116
+ super().__init__()
117
+
118
+ self.dim = dim
119
+ self.max_position_embeddings = max_position_embeddings
120
+ self.base = base
121
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
122
+ self.register_buffer("inv_freq", inv_freq)
123
+
124
+ # Build here to make `torch.jit.trace` work.
125
+ self._set_cos_sin_cache(
126
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
127
+ )
128
+
129
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
130
+ self.max_seq_len_cached = seq_len
131
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
132
+
133
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
134
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
135
+ emb = torch.cat((freqs, freqs), dim=-1)
136
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
137
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
138
+
139
+ def forward(self, x, seq_len=None):
140
+ # x: [bs, num_attention_heads, seq_len, head_size]
141
+ if seq_len > self.max_seq_len_cached:
142
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
143
+
144
+ return (
145
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
146
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
147
+ )
148
+
149
+
150
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
151
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
152
+
153
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
154
+ self.scaling_factor = scaling_factor
155
+ super().__init__(dim, max_position_embeddings, base, device)
156
+
157
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
158
+ self.max_seq_len_cached = seq_len
159
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
160
+ t = t / self.scaling_factor
161
+
162
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
163
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
164
+ emb = torch.cat((freqs, freqs), dim=-1)
165
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
166
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
167
+
168
+
169
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
170
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
171
+
172
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
173
+ self.scaling_factor = scaling_factor
174
+ super().__init__(dim, max_position_embeddings, base, device)
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+
179
+ if seq_len > self.max_position_embeddings:
180
+ base = self.base * (
181
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
182
+ ) ** (self.dim / (self.dim - 2))
183
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
184
+ self.register_buffer("inv_freq", inv_freq)
185
+
186
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
187
+
188
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
189
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
190
+ emb = torch.cat((freqs, freqs), dim=-1)
191
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
192
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
193
+
194
+
195
+ def rotate_half(x):
196
+ """Rotates half the hidden dims of the input."""
197
+ x1 = x[..., : x.shape[-1] // 2]
198
+ x2 = x[..., x.shape[-1] // 2 :]
199
+ return torch.cat((-x2, x1), dim=-1)
200
+
201
+
202
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
203
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
204
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
205
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
206
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
207
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
208
+ q_embed = (q * cos) + (rotate_half(q) * sin)
209
+ k_embed = (k * cos) + (rotate_half(k) * sin)
210
+ return q_embed, k_embed
211
+
212
+
213
+ class LlamaMLP(nn.Module):
214
+ def __init__(self, config):
215
+ super().__init__()
216
+ self.pretraining_tp = config.pretraining_tp
217
+ self.hidden_size = config.hidden_size
218
+ self.intermediate_size = config.intermediate_size
219
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
220
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
221
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
222
+ self.act_fn = ACT2FN[config.hidden_act]
223
+
224
+ def forward(self, x):
225
+ if self.pretraining_tp > 1:
226
+ slice = self.intermediate_size // self.pretraining_tp
227
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
228
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
229
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
230
+
231
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
232
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
233
+
234
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
235
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
236
+ down_proj = sum(down_proj)
237
+ else:
238
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
239
+
240
+ return down_proj
241
+
242
+
243
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
244
+ """
245
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
246
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
247
+ """
248
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
249
+ if n_rep == 1:
250
+ return hidden_states
251
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
252
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
253
+
254
+
255
+ class LlamaAttention(nn.Module):
256
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
257
+
258
+ def __init__(self, config: LlamaMoEConfig):
259
+ super().__init__()
260
+ self.config = config
261
+ self.hidden_size = config.hidden_size
262
+ self.num_heads = config.num_attention_heads
263
+ self.head_dim = self.hidden_size // self.num_heads
264
+ self.num_key_value_heads = config.num_key_value_heads
265
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
266
+ self.pretraining_tp = config.pretraining_tp
267
+ self.max_position_embeddings = config.max_position_embeddings
268
+
269
+ if (self.head_dim * self.num_heads) != self.hidden_size:
270
+ raise ValueError(
271
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
272
+ f" and `num_heads`: {self.num_heads})."
273
+ )
274
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
275
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
276
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
277
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
278
+ self._init_rope()
279
+
280
+ def _init_rope(self):
281
+ if self.config.rope_scaling is None:
282
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
283
+ else:
284
+ scaling_type = self.config.rope_scaling["type"]
285
+ scaling_factor = self.config.rope_scaling["factor"]
286
+ if scaling_type == "linear":
287
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
288
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
289
+ )
290
+ elif scaling_type == "dynamic":
291
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
292
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
293
+ )
294
+ else:
295
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
296
+
297
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
298
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
299
+
300
+ def forward(
301
+ self,
302
+ hidden_states: torch.Tensor,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ position_ids: Optional[torch.LongTensor] = None,
305
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
306
+ output_attentions: bool = False,
307
+ use_cache: bool = False,
308
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
309
+ bsz, q_len, _ = hidden_states.size()
310
+
311
+ if self.pretraining_tp > 1:
312
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
313
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
314
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
315
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
316
+
317
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
318
+ query_states = torch.cat(query_states, dim=-1)
319
+
320
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
321
+ key_states = torch.cat(key_states, dim=-1)
322
+
323
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
324
+ value_states = torch.cat(value_states, dim=-1)
325
+
326
+ else:
327
+ query_states = self.q_proj(hidden_states)
328
+ key_states = self.k_proj(hidden_states)
329
+ value_states = self.v_proj(hidden_states)
330
+
331
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
332
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
333
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
334
+
335
+ kv_seq_len = key_states.shape[-2]
336
+ if past_key_value is not None:
337
+ kv_seq_len += past_key_value[0].shape[-2]
338
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
339
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
340
+
341
+ if past_key_value is not None:
342
+ # reuse k, v, self_attention
343
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
344
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
345
+
346
+ past_key_value = (key_states, value_states) if use_cache else None
347
+
348
+ # repeat k/v heads if n_kv_heads < n_heads
349
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
350
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
351
+
352
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
353
+
354
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
355
+ raise ValueError(
356
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
357
+ f" {attn_weights.size()}"
358
+ )
359
+
360
+ if attention_mask is not None:
361
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
362
+ raise ValueError(
363
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
364
+ )
365
+ attn_weights = attn_weights + attention_mask
366
+
367
+ # upcast attention to fp32
368
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
369
+ attn_output = torch.matmul(attn_weights, value_states)
370
+
371
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
372
+ raise ValueError(
373
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
374
+ f" {attn_output.size()}"
375
+ )
376
+
377
+ attn_output = attn_output.transpose(1, 2).contiguous()
378
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
379
+
380
+ if self.pretraining_tp > 1:
381
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
382
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
383
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
384
+ else:
385
+ attn_output = self.o_proj(attn_output)
386
+
387
+ if not output_attentions:
388
+ attn_weights = None
389
+
390
+ return attn_output, attn_weights, past_key_value
391
+
392
+
393
+ class TopKBalancedNoisyGate(nn.Module):
394
+ def __init__(
395
+ self,
396
+ input_size,
397
+ num_experts,
398
+ num_selects,
399
+ gate_network="mlp",
400
+ use_softmax=True,
401
+ use_balance=True,
402
+ balance_loss_weight=1e-2,
403
+ add_noise=True,
404
+ noise_epsilon=1e-2,
405
+ ):
406
+ super(TopKBalancedNoisyGate, self).__init__()
407
+ assert num_selects <= num_experts
408
+ self.input_size = input_size
409
+ self.num_experts = num_experts
410
+ self.num_selects = num_selects
411
+
412
+ self.gate_network_type = gate_network
413
+ self.gate_network = self.get_gate_network(gate_network, input_size, num_experts)
414
+
415
+ self.use_softmax = use_softmax
416
+ self.softmax = nn.Softmax(1)
417
+
418
+ self.use_balance = use_balance
419
+ self.balance_loss_weight = balance_loss_weight
420
+
421
+ # add_noise
422
+ self.add_noise = add_noise
423
+ self.noise_epsilon = noise_epsilon
424
+ self.warned = False
425
+ if self.add_noise:
426
+ self.weight_noise = nn.Linear(input_size, num_experts, bias=False)
427
+ self.weight_noise.weight.data = torch.zeros(
428
+ (num_experts, input_size),
429
+ requires_grad=True,
430
+ device=self.weight_noise.weight.data.device,
431
+ dtype=self.weight_noise.weight.data.dtype,
432
+ )
433
+ self.mean = 0.0
434
+ self.std = 1.0
435
+ self.normal = Normal(self.mean, self.std)
436
+ self.softplus = nn.Softplus()
437
+
438
+ self.reset_parameters()
439
+
440
+ def get_gate_network(self, gate_type, input_size, num_experts):
441
+ gate_type = gate_type.lower()
442
+
443
+ if gate_type == "linear":
444
+ gate_network = nn.Linear(input_size, num_experts, bias=False)
445
+ nn.init.zeros_(gate_network.weight)
446
+ elif gate_type == "mlp":
447
+ gate_network = torch.nn.Sequential(
448
+ torch.nn.Linear(input_size, num_experts, bias=False),
449
+ torch.nn.Tanh(),
450
+ torch.nn.Linear(num_experts, num_experts, bias=False),
451
+ )
452
+ else:
453
+ raise ValueError(f'Unexpected gate_type: {gate_type}.')
454
+
455
+ return gate_network
456
+
457
+ def reset_gate_network(self):
458
+ if "gate_network_type" not in vars(self):
459
+ raise KeyError(f"{type(self)} does not have a gate network.")
460
+ else:
461
+ self.gate_network = self.get_gate_network(
462
+ self.gate_network_type, self.input_size, self.num_experts
463
+ )
464
+
465
+ def reset_parameters(self):
466
+ if self.add_noise:
467
+ nn.init.zeros_(self.weight_noise.weight)
468
+ # nn.init.zeros_(self.weight_noise)
469
+
470
+ def cv_squared(self, x, eps=1e-10):
471
+ """The squared coefficient of variation of a sample.
472
+ Useful as a loss to encourage a positive distribution to be more uniform.
473
+ Epsilons added for numerical stability.
474
+ Returns 0 for an empty Tensor.
475
+ Args:
476
+ x: a `Tensor`.
477
+ Returns:
478
+ a `Scalar`.s
479
+ """
480
+ if x.shape[0] == 1:
481
+ return torch.tensor(0.0, device=x.device)
482
+ return x.float().var() / (x.float().mean() ** 2 + eps)
483
+
484
+ def forward(self, x):
485
+ logits_gate = self.gate_network(x)
486
+ if self.training and self.add_noise:
487
+ noise_mm = self.weight_noise(x)
488
+ noise_control = self.softplus(noise_mm) + self.noise_epsilon
489
+ logits_noise = torch.randn_like(logits_gate) * noise_control
490
+ logits = logits_gate + logits_noise
491
+ else:
492
+ logits = logits_gate
493
+
494
+ top_logits, top_indices = logits.topk(min(self.num_selects + 1, self.num_experts), dim=1) # 选择并排序前k+1个权重
495
+ top_k_logits = top_logits[:, :self.num_selects]
496
+ top_k_indices = top_indices[:, :self.num_selects]
497
+ top_k_scores = self.softmax(top_k_logits.to(torch.float32)) if self.use_softmax else top_k_logits
498
+ top_k_scores = top_k_scores.to(logits.dtype)
499
+
500
+ zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device)
501
+ scores_filtered = zeros.scatter(dim=1, index=top_k_indices, src=top_k_scores) # shape(batch_size, num_experts)
502
+ importance = scores_filtered.sum(0) # shape(num_experts)
503
+
504
+ if self.training:
505
+ if self.add_noise and self.num_selects != self.num_experts:
506
+ batch_size = top_logits.size(0)
507
+ m = top_logits.size(1)
508
+ top_values_flat = top_logits.flatten()
509
+ threshold_positions_if_in = torch.arange(batch_size, device=x.device) * m + self.num_selects
510
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
511
+ is_in = torch.gt(logits_noise, threshold_if_in)
512
+ threshold_positions_if_out = threshold_positions_if_in - 1
513
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
514
+ # is each value currently in the top k.
515
+ prob_if_in = self.normal.cdf((logits_gate - threshold_if_in) / noise_control)
516
+ prob_if_out = self.normal.cdf((logits_gate - threshold_if_out) / noise_control)
517
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
518
+ load = prob.sum(0)
519
+ else:
520
+ load = (scores_filtered > 0).sum(0)
521
+ if not self.add_noise and not self.warned:
522
+ warnings.warn('Gradient-trackable implementation for load calculation is only available when "add_noise=True". '
523
+ 'Training without noise will block the gradient from "load" path and lead to inconsistency in optimization objectives.')
524
+ self.warned = True
525
+ else:
526
+ load = (scores_filtered > 0).sum(0)
527
+
528
+ if self.use_balance:
529
+ balance_loss = self.cv_squared(importance) + self.cv_squared(load)
530
+ balance_loss *= self.balance_loss_weight
531
+ else:
532
+ balance_loss = torch.tensor(-100.0, device=x.device)
533
+
534
+ return {
535
+ "topK_indices": top_k_indices,
536
+ "topK_scores": top_k_scores,
537
+ "balance_loss": balance_loss,
538
+ "load": load,
539
+ "importance": importance,
540
+ }
541
+
542
+
543
+ class LinearGLUExperts(nn.Module):
544
+ """
545
+ Modified from transformers.models.llama.modeling_llama.LlamaMLP
546
+ """
547
+
548
+ __constants__ = [
549
+ "bias",
550
+ "in_features",
551
+ "hidden_features",
552
+ "out_features",
553
+ "hidden_act",
554
+ "num_experts",
555
+ "size_experts",
556
+ ]
557
+
558
+ def __init__(
559
+ self,
560
+ in_features,
561
+ hidden_features,
562
+ out_features,
563
+ hidden_act,
564
+ num_experts,
565
+ size_experts=None,
566
+ bias=True,
567
+ device=None,
568
+ dtype=None,
569
+ ):
570
+ factory_kwargs = {"device": device, "dtype": dtype}
571
+ super(LinearGLUExperts, self).__init__()
572
+ self.in_features = in_features
573
+ self.hidden_features = hidden_features
574
+ self.out_features = out_features
575
+ self.hidden_act = hidden_act
576
+ self.num_experts = num_experts
577
+
578
+ if size_experts is None:
579
+ # all experts share the same number of hidden neurons
580
+ assert hidden_features % num_experts == 0
581
+ size_per_expert = hidden_features // num_experts
582
+ size_experts = [size_per_expert for _ in range(num_experts)]
583
+ else:
584
+ # use specified expert sizes
585
+ assert (
586
+ len(size_experts) == num_experts
587
+ and sum(size_experts) == hidden_features
588
+ )
589
+ self.size_experts = size_experts
590
+
591
+ self.act_fn = ACT2FN[hidden_act]
592
+
593
+ self.weight_gate = nn.ParameterList()
594
+ self.weight_up = nn.ParameterList()
595
+ self.weight_down = nn.ParameterList()
596
+
597
+ for i in range(num_experts):
598
+ # this matrix will be transposed when performing linear forwarding
599
+ this_expert_weight_gate = nn.Parameter(
600
+ torch.empty((size_experts[i], in_features), **factory_kwargs)
601
+ )
602
+ # this matrix will be transposed when performing linear forwarding
603
+ this_expert_weight_up = nn.Parameter(
604
+ torch.empty((size_experts[i], in_features), **factory_kwargs)
605
+ )
606
+ # this matrix will be transposed when performing linear forwarding
607
+ this_expert_weight_down = nn.Parameter(
608
+ torch.empty((out_features, size_experts[i]), **factory_kwargs)
609
+ )
610
+ self.weight_gate.append(this_expert_weight_gate)
611
+ self.weight_up.append(this_expert_weight_up)
612
+ self.weight_down.append(this_expert_weight_down)
613
+
614
+ if bias:
615
+ self.bias_gate = nn.ParameterList()
616
+ self.bias_up = nn.ParameterList()
617
+ self.bias_down = nn.ParameterList()
618
+
619
+ for i in range(num_experts):
620
+ this_expert_bias_gate = nn.Parameter(
621
+ torch.empty((size_experts[i],), **factory_kwargs)
622
+ )
623
+ this_expert_bias_up = nn.Parameter(
624
+ torch.empty((size_experts[i],), **factory_kwargs)
625
+ )
626
+ this_expert_bias_down = nn.Parameter(
627
+ torch.empty((out_features,), **factory_kwargs)
628
+ )
629
+ self.bias_gate.append(this_expert_bias_gate)
630
+ self.bias_up.append(this_expert_bias_up)
631
+ self.bias_down.append(this_expert_bias_down)
632
+ else:
633
+ self.register_parameter("bias_gate", None)
634
+ self.register_parameter("bias_up", None)
635
+ self.register_parameter("bias_down", None)
636
+
637
+ self.reset_parameters()
638
+
639
+ def reset_parameters(self):
640
+ for i in range(self.num_experts):
641
+ nn.init.kaiming_uniform_(self.weight_gate[i], a=math.sqrt(5))
642
+ nn.init.kaiming_uniform_(self.weight_up[i], a=math.sqrt(5))
643
+ nn.init.kaiming_uniform_(self.weight_down[i], a=math.sqrt(5))
644
+ if self.bias_gate is not None:
645
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_gate[i])
646
+ bound = 1 / math.sqrt(fan_in)
647
+ nn.init.uniform_(self.bias_gate[i], -bound, bound)
648
+ if self.bias_up is not None:
649
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_up[i])
650
+ bound = 1 / math.sqrt(fan_in)
651
+ nn.init.uniform_(self.bias_up[i], -bound, bound)
652
+ if self.bias_down is not None:
653
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_down[i])
654
+ bound = 1 / math.sqrt(fan_in)
655
+ nn.init.uniform_(self.bias_down[i], -bound, bound)
656
+
657
+ def forward(self, input, i):
658
+ gate = self.act_fn(
659
+ F.linear(
660
+ input,
661
+ self.weight_gate[i],
662
+ self.bias_gate[i] if self.bias_gate is not None else None,
663
+ )
664
+ )
665
+ up = F.linear(
666
+ input,
667
+ self.weight_up[i],
668
+ self.bias_up[i] if self.bias_up is not None else None,
669
+ )
670
+ down = F.linear(
671
+ gate * up,
672
+ self.weight_down[i],
673
+ self.bias_down[i] if self.bias_down is not None else None,
674
+ )
675
+ return down
676
+
677
+ def extra_repr(self):
678
+ return (
679
+ "in_features={}, hidden_features={}, out_features={}, hidden_act={},"
680
+ " num_experts={}, size_experts={}, bias={}".format(
681
+ self.in_features,
682
+ self.hidden_features,
683
+ self.out_features,
684
+ self.hidden_act,
685
+ self.num_experts,
686
+ self.size_experts,
687
+ self.bias_gate is not None,
688
+ )
689
+ )
690
+
691
+
692
+ class UniversalCalculator(nn.Module):
693
+ def __init__(
694
+ self,
695
+ experts: LinearGLUExperts,
696
+ multiply_gate_scores=True,
697
+ score_scale_factor=1.0,
698
+ add_weight_norm: bool = False,
699
+ ):
700
+ super(UniversalCalculator, self).__init__()
701
+ self.experts = experts
702
+ # TODO (zhutong): use vmap to boost the training efficiency
703
+ # self.experts_vmap = torch.vmap(self.experts)
704
+ self.multiply_gate_scores = multiply_gate_scores
705
+ self.score_scale_factor = score_scale_factor
706
+ self.num_experts = experts.num_experts
707
+ self.mlp_norm = None
708
+ if multiply_gate_scores and add_weight_norm:
709
+ raise NotImplementedError
710
+
711
+ def reset_experts(self):
712
+ self.experts.reset_parameters()
713
+
714
+ def forward(
715
+ self, x, topK_indices, topK_scores, expert_batch_size=None, **kwargs
716
+ ) -> CalculatorOutput:
717
+ batch_size = topK_indices.size(0) # topK_indices: (bsz*seq_len, num_selects)
718
+ num_selects = topK_indices.size(1)
719
+ topK_indices = topK_indices.flatten() # shape(batch_size*num_selects)
720
+ topK_scores = topK_scores.flatten() # shape(batch_size*num_selects)
721
+ batch_indices = torch.arange(
722
+ batch_size, device=topK_scores.device
723
+ ).repeat_interleave(num_selects)
724
+
725
+ _, index_sorted_topK_indices = topK_indices.sort(0)
726
+
727
+ sorted_topK_scores = topK_scores.index_select(0, index_sorted_topK_indices)
728
+ sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices)
729
+
730
+ if expert_batch_size is None:
731
+ expert_batch_size = topK_indices.bincount(
732
+ minlength=self.num_experts
733
+ ).tolist()
734
+
735
+ sorted_x = x.index_select(0, sorted_batch_indices)
736
+ split_x = torch.split(sorted_x, expert_batch_size, dim=0)
737
+
738
+ expert_outputs = [
739
+ self.experts(split_x[i], i)
740
+ for i in range(self.num_experts)
741
+ if split_x[i].shape[0] > 0
742
+ ]
743
+
744
+ # (bsz*seq_len*num_selects, hidden_size)
745
+ cat_expert_outputs = torch.cat(expert_outputs, 0)
746
+ output_dim = cat_expert_outputs.size(1)
747
+ if self.multiply_gate_scores:
748
+ if self.mlp_norm is None:
749
+ cat_expert_outputs = torch.mul(
750
+ cat_expert_outputs,
751
+ sorted_topK_scores.reshape(-1, 1) * self.score_scale_factor,
752
+ )
753
+ # cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) * 1.0)
754
+ else:
755
+ cat_expert_outputs = torch.mul(
756
+ cat_expert_outputs, sorted_topK_scores.reshape(-1, 1)
757
+ )
758
+ cat_expert_outputs = self.mlp_norm(cat_expert_outputs)
759
+
760
+ zeros = torch.zeros(
761
+ (batch_size, output_dim),
762
+ device=cat_expert_outputs.device,
763
+ dtype=cat_expert_outputs.dtype,
764
+ )
765
+ y = zeros.index_add(0, sorted_batch_indices, cat_expert_outputs)
766
+
767
+ return CalculatorOutput(hidden_states=y, num_dropped_tokens=torch.tensor(-1.0))
768
+
769
+
770
+ class BaseMoELayer(nn.Module):
771
+ def __init__(self):
772
+ super(BaseMoELayer, self).__init__()
773
+
774
+ self.gate: TopKBalancedNoisyGate
775
+ self.calculator: UniversalCalculator
776
+
777
+ def _create_gate(self, **kwargs):
778
+ self.gate_type = kwargs.get("gate_type", "TopKBalancedNoisyGate")
779
+
780
+ if self.gate_type == "TopKBalancedNoisyGate": # noisy gate
781
+ self.gate = TopKBalancedNoisyGate(
782
+ self.input_size,
783
+ self.num_experts,
784
+ self.num_selects,
785
+ gate_network=kwargs.get("gate_network", "mlp"),
786
+ use_softmax=kwargs.get("gate_use_softmax", True),
787
+ use_balance=kwargs.get("gate_use_balance", True),
788
+ balance_loss_weight=kwargs.get("gate_balance_loss_weight", 1e-2),
789
+ add_noise=kwargs.get("gate_add_noise", True),
790
+ noise_epsilon=kwargs.get("gate_noise_epsilon", 1e-2),
791
+ )
792
+ else:
793
+ raise NotImplementedError
794
+
795
+ def _create_calculator(self, experts, **kwargs):
796
+ self.calculator_type = kwargs.get("calculator_type", "UniversalCalculator")
797
+
798
+ if self.calculator_type == "UniversalCalculator": # top K calculator
799
+ self.calculator = UniversalCalculator(
800
+ experts,
801
+ multiply_gate_scores=kwargs.get("multiply_gate_scores", True),
802
+ score_scale_factor=kwargs.get("score_scale_factor", 1.0),
803
+ add_weight_norm=kwargs.get("add_weight_norm", False),
804
+ )
805
+ else:
806
+ raise NotImplementedError
807
+
808
+ def forward(self, x) -> MoEMlpOutput:
809
+ original_shape = x.shape[:-1]
810
+ x = x.reshape(-1, self.input_size)
811
+ gate_outputs: dict = self.gate(x)
812
+ calc_outs: CalculatorOutput = self.calculator(x, **gate_outputs)
813
+ y = calc_outs.hidden_states
814
+ y = y.reshape(original_shape + (self.output_size,))
815
+
816
+ return MoEMlpOutput(
817
+ hidden_states=y,
818
+ balance_loss=gate_outputs.get("balance_loss"),
819
+ num_dropped_tokens=calc_outs.num_dropped_tokens,
820
+ gate_load=gate_outputs.get("load", torch.tensor(-1)),
821
+ gate_importance=gate_outputs.get("importance", torch.tensor(-1)),
822
+ )
823
+
824
+ def set_num_selects(self, num_selects):
825
+ if "num_selects" not in vars(self.gate):
826
+ raise KeyError(f'{self.gate_type} does not have a key named "num_selects".')
827
+ elif num_selects > self.gate.num_experts:
828
+ raise ValueError(
829
+ 'The value of "num_selects" must satisfy "num_selects <= num_experts"!'
830
+ )
831
+ elif self.gate_type in ("SwitchBalancedGate",):
832
+ raise ValueError(
833
+ f"{self.gate_type} doesn't support manually setting num_selects."
834
+ )
835
+ else:
836
+ self.num_selects = num_selects
837
+ self.gate.num_selects = num_selects
838
+
839
+ def set_gate_use_softmax(self, use_softmax):
840
+ if "use_softmax" not in vars(self.gate):
841
+ raise KeyError(f'{self.gate_type} does not have a key named "use_softmax".')
842
+ else:
843
+ self.gate.use_softmax = use_softmax
844
+
845
+ def set_gate_use_balance(self, use_balance):
846
+ if "use_balance" not in vars(self.gate):
847
+ raise KeyError(f'{self.gate_type} does not have a key named "use_balance".')
848
+ else:
849
+ self.gate.use_balance = use_balance
850
+
851
+ def set_gate_balance_loss_weight(self, balance_loss_weight):
852
+ if "balance_loss_weight" not in vars(self.gate):
853
+ raise KeyError(
854
+ f'{self.gate_type} does not have a key named "balance_loss_weight".'
855
+ )
856
+ else:
857
+ self.gate.balance_loss_weight = balance_loss_weight
858
+
859
+ def set_gate_add_noise(self, add_noise):
860
+ if "add_noise" not in vars(self.gate):
861
+ raise KeyError(f'{self.gate_type} does not have a key named "add_noise".')
862
+ else:
863
+ self.gate.add_noise = add_noise
864
+
865
+ def set_gate_noise_epsilon(self, noise_epsilon):
866
+ if "noise_epsilon" not in vars(self.gate):
867
+ raise KeyError(
868
+ f'{self.gate_type} does not have a key named "noise_epsilon".'
869
+ )
870
+ else:
871
+ self.gate.noise_epsilon = noise_epsilon
872
+
873
+ def set_calculator_multiply_gate_scores(self, multiply_gate_scores):
874
+ if "multiply_gate_scores" not in vars(self.calculator):
875
+ raise KeyError(
876
+ f'{self.gate_type} does not have a key named "multiply_gate_scores".'
877
+ )
878
+ else:
879
+ self.calculator.multiply_gate_scores = multiply_gate_scores
880
+
881
+ def set_calculator_score_scale_factor(self, score_scale_factor):
882
+ if "score_scale_factor" not in vars(self.calculator):
883
+ raise KeyError(
884
+ f'{self.gate_type} does not have a key named "score_scale_factor".'
885
+ )
886
+ else:
887
+ self.calculator.score_scale_factor = score_scale_factor
888
+
889
+ def set_calculator_drop_tokens(self, drop_tokens):
890
+ if "drop_tokens" not in vars(self.calculator):
891
+ raise KeyError(f'{self.gate_type} does not have a key named "drop_tokens".')
892
+ elif (
893
+ drop_tokens
894
+ and self.calculator.dropped_padding != "zero"
895
+ and self.input_size != self.output_size
896
+ ):
897
+ warnings.warn(
898
+ 'Setting "drop_tokens=True" without zero dropped padding when "input_size != output_size" will cause error!'
899
+ )
900
+ else:
901
+ self.calculator.drop_tokens = drop_tokens
902
+
903
+ def set_calculator_dropped_padding(self, dropped_padding):
904
+ if "dropped_padding" not in vars(self.calculator):
905
+ raise KeyError(
906
+ f'{self.gate_type} does not have a key named "dropped_padding".'
907
+ )
908
+ elif dropped_padding not in self.calculator.available_dropped_padding_choices:
909
+ raise ValueError(
910
+ f"'dropped_padding' type not available! (available choices: {self.calculator.available_dropped_padding_choices})"
911
+ )
912
+ elif (
913
+ self.calculator.drop_tokens
914
+ and dropped_padding != "zero"
915
+ and self.input_size != self.output_size
916
+ ):
917
+ warnings.warn(
918
+ f'Setting "dropped_padding={dropped_padding}" with "drop_tokens=True" when "input_size != output_size" will cause error!'
919
+ )
920
+ else:
921
+ self.calculator.dropped_padding = dropped_padding
922
+
923
+ def set_calculator_capacity_factor(self, capacity_factor):
924
+ if "capacity_factor" not in vars(self.calculator):
925
+ raise KeyError(
926
+ f'{self.gate_type} does not have a key named "capacity_factor".'
927
+ )
928
+ else:
929
+ self.calculator.capacity_factor = capacity_factor
930
+
931
+ def reset_gate_network(self):
932
+ self.gate.reset_gate_network()
933
+
934
+ def reset_experts(self):
935
+ self.calculator.reset_experts()
936
+
937
+
938
+ class LinearGLUMoELayer(BaseMoELayer):
939
+ def __init__(
940
+ self,
941
+ input_size,
942
+ hidden_size,
943
+ output_size,
944
+ hidden_act,
945
+ num_experts,
946
+ num_selects,
947
+ size_experts=None,
948
+ bias=True,
949
+ **kwargs,
950
+ ):
951
+ super(LinearGLUMoELayer, self).__init__()
952
+ assert num_selects <= num_experts
953
+ self.input_size = input_size
954
+ self.hidden_size = hidden_size
955
+ self.output_size = output_size
956
+ self.hidden_act = hidden_act
957
+ self.num_experts = num_experts
958
+ self.num_selects = num_selects
959
+ self.size_experts = size_experts
960
+ self.bias = bias
961
+
962
+ experts = LinearGLUExperts(
963
+ input_size,
964
+ hidden_size,
965
+ output_size,
966
+ hidden_act,
967
+ num_experts,
968
+ size_experts=size_experts,
969
+ bias=bias,
970
+ )
971
+
972
+ self._create_gate(**kwargs)
973
+ self._create_calculator(experts, **kwargs)
974
+
975
+
976
+ class LlamaMoEDecoderLayer(nn.Module):
977
+ def __init__(self, config: LlamaMoEConfig, layer_index):
978
+ super().__init__()
979
+
980
+ self.hidden_size = config.hidden_size
981
+ self.self_attn = LlamaAttention(config=config)
982
+ self.mlp = LlamaMLP(config)
983
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
984
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
985
+
986
+ gating_config = {
987
+ # all gates
988
+ "gate_type": config.gate_type,
989
+ "gate_network": config.gate_network,
990
+ "gate_use_softmax": config.gate_use_softmax,
991
+ "gate_use_balance": config.gate_use_balance,
992
+ "gate_balance_loss_weight": config.gate_balance_loss_weight,
993
+ "gate_add_noise": config.gate_add_noise,
994
+ # TopKBalancedNoisyGate
995
+ "gate_noise_epsilon": config.gate_noise_epsilon,
996
+ }
997
+ calculator_config = {
998
+ # all calculators
999
+ "calculator_type": config.calculator_type,
1000
+ "multiply_gate_scores": config.multiply_gate_scores,
1001
+ "score_scale_factor": (
1002
+ config.score_scale_factor[layer_index]
1003
+ if isinstance(config.score_scale_factor, list)
1004
+ else config.score_scale_factor
1005
+ ),
1006
+ "add_weight_norm": config.add_weight_norm,
1007
+ # SwitchDropTokenCalculator
1008
+ "drop_tokens": config.drop_tokens,
1009
+ "dropped_padding": config.dropped_padding,
1010
+ "capacity_factor": config.capacity_factor,
1011
+ }
1012
+
1013
+ self.mlp = LinearGLUMoELayer(
1014
+ input_size=self.hidden_size,
1015
+ hidden_size=config.intermediate_size,
1016
+ output_size=self.hidden_size,
1017
+ hidden_act=config.hidden_act,
1018
+ num_experts=config.num_experts,
1019
+ num_selects=config.num_selects,
1020
+ size_experts=(
1021
+ config.size_experts[layer_index]
1022
+ if config.size_experts is not None
1023
+ else None
1024
+ ),
1025
+ bias=False,
1026
+ **gating_config,
1027
+ **calculator_config,
1028
+ )
1029
+
1030
+ def forward(
1031
+ self,
1032
+ hidden_states,
1033
+ attention_mask=None,
1034
+ position_ids=None,
1035
+ past_key_value=None,
1036
+ output_attentions=False,
1037
+ use_cache=False,
1038
+ ) -> tuple:
1039
+ residual = hidden_states
1040
+ hidden_states = self.input_layernorm(hidden_states)
1041
+
1042
+ # Self Attention
1043
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1044
+ hidden_states=hidden_states,
1045
+ attention_mask=attention_mask,
1046
+ position_ids=position_ids,
1047
+ past_key_value=past_key_value,
1048
+ output_attentions=output_attentions,
1049
+ use_cache=use_cache,
1050
+ )
1051
+ hidden_states = residual + hidden_states
1052
+
1053
+ # Fully Connected
1054
+ residual = hidden_states
1055
+ hidden_states = self.post_attention_layernorm(hidden_states)
1056
+ mlp_outs: MoEMlpOutput = self.mlp(hidden_states)
1057
+ hidden_states = residual + mlp_outs.hidden_states
1058
+
1059
+ outputs = (
1060
+ hidden_states,
1061
+ mlp_outs.balance_loss,
1062
+ mlp_outs.num_dropped_tokens,
1063
+ mlp_outs.gate_load,
1064
+ mlp_outs.gate_importance,
1065
+ )
1066
+ if output_attentions:
1067
+ outputs += (self_attn_weights,)
1068
+ if use_cache:
1069
+ outputs += (present_key_value,)
1070
+
1071
+ return outputs
1072
+
1073
+ def set_moe_num_selects(self, num_selects):
1074
+ self.mlp.set_num_selects(num_selects)
1075
+
1076
+ def set_moe_gate_use_softmax(self, use_softmax):
1077
+ self.mlp.set_gate_use_softmax(use_softmax)
1078
+
1079
+ def set_moe_gate_use_balance(self, use_balance):
1080
+ self.mlp.set_gate_use_balance(use_balance)
1081
+
1082
+ def set_moe_gate_balance_loss_weight(self, balance_loss_weight):
1083
+ self.mlp.set_gate_balance_loss_weight(balance_loss_weight)
1084
+
1085
+ def set_moe_gate_add_noise(self, add_noise):
1086
+ self.mlp.set_gate_add_noise(add_noise)
1087
+
1088
+ def set_moe_gate_noise_epsilon(self, noise_epsilon):
1089
+ self.mlp.set_gate_noise_epsilon(noise_epsilon)
1090
+
1091
+ def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores):
1092
+ self.mlp.set_calculator_multiply_gate_scores(multiply_gate_scores)
1093
+
1094
+ def set_moe_calculator_score_scale_factor(self, score_scale_factor):
1095
+ self.mlp.set_calculator_score_scale_factor(score_scale_factor)
1096
+
1097
+ def set_moe_calculator_drop_tokens(self, drop_tokens):
1098
+ self.mlp.set_calculator_drop_tokens(drop_tokens)
1099
+
1100
+ def set_moe_calculator_dropped_padding(self, dropped_padding):
1101
+ self.mlp.set_calculator_dropped_padding(dropped_padding)
1102
+
1103
+ def set_moe_calculator_capacity_factor(self, capacity_factor):
1104
+ self.mlp.set_calculator_capacity_factor(capacity_factor)
1105
+
1106
+ def reset_gate_network(self):
1107
+ self.mlp.reset_gate_network()
1108
+
1109
+ def reset_experts(self):
1110
+ self.mlp.reset_experts()
1111
+
1112
+
1113
+ class LlamaMoEPreTrainedModel(PreTrainedModel):
1114
+ config_class = LlamaMoEConfig
1115
+ base_model_prefix = "model"
1116
+ supports_gradient_checkpointing = True
1117
+ _no_split_modules = ["LlamaMoEDecoderLayer"]
1118
+ _skip_keys_device_placement = "past_key_values"
1119
+
1120
+ def _init_weights(self, module):
1121
+ std = self.config.initializer_range
1122
+ if isinstance(module, nn.Linear):
1123
+ module.weight.data.normal_(mean=0.0, std=std)
1124
+ if module.bias is not None:
1125
+ module.bias.data.zero_()
1126
+ elif isinstance(module, nn.Embedding):
1127
+ module.weight.data.normal_(mean=0.0, std=std)
1128
+ if module.padding_idx is not None:
1129
+ module.weight.data[module.padding_idx].zero_()
1130
+
1131
+ def _set_gradient_checkpointing(self, module, value=False):
1132
+ if isinstance(module, LlamaMoEModel):
1133
+ module.gradient_checkpointing = value
1134
+
1135
+
1136
+ class LlamaMoEModel(LlamaMoEPreTrainedModel):
1137
+ def __init__(self, config: LlamaMoEConfig):
1138
+ super().__init__(config)
1139
+ self.padding_idx = config.pad_token_id
1140
+ self.vocab_size = config.vocab_size
1141
+
1142
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1143
+ self.layers = nn.ModuleList(
1144
+ [LlamaMoEDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
1145
+ )
1146
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1147
+ self.gradient_checkpointing = False
1148
+ self.post_init()
1149
+
1150
+ def get_input_embeddings(self):
1151
+ return self.embed_tokens
1152
+
1153
+ def set_input_embeddings(self, value):
1154
+ self.embed_tokens = value
1155
+
1156
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1157
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
1158
+ # create causal mask
1159
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1160
+ combined_attention_mask = None
1161
+ if input_shape[-1] > 1:
1162
+ combined_attention_mask = _make_causal_mask(
1163
+ input_shape,
1164
+ inputs_embeds.dtype,
1165
+ device=inputs_embeds.device,
1166
+ past_key_values_length=past_key_values_length,
1167
+ )
1168
+
1169
+ if attention_mask is not None:
1170
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1171
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
1172
+ inputs_embeds.device
1173
+ )
1174
+ combined_attention_mask = (
1175
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1176
+ )
1177
+
1178
+ return combined_attention_mask
1179
+
1180
+ def forward(
1181
+ self,
1182
+ input_ids=None,
1183
+ attention_mask=None,
1184
+ position_ids=None,
1185
+ past_key_values=None,
1186
+ inputs_embeds=None,
1187
+ use_cache=None,
1188
+ output_attentions=None,
1189
+ output_hidden_states=None,
1190
+ return_dict=None,
1191
+ ):
1192
+ output_attentions = (
1193
+ output_attentions
1194
+ if output_attentions is not None
1195
+ else self.config.output_attentions
1196
+ )
1197
+ output_hidden_states = (
1198
+ output_hidden_states
1199
+ if output_hidden_states is not None
1200
+ else self.config.output_hidden_states
1201
+ )
1202
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1203
+
1204
+ return_dict = (
1205
+ return_dict if return_dict is not None else self.config.use_return_dict
1206
+ )
1207
+
1208
+ # retrieve input_ids and inputs_embeds
1209
+ if input_ids is not None and inputs_embeds is not None:
1210
+ raise ValueError(
1211
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at"
1212
+ " the same time"
1213
+ )
1214
+ elif input_ids is not None:
1215
+ batch_size, seq_length = input_ids.shape
1216
+ elif inputs_embeds is not None:
1217
+ batch_size, seq_length, _ = inputs_embeds.shape
1218
+ else:
1219
+ raise ValueError(
1220
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1221
+ )
1222
+
1223
+ seq_length_with_past = seq_length
1224
+ past_key_values_length = 0
1225
+
1226
+ if past_key_values is not None:
1227
+ past_key_values_length = past_key_values[0][0].shape[2]
1228
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1229
+
1230
+ if position_ids is None:
1231
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1232
+ position_ids = torch.arange(
1233
+ past_key_values_length,
1234
+ seq_length + past_key_values_length,
1235
+ dtype=torch.long,
1236
+ device=device,
1237
+ )
1238
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1239
+ else:
1240
+ position_ids = position_ids.view(-1, seq_length).long()
1241
+
1242
+ if inputs_embeds is None:
1243
+ inputs_embeds = self.embed_tokens(input_ids)
1244
+ # embed positions
1245
+ if attention_mask is None:
1246
+ attention_mask = torch.ones(
1247
+ (batch_size, seq_length_with_past),
1248
+ dtype=torch.bool,
1249
+ device=inputs_embeds.device,
1250
+ )
1251
+ attention_mask = self._prepare_decoder_attention_mask(
1252
+ attention_mask,
1253
+ (batch_size, seq_length),
1254
+ inputs_embeds,
1255
+ past_key_values_length,
1256
+ )
1257
+
1258
+ hidden_states = inputs_embeds
1259
+ balance_loss = 0.0
1260
+
1261
+ if self.gradient_checkpointing and self.training:
1262
+ if use_cache:
1263
+ logger.warning_once(
1264
+ "`use_cache=True` is incompatible with gradient checkpointing."
1265
+ " Setting `use_cache=False`..."
1266
+ )
1267
+ use_cache = False
1268
+
1269
+ # decoder layers
1270
+ all_hidden_states = () if output_hidden_states else None
1271
+ all_self_attns = () if output_attentions else None
1272
+ next_decoder_cache = () if use_cache else None
1273
+
1274
+ num_dropped_tokens = ()
1275
+ gate_load = ()
1276
+ gate_importance = ()
1277
+ for idx, decoder_layer in enumerate(self.layers):
1278
+ if output_hidden_states:
1279
+ all_hidden_states += (hidden_states,)
1280
+
1281
+ past_key_value = (
1282
+ past_key_values[idx] if past_key_values is not None else None
1283
+ )
1284
+
1285
+ if self.gradient_checkpointing and self.training:
1286
+
1287
+ def create_custom_forward(module):
1288
+ def custom_forward(*inputs):
1289
+ # None for past_key_value
1290
+ return module(*inputs, output_attentions, None)
1291
+
1292
+ return custom_forward
1293
+
1294
+ layer_outputs: tuple = torch.utils.checkpoint.checkpoint(
1295
+ create_custom_forward(decoder_layer),
1296
+ hidden_states,
1297
+ attention_mask,
1298
+ position_ids,
1299
+ None,
1300
+ )
1301
+ else:
1302
+ layer_outputs: tuple = decoder_layer(
1303
+ hidden_states,
1304
+ attention_mask=attention_mask,
1305
+ position_ids=position_ids,
1306
+ past_key_value=past_key_value,
1307
+ output_attentions=output_attentions,
1308
+ use_cache=use_cache,
1309
+ )
1310
+
1311
+ hidden_states = layer_outputs[0]
1312
+ if layer_outputs[1] is not None:
1313
+ balance_loss += layer_outputs[1]
1314
+
1315
+ if use_cache:
1316
+ next_decoder_cache += (layer_outputs[6 if output_attentions else 5],)
1317
+
1318
+ if output_attentions:
1319
+ all_self_attns += (layer_outputs[5],)
1320
+
1321
+ num_dropped_tokens += (layer_outputs[2],)
1322
+ gate_load += (layer_outputs[3],)
1323
+ gate_importance += (layer_outputs[4],)
1324
+
1325
+ hidden_states = self.norm(hidden_states)
1326
+
1327
+ # add hidden states from the last decoder layer
1328
+ if output_hidden_states:
1329
+ all_hidden_states += (hidden_states,)
1330
+
1331
+ next_cache = next_decoder_cache if use_cache else None
1332
+ if not return_dict:
1333
+ return tuple(
1334
+ v
1335
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1336
+ if v is not None
1337
+ )
1338
+ return BaseMoEModelOutputWithPast(
1339
+ last_hidden_state=hidden_states,
1340
+ balance_loss=balance_loss,
1341
+ past_key_values=next_cache,
1342
+ hidden_states=all_hidden_states,
1343
+ attentions=all_self_attns,
1344
+ num_dropped_tokens=num_dropped_tokens,
1345
+ gate_load=gate_load,
1346
+ gate_importance=gate_importance,
1347
+ )
1348
+
1349
+ def update_config(self):
1350
+ self.config.vocab_size = self.config.vocab_size
1351
+ self.config.max_position_embeddings = self.config.max_position_embeddings
1352
+ # ↓↓↓↓↓↓↓↓↓↓↓↓ changed here ↓↓↓↓↓↓↓↓↓↓↓↓ #
1353
+ self.config.hidden_size = self.layers[0].mlp.input_size
1354
+ self.config.intermediate_size = self.layers[0].mlp.hidden_size
1355
+ self.config.num_hidden_layers = len(self.layers)
1356
+ self.config.num_attention_heads = self.layers[0].self_attn.num_heads
1357
+ self.config.hidden_act = self.layers[0].mlp.hidden_act
1358
+ # ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ #
1359
+ self.config.initializer_range = self.config.initializer_range
1360
+ self.config.rms_norm_eps = self.config.rms_norm_eps
1361
+ self.config.pretraining_tp = self.config.pretraining_tp
1362
+ self.config.use_cache = self.config.use_cache
1363
+ self.config.rope_scaling = self.config.rope_scaling
1364
+ self.config._rope_scaling_validation()
1365
+
1366
+ self.config.num_experts = self.layers[0].mlp.num_experts
1367
+ self.config.num_selects = self.layers[0].mlp.num_selects
1368
+ self.config.size_experts = [
1369
+ self.layers[i].mlp.calculator.experts.size_experts
1370
+ for i in range(self.config.num_hidden_layers)
1371
+ ]
1372
+
1373
+ self.config.gate_type = vars(self.layers[0].mlp).get(
1374
+ "gate_type", "TopKBalancedNoisyGate"
1375
+ )
1376
+ self.config.gate_network = vars(self.layers[0].mlp.gate).get(
1377
+ "gate_network_type", "mlp"
1378
+ )
1379
+ self.config.gate_use_softmax = vars(self.layers[0].mlp.gate).get(
1380
+ "use_softmax", True
1381
+ )
1382
+ self.config.gate_use_balance = vars(self.layers[0].mlp.gate).get(
1383
+ "use_balance", True
1384
+ )
1385
+ self.config.gate_balance_loss_weight = vars(self.layers[0].mlp.gate).get(
1386
+ "balance_loss_weight", 1e-2
1387
+ )
1388
+ self.config.gate_add_noise = vars(self.layers[0].mlp.gate).get(
1389
+ "add_noise", True
1390
+ )
1391
+ self.config.gate_noise_epsilon = vars(self.layers[0].mlp.gate).get(
1392
+ "noise_epsilon", 1e-2
1393
+ )
1394
+
1395
+ self.config.calculator_type = vars(self.layers[0].mlp).get(
1396
+ "calculator_type", "UniversalCalculator"
1397
+ )
1398
+ self.config.multiply_gate_scores = vars(self.layers[0].mlp.calculator).get(
1399
+ "multiply_gate_scores", True
1400
+ )
1401
+ self.config.score_scale_factor = [
1402
+ vars(self.layers[i].mlp.calculator).get("score_scale_factor", 1.0)
1403
+ for i in range(self.config.num_hidden_layers)
1404
+ ]
1405
+ self.config.drop_tokens = vars(self.layers[0].mlp.calculator).get(
1406
+ "drop_tokens", True
1407
+ )
1408
+ self.config.dropped_padding = vars(self.layers[0].mlp.calculator).get(
1409
+ "dropped_padding", "zero"
1410
+ )
1411
+ self.config.capacity_factor = vars(self.layers[0].mlp.calculator).get(
1412
+ "capacity_factor", 1.25
1413
+ )
1414
+
1415
+ def set_moe_num_selects(self, num_selects):
1416
+ for idx, decoder_layer in enumerate(self.layers):
1417
+ decoder_layer.set_moe_num_selects(num_selects)
1418
+
1419
+ def set_moe_gate_use_softmax(self, use_softmax):
1420
+ for idx, decoder_layer in enumerate(self.layers):
1421
+ decoder_layer.set_moe_gate_use_softmax(use_softmax)
1422
+
1423
+ def set_moe_gate_use_balance(self, use_balance):
1424
+ for idx, decoder_layer in enumerate(self.layers):
1425
+ decoder_layer.set_moe_gate_use_balance(use_balance)
1426
+
1427
+ def set_moe_gate_balance_loss_weight(self, balance_loss_weight):
1428
+ for idx, decoder_layer in enumerate(self.layers):
1429
+ decoder_layer.set_moe_gate_balance_loss_weight(balance_loss_weight)
1430
+
1431
+ def set_moe_gate_add_noise(self, add_noise):
1432
+ for idx, decoder_layer in enumerate(self.layers):
1433
+ decoder_layer.set_moe_gate_add_noise(add_noise)
1434
+
1435
+ def set_moe_gate_noise_epsilon(self, noise_epsilon):
1436
+ for idx, decoder_layer in enumerate(self.layers):
1437
+ decoder_layer.set_moe_gate_noise_epsilon(noise_epsilon)
1438
+
1439
+ def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores):
1440
+ for idx, decoder_layer in enumerate(self.layers):
1441
+ decoder_layer.set_moe_calculator_multiply_gate_scores(multiply_gate_scores)
1442
+
1443
+ def set_moe_calculator_score_scale_factor(
1444
+ self, score_scale_factor, layer_index=None
1445
+ ):
1446
+ if layer_index is None:
1447
+ for idx, decoder_layer in enumerate(self.layers):
1448
+ decoder_layer.set_moe_calculator_score_scale_factor(score_scale_factor)
1449
+ else:
1450
+ self.layers[layer_index].set_moe_calculator_score_scale_factor(
1451
+ score_scale_factor
1452
+ )
1453
+
1454
+ def set_moe_calculator_drop_tokens(self, drop_tokens):
1455
+ for idx, decoder_layer in enumerate(self.layers):
1456
+ decoder_layer.set_moe_calculator_drop_tokens(drop_tokens)
1457
+
1458
+ def set_moe_calculator_dropped_padding(self, dropped_padding):
1459
+ for idx, decoder_layer in enumerate(self.layers):
1460
+ decoder_layer.set_moe_calculator_dropped_padding(dropped_padding)
1461
+
1462
+ def set_moe_calculator_capacity_factor(self, capacity_factor):
1463
+ for idx, decoder_layer in enumerate(self.layers):
1464
+ decoder_layer.set_moe_calculator_capacity_factor(capacity_factor)
1465
+
1466
+ def reset_gate_network(self):
1467
+ for idx, decoder_layer in enumerate(self.layers):
1468
+ decoder_layer.reset_gate_network()
1469
+
1470
+ def reset_experts(self):
1471
+ for idx, decoder_layer in enumerate(self.layers):
1472
+ decoder_layer.reset_experts()
1473
+
1474
+
1475
+ class LlamaMoEForCausalLM(LlamaMoEPreTrainedModel):
1476
+ _tied_weights_keys = ["lm_head.weight"]
1477
+
1478
+ def __init__(self, config):
1479
+ super().__init__(config)
1480
+ self.model = LlamaMoEModel(config)
1481
+ self.pretraining_tp = config.pretraining_tp
1482
+ self.vocab_size = config.vocab_size
1483
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1484
+
1485
+ # Initialize weights and apply final processing
1486
+ self.post_init()
1487
+
1488
+ def get_input_embeddings(self):
1489
+ return self.model.embed_tokens
1490
+
1491
+ def set_input_embeddings(self, value):
1492
+ self.model.embed_tokens = value
1493
+
1494
+ def get_output_embeddings(self):
1495
+ return self.lm_head
1496
+
1497
+ def set_output_embeddings(self, new_embeddings):
1498
+ self.lm_head = new_embeddings
1499
+
1500
+ def set_decoder(self, decoder):
1501
+ self.model = decoder
1502
+
1503
+ def get_decoder(self):
1504
+ return self.model
1505
+
1506
+ def forward(
1507
+ self,
1508
+ input_ids=None,
1509
+ attention_mask=None,
1510
+ position_ids=None,
1511
+ past_key_values=None,
1512
+ inputs_embeds=None,
1513
+ labels=None,
1514
+ use_cache=None,
1515
+ output_attentions=None,
1516
+ output_hidden_states=None,
1517
+ return_dict=None,
1518
+ **kwargs,
1519
+ ):
1520
+ output_attentions = (
1521
+ output_attentions
1522
+ if output_attentions is not None
1523
+ else self.config.output_attentions
1524
+ )
1525
+ output_hidden_states = (
1526
+ output_hidden_states
1527
+ if output_hidden_states is not None
1528
+ else self.config.output_hidden_states
1529
+ )
1530
+ return_dict = (
1531
+ return_dict if return_dict is not None else self.config.use_return_dict
1532
+ )
1533
+
1534
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1535
+ outputs: BaseMoEModelOutputWithPast = self.model(
1536
+ input_ids=input_ids,
1537
+ attention_mask=attention_mask,
1538
+ position_ids=position_ids,
1539
+ past_key_values=past_key_values,
1540
+ inputs_embeds=inputs_embeds,
1541
+ use_cache=use_cache,
1542
+ output_attentions=output_attentions,
1543
+ output_hidden_states=output_hidden_states,
1544
+ return_dict=return_dict,
1545
+ )
1546
+
1547
+ hidden_states = outputs.last_hidden_state
1548
+ logits = self.lm_head(hidden_states)
1549
+
1550
+ loss = None
1551
+ if labels is not None:
1552
+ # Shift so that tokens < n predict n
1553
+ shift_logits = logits[..., :-1, :].contiguous()
1554
+ shift_labels = labels[..., 1:].contiguous()
1555
+ # Flatten the tokens
1556
+ loss_fct = nn.CrossEntropyLoss()
1557
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1558
+ shift_labels = shift_labels.view(-1)
1559
+ # Enable model parallelism
1560
+ shift_labels = shift_labels.to(shift_logits.device)
1561
+ loss = loss_fct(shift_logits, shift_labels)
1562
+ if outputs.balance_loss is not None and outputs.balance_loss > 0:
1563
+ loss += outputs.balance_loss
1564
+
1565
+ if not return_dict:
1566
+ output = (logits,) + outputs[1:]
1567
+ return (loss,) + output if loss is not None else output
1568
+
1569
+ return MoECausalLMOutputWithPast(
1570
+ loss=loss,
1571
+ logits=logits,
1572
+ past_key_values=outputs.past_key_values,
1573
+ hidden_states=outputs.hidden_states,
1574
+ attentions=outputs.attentions,
1575
+ num_dropped_tokens=outputs.num_dropped_tokens,
1576
+ balance_loss=outputs.balance_loss,
1577
+ gate_load=outputs.gate_load,
1578
+ gate_importance=outputs.gate_importance,
1579
+ )
1580
+
1581
+ def prepare_inputs_for_generation(
1582
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1583
+ ):
1584
+ if past_key_values:
1585
+ input_ids = input_ids[:, -1:]
1586
+
1587
+ position_ids = kwargs.get("position_ids", None)
1588
+ if attention_mask is not None and position_ids is None:
1589
+ # create position_ids on the fly for batch generation
1590
+ position_ids = attention_mask.long().cumsum(-1) - 1
1591
+ position_ids.masked_fill_(attention_mask == 0, 1)
1592
+ if past_key_values:
1593
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1594
+
1595
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1596
+ if inputs_embeds is not None and past_key_values is None:
1597
+ model_inputs = {"inputs_embeds": inputs_embeds}
1598
+ else:
1599
+ model_inputs = {"input_ids": input_ids}
1600
+
1601
+ model_inputs.update(
1602
+ {
1603
+ "position_ids": position_ids,
1604
+ "past_key_values": past_key_values,
1605
+ "use_cache": kwargs.get("use_cache"),
1606
+ "attention_mask": attention_mask,
1607
+ }
1608
+ )
1609
+ return model_inputs
1610
+
1611
+ @staticmethod
1612
+ def _reorder_cache(past_key_values, beam_idx):
1613
+ reordered_past = ()
1614
+ for layer_past in past_key_values:
1615
+ reordered_past += (
1616
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1617
+ )
1618
+ return reordered_past
1619
+
1620
+ def update_config(self):
1621
+ self.model.update_config()
1622
+
1623
+ def set_moe_num_selects(self, num_selects):
1624
+ self.model.set_moe_num_selects(num_selects)
1625
+
1626
+ def set_moe_gate_use_softmax(self, use_softmax):
1627
+ self.model.set_moe_gate_use_softmax(use_softmax)
1628
+
1629
+ def set_moe_gate_use_balance(self, use_balance):
1630
+ self.model.set_moe_gate_use_balance(use_balance)
1631
+
1632
+ def set_moe_gate_balance_loss_weight(self, balance_loss_weight):
1633
+ self.model.set_moe_gate_balance_loss_weight(balance_loss_weight)
1634
+
1635
+ def set_moe_gate_add_noise(self, add_noise):
1636
+ self.model.set_moe_gate_add_noise(add_noise)
1637
+
1638
+ def set_moe_gate_noise_epsilon(self, noise_epsilon):
1639
+ self.model.set_moe_gate_noise_epsilon(noise_epsilon)
1640
+
1641
+ def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores):
1642
+ self.model.set_moe_calculator_multiply_gate_scores(multiply_gate_scores)
1643
+
1644
+ def set_moe_calculator_score_scale_factor(
1645
+ self, score_scale_factor, layer_index=None
1646
+ ):
1647
+ self.model.set_moe_calculator_score_scale_factor(
1648
+ score_scale_factor, layer_index=layer_index
1649
+ )
1650
+
1651
+ def set_moe_calculator_drop_tokens(self, drop_tokens):
1652
+ self.model.set_moe_calculator_drop_tokens(drop_tokens)
1653
+
1654
+ def set_moe_calculator_dropped_padding(self, dropped_padding):
1655
+ self.model.set_moe_calculator_dropped_padding(dropped_padding)
1656
+
1657
+ def set_moe_calculator_capacity_factor(self, capacity_factor):
1658
+ self.model.set_moe_calculator_capacity_factor(capacity_factor)
1659
+
1660
+ def reset_gate_network(self):
1661
+ self.model.reset_gate_network()
1662
+
1663
+ def reset_experts(self):
1664
+ self.model.reset_experts()
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41c5350c512c27afeb43152edf5651335e50c98854ff9c8a183a18b78b6d62e1
3
+ size 9983352727
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50cf1ea4ff5715675495ab9c53a41c4ea31b7509c08e32215f30d3daad9e2af3
3
+ size 3502552065
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": false,
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "pad_token": null,
24
+ "padding_side": "right",
25
+ "sp_model_kwargs": {},
26
+ "tokenizer_class": "LlamaTokenizer",
27
+ "unk_token": {
28
+ "__type": "AddedToken",
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ },
35
+ "use_fast": true
36
+ }