pere commited on
Commit
864d4c5
1 Parent(s): 97e3033

Saving best state, step 500, val wer 96.036

Browse files
Files changed (47) hide show
  1. .gitignore +1 -0
  2. added_tokens.json +1611 -0
  3. checkpoint-500-epoch-0-val-wer-96.036/added_tokens.json +1611 -0
  4. checkpoint-500-epoch-0-val-wer-96.036/config.json +285 -0
  5. checkpoint-500-epoch-0-val-wer-96.036/generation_config.json +256 -0
  6. checkpoint-500-epoch-0-val-wer-96.036/merges.txt +0 -0
  7. checkpoint-500-epoch-0-val-wer-96.036/model.safetensors +3 -0
  8. checkpoint-500-epoch-0-val-wer-96.036/model_1.safetensors +3 -0
  9. checkpoint-500-epoch-0-val-wer-96.036/optimizer.bin +3 -0
  10. checkpoint-500-epoch-0-val-wer-96.036/preprocessor_config.json +14 -0
  11. checkpoint-500-epoch-0-val-wer-96.036/random_states_0.pkl +3 -0
  12. checkpoint-500-epoch-0-val-wer-96.036/scheduler.bin +3 -0
  13. checkpoint-500-epoch-0-val-wer-96.036/special_tokens_map.json +139 -0
  14. checkpoint-500-epoch-0-val-wer-96.036/tokenizer.json +0 -0
  15. checkpoint-500-epoch-0-val-wer-96.036/tokenizer_config.json +0 -0
  16. checkpoint-500-epoch-0-val-wer-96.036/vocab.json +0 -0
  17. config.json +285 -0
  18. create_student_model.py +231 -0
  19. distil-whisper/events.out.tfevents.1730988960.a100-80-west4a.48904.0 +3 -0
  20. distil-whisper/events.out.tfevents.1730989066.a100-80-west4a.49408.0 +3 -0
  21. distil-whisper/events.out.tfevents.1730989452.a100-80-west4a.68077.0 +3 -0
  22. distil-whisper/events.out.tfevents.1730990001.a100-80-west4a.87125.0 +3 -0
  23. distil_whisper/__init__.py +21 -0
  24. distil_whisper/layers.py +1338 -0
  25. distil_whisper/modeling_flax_whisper.py +2135 -0
  26. distil_whisper/partitioner.py +965 -0
  27. distil_whisper/pipeline.py +527 -0
  28. distil_whisper/train_state.py +118 -0
  29. generation_config.json +256 -0
  30. merges.txt +0 -0
  31. model.safetensors +3 -0
  32. nb-distil-large-init/added_tokens.json +1611 -0
  33. nb-distil-large-init/config.json +285 -0
  34. nb-distil-large-init/generation_config.json +256 -0
  35. nb-distil-large-init/merges.txt +0 -0
  36. nb-distil-large-init/model.safetensors +3 -0
  37. nb-distil-large-init/preprocessor_config.json +14 -0
  38. nb-distil-large-init/special_tokens_map.json +139 -0
  39. nb-distil-large-init/tokenizer_config.json +0 -0
  40. nb-distil-large-init/vocab.json +0 -0
  41. preprocessor_config.json +14 -0
  42. run_distillation.py +1827 -0
  43. run_large_training.sh +41 -0
  44. special_tokens_map.json +139 -0
  45. tokenizer.json +0 -0
  46. tokenizer_config.json +0 -0
  47. vocab.json +0 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ wandb
added_tokens.json ADDED
@@ -0,0 +1,1611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|0.00|>": 50365,
3
+ "<|0.02|>": 50366,
4
+ "<|0.04|>": 50367,
5
+ "<|0.06|>": 50368,
6
+ "<|0.08|>": 50369,
7
+ "<|0.10|>": 50370,
8
+ "<|0.12|>": 50371,
9
+ "<|0.14|>": 50372,
10
+ "<|0.16|>": 50373,
11
+ "<|0.18|>": 50374,
12
+ "<|0.20|>": 50375,
13
+ "<|0.22|>": 50376,
14
+ "<|0.24|>": 50377,
15
+ "<|0.26|>": 50378,
16
+ "<|0.28|>": 50379,
17
+ "<|0.30|>": 50380,
18
+ "<|0.32|>": 50381,
19
+ "<|0.34|>": 50382,
20
+ "<|0.36|>": 50383,
21
+ "<|0.38|>": 50384,
22
+ "<|0.40|>": 50385,
23
+ "<|0.42|>": 50386,
24
+ "<|0.44|>": 50387,
25
+ "<|0.46|>": 50388,
26
+ "<|0.48|>": 50389,
27
+ "<|0.50|>": 50390,
28
+ "<|0.52|>": 50391,
29
+ "<|0.54|>": 50392,
30
+ "<|0.56|>": 50393,
31
+ "<|0.58|>": 50394,
32
+ "<|0.60|>": 50395,
33
+ "<|0.62|>": 50396,
34
+ "<|0.64|>": 50397,
35
+ "<|0.66|>": 50398,
36
+ "<|0.68|>": 50399,
37
+ "<|0.70|>": 50400,
38
+ "<|0.72|>": 50401,
39
+ "<|0.74|>": 50402,
40
+ "<|0.76|>": 50403,
41
+ "<|0.78|>": 50404,
42
+ "<|0.80|>": 50405,
43
+ "<|0.82|>": 50406,
44
+ "<|0.84|>": 50407,
45
+ "<|0.86|>": 50408,
46
+ "<|0.88|>": 50409,
47
+ "<|0.90|>": 50410,
48
+ "<|0.92|>": 50411,
49
+ "<|0.94|>": 50412,
50
+ "<|0.96|>": 50413,
51
+ "<|0.98|>": 50414,
52
+ "<|1.00|>": 50415,
53
+ "<|1.02|>": 50416,
54
+ "<|1.04|>": 50417,
55
+ "<|1.06|>": 50418,
56
+ "<|1.08|>": 50419,
57
+ "<|1.10|>": 50420,
58
+ "<|1.12|>": 50421,
59
+ "<|1.14|>": 50422,
60
+ "<|1.16|>": 50423,
61
+ "<|1.18|>": 50424,
62
+ "<|1.20|>": 50425,
63
+ "<|1.22|>": 50426,
64
+ "<|1.24|>": 50427,
65
+ "<|1.26|>": 50428,
66
+ "<|1.28|>": 50429,
67
+ "<|1.30|>": 50430,
68
+ "<|1.32|>": 50431,
69
+ "<|1.34|>": 50432,
70
+ "<|1.36|>": 50433,
71
+ "<|1.38|>": 50434,
72
+ "<|1.40|>": 50435,
73
+ "<|1.42|>": 50436,
74
+ "<|1.44|>": 50437,
75
+ "<|1.46|>": 50438,
76
+ "<|1.48|>": 50439,
77
+ "<|1.50|>": 50440,
78
+ "<|1.52|>": 50441,
79
+ "<|1.54|>": 50442,
80
+ "<|1.56|>": 50443,
81
+ "<|1.58|>": 50444,
82
+ "<|1.60|>": 50445,
83
+ "<|1.62|>": 50446,
84
+ "<|1.64|>": 50447,
85
+ "<|1.66|>": 50448,
86
+ "<|1.68|>": 50449,
87
+ "<|1.70|>": 50450,
88
+ "<|1.72|>": 50451,
89
+ "<|1.74|>": 50452,
90
+ "<|1.76|>": 50453,
91
+ "<|1.78|>": 50454,
92
+ "<|1.80|>": 50455,
93
+ "<|1.82|>": 50456,
94
+ "<|1.84|>": 50457,
95
+ "<|1.86|>": 50458,
96
+ "<|1.88|>": 50459,
97
+ "<|1.90|>": 50460,
98
+ "<|1.92|>": 50461,
99
+ "<|1.94|>": 50462,
100
+ "<|1.96|>": 50463,
101
+ "<|1.98|>": 50464,
102
+ "<|10.00|>": 50865,
103
+ "<|10.02|>": 50866,
104
+ "<|10.04|>": 50867,
105
+ "<|10.06|>": 50868,
106
+ "<|10.08|>": 50869,
107
+ "<|10.10|>": 50870,
108
+ "<|10.12|>": 50871,
109
+ "<|10.14|>": 50872,
110
+ "<|10.16|>": 50873,
111
+ "<|10.18|>": 50874,
112
+ "<|10.20|>": 50875,
113
+ "<|10.22|>": 50876,
114
+ "<|10.24|>": 50877,
115
+ "<|10.26|>": 50878,
116
+ "<|10.28|>": 50879,
117
+ "<|10.30|>": 50880,
118
+ "<|10.32|>": 50881,
119
+ "<|10.34|>": 50882,
120
+ "<|10.36|>": 50883,
121
+ "<|10.38|>": 50884,
122
+ "<|10.40|>": 50885,
123
+ "<|10.42|>": 50886,
124
+ "<|10.44|>": 50887,
125
+ "<|10.46|>": 50888,
126
+ "<|10.48|>": 50889,
127
+ "<|10.50|>": 50890,
128
+ "<|10.52|>": 50891,
129
+ "<|10.54|>": 50892,
130
+ "<|10.56|>": 50893,
131
+ "<|10.58|>": 50894,
132
+ "<|10.60|>": 50895,
133
+ "<|10.62|>": 50896,
134
+ "<|10.64|>": 50897,
135
+ "<|10.66|>": 50898,
136
+ "<|10.68|>": 50899,
137
+ "<|10.70|>": 50900,
138
+ "<|10.72|>": 50901,
139
+ "<|10.74|>": 50902,
140
+ "<|10.76|>": 50903,
141
+ "<|10.78|>": 50904,
142
+ "<|10.80|>": 50905,
143
+ "<|10.82|>": 50906,
144
+ "<|10.84|>": 50907,
145
+ "<|10.86|>": 50908,
146
+ "<|10.88|>": 50909,
147
+ "<|10.90|>": 50910,
148
+ "<|10.92|>": 50911,
149
+ "<|10.94|>": 50912,
150
+ "<|10.96|>": 50913,
151
+ "<|10.98|>": 50914,
152
+ "<|11.00|>": 50915,
153
+ "<|11.02|>": 50916,
154
+ "<|11.04|>": 50917,
155
+ "<|11.06|>": 50918,
156
+ "<|11.08|>": 50919,
157
+ "<|11.10|>": 50920,
158
+ "<|11.12|>": 50921,
159
+ "<|11.14|>": 50922,
160
+ "<|11.16|>": 50923,
161
+ "<|11.18|>": 50924,
162
+ "<|11.20|>": 50925,
163
+ "<|11.22|>": 50926,
164
+ "<|11.24|>": 50927,
165
+ "<|11.26|>": 50928,
166
+ "<|11.28|>": 50929,
167
+ "<|11.30|>": 50930,
168
+ "<|11.32|>": 50931,
169
+ "<|11.34|>": 50932,
170
+ "<|11.36|>": 50933,
171
+ "<|11.38|>": 50934,
172
+ "<|11.40|>": 50935,
173
+ "<|11.42|>": 50936,
174
+ "<|11.44|>": 50937,
175
+ "<|11.46|>": 50938,
176
+ "<|11.48|>": 50939,
177
+ "<|11.50|>": 50940,
178
+ "<|11.52|>": 50941,
179
+ "<|11.54|>": 50942,
180
+ "<|11.56|>": 50943,
181
+ "<|11.58|>": 50944,
182
+ "<|11.60|>": 50945,
183
+ "<|11.62|>": 50946,
184
+ "<|11.64|>": 50947,
185
+ "<|11.66|>": 50948,
186
+ "<|11.68|>": 50949,
187
+ "<|11.70|>": 50950,
188
+ "<|11.72|>": 50951,
189
+ "<|11.74|>": 50952,
190
+ "<|11.76|>": 50953,
191
+ "<|11.78|>": 50954,
192
+ "<|11.80|>": 50955,
193
+ "<|11.82|>": 50956,
194
+ "<|11.84|>": 50957,
195
+ "<|11.86|>": 50958,
196
+ "<|11.88|>": 50959,
197
+ "<|11.90|>": 50960,
198
+ "<|11.92|>": 50961,
199
+ "<|11.94|>": 50962,
200
+ "<|11.96|>": 50963,
201
+ "<|11.98|>": 50964,
202
+ "<|12.00|>": 50965,
203
+ "<|12.02|>": 50966,
204
+ "<|12.04|>": 50967,
205
+ "<|12.06|>": 50968,
206
+ "<|12.08|>": 50969,
207
+ "<|12.10|>": 50970,
208
+ "<|12.12|>": 50971,
209
+ "<|12.14|>": 50972,
210
+ "<|12.16|>": 50973,
211
+ "<|12.18|>": 50974,
212
+ "<|12.20|>": 50975,
213
+ "<|12.22|>": 50976,
214
+ "<|12.24|>": 50977,
215
+ "<|12.26|>": 50978,
216
+ "<|12.28|>": 50979,
217
+ "<|12.30|>": 50980,
218
+ "<|12.32|>": 50981,
219
+ "<|12.34|>": 50982,
220
+ "<|12.36|>": 50983,
221
+ "<|12.38|>": 50984,
222
+ "<|12.40|>": 50985,
223
+ "<|12.42|>": 50986,
224
+ "<|12.44|>": 50987,
225
+ "<|12.46|>": 50988,
226
+ "<|12.48|>": 50989,
227
+ "<|12.50|>": 50990,
228
+ "<|12.52|>": 50991,
229
+ "<|12.54|>": 50992,
230
+ "<|12.56|>": 50993,
231
+ "<|12.58|>": 50994,
232
+ "<|12.60|>": 50995,
233
+ "<|12.62|>": 50996,
234
+ "<|12.64|>": 50997,
235
+ "<|12.66|>": 50998,
236
+ "<|12.68|>": 50999,
237
+ "<|12.70|>": 51000,
238
+ "<|12.72|>": 51001,
239
+ "<|12.74|>": 51002,
240
+ "<|12.76|>": 51003,
241
+ "<|12.78|>": 51004,
242
+ "<|12.80|>": 51005,
243
+ "<|12.82|>": 51006,
244
+ "<|12.84|>": 51007,
245
+ "<|12.86|>": 51008,
246
+ "<|12.88|>": 51009,
247
+ "<|12.90|>": 51010,
248
+ "<|12.92|>": 51011,
249
+ "<|12.94|>": 51012,
250
+ "<|12.96|>": 51013,
251
+ "<|12.98|>": 51014,
252
+ "<|13.00|>": 51015,
253
+ "<|13.02|>": 51016,
254
+ "<|13.04|>": 51017,
255
+ "<|13.06|>": 51018,
256
+ "<|13.08|>": 51019,
257
+ "<|13.10|>": 51020,
258
+ "<|13.12|>": 51021,
259
+ "<|13.14|>": 51022,
260
+ "<|13.16|>": 51023,
261
+ "<|13.18|>": 51024,
262
+ "<|13.20|>": 51025,
263
+ "<|13.22|>": 51026,
264
+ "<|13.24|>": 51027,
265
+ "<|13.26|>": 51028,
266
+ "<|13.28|>": 51029,
267
+ "<|13.30|>": 51030,
268
+ "<|13.32|>": 51031,
269
+ "<|13.34|>": 51032,
270
+ "<|13.36|>": 51033,
271
+ "<|13.38|>": 51034,
272
+ "<|13.40|>": 51035,
273
+ "<|13.42|>": 51036,
274
+ "<|13.44|>": 51037,
275
+ "<|13.46|>": 51038,
276
+ "<|13.48|>": 51039,
277
+ "<|13.50|>": 51040,
278
+ "<|13.52|>": 51041,
279
+ "<|13.54|>": 51042,
280
+ "<|13.56|>": 51043,
281
+ "<|13.58|>": 51044,
282
+ "<|13.60|>": 51045,
283
+ "<|13.62|>": 51046,
284
+ "<|13.64|>": 51047,
285
+ "<|13.66|>": 51048,
286
+ "<|13.68|>": 51049,
287
+ "<|13.70|>": 51050,
288
+ "<|13.72|>": 51051,
289
+ "<|13.74|>": 51052,
290
+ "<|13.76|>": 51053,
291
+ "<|13.78|>": 51054,
292
+ "<|13.80|>": 51055,
293
+ "<|13.82|>": 51056,
294
+ "<|13.84|>": 51057,
295
+ "<|13.86|>": 51058,
296
+ "<|13.88|>": 51059,
297
+ "<|13.90|>": 51060,
298
+ "<|13.92|>": 51061,
299
+ "<|13.94|>": 51062,
300
+ "<|13.96|>": 51063,
301
+ "<|13.98|>": 51064,
302
+ "<|14.00|>": 51065,
303
+ "<|14.02|>": 51066,
304
+ "<|14.04|>": 51067,
305
+ "<|14.06|>": 51068,
306
+ "<|14.08|>": 51069,
307
+ "<|14.10|>": 51070,
308
+ "<|14.12|>": 51071,
309
+ "<|14.14|>": 51072,
310
+ "<|14.16|>": 51073,
311
+ "<|14.18|>": 51074,
312
+ "<|14.20|>": 51075,
313
+ "<|14.22|>": 51076,
314
+ "<|14.24|>": 51077,
315
+ "<|14.26|>": 51078,
316
+ "<|14.28|>": 51079,
317
+ "<|14.30|>": 51080,
318
+ "<|14.32|>": 51081,
319
+ "<|14.34|>": 51082,
320
+ "<|14.36|>": 51083,
321
+ "<|14.38|>": 51084,
322
+ "<|14.40|>": 51085,
323
+ "<|14.42|>": 51086,
324
+ "<|14.44|>": 51087,
325
+ "<|14.46|>": 51088,
326
+ "<|14.48|>": 51089,
327
+ "<|14.50|>": 51090,
328
+ "<|14.52|>": 51091,
329
+ "<|14.54|>": 51092,
330
+ "<|14.56|>": 51093,
331
+ "<|14.58|>": 51094,
332
+ "<|14.60|>": 51095,
333
+ "<|14.62|>": 51096,
334
+ "<|14.64|>": 51097,
335
+ "<|14.66|>": 51098,
336
+ "<|14.68|>": 51099,
337
+ "<|14.70|>": 51100,
338
+ "<|14.72|>": 51101,
339
+ "<|14.74|>": 51102,
340
+ "<|14.76|>": 51103,
341
+ "<|14.78|>": 51104,
342
+ "<|14.80|>": 51105,
343
+ "<|14.82|>": 51106,
344
+ "<|14.84|>": 51107,
345
+ "<|14.86|>": 51108,
346
+ "<|14.88|>": 51109,
347
+ "<|14.90|>": 51110,
348
+ "<|14.92|>": 51111,
349
+ "<|14.94|>": 51112,
350
+ "<|14.96|>": 51113,
351
+ "<|14.98|>": 51114,
352
+ "<|15.00|>": 51115,
353
+ "<|15.02|>": 51116,
354
+ "<|15.04|>": 51117,
355
+ "<|15.06|>": 51118,
356
+ "<|15.08|>": 51119,
357
+ "<|15.10|>": 51120,
358
+ "<|15.12|>": 51121,
359
+ "<|15.14|>": 51122,
360
+ "<|15.16|>": 51123,
361
+ "<|15.18|>": 51124,
362
+ "<|15.20|>": 51125,
363
+ "<|15.22|>": 51126,
364
+ "<|15.24|>": 51127,
365
+ "<|15.26|>": 51128,
366
+ "<|15.28|>": 51129,
367
+ "<|15.30|>": 51130,
368
+ "<|15.32|>": 51131,
369
+ "<|15.34|>": 51132,
370
+ "<|15.36|>": 51133,
371
+ "<|15.38|>": 51134,
372
+ "<|15.40|>": 51135,
373
+ "<|15.42|>": 51136,
374
+ "<|15.44|>": 51137,
375
+ "<|15.46|>": 51138,
376
+ "<|15.48|>": 51139,
377
+ "<|15.50|>": 51140,
378
+ "<|15.52|>": 51141,
379
+ "<|15.54|>": 51142,
380
+ "<|15.56|>": 51143,
381
+ "<|15.58|>": 51144,
382
+ "<|15.60|>": 51145,
383
+ "<|15.62|>": 51146,
384
+ "<|15.64|>": 51147,
385
+ "<|15.66|>": 51148,
386
+ "<|15.68|>": 51149,
387
+ "<|15.70|>": 51150,
388
+ "<|15.72|>": 51151,
389
+ "<|15.74|>": 51152,
390
+ "<|15.76|>": 51153,
391
+ "<|15.78|>": 51154,
392
+ "<|15.80|>": 51155,
393
+ "<|15.82|>": 51156,
394
+ "<|15.84|>": 51157,
395
+ "<|15.86|>": 51158,
396
+ "<|15.88|>": 51159,
397
+ "<|15.90|>": 51160,
398
+ "<|15.92|>": 51161,
399
+ "<|15.94|>": 51162,
400
+ "<|15.96|>": 51163,
401
+ "<|15.98|>": 51164,
402
+ "<|16.00|>": 51165,
403
+ "<|16.02|>": 51166,
404
+ "<|16.04|>": 51167,
405
+ "<|16.06|>": 51168,
406
+ "<|16.08|>": 51169,
407
+ "<|16.10|>": 51170,
408
+ "<|16.12|>": 51171,
409
+ "<|16.14|>": 51172,
410
+ "<|16.16|>": 51173,
411
+ "<|16.18|>": 51174,
412
+ "<|16.20|>": 51175,
413
+ "<|16.22|>": 51176,
414
+ "<|16.24|>": 51177,
415
+ "<|16.26|>": 51178,
416
+ "<|16.28|>": 51179,
417
+ "<|16.30|>": 51180,
418
+ "<|16.32|>": 51181,
419
+ "<|16.34|>": 51182,
420
+ "<|16.36|>": 51183,
421
+ "<|16.38|>": 51184,
422
+ "<|16.40|>": 51185,
423
+ "<|16.42|>": 51186,
424
+ "<|16.44|>": 51187,
425
+ "<|16.46|>": 51188,
426
+ "<|16.48|>": 51189,
427
+ "<|16.50|>": 51190,
428
+ "<|16.52|>": 51191,
429
+ "<|16.54|>": 51192,
430
+ "<|16.56|>": 51193,
431
+ "<|16.58|>": 51194,
432
+ "<|16.60|>": 51195,
433
+ "<|16.62|>": 51196,
434
+ "<|16.64|>": 51197,
435
+ "<|16.66|>": 51198,
436
+ "<|16.68|>": 51199,
437
+ "<|16.70|>": 51200,
438
+ "<|16.72|>": 51201,
439
+ "<|16.74|>": 51202,
440
+ "<|16.76|>": 51203,
441
+ "<|16.78|>": 51204,
442
+ "<|16.80|>": 51205,
443
+ "<|16.82|>": 51206,
444
+ "<|16.84|>": 51207,
445
+ "<|16.86|>": 51208,
446
+ "<|16.88|>": 51209,
447
+ "<|16.90|>": 51210,
448
+ "<|16.92|>": 51211,
449
+ "<|16.94|>": 51212,
450
+ "<|16.96|>": 51213,
451
+ "<|16.98|>": 51214,
452
+ "<|17.00|>": 51215,
453
+ "<|17.02|>": 51216,
454
+ "<|17.04|>": 51217,
455
+ "<|17.06|>": 51218,
456
+ "<|17.08|>": 51219,
457
+ "<|17.10|>": 51220,
458
+ "<|17.12|>": 51221,
459
+ "<|17.14|>": 51222,
460
+ "<|17.16|>": 51223,
461
+ "<|17.18|>": 51224,
462
+ "<|17.20|>": 51225,
463
+ "<|17.22|>": 51226,
464
+ "<|17.24|>": 51227,
465
+ "<|17.26|>": 51228,
466
+ "<|17.28|>": 51229,
467
+ "<|17.30|>": 51230,
468
+ "<|17.32|>": 51231,
469
+ "<|17.34|>": 51232,
470
+ "<|17.36|>": 51233,
471
+ "<|17.38|>": 51234,
472
+ "<|17.40|>": 51235,
473
+ "<|17.42|>": 51236,
474
+ "<|17.44|>": 51237,
475
+ "<|17.46|>": 51238,
476
+ "<|17.48|>": 51239,
477
+ "<|17.50|>": 51240,
478
+ "<|17.52|>": 51241,
479
+ "<|17.54|>": 51242,
480
+ "<|17.56|>": 51243,
481
+ "<|17.58|>": 51244,
482
+ "<|17.60|>": 51245,
483
+ "<|17.62|>": 51246,
484
+ "<|17.64|>": 51247,
485
+ "<|17.66|>": 51248,
486
+ "<|17.68|>": 51249,
487
+ "<|17.70|>": 51250,
488
+ "<|17.72|>": 51251,
489
+ "<|17.74|>": 51252,
490
+ "<|17.76|>": 51253,
491
+ "<|17.78|>": 51254,
492
+ "<|17.80|>": 51255,
493
+ "<|17.82|>": 51256,
494
+ "<|17.84|>": 51257,
495
+ "<|17.86|>": 51258,
496
+ "<|17.88|>": 51259,
497
+ "<|17.90|>": 51260,
498
+ "<|17.92|>": 51261,
499
+ "<|17.94|>": 51262,
500
+ "<|17.96|>": 51263,
501
+ "<|17.98|>": 51264,
502
+ "<|18.00|>": 51265,
503
+ "<|18.02|>": 51266,
504
+ "<|18.04|>": 51267,
505
+ "<|18.06|>": 51268,
506
+ "<|18.08|>": 51269,
507
+ "<|18.10|>": 51270,
508
+ "<|18.12|>": 51271,
509
+ "<|18.14|>": 51272,
510
+ "<|18.16|>": 51273,
511
+ "<|18.18|>": 51274,
512
+ "<|18.20|>": 51275,
513
+ "<|18.22|>": 51276,
514
+ "<|18.24|>": 51277,
515
+ "<|18.26|>": 51278,
516
+ "<|18.28|>": 51279,
517
+ "<|18.30|>": 51280,
518
+ "<|18.32|>": 51281,
519
+ "<|18.34|>": 51282,
520
+ "<|18.36|>": 51283,
521
+ "<|18.38|>": 51284,
522
+ "<|18.40|>": 51285,
523
+ "<|18.42|>": 51286,
524
+ "<|18.44|>": 51287,
525
+ "<|18.46|>": 51288,
526
+ "<|18.48|>": 51289,
527
+ "<|18.50|>": 51290,
528
+ "<|18.52|>": 51291,
529
+ "<|18.54|>": 51292,
530
+ "<|18.56|>": 51293,
531
+ "<|18.58|>": 51294,
532
+ "<|18.60|>": 51295,
533
+ "<|18.62|>": 51296,
534
+ "<|18.64|>": 51297,
535
+ "<|18.66|>": 51298,
536
+ "<|18.68|>": 51299,
537
+ "<|18.70|>": 51300,
538
+ "<|18.72|>": 51301,
539
+ "<|18.74|>": 51302,
540
+ "<|18.76|>": 51303,
541
+ "<|18.78|>": 51304,
542
+ "<|18.80|>": 51305,
543
+ "<|18.82|>": 51306,
544
+ "<|18.84|>": 51307,
545
+ "<|18.86|>": 51308,
546
+ "<|18.88|>": 51309,
547
+ "<|18.90|>": 51310,
548
+ "<|18.92|>": 51311,
549
+ "<|18.94|>": 51312,
550
+ "<|18.96|>": 51313,
551
+ "<|18.98|>": 51314,
552
+ "<|19.00|>": 51315,
553
+ "<|19.02|>": 51316,
554
+ "<|19.04|>": 51317,
555
+ "<|19.06|>": 51318,
556
+ "<|19.08|>": 51319,
557
+ "<|19.10|>": 51320,
558
+ "<|19.12|>": 51321,
559
+ "<|19.14|>": 51322,
560
+ "<|19.16|>": 51323,
561
+ "<|19.18|>": 51324,
562
+ "<|19.20|>": 51325,
563
+ "<|19.22|>": 51326,
564
+ "<|19.24|>": 51327,
565
+ "<|19.26|>": 51328,
566
+ "<|19.28|>": 51329,
567
+ "<|19.30|>": 51330,
568
+ "<|19.32|>": 51331,
569
+ "<|19.34|>": 51332,
570
+ "<|19.36|>": 51333,
571
+ "<|19.38|>": 51334,
572
+ "<|19.40|>": 51335,
573
+ "<|19.42|>": 51336,
574
+ "<|19.44|>": 51337,
575
+ "<|19.46|>": 51338,
576
+ "<|19.48|>": 51339,
577
+ "<|19.50|>": 51340,
578
+ "<|19.52|>": 51341,
579
+ "<|19.54|>": 51342,
580
+ "<|19.56|>": 51343,
581
+ "<|19.58|>": 51344,
582
+ "<|19.60|>": 51345,
583
+ "<|19.62|>": 51346,
584
+ "<|19.64|>": 51347,
585
+ "<|19.66|>": 51348,
586
+ "<|19.68|>": 51349,
587
+ "<|19.70|>": 51350,
588
+ "<|19.72|>": 51351,
589
+ "<|19.74|>": 51352,
590
+ "<|19.76|>": 51353,
591
+ "<|19.78|>": 51354,
592
+ "<|19.80|>": 51355,
593
+ "<|19.82|>": 51356,
594
+ "<|19.84|>": 51357,
595
+ "<|19.86|>": 51358,
596
+ "<|19.88|>": 51359,
597
+ "<|19.90|>": 51360,
598
+ "<|19.92|>": 51361,
599
+ "<|19.94|>": 51362,
600
+ "<|19.96|>": 51363,
601
+ "<|19.98|>": 51364,
602
+ "<|2.00|>": 50465,
603
+ "<|2.02|>": 50466,
604
+ "<|2.04|>": 50467,
605
+ "<|2.06|>": 50468,
606
+ "<|2.08|>": 50469,
607
+ "<|2.10|>": 50470,
608
+ "<|2.12|>": 50471,
609
+ "<|2.14|>": 50472,
610
+ "<|2.16|>": 50473,
611
+ "<|2.18|>": 50474,
612
+ "<|2.20|>": 50475,
613
+ "<|2.22|>": 50476,
614
+ "<|2.24|>": 50477,
615
+ "<|2.26|>": 50478,
616
+ "<|2.28|>": 50479,
617
+ "<|2.30|>": 50480,
618
+ "<|2.32|>": 50481,
619
+ "<|2.34|>": 50482,
620
+ "<|2.36|>": 50483,
621
+ "<|2.38|>": 50484,
622
+ "<|2.40|>": 50485,
623
+ "<|2.42|>": 50486,
624
+ "<|2.44|>": 50487,
625
+ "<|2.46|>": 50488,
626
+ "<|2.48|>": 50489,
627
+ "<|2.50|>": 50490,
628
+ "<|2.52|>": 50491,
629
+ "<|2.54|>": 50492,
630
+ "<|2.56|>": 50493,
631
+ "<|2.58|>": 50494,
632
+ "<|2.60|>": 50495,
633
+ "<|2.62|>": 50496,
634
+ "<|2.64|>": 50497,
635
+ "<|2.66|>": 50498,
636
+ "<|2.68|>": 50499,
637
+ "<|2.70|>": 50500,
638
+ "<|2.72|>": 50501,
639
+ "<|2.74|>": 50502,
640
+ "<|2.76|>": 50503,
641
+ "<|2.78|>": 50504,
642
+ "<|2.80|>": 50505,
643
+ "<|2.82|>": 50506,
644
+ "<|2.84|>": 50507,
645
+ "<|2.86|>": 50508,
646
+ "<|2.88|>": 50509,
647
+ "<|2.90|>": 50510,
648
+ "<|2.92|>": 50511,
649
+ "<|2.94|>": 50512,
650
+ "<|2.96|>": 50513,
651
+ "<|2.98|>": 50514,
652
+ "<|20.00|>": 51365,
653
+ "<|20.02|>": 51366,
654
+ "<|20.04|>": 51367,
655
+ "<|20.06|>": 51368,
656
+ "<|20.08|>": 51369,
657
+ "<|20.10|>": 51370,
658
+ "<|20.12|>": 51371,
659
+ "<|20.14|>": 51372,
660
+ "<|20.16|>": 51373,
661
+ "<|20.18|>": 51374,
662
+ "<|20.20|>": 51375,
663
+ "<|20.22|>": 51376,
664
+ "<|20.24|>": 51377,
665
+ "<|20.26|>": 51378,
666
+ "<|20.28|>": 51379,
667
+ "<|20.30|>": 51380,
668
+ "<|20.32|>": 51381,
669
+ "<|20.34|>": 51382,
670
+ "<|20.36|>": 51383,
671
+ "<|20.38|>": 51384,
672
+ "<|20.40|>": 51385,
673
+ "<|20.42|>": 51386,
674
+ "<|20.44|>": 51387,
675
+ "<|20.46|>": 51388,
676
+ "<|20.48|>": 51389,
677
+ "<|20.50|>": 51390,
678
+ "<|20.52|>": 51391,
679
+ "<|20.54|>": 51392,
680
+ "<|20.56|>": 51393,
681
+ "<|20.58|>": 51394,
682
+ "<|20.60|>": 51395,
683
+ "<|20.62|>": 51396,
684
+ "<|20.64|>": 51397,
685
+ "<|20.66|>": 51398,
686
+ "<|20.68|>": 51399,
687
+ "<|20.70|>": 51400,
688
+ "<|20.72|>": 51401,
689
+ "<|20.74|>": 51402,
690
+ "<|20.76|>": 51403,
691
+ "<|20.78|>": 51404,
692
+ "<|20.80|>": 51405,
693
+ "<|20.82|>": 51406,
694
+ "<|20.84|>": 51407,
695
+ "<|20.86|>": 51408,
696
+ "<|20.88|>": 51409,
697
+ "<|20.90|>": 51410,
698
+ "<|20.92|>": 51411,
699
+ "<|20.94|>": 51412,
700
+ "<|20.96|>": 51413,
701
+ "<|20.98|>": 51414,
702
+ "<|21.00|>": 51415,
703
+ "<|21.02|>": 51416,
704
+ "<|21.04|>": 51417,
705
+ "<|21.06|>": 51418,
706
+ "<|21.08|>": 51419,
707
+ "<|21.10|>": 51420,
708
+ "<|21.12|>": 51421,
709
+ "<|21.14|>": 51422,
710
+ "<|21.16|>": 51423,
711
+ "<|21.18|>": 51424,
712
+ "<|21.20|>": 51425,
713
+ "<|21.22|>": 51426,
714
+ "<|21.24|>": 51427,
715
+ "<|21.26|>": 51428,
716
+ "<|21.28|>": 51429,
717
+ "<|21.30|>": 51430,
718
+ "<|21.32|>": 51431,
719
+ "<|21.34|>": 51432,
720
+ "<|21.36|>": 51433,
721
+ "<|21.38|>": 51434,
722
+ "<|21.40|>": 51435,
723
+ "<|21.42|>": 51436,
724
+ "<|21.44|>": 51437,
725
+ "<|21.46|>": 51438,
726
+ "<|21.48|>": 51439,
727
+ "<|21.50|>": 51440,
728
+ "<|21.52|>": 51441,
729
+ "<|21.54|>": 51442,
730
+ "<|21.56|>": 51443,
731
+ "<|21.58|>": 51444,
732
+ "<|21.60|>": 51445,
733
+ "<|21.62|>": 51446,
734
+ "<|21.64|>": 51447,
735
+ "<|21.66|>": 51448,
736
+ "<|21.68|>": 51449,
737
+ "<|21.70|>": 51450,
738
+ "<|21.72|>": 51451,
739
+ "<|21.74|>": 51452,
740
+ "<|21.76|>": 51453,
741
+ "<|21.78|>": 51454,
742
+ "<|21.80|>": 51455,
743
+ "<|21.82|>": 51456,
744
+ "<|21.84|>": 51457,
745
+ "<|21.86|>": 51458,
746
+ "<|21.88|>": 51459,
747
+ "<|21.90|>": 51460,
748
+ "<|21.92|>": 51461,
749
+ "<|21.94|>": 51462,
750
+ "<|21.96|>": 51463,
751
+ "<|21.98|>": 51464,
752
+ "<|22.00|>": 51465,
753
+ "<|22.02|>": 51466,
754
+ "<|22.04|>": 51467,
755
+ "<|22.06|>": 51468,
756
+ "<|22.08|>": 51469,
757
+ "<|22.10|>": 51470,
758
+ "<|22.12|>": 51471,
759
+ "<|22.14|>": 51472,
760
+ "<|22.16|>": 51473,
761
+ "<|22.18|>": 51474,
762
+ "<|22.20|>": 51475,
763
+ "<|22.22|>": 51476,
764
+ "<|22.24|>": 51477,
765
+ "<|22.26|>": 51478,
766
+ "<|22.28|>": 51479,
767
+ "<|22.30|>": 51480,
768
+ "<|22.32|>": 51481,
769
+ "<|22.34|>": 51482,
770
+ "<|22.36|>": 51483,
771
+ "<|22.38|>": 51484,
772
+ "<|22.40|>": 51485,
773
+ "<|22.42|>": 51486,
774
+ "<|22.44|>": 51487,
775
+ "<|22.46|>": 51488,
776
+ "<|22.48|>": 51489,
777
+ "<|22.50|>": 51490,
778
+ "<|22.52|>": 51491,
779
+ "<|22.54|>": 51492,
780
+ "<|22.56|>": 51493,
781
+ "<|22.58|>": 51494,
782
+ "<|22.60|>": 51495,
783
+ "<|22.62|>": 51496,
784
+ "<|22.64|>": 51497,
785
+ "<|22.66|>": 51498,
786
+ "<|22.68|>": 51499,
787
+ "<|22.70|>": 51500,
788
+ "<|22.72|>": 51501,
789
+ "<|22.74|>": 51502,
790
+ "<|22.76|>": 51503,
791
+ "<|22.78|>": 51504,
792
+ "<|22.80|>": 51505,
793
+ "<|22.82|>": 51506,
794
+ "<|22.84|>": 51507,
795
+ "<|22.86|>": 51508,
796
+ "<|22.88|>": 51509,
797
+ "<|22.90|>": 51510,
798
+ "<|22.92|>": 51511,
799
+ "<|22.94|>": 51512,
800
+ "<|22.96|>": 51513,
801
+ "<|22.98|>": 51514,
802
+ "<|23.00|>": 51515,
803
+ "<|23.02|>": 51516,
804
+ "<|23.04|>": 51517,
805
+ "<|23.06|>": 51518,
806
+ "<|23.08|>": 51519,
807
+ "<|23.10|>": 51520,
808
+ "<|23.12|>": 51521,
809
+ "<|23.14|>": 51522,
810
+ "<|23.16|>": 51523,
811
+ "<|23.18|>": 51524,
812
+ "<|23.20|>": 51525,
813
+ "<|23.22|>": 51526,
814
+ "<|23.24|>": 51527,
815
+ "<|23.26|>": 51528,
816
+ "<|23.28|>": 51529,
817
+ "<|23.30|>": 51530,
818
+ "<|23.32|>": 51531,
819
+ "<|23.34|>": 51532,
820
+ "<|23.36|>": 51533,
821
+ "<|23.38|>": 51534,
822
+ "<|23.40|>": 51535,
823
+ "<|23.42|>": 51536,
824
+ "<|23.44|>": 51537,
825
+ "<|23.46|>": 51538,
826
+ "<|23.48|>": 51539,
827
+ "<|23.50|>": 51540,
828
+ "<|23.52|>": 51541,
829
+ "<|23.54|>": 51542,
830
+ "<|23.56|>": 51543,
831
+ "<|23.58|>": 51544,
832
+ "<|23.60|>": 51545,
833
+ "<|23.62|>": 51546,
834
+ "<|23.64|>": 51547,
835
+ "<|23.66|>": 51548,
836
+ "<|23.68|>": 51549,
837
+ "<|23.70|>": 51550,
838
+ "<|23.72|>": 51551,
839
+ "<|23.74|>": 51552,
840
+ "<|23.76|>": 51553,
841
+ "<|23.78|>": 51554,
842
+ "<|23.80|>": 51555,
843
+ "<|23.82|>": 51556,
844
+ "<|23.84|>": 51557,
845
+ "<|23.86|>": 51558,
846
+ "<|23.88|>": 51559,
847
+ "<|23.90|>": 51560,
848
+ "<|23.92|>": 51561,
849
+ "<|23.94|>": 51562,
850
+ "<|23.96|>": 51563,
851
+ "<|23.98|>": 51564,
852
+ "<|24.00|>": 51565,
853
+ "<|24.02|>": 51566,
854
+ "<|24.04|>": 51567,
855
+ "<|24.06|>": 51568,
856
+ "<|24.08|>": 51569,
857
+ "<|24.10|>": 51570,
858
+ "<|24.12|>": 51571,
859
+ "<|24.14|>": 51572,
860
+ "<|24.16|>": 51573,
861
+ "<|24.18|>": 51574,
862
+ "<|24.20|>": 51575,
863
+ "<|24.22|>": 51576,
864
+ "<|24.24|>": 51577,
865
+ "<|24.26|>": 51578,
866
+ "<|24.28|>": 51579,
867
+ "<|24.30|>": 51580,
868
+ "<|24.32|>": 51581,
869
+ "<|24.34|>": 51582,
870
+ "<|24.36|>": 51583,
871
+ "<|24.38|>": 51584,
872
+ "<|24.40|>": 51585,
873
+ "<|24.42|>": 51586,
874
+ "<|24.44|>": 51587,
875
+ "<|24.46|>": 51588,
876
+ "<|24.48|>": 51589,
877
+ "<|24.50|>": 51590,
878
+ "<|24.52|>": 51591,
879
+ "<|24.54|>": 51592,
880
+ "<|24.56|>": 51593,
881
+ "<|24.58|>": 51594,
882
+ "<|24.60|>": 51595,
883
+ "<|24.62|>": 51596,
884
+ "<|24.64|>": 51597,
885
+ "<|24.66|>": 51598,
886
+ "<|24.68|>": 51599,
887
+ "<|24.70|>": 51600,
888
+ "<|24.72|>": 51601,
889
+ "<|24.74|>": 51602,
890
+ "<|24.76|>": 51603,
891
+ "<|24.78|>": 51604,
892
+ "<|24.80|>": 51605,
893
+ "<|24.82|>": 51606,
894
+ "<|24.84|>": 51607,
895
+ "<|24.86|>": 51608,
896
+ "<|24.88|>": 51609,
897
+ "<|24.90|>": 51610,
898
+ "<|24.92|>": 51611,
899
+ "<|24.94|>": 51612,
900
+ "<|24.96|>": 51613,
901
+ "<|24.98|>": 51614,
902
+ "<|25.00|>": 51615,
903
+ "<|25.02|>": 51616,
904
+ "<|25.04|>": 51617,
905
+ "<|25.06|>": 51618,
906
+ "<|25.08|>": 51619,
907
+ "<|25.10|>": 51620,
908
+ "<|25.12|>": 51621,
909
+ "<|25.14|>": 51622,
910
+ "<|25.16|>": 51623,
911
+ "<|25.18|>": 51624,
912
+ "<|25.20|>": 51625,
913
+ "<|25.22|>": 51626,
914
+ "<|25.24|>": 51627,
915
+ "<|25.26|>": 51628,
916
+ "<|25.28|>": 51629,
917
+ "<|25.30|>": 51630,
918
+ "<|25.32|>": 51631,
919
+ "<|25.34|>": 51632,
920
+ "<|25.36|>": 51633,
921
+ "<|25.38|>": 51634,
922
+ "<|25.40|>": 51635,
923
+ "<|25.42|>": 51636,
924
+ "<|25.44|>": 51637,
925
+ "<|25.46|>": 51638,
926
+ "<|25.48|>": 51639,
927
+ "<|25.50|>": 51640,
928
+ "<|25.52|>": 51641,
929
+ "<|25.54|>": 51642,
930
+ "<|25.56|>": 51643,
931
+ "<|25.58|>": 51644,
932
+ "<|25.60|>": 51645,
933
+ "<|25.62|>": 51646,
934
+ "<|25.64|>": 51647,
935
+ "<|25.66|>": 51648,
936
+ "<|25.68|>": 51649,
937
+ "<|25.70|>": 51650,
938
+ "<|25.72|>": 51651,
939
+ "<|25.74|>": 51652,
940
+ "<|25.76|>": 51653,
941
+ "<|25.78|>": 51654,
942
+ "<|25.80|>": 51655,
943
+ "<|25.82|>": 51656,
944
+ "<|25.84|>": 51657,
945
+ "<|25.86|>": 51658,
946
+ "<|25.88|>": 51659,
947
+ "<|25.90|>": 51660,
948
+ "<|25.92|>": 51661,
949
+ "<|25.94|>": 51662,
950
+ "<|25.96|>": 51663,
951
+ "<|25.98|>": 51664,
952
+ "<|26.00|>": 51665,
953
+ "<|26.02|>": 51666,
954
+ "<|26.04|>": 51667,
955
+ "<|26.06|>": 51668,
956
+ "<|26.08|>": 51669,
957
+ "<|26.10|>": 51670,
958
+ "<|26.12|>": 51671,
959
+ "<|26.14|>": 51672,
960
+ "<|26.16|>": 51673,
961
+ "<|26.18|>": 51674,
962
+ "<|26.20|>": 51675,
963
+ "<|26.22|>": 51676,
964
+ "<|26.24|>": 51677,
965
+ "<|26.26|>": 51678,
966
+ "<|26.28|>": 51679,
967
+ "<|26.30|>": 51680,
968
+ "<|26.32|>": 51681,
969
+ "<|26.34|>": 51682,
970
+ "<|26.36|>": 51683,
971
+ "<|26.38|>": 51684,
972
+ "<|26.40|>": 51685,
973
+ "<|26.42|>": 51686,
974
+ "<|26.44|>": 51687,
975
+ "<|26.46|>": 51688,
976
+ "<|26.48|>": 51689,
977
+ "<|26.50|>": 51690,
978
+ "<|26.52|>": 51691,
979
+ "<|26.54|>": 51692,
980
+ "<|26.56|>": 51693,
981
+ "<|26.58|>": 51694,
982
+ "<|26.60|>": 51695,
983
+ "<|26.62|>": 51696,
984
+ "<|26.64|>": 51697,
985
+ "<|26.66|>": 51698,
986
+ "<|26.68|>": 51699,
987
+ "<|26.70|>": 51700,
988
+ "<|26.72|>": 51701,
989
+ "<|26.74|>": 51702,
990
+ "<|26.76|>": 51703,
991
+ "<|26.78|>": 51704,
992
+ "<|26.80|>": 51705,
993
+ "<|26.82|>": 51706,
994
+ "<|26.84|>": 51707,
995
+ "<|26.86|>": 51708,
996
+ "<|26.88|>": 51709,
997
+ "<|26.90|>": 51710,
998
+ "<|26.92|>": 51711,
999
+ "<|26.94|>": 51712,
1000
+ "<|26.96|>": 51713,
1001
+ "<|26.98|>": 51714,
1002
+ "<|27.00|>": 51715,
1003
+ "<|27.02|>": 51716,
1004
+ "<|27.04|>": 51717,
1005
+ "<|27.06|>": 51718,
1006
+ "<|27.08|>": 51719,
1007
+ "<|27.10|>": 51720,
1008
+ "<|27.12|>": 51721,
1009
+ "<|27.14|>": 51722,
1010
+ "<|27.16|>": 51723,
1011
+ "<|27.18|>": 51724,
1012
+ "<|27.20|>": 51725,
1013
+ "<|27.22|>": 51726,
1014
+ "<|27.24|>": 51727,
1015
+ "<|27.26|>": 51728,
1016
+ "<|27.28|>": 51729,
1017
+ "<|27.30|>": 51730,
1018
+ "<|27.32|>": 51731,
1019
+ "<|27.34|>": 51732,
1020
+ "<|27.36|>": 51733,
1021
+ "<|27.38|>": 51734,
1022
+ "<|27.40|>": 51735,
1023
+ "<|27.42|>": 51736,
1024
+ "<|27.44|>": 51737,
1025
+ "<|27.46|>": 51738,
1026
+ "<|27.48|>": 51739,
1027
+ "<|27.50|>": 51740,
1028
+ "<|27.52|>": 51741,
1029
+ "<|27.54|>": 51742,
1030
+ "<|27.56|>": 51743,
1031
+ "<|27.58|>": 51744,
1032
+ "<|27.60|>": 51745,
1033
+ "<|27.62|>": 51746,
1034
+ "<|27.64|>": 51747,
1035
+ "<|27.66|>": 51748,
1036
+ "<|27.68|>": 51749,
1037
+ "<|27.70|>": 51750,
1038
+ "<|27.72|>": 51751,
1039
+ "<|27.74|>": 51752,
1040
+ "<|27.76|>": 51753,
1041
+ "<|27.78|>": 51754,
1042
+ "<|27.80|>": 51755,
1043
+ "<|27.82|>": 51756,
1044
+ "<|27.84|>": 51757,
1045
+ "<|27.86|>": 51758,
1046
+ "<|27.88|>": 51759,
1047
+ "<|27.90|>": 51760,
1048
+ "<|27.92|>": 51761,
1049
+ "<|27.94|>": 51762,
1050
+ "<|27.96|>": 51763,
1051
+ "<|27.98|>": 51764,
1052
+ "<|28.00|>": 51765,
1053
+ "<|28.02|>": 51766,
1054
+ "<|28.04|>": 51767,
1055
+ "<|28.06|>": 51768,
1056
+ "<|28.08|>": 51769,
1057
+ "<|28.10|>": 51770,
1058
+ "<|28.12|>": 51771,
1059
+ "<|28.14|>": 51772,
1060
+ "<|28.16|>": 51773,
1061
+ "<|28.18|>": 51774,
1062
+ "<|28.20|>": 51775,
1063
+ "<|28.22|>": 51776,
1064
+ "<|28.24|>": 51777,
1065
+ "<|28.26|>": 51778,
1066
+ "<|28.28|>": 51779,
1067
+ "<|28.30|>": 51780,
1068
+ "<|28.32|>": 51781,
1069
+ "<|28.34|>": 51782,
1070
+ "<|28.36|>": 51783,
1071
+ "<|28.38|>": 51784,
1072
+ "<|28.40|>": 51785,
1073
+ "<|28.42|>": 51786,
1074
+ "<|28.44|>": 51787,
1075
+ "<|28.46|>": 51788,
1076
+ "<|28.48|>": 51789,
1077
+ "<|28.50|>": 51790,
1078
+ "<|28.52|>": 51791,
1079
+ "<|28.54|>": 51792,
1080
+ "<|28.56|>": 51793,
1081
+ "<|28.58|>": 51794,
1082
+ "<|28.60|>": 51795,
1083
+ "<|28.62|>": 51796,
1084
+ "<|28.64|>": 51797,
1085
+ "<|28.66|>": 51798,
1086
+ "<|28.68|>": 51799,
1087
+ "<|28.70|>": 51800,
1088
+ "<|28.72|>": 51801,
1089
+ "<|28.74|>": 51802,
1090
+ "<|28.76|>": 51803,
1091
+ "<|28.78|>": 51804,
1092
+ "<|28.80|>": 51805,
1093
+ "<|28.82|>": 51806,
1094
+ "<|28.84|>": 51807,
1095
+ "<|28.86|>": 51808,
1096
+ "<|28.88|>": 51809,
1097
+ "<|28.90|>": 51810,
1098
+ "<|28.92|>": 51811,
1099
+ "<|28.94|>": 51812,
1100
+ "<|28.96|>": 51813,
1101
+ "<|28.98|>": 51814,
1102
+ "<|29.00|>": 51815,
1103
+ "<|29.02|>": 51816,
1104
+ "<|29.04|>": 51817,
1105
+ "<|29.06|>": 51818,
1106
+ "<|29.08|>": 51819,
1107
+ "<|29.10|>": 51820,
1108
+ "<|29.12|>": 51821,
1109
+ "<|29.14|>": 51822,
1110
+ "<|29.16|>": 51823,
1111
+ "<|29.18|>": 51824,
1112
+ "<|29.20|>": 51825,
1113
+ "<|29.22|>": 51826,
1114
+ "<|29.24|>": 51827,
1115
+ "<|29.26|>": 51828,
1116
+ "<|29.28|>": 51829,
1117
+ "<|29.30|>": 51830,
1118
+ "<|29.32|>": 51831,
1119
+ "<|29.34|>": 51832,
1120
+ "<|29.36|>": 51833,
1121
+ "<|29.38|>": 51834,
1122
+ "<|29.40|>": 51835,
1123
+ "<|29.42|>": 51836,
1124
+ "<|29.44|>": 51837,
1125
+ "<|29.46|>": 51838,
1126
+ "<|29.48|>": 51839,
1127
+ "<|29.50|>": 51840,
1128
+ "<|29.52|>": 51841,
1129
+ "<|29.54|>": 51842,
1130
+ "<|29.56|>": 51843,
1131
+ "<|29.58|>": 51844,
1132
+ "<|29.60|>": 51845,
1133
+ "<|29.62|>": 51846,
1134
+ "<|29.64|>": 51847,
1135
+ "<|29.66|>": 51848,
1136
+ "<|29.68|>": 51849,
1137
+ "<|29.70|>": 51850,
1138
+ "<|29.72|>": 51851,
1139
+ "<|29.74|>": 51852,
1140
+ "<|29.76|>": 51853,
1141
+ "<|29.78|>": 51854,
1142
+ "<|29.80|>": 51855,
1143
+ "<|29.82|>": 51856,
1144
+ "<|29.84|>": 51857,
1145
+ "<|29.86|>": 51858,
1146
+ "<|29.88|>": 51859,
1147
+ "<|29.90|>": 51860,
1148
+ "<|29.92|>": 51861,
1149
+ "<|29.94|>": 51862,
1150
+ "<|29.96|>": 51863,
1151
+ "<|29.98|>": 51864,
1152
+ "<|3.00|>": 50515,
1153
+ "<|3.02|>": 50516,
1154
+ "<|3.04|>": 50517,
1155
+ "<|3.06|>": 50518,
1156
+ "<|3.08|>": 50519,
1157
+ "<|3.10|>": 50520,
1158
+ "<|3.12|>": 50521,
1159
+ "<|3.14|>": 50522,
1160
+ "<|3.16|>": 50523,
1161
+ "<|3.18|>": 50524,
1162
+ "<|3.20|>": 50525,
1163
+ "<|3.22|>": 50526,
1164
+ "<|3.24|>": 50527,
1165
+ "<|3.26|>": 50528,
1166
+ "<|3.28|>": 50529,
1167
+ "<|3.30|>": 50530,
1168
+ "<|3.32|>": 50531,
1169
+ "<|3.34|>": 50532,
1170
+ "<|3.36|>": 50533,
1171
+ "<|3.38|>": 50534,
1172
+ "<|3.40|>": 50535,
1173
+ "<|3.42|>": 50536,
1174
+ "<|3.44|>": 50537,
1175
+ "<|3.46|>": 50538,
1176
+ "<|3.48|>": 50539,
1177
+ "<|3.50|>": 50540,
1178
+ "<|3.52|>": 50541,
1179
+ "<|3.54|>": 50542,
1180
+ "<|3.56|>": 50543,
1181
+ "<|3.58|>": 50544,
1182
+ "<|3.60|>": 50545,
1183
+ "<|3.62|>": 50546,
1184
+ "<|3.64|>": 50547,
1185
+ "<|3.66|>": 50548,
1186
+ "<|3.68|>": 50549,
1187
+ "<|3.70|>": 50550,
1188
+ "<|3.72|>": 50551,
1189
+ "<|3.74|>": 50552,
1190
+ "<|3.76|>": 50553,
1191
+ "<|3.78|>": 50554,
1192
+ "<|3.80|>": 50555,
1193
+ "<|3.82|>": 50556,
1194
+ "<|3.84|>": 50557,
1195
+ "<|3.86|>": 50558,
1196
+ "<|3.88|>": 50559,
1197
+ "<|3.90|>": 50560,
1198
+ "<|3.92|>": 50561,
1199
+ "<|3.94|>": 50562,
1200
+ "<|3.96|>": 50563,
1201
+ "<|3.98|>": 50564,
1202
+ "<|30.00|>": 51865,
1203
+ "<|4.00|>": 50565,
1204
+ "<|4.02|>": 50566,
1205
+ "<|4.04|>": 50567,
1206
+ "<|4.06|>": 50568,
1207
+ "<|4.08|>": 50569,
1208
+ "<|4.10|>": 50570,
1209
+ "<|4.12|>": 50571,
1210
+ "<|4.14|>": 50572,
1211
+ "<|4.16|>": 50573,
1212
+ "<|4.18|>": 50574,
1213
+ "<|4.20|>": 50575,
1214
+ "<|4.22|>": 50576,
1215
+ "<|4.24|>": 50577,
1216
+ "<|4.26|>": 50578,
1217
+ "<|4.28|>": 50579,
1218
+ "<|4.30|>": 50580,
1219
+ "<|4.32|>": 50581,
1220
+ "<|4.34|>": 50582,
1221
+ "<|4.36|>": 50583,
1222
+ "<|4.38|>": 50584,
1223
+ "<|4.40|>": 50585,
1224
+ "<|4.42|>": 50586,
1225
+ "<|4.44|>": 50587,
1226
+ "<|4.46|>": 50588,
1227
+ "<|4.48|>": 50589,
1228
+ "<|4.50|>": 50590,
1229
+ "<|4.52|>": 50591,
1230
+ "<|4.54|>": 50592,
1231
+ "<|4.56|>": 50593,
1232
+ "<|4.58|>": 50594,
1233
+ "<|4.60|>": 50595,
1234
+ "<|4.62|>": 50596,
1235
+ "<|4.64|>": 50597,
1236
+ "<|4.66|>": 50598,
1237
+ "<|4.68|>": 50599,
1238
+ "<|4.70|>": 50600,
1239
+ "<|4.72|>": 50601,
1240
+ "<|4.74|>": 50602,
1241
+ "<|4.76|>": 50603,
1242
+ "<|4.78|>": 50604,
1243
+ "<|4.80|>": 50605,
1244
+ "<|4.82|>": 50606,
1245
+ "<|4.84|>": 50607,
1246
+ "<|4.86|>": 50608,
1247
+ "<|4.88|>": 50609,
1248
+ "<|4.90|>": 50610,
1249
+ "<|4.92|>": 50611,
1250
+ "<|4.94|>": 50612,
1251
+ "<|4.96|>": 50613,
1252
+ "<|4.98|>": 50614,
1253
+ "<|5.00|>": 50615,
1254
+ "<|5.02|>": 50616,
1255
+ "<|5.04|>": 50617,
1256
+ "<|5.06|>": 50618,
1257
+ "<|5.08|>": 50619,
1258
+ "<|5.10|>": 50620,
1259
+ "<|5.12|>": 50621,
1260
+ "<|5.14|>": 50622,
1261
+ "<|5.16|>": 50623,
1262
+ "<|5.18|>": 50624,
1263
+ "<|5.20|>": 50625,
1264
+ "<|5.22|>": 50626,
1265
+ "<|5.24|>": 50627,
1266
+ "<|5.26|>": 50628,
1267
+ "<|5.28|>": 50629,
1268
+ "<|5.30|>": 50630,
1269
+ "<|5.32|>": 50631,
1270
+ "<|5.34|>": 50632,
1271
+ "<|5.36|>": 50633,
1272
+ "<|5.38|>": 50634,
1273
+ "<|5.40|>": 50635,
1274
+ "<|5.42|>": 50636,
1275
+ "<|5.44|>": 50637,
1276
+ "<|5.46|>": 50638,
1277
+ "<|5.48|>": 50639,
1278
+ "<|5.50|>": 50640,
1279
+ "<|5.52|>": 50641,
1280
+ "<|5.54|>": 50642,
1281
+ "<|5.56|>": 50643,
1282
+ "<|5.58|>": 50644,
1283
+ "<|5.60|>": 50645,
1284
+ "<|5.62|>": 50646,
1285
+ "<|5.64|>": 50647,
1286
+ "<|5.66|>": 50648,
1287
+ "<|5.68|>": 50649,
1288
+ "<|5.70|>": 50650,
1289
+ "<|5.72|>": 50651,
1290
+ "<|5.74|>": 50652,
1291
+ "<|5.76|>": 50653,
1292
+ "<|5.78|>": 50654,
1293
+ "<|5.80|>": 50655,
1294
+ "<|5.82|>": 50656,
1295
+ "<|5.84|>": 50657,
1296
+ "<|5.86|>": 50658,
1297
+ "<|5.88|>": 50659,
1298
+ "<|5.90|>": 50660,
1299
+ "<|5.92|>": 50661,
1300
+ "<|5.94|>": 50662,
1301
+ "<|5.96|>": 50663,
1302
+ "<|5.98|>": 50664,
1303
+ "<|6.00|>": 50665,
1304
+ "<|6.02|>": 50666,
1305
+ "<|6.04|>": 50667,
1306
+ "<|6.06|>": 50668,
1307
+ "<|6.08|>": 50669,
1308
+ "<|6.10|>": 50670,
1309
+ "<|6.12|>": 50671,
1310
+ "<|6.14|>": 50672,
1311
+ "<|6.16|>": 50673,
1312
+ "<|6.18|>": 50674,
1313
+ "<|6.20|>": 50675,
1314
+ "<|6.22|>": 50676,
1315
+ "<|6.24|>": 50677,
1316
+ "<|6.26|>": 50678,
1317
+ "<|6.28|>": 50679,
1318
+ "<|6.30|>": 50680,
1319
+ "<|6.32|>": 50681,
1320
+ "<|6.34|>": 50682,
1321
+ "<|6.36|>": 50683,
1322
+ "<|6.38|>": 50684,
1323
+ "<|6.40|>": 50685,
1324
+ "<|6.42|>": 50686,
1325
+ "<|6.44|>": 50687,
1326
+ "<|6.46|>": 50688,
1327
+ "<|6.48|>": 50689,
1328
+ "<|6.50|>": 50690,
1329
+ "<|6.52|>": 50691,
1330
+ "<|6.54|>": 50692,
1331
+ "<|6.56|>": 50693,
1332
+ "<|6.58|>": 50694,
1333
+ "<|6.60|>": 50695,
1334
+ "<|6.62|>": 50696,
1335
+ "<|6.64|>": 50697,
1336
+ "<|6.66|>": 50698,
1337
+ "<|6.68|>": 50699,
1338
+ "<|6.70|>": 50700,
1339
+ "<|6.72|>": 50701,
1340
+ "<|6.74|>": 50702,
1341
+ "<|6.76|>": 50703,
1342
+ "<|6.78|>": 50704,
1343
+ "<|6.80|>": 50705,
1344
+ "<|6.82|>": 50706,
1345
+ "<|6.84|>": 50707,
1346
+ "<|6.86|>": 50708,
1347
+ "<|6.88|>": 50709,
1348
+ "<|6.90|>": 50710,
1349
+ "<|6.92|>": 50711,
1350
+ "<|6.94|>": 50712,
1351
+ "<|6.96|>": 50713,
1352
+ "<|6.98|>": 50714,
1353
+ "<|7.00|>": 50715,
1354
+ "<|7.02|>": 50716,
1355
+ "<|7.04|>": 50717,
1356
+ "<|7.06|>": 50718,
1357
+ "<|7.08|>": 50719,
1358
+ "<|7.10|>": 50720,
1359
+ "<|7.12|>": 50721,
1360
+ "<|7.14|>": 50722,
1361
+ "<|7.16|>": 50723,
1362
+ "<|7.18|>": 50724,
1363
+ "<|7.20|>": 50725,
1364
+ "<|7.22|>": 50726,
1365
+ "<|7.24|>": 50727,
1366
+ "<|7.26|>": 50728,
1367
+ "<|7.28|>": 50729,
1368
+ "<|7.30|>": 50730,
1369
+ "<|7.32|>": 50731,
1370
+ "<|7.34|>": 50732,
1371
+ "<|7.36|>": 50733,
1372
+ "<|7.38|>": 50734,
1373
+ "<|7.40|>": 50735,
1374
+ "<|7.42|>": 50736,
1375
+ "<|7.44|>": 50737,
1376
+ "<|7.46|>": 50738,
1377
+ "<|7.48|>": 50739,
1378
+ "<|7.50|>": 50740,
1379
+ "<|7.52|>": 50741,
1380
+ "<|7.54|>": 50742,
1381
+ "<|7.56|>": 50743,
1382
+ "<|7.58|>": 50744,
1383
+ "<|7.60|>": 50745,
1384
+ "<|7.62|>": 50746,
1385
+ "<|7.64|>": 50747,
1386
+ "<|7.66|>": 50748,
1387
+ "<|7.68|>": 50749,
1388
+ "<|7.70|>": 50750,
1389
+ "<|7.72|>": 50751,
1390
+ "<|7.74|>": 50752,
1391
+ "<|7.76|>": 50753,
1392
+ "<|7.78|>": 50754,
1393
+ "<|7.80|>": 50755,
1394
+ "<|7.82|>": 50756,
1395
+ "<|7.84|>": 50757,
1396
+ "<|7.86|>": 50758,
1397
+ "<|7.88|>": 50759,
1398
+ "<|7.90|>": 50760,
1399
+ "<|7.92|>": 50761,
1400
+ "<|7.94|>": 50762,
1401
+ "<|7.96|>": 50763,
1402
+ "<|7.98|>": 50764,
1403
+ "<|8.00|>": 50765,
1404
+ "<|8.02|>": 50766,
1405
+ "<|8.04|>": 50767,
1406
+ "<|8.06|>": 50768,
1407
+ "<|8.08|>": 50769,
1408
+ "<|8.10|>": 50770,
1409
+ "<|8.12|>": 50771,
1410
+ "<|8.14|>": 50772,
1411
+ "<|8.16|>": 50773,
1412
+ "<|8.18|>": 50774,
1413
+ "<|8.20|>": 50775,
1414
+ "<|8.22|>": 50776,
1415
+ "<|8.24|>": 50777,
1416
+ "<|8.26|>": 50778,
1417
+ "<|8.28|>": 50779,
1418
+ "<|8.30|>": 50780,
1419
+ "<|8.32|>": 50781,
1420
+ "<|8.34|>": 50782,
1421
+ "<|8.36|>": 50783,
1422
+ "<|8.38|>": 50784,
1423
+ "<|8.40|>": 50785,
1424
+ "<|8.42|>": 50786,
1425
+ "<|8.44|>": 50787,
1426
+ "<|8.46|>": 50788,
1427
+ "<|8.48|>": 50789,
1428
+ "<|8.50|>": 50790,
1429
+ "<|8.52|>": 50791,
1430
+ "<|8.54|>": 50792,
1431
+ "<|8.56|>": 50793,
1432
+ "<|8.58|>": 50794,
1433
+ "<|8.60|>": 50795,
1434
+ "<|8.62|>": 50796,
1435
+ "<|8.64|>": 50797,
1436
+ "<|8.66|>": 50798,
1437
+ "<|8.68|>": 50799,
1438
+ "<|8.70|>": 50800,
1439
+ "<|8.72|>": 50801,
1440
+ "<|8.74|>": 50802,
1441
+ "<|8.76|>": 50803,
1442
+ "<|8.78|>": 50804,
1443
+ "<|8.80|>": 50805,
1444
+ "<|8.82|>": 50806,
1445
+ "<|8.84|>": 50807,
1446
+ "<|8.86|>": 50808,
1447
+ "<|8.88|>": 50809,
1448
+ "<|8.90|>": 50810,
1449
+ "<|8.92|>": 50811,
1450
+ "<|8.94|>": 50812,
1451
+ "<|8.96|>": 50813,
1452
+ "<|8.98|>": 50814,
1453
+ "<|9.00|>": 50815,
1454
+ "<|9.02|>": 50816,
1455
+ "<|9.04|>": 50817,
1456
+ "<|9.06|>": 50818,
1457
+ "<|9.08|>": 50819,
1458
+ "<|9.10|>": 50820,
1459
+ "<|9.12|>": 50821,
1460
+ "<|9.14|>": 50822,
1461
+ "<|9.16|>": 50823,
1462
+ "<|9.18|>": 50824,
1463
+ "<|9.20|>": 50825,
1464
+ "<|9.22|>": 50826,
1465
+ "<|9.24|>": 50827,
1466
+ "<|9.26|>": 50828,
1467
+ "<|9.28|>": 50829,
1468
+ "<|9.30|>": 50830,
1469
+ "<|9.32|>": 50831,
1470
+ "<|9.34|>": 50832,
1471
+ "<|9.36|>": 50833,
1472
+ "<|9.38|>": 50834,
1473
+ "<|9.40|>": 50835,
1474
+ "<|9.42|>": 50836,
1475
+ "<|9.44|>": 50837,
1476
+ "<|9.46|>": 50838,
1477
+ "<|9.48|>": 50839,
1478
+ "<|9.50|>": 50840,
1479
+ "<|9.52|>": 50841,
1480
+ "<|9.54|>": 50842,
1481
+ "<|9.56|>": 50843,
1482
+ "<|9.58|>": 50844,
1483
+ "<|9.60|>": 50845,
1484
+ "<|9.62|>": 50846,
1485
+ "<|9.64|>": 50847,
1486
+ "<|9.66|>": 50848,
1487
+ "<|9.68|>": 50849,
1488
+ "<|9.70|>": 50850,
1489
+ "<|9.72|>": 50851,
1490
+ "<|9.74|>": 50852,
1491
+ "<|9.76|>": 50853,
1492
+ "<|9.78|>": 50854,
1493
+ "<|9.80|>": 50855,
1494
+ "<|9.82|>": 50856,
1495
+ "<|9.84|>": 50857,
1496
+ "<|9.86|>": 50858,
1497
+ "<|9.88|>": 50859,
1498
+ "<|9.90|>": 50860,
1499
+ "<|9.92|>": 50861,
1500
+ "<|9.94|>": 50862,
1501
+ "<|9.96|>": 50863,
1502
+ "<|9.98|>": 50864,
1503
+ "<|af|>": 50327,
1504
+ "<|am|>": 50334,
1505
+ "<|ar|>": 50272,
1506
+ "<|as|>": 50350,
1507
+ "<|az|>": 50304,
1508
+ "<|ba|>": 50355,
1509
+ "<|be|>": 50330,
1510
+ "<|bg|>": 50292,
1511
+ "<|bn|>": 50302,
1512
+ "<|bo|>": 50347,
1513
+ "<|br|>": 50309,
1514
+ "<|bs|>": 50315,
1515
+ "<|ca|>": 50270,
1516
+ "<|cs|>": 50283,
1517
+ "<|cy|>": 50297,
1518
+ "<|da|>": 50285,
1519
+ "<|de|>": 50261,
1520
+ "<|el|>": 50281,
1521
+ "<|endoftext|>": 50257,
1522
+ "<|en|>": 50259,
1523
+ "<|es|>": 50262,
1524
+ "<|et|>": 50307,
1525
+ "<|eu|>": 50310,
1526
+ "<|fa|>": 50300,
1527
+ "<|fi|>": 50277,
1528
+ "<|fo|>": 50338,
1529
+ "<|fr|>": 50265,
1530
+ "<|gl|>": 50319,
1531
+ "<|gu|>": 50333,
1532
+ "<|haw|>": 50352,
1533
+ "<|ha|>": 50354,
1534
+ "<|he|>": 50279,
1535
+ "<|hi|>": 50276,
1536
+ "<|hr|>": 50291,
1537
+ "<|ht|>": 50339,
1538
+ "<|hu|>": 50286,
1539
+ "<|hy|>": 50312,
1540
+ "<|id|>": 50275,
1541
+ "<|is|>": 50311,
1542
+ "<|it|>": 50274,
1543
+ "<|ja|>": 50266,
1544
+ "<|jw|>": 50356,
1545
+ "<|ka|>": 50329,
1546
+ "<|kk|>": 50316,
1547
+ "<|km|>": 50323,
1548
+ "<|kn|>": 50306,
1549
+ "<|ko|>": 50264,
1550
+ "<|la|>": 50294,
1551
+ "<|lb|>": 50345,
1552
+ "<|ln|>": 50353,
1553
+ "<|lo|>": 50336,
1554
+ "<|lt|>": 50293,
1555
+ "<|lv|>": 50301,
1556
+ "<|mg|>": 50349,
1557
+ "<|mi|>": 50295,
1558
+ "<|mk|>": 50308,
1559
+ "<|ml|>": 50296,
1560
+ "<|mn|>": 50314,
1561
+ "<|mr|>": 50320,
1562
+ "<|ms|>": 50282,
1563
+ "<|mt|>": 50343,
1564
+ "<|my|>": 50346,
1565
+ "<|ne|>": 50313,
1566
+ "<|nl|>": 50271,
1567
+ "<|nn|>": 50342,
1568
+ "<|nospeech|>": 50363,
1569
+ "<|notimestamps|>": 50364,
1570
+ "<|no|>": 50288,
1571
+ "<|oc|>": 50328,
1572
+ "<|pa|>": 50321,
1573
+ "<|pl|>": 50269,
1574
+ "<|ps|>": 50340,
1575
+ "<|pt|>": 50267,
1576
+ "<|ro|>": 50284,
1577
+ "<|ru|>": 50263,
1578
+ "<|sa|>": 50344,
1579
+ "<|sd|>": 50332,
1580
+ "<|si|>": 50322,
1581
+ "<|sk|>": 50298,
1582
+ "<|sl|>": 50305,
1583
+ "<|sn|>": 50324,
1584
+ "<|so|>": 50326,
1585
+ "<|sq|>": 50317,
1586
+ "<|sr|>": 50303,
1587
+ "<|startoflm|>": 50361,
1588
+ "<|startofprev|>": 50362,
1589
+ "<|startoftranscript|>": 50258,
1590
+ "<|su|>": 50357,
1591
+ "<|sv|>": 50273,
1592
+ "<|sw|>": 50318,
1593
+ "<|ta|>": 50287,
1594
+ "<|te|>": 50299,
1595
+ "<|tg|>": 50331,
1596
+ "<|th|>": 50289,
1597
+ "<|tk|>": 50341,
1598
+ "<|tl|>": 50348,
1599
+ "<|transcribe|>": 50360,
1600
+ "<|translate|>": 50359,
1601
+ "<|tr|>": 50268,
1602
+ "<|tt|>": 50351,
1603
+ "<|uk|>": 50280,
1604
+ "<|ur|>": 50290,
1605
+ "<|uz|>": 50337,
1606
+ "<|vi|>": 50278,
1607
+ "<|yi|>": 50335,
1608
+ "<|yo|>": 50325,
1609
+ "<|yue|>": 50358,
1610
+ "<|zh|>": 50260
1611
+ }
checkpoint-500-epoch-0-val-wer-96.036/added_tokens.json ADDED
@@ -0,0 +1,1611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|0.00|>": 50365,
3
+ "<|0.02|>": 50366,
4
+ "<|0.04|>": 50367,
5
+ "<|0.06|>": 50368,
6
+ "<|0.08|>": 50369,
7
+ "<|0.10|>": 50370,
8
+ "<|0.12|>": 50371,
9
+ "<|0.14|>": 50372,
10
+ "<|0.16|>": 50373,
11
+ "<|0.18|>": 50374,
12
+ "<|0.20|>": 50375,
13
+ "<|0.22|>": 50376,
14
+ "<|0.24|>": 50377,
15
+ "<|0.26|>": 50378,
16
+ "<|0.28|>": 50379,
17
+ "<|0.30|>": 50380,
18
+ "<|0.32|>": 50381,
19
+ "<|0.34|>": 50382,
20
+ "<|0.36|>": 50383,
21
+ "<|0.38|>": 50384,
22
+ "<|0.40|>": 50385,
23
+ "<|0.42|>": 50386,
24
+ "<|0.44|>": 50387,
25
+ "<|0.46|>": 50388,
26
+ "<|0.48|>": 50389,
27
+ "<|0.50|>": 50390,
28
+ "<|0.52|>": 50391,
29
+ "<|0.54|>": 50392,
30
+ "<|0.56|>": 50393,
31
+ "<|0.58|>": 50394,
32
+ "<|0.60|>": 50395,
33
+ "<|0.62|>": 50396,
34
+ "<|0.64|>": 50397,
35
+ "<|0.66|>": 50398,
36
+ "<|0.68|>": 50399,
37
+ "<|0.70|>": 50400,
38
+ "<|0.72|>": 50401,
39
+ "<|0.74|>": 50402,
40
+ "<|0.76|>": 50403,
41
+ "<|0.78|>": 50404,
42
+ "<|0.80|>": 50405,
43
+ "<|0.82|>": 50406,
44
+ "<|0.84|>": 50407,
45
+ "<|0.86|>": 50408,
46
+ "<|0.88|>": 50409,
47
+ "<|0.90|>": 50410,
48
+ "<|0.92|>": 50411,
49
+ "<|0.94|>": 50412,
50
+ "<|0.96|>": 50413,
51
+ "<|0.98|>": 50414,
52
+ "<|1.00|>": 50415,
53
+ "<|1.02|>": 50416,
54
+ "<|1.04|>": 50417,
55
+ "<|1.06|>": 50418,
56
+ "<|1.08|>": 50419,
57
+ "<|1.10|>": 50420,
58
+ "<|1.12|>": 50421,
59
+ "<|1.14|>": 50422,
60
+ "<|1.16|>": 50423,
61
+ "<|1.18|>": 50424,
62
+ "<|1.20|>": 50425,
63
+ "<|1.22|>": 50426,
64
+ "<|1.24|>": 50427,
65
+ "<|1.26|>": 50428,
66
+ "<|1.28|>": 50429,
67
+ "<|1.30|>": 50430,
68
+ "<|1.32|>": 50431,
69
+ "<|1.34|>": 50432,
70
+ "<|1.36|>": 50433,
71
+ "<|1.38|>": 50434,
72
+ "<|1.40|>": 50435,
73
+ "<|1.42|>": 50436,
74
+ "<|1.44|>": 50437,
75
+ "<|1.46|>": 50438,
76
+ "<|1.48|>": 50439,
77
+ "<|1.50|>": 50440,
78
+ "<|1.52|>": 50441,
79
+ "<|1.54|>": 50442,
80
+ "<|1.56|>": 50443,
81
+ "<|1.58|>": 50444,
82
+ "<|1.60|>": 50445,
83
+ "<|1.62|>": 50446,
84
+ "<|1.64|>": 50447,
85
+ "<|1.66|>": 50448,
86
+ "<|1.68|>": 50449,
87
+ "<|1.70|>": 50450,
88
+ "<|1.72|>": 50451,
89
+ "<|1.74|>": 50452,
90
+ "<|1.76|>": 50453,
91
+ "<|1.78|>": 50454,
92
+ "<|1.80|>": 50455,
93
+ "<|1.82|>": 50456,
94
+ "<|1.84|>": 50457,
95
+ "<|1.86|>": 50458,
96
+ "<|1.88|>": 50459,
97
+ "<|1.90|>": 50460,
98
+ "<|1.92|>": 50461,
99
+ "<|1.94|>": 50462,
100
+ "<|1.96|>": 50463,
101
+ "<|1.98|>": 50464,
102
+ "<|10.00|>": 50865,
103
+ "<|10.02|>": 50866,
104
+ "<|10.04|>": 50867,
105
+ "<|10.06|>": 50868,
106
+ "<|10.08|>": 50869,
107
+ "<|10.10|>": 50870,
108
+ "<|10.12|>": 50871,
109
+ "<|10.14|>": 50872,
110
+ "<|10.16|>": 50873,
111
+ "<|10.18|>": 50874,
112
+ "<|10.20|>": 50875,
113
+ "<|10.22|>": 50876,
114
+ "<|10.24|>": 50877,
115
+ "<|10.26|>": 50878,
116
+ "<|10.28|>": 50879,
117
+ "<|10.30|>": 50880,
118
+ "<|10.32|>": 50881,
119
+ "<|10.34|>": 50882,
120
+ "<|10.36|>": 50883,
121
+ "<|10.38|>": 50884,
122
+ "<|10.40|>": 50885,
123
+ "<|10.42|>": 50886,
124
+ "<|10.44|>": 50887,
125
+ "<|10.46|>": 50888,
126
+ "<|10.48|>": 50889,
127
+ "<|10.50|>": 50890,
128
+ "<|10.52|>": 50891,
129
+ "<|10.54|>": 50892,
130
+ "<|10.56|>": 50893,
131
+ "<|10.58|>": 50894,
132
+ "<|10.60|>": 50895,
133
+ "<|10.62|>": 50896,
134
+ "<|10.64|>": 50897,
135
+ "<|10.66|>": 50898,
136
+ "<|10.68|>": 50899,
137
+ "<|10.70|>": 50900,
138
+ "<|10.72|>": 50901,
139
+ "<|10.74|>": 50902,
140
+ "<|10.76|>": 50903,
141
+ "<|10.78|>": 50904,
142
+ "<|10.80|>": 50905,
143
+ "<|10.82|>": 50906,
144
+ "<|10.84|>": 50907,
145
+ "<|10.86|>": 50908,
146
+ "<|10.88|>": 50909,
147
+ "<|10.90|>": 50910,
148
+ "<|10.92|>": 50911,
149
+ "<|10.94|>": 50912,
150
+ "<|10.96|>": 50913,
151
+ "<|10.98|>": 50914,
152
+ "<|11.00|>": 50915,
153
+ "<|11.02|>": 50916,
154
+ "<|11.04|>": 50917,
155
+ "<|11.06|>": 50918,
156
+ "<|11.08|>": 50919,
157
+ "<|11.10|>": 50920,
158
+ "<|11.12|>": 50921,
159
+ "<|11.14|>": 50922,
160
+ "<|11.16|>": 50923,
161
+ "<|11.18|>": 50924,
162
+ "<|11.20|>": 50925,
163
+ "<|11.22|>": 50926,
164
+ "<|11.24|>": 50927,
165
+ "<|11.26|>": 50928,
166
+ "<|11.28|>": 50929,
167
+ "<|11.30|>": 50930,
168
+ "<|11.32|>": 50931,
169
+ "<|11.34|>": 50932,
170
+ "<|11.36|>": 50933,
171
+ "<|11.38|>": 50934,
172
+ "<|11.40|>": 50935,
173
+ "<|11.42|>": 50936,
174
+ "<|11.44|>": 50937,
175
+ "<|11.46|>": 50938,
176
+ "<|11.48|>": 50939,
177
+ "<|11.50|>": 50940,
178
+ "<|11.52|>": 50941,
179
+ "<|11.54|>": 50942,
180
+ "<|11.56|>": 50943,
181
+ "<|11.58|>": 50944,
182
+ "<|11.60|>": 50945,
183
+ "<|11.62|>": 50946,
184
+ "<|11.64|>": 50947,
185
+ "<|11.66|>": 50948,
186
+ "<|11.68|>": 50949,
187
+ "<|11.70|>": 50950,
188
+ "<|11.72|>": 50951,
189
+ "<|11.74|>": 50952,
190
+ "<|11.76|>": 50953,
191
+ "<|11.78|>": 50954,
192
+ "<|11.80|>": 50955,
193
+ "<|11.82|>": 50956,
194
+ "<|11.84|>": 50957,
195
+ "<|11.86|>": 50958,
196
+ "<|11.88|>": 50959,
197
+ "<|11.90|>": 50960,
198
+ "<|11.92|>": 50961,
199
+ "<|11.94|>": 50962,
200
+ "<|11.96|>": 50963,
201
+ "<|11.98|>": 50964,
202
+ "<|12.00|>": 50965,
203
+ "<|12.02|>": 50966,
204
+ "<|12.04|>": 50967,
205
+ "<|12.06|>": 50968,
206
+ "<|12.08|>": 50969,
207
+ "<|12.10|>": 50970,
208
+ "<|12.12|>": 50971,
209
+ "<|12.14|>": 50972,
210
+ "<|12.16|>": 50973,
211
+ "<|12.18|>": 50974,
212
+ "<|12.20|>": 50975,
213
+ "<|12.22|>": 50976,
214
+ "<|12.24|>": 50977,
215
+ "<|12.26|>": 50978,
216
+ "<|12.28|>": 50979,
217
+ "<|12.30|>": 50980,
218
+ "<|12.32|>": 50981,
219
+ "<|12.34|>": 50982,
220
+ "<|12.36|>": 50983,
221
+ "<|12.38|>": 50984,
222
+ "<|12.40|>": 50985,
223
+ "<|12.42|>": 50986,
224
+ "<|12.44|>": 50987,
225
+ "<|12.46|>": 50988,
226
+ "<|12.48|>": 50989,
227
+ "<|12.50|>": 50990,
228
+ "<|12.52|>": 50991,
229
+ "<|12.54|>": 50992,
230
+ "<|12.56|>": 50993,
231
+ "<|12.58|>": 50994,
232
+ "<|12.60|>": 50995,
233
+ "<|12.62|>": 50996,
234
+ "<|12.64|>": 50997,
235
+ "<|12.66|>": 50998,
236
+ "<|12.68|>": 50999,
237
+ "<|12.70|>": 51000,
238
+ "<|12.72|>": 51001,
239
+ "<|12.74|>": 51002,
240
+ "<|12.76|>": 51003,
241
+ "<|12.78|>": 51004,
242
+ "<|12.80|>": 51005,
243
+ "<|12.82|>": 51006,
244
+ "<|12.84|>": 51007,
245
+ "<|12.86|>": 51008,
246
+ "<|12.88|>": 51009,
247
+ "<|12.90|>": 51010,
248
+ "<|12.92|>": 51011,
249
+ "<|12.94|>": 51012,
250
+ "<|12.96|>": 51013,
251
+ "<|12.98|>": 51014,
252
+ "<|13.00|>": 51015,
253
+ "<|13.02|>": 51016,
254
+ "<|13.04|>": 51017,
255
+ "<|13.06|>": 51018,
256
+ "<|13.08|>": 51019,
257
+ "<|13.10|>": 51020,
258
+ "<|13.12|>": 51021,
259
+ "<|13.14|>": 51022,
260
+ "<|13.16|>": 51023,
261
+ "<|13.18|>": 51024,
262
+ "<|13.20|>": 51025,
263
+ "<|13.22|>": 51026,
264
+ "<|13.24|>": 51027,
265
+ "<|13.26|>": 51028,
266
+ "<|13.28|>": 51029,
267
+ "<|13.30|>": 51030,
268
+ "<|13.32|>": 51031,
269
+ "<|13.34|>": 51032,
270
+ "<|13.36|>": 51033,
271
+ "<|13.38|>": 51034,
272
+ "<|13.40|>": 51035,
273
+ "<|13.42|>": 51036,
274
+ "<|13.44|>": 51037,
275
+ "<|13.46|>": 51038,
276
+ "<|13.48|>": 51039,
277
+ "<|13.50|>": 51040,
278
+ "<|13.52|>": 51041,
279
+ "<|13.54|>": 51042,
280
+ "<|13.56|>": 51043,
281
+ "<|13.58|>": 51044,
282
+ "<|13.60|>": 51045,
283
+ "<|13.62|>": 51046,
284
+ "<|13.64|>": 51047,
285
+ "<|13.66|>": 51048,
286
+ "<|13.68|>": 51049,
287
+ "<|13.70|>": 51050,
288
+ "<|13.72|>": 51051,
289
+ "<|13.74|>": 51052,
290
+ "<|13.76|>": 51053,
291
+ "<|13.78|>": 51054,
292
+ "<|13.80|>": 51055,
293
+ "<|13.82|>": 51056,
294
+ "<|13.84|>": 51057,
295
+ "<|13.86|>": 51058,
296
+ "<|13.88|>": 51059,
297
+ "<|13.90|>": 51060,
298
+ "<|13.92|>": 51061,
299
+ "<|13.94|>": 51062,
300
+ "<|13.96|>": 51063,
301
+ "<|13.98|>": 51064,
302
+ "<|14.00|>": 51065,
303
+ "<|14.02|>": 51066,
304
+ "<|14.04|>": 51067,
305
+ "<|14.06|>": 51068,
306
+ "<|14.08|>": 51069,
307
+ "<|14.10|>": 51070,
308
+ "<|14.12|>": 51071,
309
+ "<|14.14|>": 51072,
310
+ "<|14.16|>": 51073,
311
+ "<|14.18|>": 51074,
312
+ "<|14.20|>": 51075,
313
+ "<|14.22|>": 51076,
314
+ "<|14.24|>": 51077,
315
+ "<|14.26|>": 51078,
316
+ "<|14.28|>": 51079,
317
+ "<|14.30|>": 51080,
318
+ "<|14.32|>": 51081,
319
+ "<|14.34|>": 51082,
320
+ "<|14.36|>": 51083,
321
+ "<|14.38|>": 51084,
322
+ "<|14.40|>": 51085,
323
+ "<|14.42|>": 51086,
324
+ "<|14.44|>": 51087,
325
+ "<|14.46|>": 51088,
326
+ "<|14.48|>": 51089,
327
+ "<|14.50|>": 51090,
328
+ "<|14.52|>": 51091,
329
+ "<|14.54|>": 51092,
330
+ "<|14.56|>": 51093,
331
+ "<|14.58|>": 51094,
332
+ "<|14.60|>": 51095,
333
+ "<|14.62|>": 51096,
334
+ "<|14.64|>": 51097,
335
+ "<|14.66|>": 51098,
336
+ "<|14.68|>": 51099,
337
+ "<|14.70|>": 51100,
338
+ "<|14.72|>": 51101,
339
+ "<|14.74|>": 51102,
340
+ "<|14.76|>": 51103,
341
+ "<|14.78|>": 51104,
342
+ "<|14.80|>": 51105,
343
+ "<|14.82|>": 51106,
344
+ "<|14.84|>": 51107,
345
+ "<|14.86|>": 51108,
346
+ "<|14.88|>": 51109,
347
+ "<|14.90|>": 51110,
348
+ "<|14.92|>": 51111,
349
+ "<|14.94|>": 51112,
350
+ "<|14.96|>": 51113,
351
+ "<|14.98|>": 51114,
352
+ "<|15.00|>": 51115,
353
+ "<|15.02|>": 51116,
354
+ "<|15.04|>": 51117,
355
+ "<|15.06|>": 51118,
356
+ "<|15.08|>": 51119,
357
+ "<|15.10|>": 51120,
358
+ "<|15.12|>": 51121,
359
+ "<|15.14|>": 51122,
360
+ "<|15.16|>": 51123,
361
+ "<|15.18|>": 51124,
362
+ "<|15.20|>": 51125,
363
+ "<|15.22|>": 51126,
364
+ "<|15.24|>": 51127,
365
+ "<|15.26|>": 51128,
366
+ "<|15.28|>": 51129,
367
+ "<|15.30|>": 51130,
368
+ "<|15.32|>": 51131,
369
+ "<|15.34|>": 51132,
370
+ "<|15.36|>": 51133,
371
+ "<|15.38|>": 51134,
372
+ "<|15.40|>": 51135,
373
+ "<|15.42|>": 51136,
374
+ "<|15.44|>": 51137,
375
+ "<|15.46|>": 51138,
376
+ "<|15.48|>": 51139,
377
+ "<|15.50|>": 51140,
378
+ "<|15.52|>": 51141,
379
+ "<|15.54|>": 51142,
380
+ "<|15.56|>": 51143,
381
+ "<|15.58|>": 51144,
382
+ "<|15.60|>": 51145,
383
+ "<|15.62|>": 51146,
384
+ "<|15.64|>": 51147,
385
+ "<|15.66|>": 51148,
386
+ "<|15.68|>": 51149,
387
+ "<|15.70|>": 51150,
388
+ "<|15.72|>": 51151,
389
+ "<|15.74|>": 51152,
390
+ "<|15.76|>": 51153,
391
+ "<|15.78|>": 51154,
392
+ "<|15.80|>": 51155,
393
+ "<|15.82|>": 51156,
394
+ "<|15.84|>": 51157,
395
+ "<|15.86|>": 51158,
396
+ "<|15.88|>": 51159,
397
+ "<|15.90|>": 51160,
398
+ "<|15.92|>": 51161,
399
+ "<|15.94|>": 51162,
400
+ "<|15.96|>": 51163,
401
+ "<|15.98|>": 51164,
402
+ "<|16.00|>": 51165,
403
+ "<|16.02|>": 51166,
404
+ "<|16.04|>": 51167,
405
+ "<|16.06|>": 51168,
406
+ "<|16.08|>": 51169,
407
+ "<|16.10|>": 51170,
408
+ "<|16.12|>": 51171,
409
+ "<|16.14|>": 51172,
410
+ "<|16.16|>": 51173,
411
+ "<|16.18|>": 51174,
412
+ "<|16.20|>": 51175,
413
+ "<|16.22|>": 51176,
414
+ "<|16.24|>": 51177,
415
+ "<|16.26|>": 51178,
416
+ "<|16.28|>": 51179,
417
+ "<|16.30|>": 51180,
418
+ "<|16.32|>": 51181,
419
+ "<|16.34|>": 51182,
420
+ "<|16.36|>": 51183,
421
+ "<|16.38|>": 51184,
422
+ "<|16.40|>": 51185,
423
+ "<|16.42|>": 51186,
424
+ "<|16.44|>": 51187,
425
+ "<|16.46|>": 51188,
426
+ "<|16.48|>": 51189,
427
+ "<|16.50|>": 51190,
428
+ "<|16.52|>": 51191,
429
+ "<|16.54|>": 51192,
430
+ "<|16.56|>": 51193,
431
+ "<|16.58|>": 51194,
432
+ "<|16.60|>": 51195,
433
+ "<|16.62|>": 51196,
434
+ "<|16.64|>": 51197,
435
+ "<|16.66|>": 51198,
436
+ "<|16.68|>": 51199,
437
+ "<|16.70|>": 51200,
438
+ "<|16.72|>": 51201,
439
+ "<|16.74|>": 51202,
440
+ "<|16.76|>": 51203,
441
+ "<|16.78|>": 51204,
442
+ "<|16.80|>": 51205,
443
+ "<|16.82|>": 51206,
444
+ "<|16.84|>": 51207,
445
+ "<|16.86|>": 51208,
446
+ "<|16.88|>": 51209,
447
+ "<|16.90|>": 51210,
448
+ "<|16.92|>": 51211,
449
+ "<|16.94|>": 51212,
450
+ "<|16.96|>": 51213,
451
+ "<|16.98|>": 51214,
452
+ "<|17.00|>": 51215,
453
+ "<|17.02|>": 51216,
454
+ "<|17.04|>": 51217,
455
+ "<|17.06|>": 51218,
456
+ "<|17.08|>": 51219,
457
+ "<|17.10|>": 51220,
458
+ "<|17.12|>": 51221,
459
+ "<|17.14|>": 51222,
460
+ "<|17.16|>": 51223,
461
+ "<|17.18|>": 51224,
462
+ "<|17.20|>": 51225,
463
+ "<|17.22|>": 51226,
464
+ "<|17.24|>": 51227,
465
+ "<|17.26|>": 51228,
466
+ "<|17.28|>": 51229,
467
+ "<|17.30|>": 51230,
468
+ "<|17.32|>": 51231,
469
+ "<|17.34|>": 51232,
470
+ "<|17.36|>": 51233,
471
+ "<|17.38|>": 51234,
472
+ "<|17.40|>": 51235,
473
+ "<|17.42|>": 51236,
474
+ "<|17.44|>": 51237,
475
+ "<|17.46|>": 51238,
476
+ "<|17.48|>": 51239,
477
+ "<|17.50|>": 51240,
478
+ "<|17.52|>": 51241,
479
+ "<|17.54|>": 51242,
480
+ "<|17.56|>": 51243,
481
+ "<|17.58|>": 51244,
482
+ "<|17.60|>": 51245,
483
+ "<|17.62|>": 51246,
484
+ "<|17.64|>": 51247,
485
+ "<|17.66|>": 51248,
486
+ "<|17.68|>": 51249,
487
+ "<|17.70|>": 51250,
488
+ "<|17.72|>": 51251,
489
+ "<|17.74|>": 51252,
490
+ "<|17.76|>": 51253,
491
+ "<|17.78|>": 51254,
492
+ "<|17.80|>": 51255,
493
+ "<|17.82|>": 51256,
494
+ "<|17.84|>": 51257,
495
+ "<|17.86|>": 51258,
496
+ "<|17.88|>": 51259,
497
+ "<|17.90|>": 51260,
498
+ "<|17.92|>": 51261,
499
+ "<|17.94|>": 51262,
500
+ "<|17.96|>": 51263,
501
+ "<|17.98|>": 51264,
502
+ "<|18.00|>": 51265,
503
+ "<|18.02|>": 51266,
504
+ "<|18.04|>": 51267,
505
+ "<|18.06|>": 51268,
506
+ "<|18.08|>": 51269,
507
+ "<|18.10|>": 51270,
508
+ "<|18.12|>": 51271,
509
+ "<|18.14|>": 51272,
510
+ "<|18.16|>": 51273,
511
+ "<|18.18|>": 51274,
512
+ "<|18.20|>": 51275,
513
+ "<|18.22|>": 51276,
514
+ "<|18.24|>": 51277,
515
+ "<|18.26|>": 51278,
516
+ "<|18.28|>": 51279,
517
+ "<|18.30|>": 51280,
518
+ "<|18.32|>": 51281,
519
+ "<|18.34|>": 51282,
520
+ "<|18.36|>": 51283,
521
+ "<|18.38|>": 51284,
522
+ "<|18.40|>": 51285,
523
+ "<|18.42|>": 51286,
524
+ "<|18.44|>": 51287,
525
+ "<|18.46|>": 51288,
526
+ "<|18.48|>": 51289,
527
+ "<|18.50|>": 51290,
528
+ "<|18.52|>": 51291,
529
+ "<|18.54|>": 51292,
530
+ "<|18.56|>": 51293,
531
+ "<|18.58|>": 51294,
532
+ "<|18.60|>": 51295,
533
+ "<|18.62|>": 51296,
534
+ "<|18.64|>": 51297,
535
+ "<|18.66|>": 51298,
536
+ "<|18.68|>": 51299,
537
+ "<|18.70|>": 51300,
538
+ "<|18.72|>": 51301,
539
+ "<|18.74|>": 51302,
540
+ "<|18.76|>": 51303,
541
+ "<|18.78|>": 51304,
542
+ "<|18.80|>": 51305,
543
+ "<|18.82|>": 51306,
544
+ "<|18.84|>": 51307,
545
+ "<|18.86|>": 51308,
546
+ "<|18.88|>": 51309,
547
+ "<|18.90|>": 51310,
548
+ "<|18.92|>": 51311,
549
+ "<|18.94|>": 51312,
550
+ "<|18.96|>": 51313,
551
+ "<|18.98|>": 51314,
552
+ "<|19.00|>": 51315,
553
+ "<|19.02|>": 51316,
554
+ "<|19.04|>": 51317,
555
+ "<|19.06|>": 51318,
556
+ "<|19.08|>": 51319,
557
+ "<|19.10|>": 51320,
558
+ "<|19.12|>": 51321,
559
+ "<|19.14|>": 51322,
560
+ "<|19.16|>": 51323,
561
+ "<|19.18|>": 51324,
562
+ "<|19.20|>": 51325,
563
+ "<|19.22|>": 51326,
564
+ "<|19.24|>": 51327,
565
+ "<|19.26|>": 51328,
566
+ "<|19.28|>": 51329,
567
+ "<|19.30|>": 51330,
568
+ "<|19.32|>": 51331,
569
+ "<|19.34|>": 51332,
570
+ "<|19.36|>": 51333,
571
+ "<|19.38|>": 51334,
572
+ "<|19.40|>": 51335,
573
+ "<|19.42|>": 51336,
574
+ "<|19.44|>": 51337,
575
+ "<|19.46|>": 51338,
576
+ "<|19.48|>": 51339,
577
+ "<|19.50|>": 51340,
578
+ "<|19.52|>": 51341,
579
+ "<|19.54|>": 51342,
580
+ "<|19.56|>": 51343,
581
+ "<|19.58|>": 51344,
582
+ "<|19.60|>": 51345,
583
+ "<|19.62|>": 51346,
584
+ "<|19.64|>": 51347,
585
+ "<|19.66|>": 51348,
586
+ "<|19.68|>": 51349,
587
+ "<|19.70|>": 51350,
588
+ "<|19.72|>": 51351,
589
+ "<|19.74|>": 51352,
590
+ "<|19.76|>": 51353,
591
+ "<|19.78|>": 51354,
592
+ "<|19.80|>": 51355,
593
+ "<|19.82|>": 51356,
594
+ "<|19.84|>": 51357,
595
+ "<|19.86|>": 51358,
596
+ "<|19.88|>": 51359,
597
+ "<|19.90|>": 51360,
598
+ "<|19.92|>": 51361,
599
+ "<|19.94|>": 51362,
600
+ "<|19.96|>": 51363,
601
+ "<|19.98|>": 51364,
602
+ "<|2.00|>": 50465,
603
+ "<|2.02|>": 50466,
604
+ "<|2.04|>": 50467,
605
+ "<|2.06|>": 50468,
606
+ "<|2.08|>": 50469,
607
+ "<|2.10|>": 50470,
608
+ "<|2.12|>": 50471,
609
+ "<|2.14|>": 50472,
610
+ "<|2.16|>": 50473,
611
+ "<|2.18|>": 50474,
612
+ "<|2.20|>": 50475,
613
+ "<|2.22|>": 50476,
614
+ "<|2.24|>": 50477,
615
+ "<|2.26|>": 50478,
616
+ "<|2.28|>": 50479,
617
+ "<|2.30|>": 50480,
618
+ "<|2.32|>": 50481,
619
+ "<|2.34|>": 50482,
620
+ "<|2.36|>": 50483,
621
+ "<|2.38|>": 50484,
622
+ "<|2.40|>": 50485,
623
+ "<|2.42|>": 50486,
624
+ "<|2.44|>": 50487,
625
+ "<|2.46|>": 50488,
626
+ "<|2.48|>": 50489,
627
+ "<|2.50|>": 50490,
628
+ "<|2.52|>": 50491,
629
+ "<|2.54|>": 50492,
630
+ "<|2.56|>": 50493,
631
+ "<|2.58|>": 50494,
632
+ "<|2.60|>": 50495,
633
+ "<|2.62|>": 50496,
634
+ "<|2.64|>": 50497,
635
+ "<|2.66|>": 50498,
636
+ "<|2.68|>": 50499,
637
+ "<|2.70|>": 50500,
638
+ "<|2.72|>": 50501,
639
+ "<|2.74|>": 50502,
640
+ "<|2.76|>": 50503,
641
+ "<|2.78|>": 50504,
642
+ "<|2.80|>": 50505,
643
+ "<|2.82|>": 50506,
644
+ "<|2.84|>": 50507,
645
+ "<|2.86|>": 50508,
646
+ "<|2.88|>": 50509,
647
+ "<|2.90|>": 50510,
648
+ "<|2.92|>": 50511,
649
+ "<|2.94|>": 50512,
650
+ "<|2.96|>": 50513,
651
+ "<|2.98|>": 50514,
652
+ "<|20.00|>": 51365,
653
+ "<|20.02|>": 51366,
654
+ "<|20.04|>": 51367,
655
+ "<|20.06|>": 51368,
656
+ "<|20.08|>": 51369,
657
+ "<|20.10|>": 51370,
658
+ "<|20.12|>": 51371,
659
+ "<|20.14|>": 51372,
660
+ "<|20.16|>": 51373,
661
+ "<|20.18|>": 51374,
662
+ "<|20.20|>": 51375,
663
+ "<|20.22|>": 51376,
664
+ "<|20.24|>": 51377,
665
+ "<|20.26|>": 51378,
666
+ "<|20.28|>": 51379,
667
+ "<|20.30|>": 51380,
668
+ "<|20.32|>": 51381,
669
+ "<|20.34|>": 51382,
670
+ "<|20.36|>": 51383,
671
+ "<|20.38|>": 51384,
672
+ "<|20.40|>": 51385,
673
+ "<|20.42|>": 51386,
674
+ "<|20.44|>": 51387,
675
+ "<|20.46|>": 51388,
676
+ "<|20.48|>": 51389,
677
+ "<|20.50|>": 51390,
678
+ "<|20.52|>": 51391,
679
+ "<|20.54|>": 51392,
680
+ "<|20.56|>": 51393,
681
+ "<|20.58|>": 51394,
682
+ "<|20.60|>": 51395,
683
+ "<|20.62|>": 51396,
684
+ "<|20.64|>": 51397,
685
+ "<|20.66|>": 51398,
686
+ "<|20.68|>": 51399,
687
+ "<|20.70|>": 51400,
688
+ "<|20.72|>": 51401,
689
+ "<|20.74|>": 51402,
690
+ "<|20.76|>": 51403,
691
+ "<|20.78|>": 51404,
692
+ "<|20.80|>": 51405,
693
+ "<|20.82|>": 51406,
694
+ "<|20.84|>": 51407,
695
+ "<|20.86|>": 51408,
696
+ "<|20.88|>": 51409,
697
+ "<|20.90|>": 51410,
698
+ "<|20.92|>": 51411,
699
+ "<|20.94|>": 51412,
700
+ "<|20.96|>": 51413,
701
+ "<|20.98|>": 51414,
702
+ "<|21.00|>": 51415,
703
+ "<|21.02|>": 51416,
704
+ "<|21.04|>": 51417,
705
+ "<|21.06|>": 51418,
706
+ "<|21.08|>": 51419,
707
+ "<|21.10|>": 51420,
708
+ "<|21.12|>": 51421,
709
+ "<|21.14|>": 51422,
710
+ "<|21.16|>": 51423,
711
+ "<|21.18|>": 51424,
712
+ "<|21.20|>": 51425,
713
+ "<|21.22|>": 51426,
714
+ "<|21.24|>": 51427,
715
+ "<|21.26|>": 51428,
716
+ "<|21.28|>": 51429,
717
+ "<|21.30|>": 51430,
718
+ "<|21.32|>": 51431,
719
+ "<|21.34|>": 51432,
720
+ "<|21.36|>": 51433,
721
+ "<|21.38|>": 51434,
722
+ "<|21.40|>": 51435,
723
+ "<|21.42|>": 51436,
724
+ "<|21.44|>": 51437,
725
+ "<|21.46|>": 51438,
726
+ "<|21.48|>": 51439,
727
+ "<|21.50|>": 51440,
728
+ "<|21.52|>": 51441,
729
+ "<|21.54|>": 51442,
730
+ "<|21.56|>": 51443,
731
+ "<|21.58|>": 51444,
732
+ "<|21.60|>": 51445,
733
+ "<|21.62|>": 51446,
734
+ "<|21.64|>": 51447,
735
+ "<|21.66|>": 51448,
736
+ "<|21.68|>": 51449,
737
+ "<|21.70|>": 51450,
738
+ "<|21.72|>": 51451,
739
+ "<|21.74|>": 51452,
740
+ "<|21.76|>": 51453,
741
+ "<|21.78|>": 51454,
742
+ "<|21.80|>": 51455,
743
+ "<|21.82|>": 51456,
744
+ "<|21.84|>": 51457,
745
+ "<|21.86|>": 51458,
746
+ "<|21.88|>": 51459,
747
+ "<|21.90|>": 51460,
748
+ "<|21.92|>": 51461,
749
+ "<|21.94|>": 51462,
750
+ "<|21.96|>": 51463,
751
+ "<|21.98|>": 51464,
752
+ "<|22.00|>": 51465,
753
+ "<|22.02|>": 51466,
754
+ "<|22.04|>": 51467,
755
+ "<|22.06|>": 51468,
756
+ "<|22.08|>": 51469,
757
+ "<|22.10|>": 51470,
758
+ "<|22.12|>": 51471,
759
+ "<|22.14|>": 51472,
760
+ "<|22.16|>": 51473,
761
+ "<|22.18|>": 51474,
762
+ "<|22.20|>": 51475,
763
+ "<|22.22|>": 51476,
764
+ "<|22.24|>": 51477,
765
+ "<|22.26|>": 51478,
766
+ "<|22.28|>": 51479,
767
+ "<|22.30|>": 51480,
768
+ "<|22.32|>": 51481,
769
+ "<|22.34|>": 51482,
770
+ "<|22.36|>": 51483,
771
+ "<|22.38|>": 51484,
772
+ "<|22.40|>": 51485,
773
+ "<|22.42|>": 51486,
774
+ "<|22.44|>": 51487,
775
+ "<|22.46|>": 51488,
776
+ "<|22.48|>": 51489,
777
+ "<|22.50|>": 51490,
778
+ "<|22.52|>": 51491,
779
+ "<|22.54|>": 51492,
780
+ "<|22.56|>": 51493,
781
+ "<|22.58|>": 51494,
782
+ "<|22.60|>": 51495,
783
+ "<|22.62|>": 51496,
784
+ "<|22.64|>": 51497,
785
+ "<|22.66|>": 51498,
786
+ "<|22.68|>": 51499,
787
+ "<|22.70|>": 51500,
788
+ "<|22.72|>": 51501,
789
+ "<|22.74|>": 51502,
790
+ "<|22.76|>": 51503,
791
+ "<|22.78|>": 51504,
792
+ "<|22.80|>": 51505,
793
+ "<|22.82|>": 51506,
794
+ "<|22.84|>": 51507,
795
+ "<|22.86|>": 51508,
796
+ "<|22.88|>": 51509,
797
+ "<|22.90|>": 51510,
798
+ "<|22.92|>": 51511,
799
+ "<|22.94|>": 51512,
800
+ "<|22.96|>": 51513,
801
+ "<|22.98|>": 51514,
802
+ "<|23.00|>": 51515,
803
+ "<|23.02|>": 51516,
804
+ "<|23.04|>": 51517,
805
+ "<|23.06|>": 51518,
806
+ "<|23.08|>": 51519,
807
+ "<|23.10|>": 51520,
808
+ "<|23.12|>": 51521,
809
+ "<|23.14|>": 51522,
810
+ "<|23.16|>": 51523,
811
+ "<|23.18|>": 51524,
812
+ "<|23.20|>": 51525,
813
+ "<|23.22|>": 51526,
814
+ "<|23.24|>": 51527,
815
+ "<|23.26|>": 51528,
816
+ "<|23.28|>": 51529,
817
+ "<|23.30|>": 51530,
818
+ "<|23.32|>": 51531,
819
+ "<|23.34|>": 51532,
820
+ "<|23.36|>": 51533,
821
+ "<|23.38|>": 51534,
822
+ "<|23.40|>": 51535,
823
+ "<|23.42|>": 51536,
824
+ "<|23.44|>": 51537,
825
+ "<|23.46|>": 51538,
826
+ "<|23.48|>": 51539,
827
+ "<|23.50|>": 51540,
828
+ "<|23.52|>": 51541,
829
+ "<|23.54|>": 51542,
830
+ "<|23.56|>": 51543,
831
+ "<|23.58|>": 51544,
832
+ "<|23.60|>": 51545,
833
+ "<|23.62|>": 51546,
834
+ "<|23.64|>": 51547,
835
+ "<|23.66|>": 51548,
836
+ "<|23.68|>": 51549,
837
+ "<|23.70|>": 51550,
838
+ "<|23.72|>": 51551,
839
+ "<|23.74|>": 51552,
840
+ "<|23.76|>": 51553,
841
+ "<|23.78|>": 51554,
842
+ "<|23.80|>": 51555,
843
+ "<|23.82|>": 51556,
844
+ "<|23.84|>": 51557,
845
+ "<|23.86|>": 51558,
846
+ "<|23.88|>": 51559,
847
+ "<|23.90|>": 51560,
848
+ "<|23.92|>": 51561,
849
+ "<|23.94|>": 51562,
850
+ "<|23.96|>": 51563,
851
+ "<|23.98|>": 51564,
852
+ "<|24.00|>": 51565,
853
+ "<|24.02|>": 51566,
854
+ "<|24.04|>": 51567,
855
+ "<|24.06|>": 51568,
856
+ "<|24.08|>": 51569,
857
+ "<|24.10|>": 51570,
858
+ "<|24.12|>": 51571,
859
+ "<|24.14|>": 51572,
860
+ "<|24.16|>": 51573,
861
+ "<|24.18|>": 51574,
862
+ "<|24.20|>": 51575,
863
+ "<|24.22|>": 51576,
864
+ "<|24.24|>": 51577,
865
+ "<|24.26|>": 51578,
866
+ "<|24.28|>": 51579,
867
+ "<|24.30|>": 51580,
868
+ "<|24.32|>": 51581,
869
+ "<|24.34|>": 51582,
870
+ "<|24.36|>": 51583,
871
+ "<|24.38|>": 51584,
872
+ "<|24.40|>": 51585,
873
+ "<|24.42|>": 51586,
874
+ "<|24.44|>": 51587,
875
+ "<|24.46|>": 51588,
876
+ "<|24.48|>": 51589,
877
+ "<|24.50|>": 51590,
878
+ "<|24.52|>": 51591,
879
+ "<|24.54|>": 51592,
880
+ "<|24.56|>": 51593,
881
+ "<|24.58|>": 51594,
882
+ "<|24.60|>": 51595,
883
+ "<|24.62|>": 51596,
884
+ "<|24.64|>": 51597,
885
+ "<|24.66|>": 51598,
886
+ "<|24.68|>": 51599,
887
+ "<|24.70|>": 51600,
888
+ "<|24.72|>": 51601,
889
+ "<|24.74|>": 51602,
890
+ "<|24.76|>": 51603,
891
+ "<|24.78|>": 51604,
892
+ "<|24.80|>": 51605,
893
+ "<|24.82|>": 51606,
894
+ "<|24.84|>": 51607,
895
+ "<|24.86|>": 51608,
896
+ "<|24.88|>": 51609,
897
+ "<|24.90|>": 51610,
898
+ "<|24.92|>": 51611,
899
+ "<|24.94|>": 51612,
900
+ "<|24.96|>": 51613,
901
+ "<|24.98|>": 51614,
902
+ "<|25.00|>": 51615,
903
+ "<|25.02|>": 51616,
904
+ "<|25.04|>": 51617,
905
+ "<|25.06|>": 51618,
906
+ "<|25.08|>": 51619,
907
+ "<|25.10|>": 51620,
908
+ "<|25.12|>": 51621,
909
+ "<|25.14|>": 51622,
910
+ "<|25.16|>": 51623,
911
+ "<|25.18|>": 51624,
912
+ "<|25.20|>": 51625,
913
+ "<|25.22|>": 51626,
914
+ "<|25.24|>": 51627,
915
+ "<|25.26|>": 51628,
916
+ "<|25.28|>": 51629,
917
+ "<|25.30|>": 51630,
918
+ "<|25.32|>": 51631,
919
+ "<|25.34|>": 51632,
920
+ "<|25.36|>": 51633,
921
+ "<|25.38|>": 51634,
922
+ "<|25.40|>": 51635,
923
+ "<|25.42|>": 51636,
924
+ "<|25.44|>": 51637,
925
+ "<|25.46|>": 51638,
926
+ "<|25.48|>": 51639,
927
+ "<|25.50|>": 51640,
928
+ "<|25.52|>": 51641,
929
+ "<|25.54|>": 51642,
930
+ "<|25.56|>": 51643,
931
+ "<|25.58|>": 51644,
932
+ "<|25.60|>": 51645,
933
+ "<|25.62|>": 51646,
934
+ "<|25.64|>": 51647,
935
+ "<|25.66|>": 51648,
936
+ "<|25.68|>": 51649,
937
+ "<|25.70|>": 51650,
938
+ "<|25.72|>": 51651,
939
+ "<|25.74|>": 51652,
940
+ "<|25.76|>": 51653,
941
+ "<|25.78|>": 51654,
942
+ "<|25.80|>": 51655,
943
+ "<|25.82|>": 51656,
944
+ "<|25.84|>": 51657,
945
+ "<|25.86|>": 51658,
946
+ "<|25.88|>": 51659,
947
+ "<|25.90|>": 51660,
948
+ "<|25.92|>": 51661,
949
+ "<|25.94|>": 51662,
950
+ "<|25.96|>": 51663,
951
+ "<|25.98|>": 51664,
952
+ "<|26.00|>": 51665,
953
+ "<|26.02|>": 51666,
954
+ "<|26.04|>": 51667,
955
+ "<|26.06|>": 51668,
956
+ "<|26.08|>": 51669,
957
+ "<|26.10|>": 51670,
958
+ "<|26.12|>": 51671,
959
+ "<|26.14|>": 51672,
960
+ "<|26.16|>": 51673,
961
+ "<|26.18|>": 51674,
962
+ "<|26.20|>": 51675,
963
+ "<|26.22|>": 51676,
964
+ "<|26.24|>": 51677,
965
+ "<|26.26|>": 51678,
966
+ "<|26.28|>": 51679,
967
+ "<|26.30|>": 51680,
968
+ "<|26.32|>": 51681,
969
+ "<|26.34|>": 51682,
970
+ "<|26.36|>": 51683,
971
+ "<|26.38|>": 51684,
972
+ "<|26.40|>": 51685,
973
+ "<|26.42|>": 51686,
974
+ "<|26.44|>": 51687,
975
+ "<|26.46|>": 51688,
976
+ "<|26.48|>": 51689,
977
+ "<|26.50|>": 51690,
978
+ "<|26.52|>": 51691,
979
+ "<|26.54|>": 51692,
980
+ "<|26.56|>": 51693,
981
+ "<|26.58|>": 51694,
982
+ "<|26.60|>": 51695,
983
+ "<|26.62|>": 51696,
984
+ "<|26.64|>": 51697,
985
+ "<|26.66|>": 51698,
986
+ "<|26.68|>": 51699,
987
+ "<|26.70|>": 51700,
988
+ "<|26.72|>": 51701,
989
+ "<|26.74|>": 51702,
990
+ "<|26.76|>": 51703,
991
+ "<|26.78|>": 51704,
992
+ "<|26.80|>": 51705,
993
+ "<|26.82|>": 51706,
994
+ "<|26.84|>": 51707,
995
+ "<|26.86|>": 51708,
996
+ "<|26.88|>": 51709,
997
+ "<|26.90|>": 51710,
998
+ "<|26.92|>": 51711,
999
+ "<|26.94|>": 51712,
1000
+ "<|26.96|>": 51713,
1001
+ "<|26.98|>": 51714,
1002
+ "<|27.00|>": 51715,
1003
+ "<|27.02|>": 51716,
1004
+ "<|27.04|>": 51717,
1005
+ "<|27.06|>": 51718,
1006
+ "<|27.08|>": 51719,
1007
+ "<|27.10|>": 51720,
1008
+ "<|27.12|>": 51721,
1009
+ "<|27.14|>": 51722,
1010
+ "<|27.16|>": 51723,
1011
+ "<|27.18|>": 51724,
1012
+ "<|27.20|>": 51725,
1013
+ "<|27.22|>": 51726,
1014
+ "<|27.24|>": 51727,
1015
+ "<|27.26|>": 51728,
1016
+ "<|27.28|>": 51729,
1017
+ "<|27.30|>": 51730,
1018
+ "<|27.32|>": 51731,
1019
+ "<|27.34|>": 51732,
1020
+ "<|27.36|>": 51733,
1021
+ "<|27.38|>": 51734,
1022
+ "<|27.40|>": 51735,
1023
+ "<|27.42|>": 51736,
1024
+ "<|27.44|>": 51737,
1025
+ "<|27.46|>": 51738,
1026
+ "<|27.48|>": 51739,
1027
+ "<|27.50|>": 51740,
1028
+ "<|27.52|>": 51741,
1029
+ "<|27.54|>": 51742,
1030
+ "<|27.56|>": 51743,
1031
+ "<|27.58|>": 51744,
1032
+ "<|27.60|>": 51745,
1033
+ "<|27.62|>": 51746,
1034
+ "<|27.64|>": 51747,
1035
+ "<|27.66|>": 51748,
1036
+ "<|27.68|>": 51749,
1037
+ "<|27.70|>": 51750,
1038
+ "<|27.72|>": 51751,
1039
+ "<|27.74|>": 51752,
1040
+ "<|27.76|>": 51753,
1041
+ "<|27.78|>": 51754,
1042
+ "<|27.80|>": 51755,
1043
+ "<|27.82|>": 51756,
1044
+ "<|27.84|>": 51757,
1045
+ "<|27.86|>": 51758,
1046
+ "<|27.88|>": 51759,
1047
+ "<|27.90|>": 51760,
1048
+ "<|27.92|>": 51761,
1049
+ "<|27.94|>": 51762,
1050
+ "<|27.96|>": 51763,
1051
+ "<|27.98|>": 51764,
1052
+ "<|28.00|>": 51765,
1053
+ "<|28.02|>": 51766,
1054
+ "<|28.04|>": 51767,
1055
+ "<|28.06|>": 51768,
1056
+ "<|28.08|>": 51769,
1057
+ "<|28.10|>": 51770,
1058
+ "<|28.12|>": 51771,
1059
+ "<|28.14|>": 51772,
1060
+ "<|28.16|>": 51773,
1061
+ "<|28.18|>": 51774,
1062
+ "<|28.20|>": 51775,
1063
+ "<|28.22|>": 51776,
1064
+ "<|28.24|>": 51777,
1065
+ "<|28.26|>": 51778,
1066
+ "<|28.28|>": 51779,
1067
+ "<|28.30|>": 51780,
1068
+ "<|28.32|>": 51781,
1069
+ "<|28.34|>": 51782,
1070
+ "<|28.36|>": 51783,
1071
+ "<|28.38|>": 51784,
1072
+ "<|28.40|>": 51785,
1073
+ "<|28.42|>": 51786,
1074
+ "<|28.44|>": 51787,
1075
+ "<|28.46|>": 51788,
1076
+ "<|28.48|>": 51789,
1077
+ "<|28.50|>": 51790,
1078
+ "<|28.52|>": 51791,
1079
+ "<|28.54|>": 51792,
1080
+ "<|28.56|>": 51793,
1081
+ "<|28.58|>": 51794,
1082
+ "<|28.60|>": 51795,
1083
+ "<|28.62|>": 51796,
1084
+ "<|28.64|>": 51797,
1085
+ "<|28.66|>": 51798,
1086
+ "<|28.68|>": 51799,
1087
+ "<|28.70|>": 51800,
1088
+ "<|28.72|>": 51801,
1089
+ "<|28.74|>": 51802,
1090
+ "<|28.76|>": 51803,
1091
+ "<|28.78|>": 51804,
1092
+ "<|28.80|>": 51805,
1093
+ "<|28.82|>": 51806,
1094
+ "<|28.84|>": 51807,
1095
+ "<|28.86|>": 51808,
1096
+ "<|28.88|>": 51809,
1097
+ "<|28.90|>": 51810,
1098
+ "<|28.92|>": 51811,
1099
+ "<|28.94|>": 51812,
1100
+ "<|28.96|>": 51813,
1101
+ "<|28.98|>": 51814,
1102
+ "<|29.00|>": 51815,
1103
+ "<|29.02|>": 51816,
1104
+ "<|29.04|>": 51817,
1105
+ "<|29.06|>": 51818,
1106
+ "<|29.08|>": 51819,
1107
+ "<|29.10|>": 51820,
1108
+ "<|29.12|>": 51821,
1109
+ "<|29.14|>": 51822,
1110
+ "<|29.16|>": 51823,
1111
+ "<|29.18|>": 51824,
1112
+ "<|29.20|>": 51825,
1113
+ "<|29.22|>": 51826,
1114
+ "<|29.24|>": 51827,
1115
+ "<|29.26|>": 51828,
1116
+ "<|29.28|>": 51829,
1117
+ "<|29.30|>": 51830,
1118
+ "<|29.32|>": 51831,
1119
+ "<|29.34|>": 51832,
1120
+ "<|29.36|>": 51833,
1121
+ "<|29.38|>": 51834,
1122
+ "<|29.40|>": 51835,
1123
+ "<|29.42|>": 51836,
1124
+ "<|29.44|>": 51837,
1125
+ "<|29.46|>": 51838,
1126
+ "<|29.48|>": 51839,
1127
+ "<|29.50|>": 51840,
1128
+ "<|29.52|>": 51841,
1129
+ "<|29.54|>": 51842,
1130
+ "<|29.56|>": 51843,
1131
+ "<|29.58|>": 51844,
1132
+ "<|29.60|>": 51845,
1133
+ "<|29.62|>": 51846,
1134
+ "<|29.64|>": 51847,
1135
+ "<|29.66|>": 51848,
1136
+ "<|29.68|>": 51849,
1137
+ "<|29.70|>": 51850,
1138
+ "<|29.72|>": 51851,
1139
+ "<|29.74|>": 51852,
1140
+ "<|29.76|>": 51853,
1141
+ "<|29.78|>": 51854,
1142
+ "<|29.80|>": 51855,
1143
+ "<|29.82|>": 51856,
1144
+ "<|29.84|>": 51857,
1145
+ "<|29.86|>": 51858,
1146
+ "<|29.88|>": 51859,
1147
+ "<|29.90|>": 51860,
1148
+ "<|29.92|>": 51861,
1149
+ "<|29.94|>": 51862,
1150
+ "<|29.96|>": 51863,
1151
+ "<|29.98|>": 51864,
1152
+ "<|3.00|>": 50515,
1153
+ "<|3.02|>": 50516,
1154
+ "<|3.04|>": 50517,
1155
+ "<|3.06|>": 50518,
1156
+ "<|3.08|>": 50519,
1157
+ "<|3.10|>": 50520,
1158
+ "<|3.12|>": 50521,
1159
+ "<|3.14|>": 50522,
1160
+ "<|3.16|>": 50523,
1161
+ "<|3.18|>": 50524,
1162
+ "<|3.20|>": 50525,
1163
+ "<|3.22|>": 50526,
1164
+ "<|3.24|>": 50527,
1165
+ "<|3.26|>": 50528,
1166
+ "<|3.28|>": 50529,
1167
+ "<|3.30|>": 50530,
1168
+ "<|3.32|>": 50531,
1169
+ "<|3.34|>": 50532,
1170
+ "<|3.36|>": 50533,
1171
+ "<|3.38|>": 50534,
1172
+ "<|3.40|>": 50535,
1173
+ "<|3.42|>": 50536,
1174
+ "<|3.44|>": 50537,
1175
+ "<|3.46|>": 50538,
1176
+ "<|3.48|>": 50539,
1177
+ "<|3.50|>": 50540,
1178
+ "<|3.52|>": 50541,
1179
+ "<|3.54|>": 50542,
1180
+ "<|3.56|>": 50543,
1181
+ "<|3.58|>": 50544,
1182
+ "<|3.60|>": 50545,
1183
+ "<|3.62|>": 50546,
1184
+ "<|3.64|>": 50547,
1185
+ "<|3.66|>": 50548,
1186
+ "<|3.68|>": 50549,
1187
+ "<|3.70|>": 50550,
1188
+ "<|3.72|>": 50551,
1189
+ "<|3.74|>": 50552,
1190
+ "<|3.76|>": 50553,
1191
+ "<|3.78|>": 50554,
1192
+ "<|3.80|>": 50555,
1193
+ "<|3.82|>": 50556,
1194
+ "<|3.84|>": 50557,
1195
+ "<|3.86|>": 50558,
1196
+ "<|3.88|>": 50559,
1197
+ "<|3.90|>": 50560,
1198
+ "<|3.92|>": 50561,
1199
+ "<|3.94|>": 50562,
1200
+ "<|3.96|>": 50563,
1201
+ "<|3.98|>": 50564,
1202
+ "<|30.00|>": 51865,
1203
+ "<|4.00|>": 50565,
1204
+ "<|4.02|>": 50566,
1205
+ "<|4.04|>": 50567,
1206
+ "<|4.06|>": 50568,
1207
+ "<|4.08|>": 50569,
1208
+ "<|4.10|>": 50570,
1209
+ "<|4.12|>": 50571,
1210
+ "<|4.14|>": 50572,
1211
+ "<|4.16|>": 50573,
1212
+ "<|4.18|>": 50574,
1213
+ "<|4.20|>": 50575,
1214
+ "<|4.22|>": 50576,
1215
+ "<|4.24|>": 50577,
1216
+ "<|4.26|>": 50578,
1217
+ "<|4.28|>": 50579,
1218
+ "<|4.30|>": 50580,
1219
+ "<|4.32|>": 50581,
1220
+ "<|4.34|>": 50582,
1221
+ "<|4.36|>": 50583,
1222
+ "<|4.38|>": 50584,
1223
+ "<|4.40|>": 50585,
1224
+ "<|4.42|>": 50586,
1225
+ "<|4.44|>": 50587,
1226
+ "<|4.46|>": 50588,
1227
+ "<|4.48|>": 50589,
1228
+ "<|4.50|>": 50590,
1229
+ "<|4.52|>": 50591,
1230
+ "<|4.54|>": 50592,
1231
+ "<|4.56|>": 50593,
1232
+ "<|4.58|>": 50594,
1233
+ "<|4.60|>": 50595,
1234
+ "<|4.62|>": 50596,
1235
+ "<|4.64|>": 50597,
1236
+ "<|4.66|>": 50598,
1237
+ "<|4.68|>": 50599,
1238
+ "<|4.70|>": 50600,
1239
+ "<|4.72|>": 50601,
1240
+ "<|4.74|>": 50602,
1241
+ "<|4.76|>": 50603,
1242
+ "<|4.78|>": 50604,
1243
+ "<|4.80|>": 50605,
1244
+ "<|4.82|>": 50606,
1245
+ "<|4.84|>": 50607,
1246
+ "<|4.86|>": 50608,
1247
+ "<|4.88|>": 50609,
1248
+ "<|4.90|>": 50610,
1249
+ "<|4.92|>": 50611,
1250
+ "<|4.94|>": 50612,
1251
+ "<|4.96|>": 50613,
1252
+ "<|4.98|>": 50614,
1253
+ "<|5.00|>": 50615,
1254
+ "<|5.02|>": 50616,
1255
+ "<|5.04|>": 50617,
1256
+ "<|5.06|>": 50618,
1257
+ "<|5.08|>": 50619,
1258
+ "<|5.10|>": 50620,
1259
+ "<|5.12|>": 50621,
1260
+ "<|5.14|>": 50622,
1261
+ "<|5.16|>": 50623,
1262
+ "<|5.18|>": 50624,
1263
+ "<|5.20|>": 50625,
1264
+ "<|5.22|>": 50626,
1265
+ "<|5.24|>": 50627,
1266
+ "<|5.26|>": 50628,
1267
+ "<|5.28|>": 50629,
1268
+ "<|5.30|>": 50630,
1269
+ "<|5.32|>": 50631,
1270
+ "<|5.34|>": 50632,
1271
+ "<|5.36|>": 50633,
1272
+ "<|5.38|>": 50634,
1273
+ "<|5.40|>": 50635,
1274
+ "<|5.42|>": 50636,
1275
+ "<|5.44|>": 50637,
1276
+ "<|5.46|>": 50638,
1277
+ "<|5.48|>": 50639,
1278
+ "<|5.50|>": 50640,
1279
+ "<|5.52|>": 50641,
1280
+ "<|5.54|>": 50642,
1281
+ "<|5.56|>": 50643,
1282
+ "<|5.58|>": 50644,
1283
+ "<|5.60|>": 50645,
1284
+ "<|5.62|>": 50646,
1285
+ "<|5.64|>": 50647,
1286
+ "<|5.66|>": 50648,
1287
+ "<|5.68|>": 50649,
1288
+ "<|5.70|>": 50650,
1289
+ "<|5.72|>": 50651,
1290
+ "<|5.74|>": 50652,
1291
+ "<|5.76|>": 50653,
1292
+ "<|5.78|>": 50654,
1293
+ "<|5.80|>": 50655,
1294
+ "<|5.82|>": 50656,
1295
+ "<|5.84|>": 50657,
1296
+ "<|5.86|>": 50658,
1297
+ "<|5.88|>": 50659,
1298
+ "<|5.90|>": 50660,
1299
+ "<|5.92|>": 50661,
1300
+ "<|5.94|>": 50662,
1301
+ "<|5.96|>": 50663,
1302
+ "<|5.98|>": 50664,
1303
+ "<|6.00|>": 50665,
1304
+ "<|6.02|>": 50666,
1305
+ "<|6.04|>": 50667,
1306
+ "<|6.06|>": 50668,
1307
+ "<|6.08|>": 50669,
1308
+ "<|6.10|>": 50670,
1309
+ "<|6.12|>": 50671,
1310
+ "<|6.14|>": 50672,
1311
+ "<|6.16|>": 50673,
1312
+ "<|6.18|>": 50674,
1313
+ "<|6.20|>": 50675,
1314
+ "<|6.22|>": 50676,
1315
+ "<|6.24|>": 50677,
1316
+ "<|6.26|>": 50678,
1317
+ "<|6.28|>": 50679,
1318
+ "<|6.30|>": 50680,
1319
+ "<|6.32|>": 50681,
1320
+ "<|6.34|>": 50682,
1321
+ "<|6.36|>": 50683,
1322
+ "<|6.38|>": 50684,
1323
+ "<|6.40|>": 50685,
1324
+ "<|6.42|>": 50686,
1325
+ "<|6.44|>": 50687,
1326
+ "<|6.46|>": 50688,
1327
+ "<|6.48|>": 50689,
1328
+ "<|6.50|>": 50690,
1329
+ "<|6.52|>": 50691,
1330
+ "<|6.54|>": 50692,
1331
+ "<|6.56|>": 50693,
1332
+ "<|6.58|>": 50694,
1333
+ "<|6.60|>": 50695,
1334
+ "<|6.62|>": 50696,
1335
+ "<|6.64|>": 50697,
1336
+ "<|6.66|>": 50698,
1337
+ "<|6.68|>": 50699,
1338
+ "<|6.70|>": 50700,
1339
+ "<|6.72|>": 50701,
1340
+ "<|6.74|>": 50702,
1341
+ "<|6.76|>": 50703,
1342
+ "<|6.78|>": 50704,
1343
+ "<|6.80|>": 50705,
1344
+ "<|6.82|>": 50706,
1345
+ "<|6.84|>": 50707,
1346
+ "<|6.86|>": 50708,
1347
+ "<|6.88|>": 50709,
1348
+ "<|6.90|>": 50710,
1349
+ "<|6.92|>": 50711,
1350
+ "<|6.94|>": 50712,
1351
+ "<|6.96|>": 50713,
1352
+ "<|6.98|>": 50714,
1353
+ "<|7.00|>": 50715,
1354
+ "<|7.02|>": 50716,
1355
+ "<|7.04|>": 50717,
1356
+ "<|7.06|>": 50718,
1357
+ "<|7.08|>": 50719,
1358
+ "<|7.10|>": 50720,
1359
+ "<|7.12|>": 50721,
1360
+ "<|7.14|>": 50722,
1361
+ "<|7.16|>": 50723,
1362
+ "<|7.18|>": 50724,
1363
+ "<|7.20|>": 50725,
1364
+ "<|7.22|>": 50726,
1365
+ "<|7.24|>": 50727,
1366
+ "<|7.26|>": 50728,
1367
+ "<|7.28|>": 50729,
1368
+ "<|7.30|>": 50730,
1369
+ "<|7.32|>": 50731,
1370
+ "<|7.34|>": 50732,
1371
+ "<|7.36|>": 50733,
1372
+ "<|7.38|>": 50734,
1373
+ "<|7.40|>": 50735,
1374
+ "<|7.42|>": 50736,
1375
+ "<|7.44|>": 50737,
1376
+ "<|7.46|>": 50738,
1377
+ "<|7.48|>": 50739,
1378
+ "<|7.50|>": 50740,
1379
+ "<|7.52|>": 50741,
1380
+ "<|7.54|>": 50742,
1381
+ "<|7.56|>": 50743,
1382
+ "<|7.58|>": 50744,
1383
+ "<|7.60|>": 50745,
1384
+ "<|7.62|>": 50746,
1385
+ "<|7.64|>": 50747,
1386
+ "<|7.66|>": 50748,
1387
+ "<|7.68|>": 50749,
1388
+ "<|7.70|>": 50750,
1389
+ "<|7.72|>": 50751,
1390
+ "<|7.74|>": 50752,
1391
+ "<|7.76|>": 50753,
1392
+ "<|7.78|>": 50754,
1393
+ "<|7.80|>": 50755,
1394
+ "<|7.82|>": 50756,
1395
+ "<|7.84|>": 50757,
1396
+ "<|7.86|>": 50758,
1397
+ "<|7.88|>": 50759,
1398
+ "<|7.90|>": 50760,
1399
+ "<|7.92|>": 50761,
1400
+ "<|7.94|>": 50762,
1401
+ "<|7.96|>": 50763,
1402
+ "<|7.98|>": 50764,
1403
+ "<|8.00|>": 50765,
1404
+ "<|8.02|>": 50766,
1405
+ "<|8.04|>": 50767,
1406
+ "<|8.06|>": 50768,
1407
+ "<|8.08|>": 50769,
1408
+ "<|8.10|>": 50770,
1409
+ "<|8.12|>": 50771,
1410
+ "<|8.14|>": 50772,
1411
+ "<|8.16|>": 50773,
1412
+ "<|8.18|>": 50774,
1413
+ "<|8.20|>": 50775,
1414
+ "<|8.22|>": 50776,
1415
+ "<|8.24|>": 50777,
1416
+ "<|8.26|>": 50778,
1417
+ "<|8.28|>": 50779,
1418
+ "<|8.30|>": 50780,
1419
+ "<|8.32|>": 50781,
1420
+ "<|8.34|>": 50782,
1421
+ "<|8.36|>": 50783,
1422
+ "<|8.38|>": 50784,
1423
+ "<|8.40|>": 50785,
1424
+ "<|8.42|>": 50786,
1425
+ "<|8.44|>": 50787,
1426
+ "<|8.46|>": 50788,
1427
+ "<|8.48|>": 50789,
1428
+ "<|8.50|>": 50790,
1429
+ "<|8.52|>": 50791,
1430
+ "<|8.54|>": 50792,
1431
+ "<|8.56|>": 50793,
1432
+ "<|8.58|>": 50794,
1433
+ "<|8.60|>": 50795,
1434
+ "<|8.62|>": 50796,
1435
+ "<|8.64|>": 50797,
1436
+ "<|8.66|>": 50798,
1437
+ "<|8.68|>": 50799,
1438
+ "<|8.70|>": 50800,
1439
+ "<|8.72|>": 50801,
1440
+ "<|8.74|>": 50802,
1441
+ "<|8.76|>": 50803,
1442
+ "<|8.78|>": 50804,
1443
+ "<|8.80|>": 50805,
1444
+ "<|8.82|>": 50806,
1445
+ "<|8.84|>": 50807,
1446
+ "<|8.86|>": 50808,
1447
+ "<|8.88|>": 50809,
1448
+ "<|8.90|>": 50810,
1449
+ "<|8.92|>": 50811,
1450
+ "<|8.94|>": 50812,
1451
+ "<|8.96|>": 50813,
1452
+ "<|8.98|>": 50814,
1453
+ "<|9.00|>": 50815,
1454
+ "<|9.02|>": 50816,
1455
+ "<|9.04|>": 50817,
1456
+ "<|9.06|>": 50818,
1457
+ "<|9.08|>": 50819,
1458
+ "<|9.10|>": 50820,
1459
+ "<|9.12|>": 50821,
1460
+ "<|9.14|>": 50822,
1461
+ "<|9.16|>": 50823,
1462
+ "<|9.18|>": 50824,
1463
+ "<|9.20|>": 50825,
1464
+ "<|9.22|>": 50826,
1465
+ "<|9.24|>": 50827,
1466
+ "<|9.26|>": 50828,
1467
+ "<|9.28|>": 50829,
1468
+ "<|9.30|>": 50830,
1469
+ "<|9.32|>": 50831,
1470
+ "<|9.34|>": 50832,
1471
+ "<|9.36|>": 50833,
1472
+ "<|9.38|>": 50834,
1473
+ "<|9.40|>": 50835,
1474
+ "<|9.42|>": 50836,
1475
+ "<|9.44|>": 50837,
1476
+ "<|9.46|>": 50838,
1477
+ "<|9.48|>": 50839,
1478
+ "<|9.50|>": 50840,
1479
+ "<|9.52|>": 50841,
1480
+ "<|9.54|>": 50842,
1481
+ "<|9.56|>": 50843,
1482
+ "<|9.58|>": 50844,
1483
+ "<|9.60|>": 50845,
1484
+ "<|9.62|>": 50846,
1485
+ "<|9.64|>": 50847,
1486
+ "<|9.66|>": 50848,
1487
+ "<|9.68|>": 50849,
1488
+ "<|9.70|>": 50850,
1489
+ "<|9.72|>": 50851,
1490
+ "<|9.74|>": 50852,
1491
+ "<|9.76|>": 50853,
1492
+ "<|9.78|>": 50854,
1493
+ "<|9.80|>": 50855,
1494
+ "<|9.82|>": 50856,
1495
+ "<|9.84|>": 50857,
1496
+ "<|9.86|>": 50858,
1497
+ "<|9.88|>": 50859,
1498
+ "<|9.90|>": 50860,
1499
+ "<|9.92|>": 50861,
1500
+ "<|9.94|>": 50862,
1501
+ "<|9.96|>": 50863,
1502
+ "<|9.98|>": 50864,
1503
+ "<|af|>": 50327,
1504
+ "<|am|>": 50334,
1505
+ "<|ar|>": 50272,
1506
+ "<|as|>": 50350,
1507
+ "<|az|>": 50304,
1508
+ "<|ba|>": 50355,
1509
+ "<|be|>": 50330,
1510
+ "<|bg|>": 50292,
1511
+ "<|bn|>": 50302,
1512
+ "<|bo|>": 50347,
1513
+ "<|br|>": 50309,
1514
+ "<|bs|>": 50315,
1515
+ "<|ca|>": 50270,
1516
+ "<|cs|>": 50283,
1517
+ "<|cy|>": 50297,
1518
+ "<|da|>": 50285,
1519
+ "<|de|>": 50261,
1520
+ "<|el|>": 50281,
1521
+ "<|endoftext|>": 50257,
1522
+ "<|en|>": 50259,
1523
+ "<|es|>": 50262,
1524
+ "<|et|>": 50307,
1525
+ "<|eu|>": 50310,
1526
+ "<|fa|>": 50300,
1527
+ "<|fi|>": 50277,
1528
+ "<|fo|>": 50338,
1529
+ "<|fr|>": 50265,
1530
+ "<|gl|>": 50319,
1531
+ "<|gu|>": 50333,
1532
+ "<|haw|>": 50352,
1533
+ "<|ha|>": 50354,
1534
+ "<|he|>": 50279,
1535
+ "<|hi|>": 50276,
1536
+ "<|hr|>": 50291,
1537
+ "<|ht|>": 50339,
1538
+ "<|hu|>": 50286,
1539
+ "<|hy|>": 50312,
1540
+ "<|id|>": 50275,
1541
+ "<|is|>": 50311,
1542
+ "<|it|>": 50274,
1543
+ "<|ja|>": 50266,
1544
+ "<|jw|>": 50356,
1545
+ "<|ka|>": 50329,
1546
+ "<|kk|>": 50316,
1547
+ "<|km|>": 50323,
1548
+ "<|kn|>": 50306,
1549
+ "<|ko|>": 50264,
1550
+ "<|la|>": 50294,
1551
+ "<|lb|>": 50345,
1552
+ "<|ln|>": 50353,
1553
+ "<|lo|>": 50336,
1554
+ "<|lt|>": 50293,
1555
+ "<|lv|>": 50301,
1556
+ "<|mg|>": 50349,
1557
+ "<|mi|>": 50295,
1558
+ "<|mk|>": 50308,
1559
+ "<|ml|>": 50296,
1560
+ "<|mn|>": 50314,
1561
+ "<|mr|>": 50320,
1562
+ "<|ms|>": 50282,
1563
+ "<|mt|>": 50343,
1564
+ "<|my|>": 50346,
1565
+ "<|ne|>": 50313,
1566
+ "<|nl|>": 50271,
1567
+ "<|nn|>": 50342,
1568
+ "<|nospeech|>": 50363,
1569
+ "<|notimestamps|>": 50364,
1570
+ "<|no|>": 50288,
1571
+ "<|oc|>": 50328,
1572
+ "<|pa|>": 50321,
1573
+ "<|pl|>": 50269,
1574
+ "<|ps|>": 50340,
1575
+ "<|pt|>": 50267,
1576
+ "<|ro|>": 50284,
1577
+ "<|ru|>": 50263,
1578
+ "<|sa|>": 50344,
1579
+ "<|sd|>": 50332,
1580
+ "<|si|>": 50322,
1581
+ "<|sk|>": 50298,
1582
+ "<|sl|>": 50305,
1583
+ "<|sn|>": 50324,
1584
+ "<|so|>": 50326,
1585
+ "<|sq|>": 50317,
1586
+ "<|sr|>": 50303,
1587
+ "<|startoflm|>": 50361,
1588
+ "<|startofprev|>": 50362,
1589
+ "<|startoftranscript|>": 50258,
1590
+ "<|su|>": 50357,
1591
+ "<|sv|>": 50273,
1592
+ "<|sw|>": 50318,
1593
+ "<|ta|>": 50287,
1594
+ "<|te|>": 50299,
1595
+ "<|tg|>": 50331,
1596
+ "<|th|>": 50289,
1597
+ "<|tk|>": 50341,
1598
+ "<|tl|>": 50348,
1599
+ "<|transcribe|>": 50360,
1600
+ "<|translate|>": 50359,
1601
+ "<|tr|>": 50268,
1602
+ "<|tt|>": 50351,
1603
+ "<|uk|>": 50280,
1604
+ "<|ur|>": 50290,
1605
+ "<|uz|>": 50337,
1606
+ "<|vi|>": 50278,
1607
+ "<|yi|>": 50335,
1608
+ "<|yo|>": 50325,
1609
+ "<|yue|>": 50358,
1610
+ "<|zh|>": 50260
1611
+ }
checkpoint-500-epoch-0-val-wer-96.036/config.json ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "NbAiLab/nb-whisper-large",
3
+ "activation_dropout": 0.1,
4
+ "activation_function": "gelu",
5
+ "alignment_heads": [
6
+ [
7
+ 7,
8
+ 0
9
+ ],
10
+ [
11
+ 10,
12
+ 17
13
+ ],
14
+ [
15
+ 12,
16
+ 18
17
+ ],
18
+ [
19
+ 13,
20
+ 12
21
+ ],
22
+ [
23
+ 16,
24
+ 1
25
+ ],
26
+ [
27
+ 17,
28
+ 14
29
+ ],
30
+ [
31
+ 19,
32
+ 11
33
+ ],
34
+ [
35
+ 21,
36
+ 4
37
+ ],
38
+ [
39
+ 24,
40
+ 1
41
+ ],
42
+ [
43
+ 25,
44
+ 6
45
+ ]
46
+ ],
47
+ "apply_spec_augment": false,
48
+ "architectures": [
49
+ "WhisperForConditionalGeneration"
50
+ ],
51
+ "attention_dropout": 0,
52
+ "begin_suppress_tokens": null,
53
+ "bos_token_id": 50257,
54
+ "classifier_proj_size": 256,
55
+ "d_model": 1280,
56
+ "decoder_attention_heads": 20,
57
+ "decoder_ffn_dim": 5120,
58
+ "decoder_layerdrop": 0,
59
+ "decoder_layers": 2,
60
+ "decoder_start_token_id": 50258,
61
+ "dropout": 0,
62
+ "encoder_attention_heads": 20,
63
+ "encoder_ffn_dim": 5120,
64
+ "encoder_layerdrop": 0,
65
+ "encoder_layers": 32,
66
+ "eos_token_id": 50257,
67
+ "init_std": 0.02,
68
+ "is_encoder_decoder": true,
69
+ "lang_ids": [
70
+ 50259,
71
+ 50260,
72
+ 50261,
73
+ 50262,
74
+ 50263,
75
+ 50264,
76
+ 50265,
77
+ 50266,
78
+ 50267,
79
+ 50268,
80
+ 50269,
81
+ 50270,
82
+ 50271,
83
+ 50272,
84
+ 50273,
85
+ 50274,
86
+ 50275,
87
+ 50276,
88
+ 50277,
89
+ 50278,
90
+ 50279,
91
+ 50280,
92
+ 50281,
93
+ 50282,
94
+ 50283,
95
+ 50284,
96
+ 50285,
97
+ 50286,
98
+ 50287,
99
+ 50288,
100
+ 50289,
101
+ 50290,
102
+ 50291,
103
+ 50292,
104
+ 50293,
105
+ 50294,
106
+ 50295,
107
+ 50296,
108
+ 50297,
109
+ 50298,
110
+ 50299,
111
+ 50300,
112
+ 50301,
113
+ 50302,
114
+ 50303,
115
+ 50304,
116
+ 50305,
117
+ 50306,
118
+ 50307,
119
+ 50308,
120
+ 50309,
121
+ 50310,
122
+ 50311,
123
+ 50312,
124
+ 50313,
125
+ 50314,
126
+ 50315,
127
+ 50316,
128
+ 50317,
129
+ 50318,
130
+ 50319,
131
+ 50320,
132
+ 50321,
133
+ 50322,
134
+ 50323,
135
+ 50324,
136
+ 50325,
137
+ 50326,
138
+ 50327,
139
+ 50328,
140
+ 50329,
141
+ 50330,
142
+ 50331,
143
+ 50332,
144
+ 50333,
145
+ 50334,
146
+ 50335,
147
+ 50336,
148
+ 50337,
149
+ 50338,
150
+ 50339,
151
+ 50340,
152
+ 50341,
153
+ 50342,
154
+ 50343,
155
+ 50344,
156
+ 50345,
157
+ 50346,
158
+ 50347,
159
+ 50348,
160
+ 50349,
161
+ 50350,
162
+ 50351,
163
+ 50352,
164
+ 50353,
165
+ 50354,
166
+ 50355,
167
+ 50356,
168
+ 50357,
169
+ 50358
170
+ ],
171
+ "mask_feature_length": 10,
172
+ "mask_feature_min_masks": 0,
173
+ "mask_feature_prob": 0,
174
+ "mask_time_length": 10,
175
+ "mask_time_min_masks": 2,
176
+ "mask_time_prob": 0.05,
177
+ "max_length": null,
178
+ "max_source_positions": 1500,
179
+ "max_target_positions": 448,
180
+ "median_filter_width": 7,
181
+ "model_type": "whisper",
182
+ "num_hidden_layers": 32,
183
+ "num_mel_bins": 128,
184
+ "pad_token_id": 50256,
185
+ "scale_embedding": false,
186
+ "suppress_ids": [
187
+ 1,
188
+ 2,
189
+ 7,
190
+ 8,
191
+ 9,
192
+ 10,
193
+ 14,
194
+ 25,
195
+ 26,
196
+ 27,
197
+ 28,
198
+ 29,
199
+ 31,
200
+ 58,
201
+ 59,
202
+ 60,
203
+ 61,
204
+ 62,
205
+ 63,
206
+ 90,
207
+ 91,
208
+ 92,
209
+ 93,
210
+ 359,
211
+ 503,
212
+ 522,
213
+ 542,
214
+ 873,
215
+ 893,
216
+ 902,
217
+ 918,
218
+ 922,
219
+ 931,
220
+ 1350,
221
+ 1853,
222
+ 1982,
223
+ 2460,
224
+ 2627,
225
+ 3246,
226
+ 3253,
227
+ 3268,
228
+ 3536,
229
+ 3846,
230
+ 3961,
231
+ 4183,
232
+ 4667,
233
+ 6585,
234
+ 6647,
235
+ 7273,
236
+ 9061,
237
+ 9383,
238
+ 10428,
239
+ 10929,
240
+ 11938,
241
+ 12033,
242
+ 12331,
243
+ 12562,
244
+ 13793,
245
+ 14157,
246
+ 14635,
247
+ 15265,
248
+ 15618,
249
+ 16553,
250
+ 16604,
251
+ 18362,
252
+ 18956,
253
+ 20075,
254
+ 21675,
255
+ 22520,
256
+ 26130,
257
+ 26161,
258
+ 26435,
259
+ 28279,
260
+ 29464,
261
+ 31650,
262
+ 32302,
263
+ 32470,
264
+ 36865,
265
+ 42863,
266
+ 47425,
267
+ 49870,
268
+ 50254,
269
+ 50258,
270
+ 50359,
271
+ 50360,
272
+ 50361,
273
+ 50362,
274
+ 50363
275
+ ],
276
+ "suppress_ids_begin": [
277
+ 220,
278
+ 50257
279
+ ],
280
+ "torch_dtype": "float32",
281
+ "transformers_version": "4.46.2",
282
+ "use_cache": true,
283
+ "use_weighted_layer_sum": false,
284
+ "vocab_size": 51866
285
+ }
checkpoint-500-epoch-0-val-wer-96.036/generation_config.json ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alignment_heads": [
3
+ [
4
+ 7,
5
+ 0
6
+ ],
7
+ [
8
+ 10,
9
+ 17
10
+ ],
11
+ [
12
+ 12,
13
+ 18
14
+ ],
15
+ [
16
+ 13,
17
+ 12
18
+ ],
19
+ [
20
+ 16,
21
+ 1
22
+ ],
23
+ [
24
+ 17,
25
+ 14
26
+ ],
27
+ [
28
+ 19,
29
+ 11
30
+ ],
31
+ [
32
+ 21,
33
+ 4
34
+ ],
35
+ [
36
+ 24,
37
+ 1
38
+ ],
39
+ [
40
+ 25,
41
+ 6
42
+ ]
43
+ ],
44
+ "begin_suppress_tokens": [
45
+ 220,
46
+ 50257
47
+ ],
48
+ "bos_token_id": 50257,
49
+ "decoder_start_token_id": 50258,
50
+ "eos_token_id": 50257,
51
+ "is_multilingual": true,
52
+ "lang_to_id": {
53
+ "<|af|>": 50327,
54
+ "<|am|>": 50334,
55
+ "<|ar|>": 50272,
56
+ "<|as|>": 50350,
57
+ "<|az|>": 50304,
58
+ "<|ba|>": 50355,
59
+ "<|be|>": 50330,
60
+ "<|bg|>": 50292,
61
+ "<|bn|>": 50302,
62
+ "<|bo|>": 50347,
63
+ "<|br|>": 50309,
64
+ "<|bs|>": 50315,
65
+ "<|ca|>": 50270,
66
+ "<|cs|>": 50283,
67
+ "<|cy|>": 50297,
68
+ "<|da|>": 50285,
69
+ "<|de|>": 50261,
70
+ "<|el|>": 50281,
71
+ "<|en|>": 50259,
72
+ "<|es|>": 50262,
73
+ "<|et|>": 50307,
74
+ "<|eu|>": 50310,
75
+ "<|fa|>": 50300,
76
+ "<|fi|>": 50277,
77
+ "<|fo|>": 50338,
78
+ "<|fr|>": 50265,
79
+ "<|gl|>": 50319,
80
+ "<|gu|>": 50333,
81
+ "<|haw|>": 50352,
82
+ "<|ha|>": 50354,
83
+ "<|he|>": 50279,
84
+ "<|hi|>": 50276,
85
+ "<|hr|>": 50291,
86
+ "<|ht|>": 50339,
87
+ "<|hu|>": 50286,
88
+ "<|hy|>": 50312,
89
+ "<|id|>": 50275,
90
+ "<|is|>": 50311,
91
+ "<|it|>": 50274,
92
+ "<|ja|>": 50266,
93
+ "<|jw|>": 50356,
94
+ "<|ka|>": 50329,
95
+ "<|kk|>": 50316,
96
+ "<|km|>": 50323,
97
+ "<|kn|>": 50306,
98
+ "<|ko|>": 50264,
99
+ "<|la|>": 50294,
100
+ "<|lb|>": 50345,
101
+ "<|ln|>": 50353,
102
+ "<|lo|>": 50336,
103
+ "<|lt|>": 50293,
104
+ "<|lv|>": 50301,
105
+ "<|mg|>": 50349,
106
+ "<|mi|>": 50295,
107
+ "<|mk|>": 50308,
108
+ "<|ml|>": 50296,
109
+ "<|mn|>": 50314,
110
+ "<|mr|>": 50320,
111
+ "<|ms|>": 50282,
112
+ "<|mt|>": 50343,
113
+ "<|my|>": 50346,
114
+ "<|ne|>": 50313,
115
+ "<|nl|>": 50271,
116
+ "<|nn|>": 50342,
117
+ "<|no|>": 50288,
118
+ "<|oc|>": 50328,
119
+ "<|pa|>": 50321,
120
+ "<|pl|>": 50269,
121
+ "<|ps|>": 50340,
122
+ "<|pt|>": 50267,
123
+ "<|ro|>": 50284,
124
+ "<|ru|>": 50263,
125
+ "<|sa|>": 50344,
126
+ "<|sd|>": 50332,
127
+ "<|si|>": 50322,
128
+ "<|sk|>": 50298,
129
+ "<|sl|>": 50305,
130
+ "<|sn|>": 50324,
131
+ "<|so|>": 50326,
132
+ "<|sq|>": 50317,
133
+ "<|sr|>": 50303,
134
+ "<|su|>": 50357,
135
+ "<|sv|>": 50273,
136
+ "<|sw|>": 50318,
137
+ "<|ta|>": 50287,
138
+ "<|te|>": 50299,
139
+ "<|tg|>": 50331,
140
+ "<|th|>": 50289,
141
+ "<|tk|>": 50341,
142
+ "<|tl|>": 50348,
143
+ "<|tr|>": 50268,
144
+ "<|tt|>": 50351,
145
+ "<|uk|>": 50280,
146
+ "<|ur|>": 50290,
147
+ "<|uz|>": 50337,
148
+ "<|vi|>": 50278,
149
+ "<|yi|>": 50335,
150
+ "<|yo|>": 50325,
151
+ "<|yue|>": 50358,
152
+ "<|zh|>": 50260
153
+ },
154
+ "language": "no",
155
+ "max_initial_timestamp_index": 1,
156
+ "max_length": 448,
157
+ "no_timestamps_token_id": 50364,
158
+ "pad_token_id": 50257,
159
+ "return_timestamps": false,
160
+ "suppress_tokens": [
161
+ 1,
162
+ 2,
163
+ 7,
164
+ 8,
165
+ 9,
166
+ 10,
167
+ 14,
168
+ 25,
169
+ 26,
170
+ 27,
171
+ 28,
172
+ 29,
173
+ 31,
174
+ 58,
175
+ 59,
176
+ 60,
177
+ 61,
178
+ 62,
179
+ 63,
180
+ 90,
181
+ 91,
182
+ 92,
183
+ 93,
184
+ 359,
185
+ 503,
186
+ 522,
187
+ 542,
188
+ 873,
189
+ 893,
190
+ 902,
191
+ 918,
192
+ 922,
193
+ 931,
194
+ 1350,
195
+ 1853,
196
+ 1982,
197
+ 2460,
198
+ 2627,
199
+ 3246,
200
+ 3253,
201
+ 3268,
202
+ 3536,
203
+ 3846,
204
+ 3961,
205
+ 4183,
206
+ 4667,
207
+ 6585,
208
+ 6647,
209
+ 7273,
210
+ 9061,
211
+ 9383,
212
+ 10428,
213
+ 10929,
214
+ 11938,
215
+ 12033,
216
+ 12331,
217
+ 12562,
218
+ 13793,
219
+ 14157,
220
+ 14635,
221
+ 15265,
222
+ 15618,
223
+ 16553,
224
+ 16604,
225
+ 18362,
226
+ 18956,
227
+ 20075,
228
+ 21675,
229
+ 22520,
230
+ 26130,
231
+ 26161,
232
+ 26435,
233
+ 28279,
234
+ 29464,
235
+ 31650,
236
+ 32302,
237
+ 32470,
238
+ 36865,
239
+ 42863,
240
+ 47425,
241
+ 49870,
242
+ 50254,
243
+ 50258,
244
+ 50359,
245
+ 50360,
246
+ 50361,
247
+ 50362,
248
+ 50363
249
+ ],
250
+ "task": "transcribe",
251
+ "task_to_id": {
252
+ "transcribe": 50360,
253
+ "translate": 50359
254
+ },
255
+ "transformers_version": "4.46.2"
256
+ }
checkpoint-500-epoch-0-val-wer-96.036/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-500-epoch-0-val-wer-96.036/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dceef1c98c82eee48a3a948d1ca88682b48946a66a0d989d6be8c1c49205bed
3
+ size 3025686376
checkpoint-500-epoch-0-val-wer-96.036/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28897ec4b789c0dc382a6975366fcb16206be64b6b691a60b218831c8f6af1ea
3
+ size 4361070048
checkpoint-500-epoch-0-val-wer-96.036/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aadb2de09477d862c183c69ec3806328585b886a2e268366d2ef3cbfccb89257
3
+ size 950951226
checkpoint-500-epoch-0-val-wer-96.036/preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "processor_class": "WhisperProcessor",
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000
14
+ }
checkpoint-500-epoch-0-val-wer-96.036/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f618092df3bd36d40b256e361e5acec6a0f94cbd96621ff64d347a801af7f553
3
+ size 14408
checkpoint-500-epoch-0-val-wer-96.036/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2627ccec0bb9a51b7d9d753a9441035aec88f305994eb3b5ccbb3e0571f519d6
3
+ size 1064
checkpoint-500-epoch-0-val-wer-96.036/special_tokens_map.json ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|startoftranscript|>",
4
+ "<|en|>",
5
+ "<|zh|>",
6
+ "<|de|>",
7
+ "<|es|>",
8
+ "<|ru|>",
9
+ "<|ko|>",
10
+ "<|fr|>",
11
+ "<|ja|>",
12
+ "<|pt|>",
13
+ "<|tr|>",
14
+ "<|pl|>",
15
+ "<|ca|>",
16
+ "<|nl|>",
17
+ "<|ar|>",
18
+ "<|sv|>",
19
+ "<|it|>",
20
+ "<|id|>",
21
+ "<|hi|>",
22
+ "<|fi|>",
23
+ "<|vi|>",
24
+ "<|he|>",
25
+ "<|uk|>",
26
+ "<|el|>",
27
+ "<|ms|>",
28
+ "<|cs|>",
29
+ "<|ro|>",
30
+ "<|da|>",
31
+ "<|hu|>",
32
+ "<|ta|>",
33
+ "<|no|>",
34
+ "<|th|>",
35
+ "<|ur|>",
36
+ "<|hr|>",
37
+ "<|bg|>",
38
+ "<|lt|>",
39
+ "<|la|>",
40
+ "<|mi|>",
41
+ "<|ml|>",
42
+ "<|cy|>",
43
+ "<|sk|>",
44
+ "<|te|>",
45
+ "<|fa|>",
46
+ "<|lv|>",
47
+ "<|bn|>",
48
+ "<|sr|>",
49
+ "<|az|>",
50
+ "<|sl|>",
51
+ "<|kn|>",
52
+ "<|et|>",
53
+ "<|mk|>",
54
+ "<|br|>",
55
+ "<|eu|>",
56
+ "<|is|>",
57
+ "<|hy|>",
58
+ "<|ne|>",
59
+ "<|mn|>",
60
+ "<|bs|>",
61
+ "<|kk|>",
62
+ "<|sq|>",
63
+ "<|sw|>",
64
+ "<|gl|>",
65
+ "<|mr|>",
66
+ "<|pa|>",
67
+ "<|si|>",
68
+ "<|km|>",
69
+ "<|sn|>",
70
+ "<|yo|>",
71
+ "<|so|>",
72
+ "<|af|>",
73
+ "<|oc|>",
74
+ "<|ka|>",
75
+ "<|be|>",
76
+ "<|tg|>",
77
+ "<|sd|>",
78
+ "<|gu|>",
79
+ "<|am|>",
80
+ "<|yi|>",
81
+ "<|lo|>",
82
+ "<|uz|>",
83
+ "<|fo|>",
84
+ "<|ht|>",
85
+ "<|ps|>",
86
+ "<|tk|>",
87
+ "<|nn|>",
88
+ "<|mt|>",
89
+ "<|sa|>",
90
+ "<|lb|>",
91
+ "<|my|>",
92
+ "<|bo|>",
93
+ "<|tl|>",
94
+ "<|mg|>",
95
+ "<|as|>",
96
+ "<|tt|>",
97
+ "<|haw|>",
98
+ "<|ln|>",
99
+ "<|ha|>",
100
+ "<|ba|>",
101
+ "<|jw|>",
102
+ "<|su|>",
103
+ "<|yue|>",
104
+ "<|translate|>",
105
+ "<|transcribe|>",
106
+ "<|startoflm|>",
107
+ "<|startofprev|>",
108
+ "<|nospeech|>",
109
+ "<|notimestamps|>"
110
+ ],
111
+ "bos_token": {
112
+ "content": "<|endoftext|>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "eos_token": {
119
+ "content": "<|endoftext|>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ },
125
+ "pad_token": {
126
+ "content": "<|endoftext|>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false
131
+ },
132
+ "unk_token": {
133
+ "content": "<|endoftext|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false
138
+ }
139
+ }
checkpoint-500-epoch-0-val-wer-96.036/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-500-epoch-0-val-wer-96.036/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-500-epoch-0-val-wer-96.036/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./nb-distil-large-init",
3
+ "activation_dropout": 0.1,
4
+ "activation_function": "gelu",
5
+ "alignment_heads": [
6
+ [
7
+ 7,
8
+ 0
9
+ ],
10
+ [
11
+ 10,
12
+ 17
13
+ ],
14
+ [
15
+ 12,
16
+ 18
17
+ ],
18
+ [
19
+ 13,
20
+ 12
21
+ ],
22
+ [
23
+ 16,
24
+ 1
25
+ ],
26
+ [
27
+ 17,
28
+ 14
29
+ ],
30
+ [
31
+ 19,
32
+ 11
33
+ ],
34
+ [
35
+ 21,
36
+ 4
37
+ ],
38
+ [
39
+ 24,
40
+ 1
41
+ ],
42
+ [
43
+ 25,
44
+ 6
45
+ ]
46
+ ],
47
+ "apply_spec_augment": false,
48
+ "architectures": [
49
+ "WhisperForConditionalGeneration"
50
+ ],
51
+ "attention_dropout": 0,
52
+ "begin_suppress_tokens": null,
53
+ "bos_token_id": 50257,
54
+ "classifier_proj_size": 256,
55
+ "d_model": 1280,
56
+ "decoder_attention_heads": 20,
57
+ "decoder_ffn_dim": 5120,
58
+ "decoder_layerdrop": 0,
59
+ "decoder_layers": 2,
60
+ "decoder_start_token_id": 50258,
61
+ "dropout": 0,
62
+ "encoder_attention_heads": 20,
63
+ "encoder_ffn_dim": 5120,
64
+ "encoder_layerdrop": 0,
65
+ "encoder_layers": 32,
66
+ "eos_token_id": 50257,
67
+ "init_std": 0.02,
68
+ "is_encoder_decoder": true,
69
+ "lang_ids": [
70
+ 50259,
71
+ 50260,
72
+ 50261,
73
+ 50262,
74
+ 50263,
75
+ 50264,
76
+ 50265,
77
+ 50266,
78
+ 50267,
79
+ 50268,
80
+ 50269,
81
+ 50270,
82
+ 50271,
83
+ 50272,
84
+ 50273,
85
+ 50274,
86
+ 50275,
87
+ 50276,
88
+ 50277,
89
+ 50278,
90
+ 50279,
91
+ 50280,
92
+ 50281,
93
+ 50282,
94
+ 50283,
95
+ 50284,
96
+ 50285,
97
+ 50286,
98
+ 50287,
99
+ 50288,
100
+ 50289,
101
+ 50290,
102
+ 50291,
103
+ 50292,
104
+ 50293,
105
+ 50294,
106
+ 50295,
107
+ 50296,
108
+ 50297,
109
+ 50298,
110
+ 50299,
111
+ 50300,
112
+ 50301,
113
+ 50302,
114
+ 50303,
115
+ 50304,
116
+ 50305,
117
+ 50306,
118
+ 50307,
119
+ 50308,
120
+ 50309,
121
+ 50310,
122
+ 50311,
123
+ 50312,
124
+ 50313,
125
+ 50314,
126
+ 50315,
127
+ 50316,
128
+ 50317,
129
+ 50318,
130
+ 50319,
131
+ 50320,
132
+ 50321,
133
+ 50322,
134
+ 50323,
135
+ 50324,
136
+ 50325,
137
+ 50326,
138
+ 50327,
139
+ 50328,
140
+ 50329,
141
+ 50330,
142
+ 50331,
143
+ 50332,
144
+ 50333,
145
+ 50334,
146
+ 50335,
147
+ 50336,
148
+ 50337,
149
+ 50338,
150
+ 50339,
151
+ 50340,
152
+ 50341,
153
+ 50342,
154
+ 50343,
155
+ 50344,
156
+ 50345,
157
+ 50346,
158
+ 50347,
159
+ 50348,
160
+ 50349,
161
+ 50350,
162
+ 50351,
163
+ 50352,
164
+ 50353,
165
+ 50354,
166
+ 50355,
167
+ 50356,
168
+ 50357,
169
+ 50358
170
+ ],
171
+ "mask_feature_length": 10,
172
+ "mask_feature_min_masks": 0,
173
+ "mask_feature_prob": 0,
174
+ "mask_time_length": 10,
175
+ "mask_time_min_masks": 2,
176
+ "mask_time_prob": 0.05,
177
+ "max_length": null,
178
+ "max_source_positions": 1500,
179
+ "max_target_positions": 448,
180
+ "median_filter_width": 7,
181
+ "model_type": "whisper",
182
+ "num_hidden_layers": 32,
183
+ "num_mel_bins": 128,
184
+ "pad_token_id": 50256,
185
+ "scale_embedding": false,
186
+ "suppress_ids": [
187
+ 1,
188
+ 2,
189
+ 7,
190
+ 8,
191
+ 9,
192
+ 10,
193
+ 14,
194
+ 25,
195
+ 26,
196
+ 27,
197
+ 28,
198
+ 29,
199
+ 31,
200
+ 58,
201
+ 59,
202
+ 60,
203
+ 61,
204
+ 62,
205
+ 63,
206
+ 90,
207
+ 91,
208
+ 92,
209
+ 93,
210
+ 359,
211
+ 503,
212
+ 522,
213
+ 542,
214
+ 873,
215
+ 893,
216
+ 902,
217
+ 918,
218
+ 922,
219
+ 931,
220
+ 1350,
221
+ 1853,
222
+ 1982,
223
+ 2460,
224
+ 2627,
225
+ 3246,
226
+ 3253,
227
+ 3268,
228
+ 3536,
229
+ 3846,
230
+ 3961,
231
+ 4183,
232
+ 4667,
233
+ 6585,
234
+ 6647,
235
+ 7273,
236
+ 9061,
237
+ 9383,
238
+ 10428,
239
+ 10929,
240
+ 11938,
241
+ 12033,
242
+ 12331,
243
+ 12562,
244
+ 13793,
245
+ 14157,
246
+ 14635,
247
+ 15265,
248
+ 15618,
249
+ 16553,
250
+ 16604,
251
+ 18362,
252
+ 18956,
253
+ 20075,
254
+ 21675,
255
+ 22520,
256
+ 26130,
257
+ 26161,
258
+ 26435,
259
+ 28279,
260
+ 29464,
261
+ 31650,
262
+ 32302,
263
+ 32470,
264
+ 36865,
265
+ 42863,
266
+ 47425,
267
+ 49870,
268
+ 50254,
269
+ 50258,
270
+ 50359,
271
+ 50360,
272
+ 50361,
273
+ 50362,
274
+ 50363
275
+ ],
276
+ "suppress_ids_begin": [
277
+ 220,
278
+ 50257
279
+ ],
280
+ "torch_dtype": "float32",
281
+ "transformers_version": "4.46.2",
282
+ "use_cache": true,
283
+ "use_weighted_layer_sum": false,
284
+ "vocab_size": 51866
285
+ }
create_student_model.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Initialise a student Whisper model from a pre-trained teacher model for
18
+ teacher-student distillation.
19
+ """
20
+
21
+ import argparse
22
+ import copy
23
+ import logging
24
+
25
+ import numpy as np
26
+ import torch
27
+ from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser(
35
+ description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
36
+ )
37
+ parser.add_argument(
38
+ "--teacher_checkpoint",
39
+ type=str,
40
+ required=True,
41
+ help="The HF Hub ID of the teacher checkpoint.",
42
+ )
43
+ parser.add_argument(
44
+ "--subfolder",
45
+ type=str,
46
+ default="",
47
+ help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
48
+ "can specify the folder name here.",
49
+ )
50
+ parser.add_argument(
51
+ "--encoder_layers",
52
+ type=int,
53
+ default=None,
54
+ help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
55
+ )
56
+ parser.add_argument(
57
+ "--decoder_layers",
58
+ type=int,
59
+ default=2,
60
+ help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
61
+ )
62
+ parser.add_argument(
63
+ "--decoder_layers_numbers",
64
+ type=int,
65
+ nargs="*",
66
+ help="Layers numbers of the decoder teacher to use in the student model. Defaults to None, equivalent to taking first and last layer (and equivalent to `--decoder_layers_numbers 0 -1`).",
67
+ )
68
+ parser.add_argument(
69
+ "--save_dir",
70
+ type=str,
71
+ required=True,
72
+ help="Where to save the student weights and processor.",
73
+ )
74
+ parser.add_argument(
75
+ "--push_to_hub",
76
+ type=bool,
77
+ required=False,
78
+ default=False,
79
+ help="Whether to push the student weights and processor to the Hub.",
80
+ )
81
+ parser.add_argument(
82
+ "--cache_dir",
83
+ type=str,
84
+ default=None,
85
+ help="Where to store the pretrained models downloaded from huggingface.co",
86
+ )
87
+
88
+ args = parser.parse_args()
89
+ return args
90
+
91
+
92
+ def init_student_model_from_teacher(
93
+ teacher_checkpoint,
94
+ encoder_layers=None,
95
+ decoder_layers=2,
96
+ decoder_layers_numbers=None,
97
+ save_dir=None,
98
+ push_to_hub=None,
99
+ cache_dir=None,
100
+ subfolder="",
101
+ ):
102
+ if decoder_layers_numbers is not None and len(decoder_layers_numbers) != decoder_layers:
103
+ raise ValueError(
104
+ f"Got {len(decoder_layers_numbers)} layers number for {decoder_layers} decoder layers."
105
+ )
106
+
107
+ teacher_model = WhisperForConditionalGeneration.from_pretrained(
108
+ teacher_checkpoint,
109
+ cache_dir=cache_dir,
110
+ subfolder=subfolder,
111
+ low_cpu_mem_usage=True,
112
+ )
113
+ processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
114
+ generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)
115
+ generation_config.forced_decoder_ids = None
116
+
117
+ teacher_config = teacher_model.config
118
+ teacher_encoder_layers = teacher_config.encoder_layers
119
+ teacher_decoder_layers = teacher_config.decoder_layers
120
+
121
+ student_config = copy.deepcopy(teacher_config)
122
+ student_config.update(
123
+ {
124
+ "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
125
+ "decoder_layers": decoder_layers,
126
+ }
127
+ )
128
+
129
+ encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
130
+ encoder_mapping[-1] = teacher_encoder_layers - 1
131
+
132
+ encoder_map = {}
133
+ for student_layer, teacher_layer in enumerate(encoder_mapping):
134
+ encoder_map[teacher_layer] = student_layer
135
+
136
+ if decoder_layers_numbers is None:
137
+ decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
138
+ decoder_mapping[-1] = teacher_decoder_layers - 1
139
+ else:
140
+ decoder_mapping = decoder_layers_numbers
141
+
142
+ decoder_map = {}
143
+ for student_layer, teacher_layer in enumerate(decoder_mapping):
144
+ decoder_map[teacher_layer] = student_layer
145
+
146
+ # init the student params from the teacher model
147
+ student_model = WhisperForConditionalGeneration(student_config)
148
+ missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
149
+ if len(missing_keys) > 0:
150
+ raise RuntimeError(
151
+ "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
152
+ f"Missing key(s) in state_dict: {missing_keys}"
153
+ )
154
+ if decoder_layers == teacher_decoder_layers:
155
+ decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key]
156
+ if len(decoder_keys) > 0:
157
+ raise RuntimeError(
158
+ "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
159
+ f"Unexpected key(s) in state_dict: {decoder_keys}"
160
+ )
161
+ if encoder_layers == teacher_encoder_layers:
162
+ encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key]
163
+ if len(encoder_keys) > 0:
164
+ raise RuntimeError(
165
+ "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
166
+ f"Unexpected key(s) in state_dict: {encoder_keys}"
167
+ )
168
+
169
+ for layer in range(teacher_decoder_layers):
170
+ if layer in decoder_map:
171
+ # re-introduce pre-defined layers from the teacher
172
+ student_model.model.decoder.layers[decoder_map[layer]].load_state_dict(
173
+ teacher_model.model.decoder.layers[layer].state_dict()
174
+ )
175
+
176
+ if encoder_layers is not None:
177
+ for layer in range(teacher_encoder_layers):
178
+ if layer in encoder_map:
179
+ # re-introduce pre-defined layers from the teacher
180
+ student_model.model.encoder.layers[encoder_map[layer]].load_state_dict(
181
+ teacher_model.model.encoder.layers[layer].state_dict()
182
+ )
183
+
184
+ # remove the teacher params and model
185
+ del teacher_model
186
+
187
+ # save the converted weights and model
188
+ if save_dir is not None:
189
+ student_model.save_pretrained(save_dir)
190
+ # we also need to correctly save the processor and generation config
191
+ processor.save_pretrained(save_dir)
192
+ generation_config.save_pretrained(save_dir)
193
+
194
+ # check we can do a forward pass with the saved model - first load the weights and processor
195
+ logger.info("Checking we can load the saved model...")
196
+ student_model = WhisperForConditionalGeneration.from_pretrained(
197
+ save_dir,
198
+ low_cpu_mem_usage=True,
199
+ )
200
+ processor = WhisperProcessor.from_pretrained(save_dir)
201
+
202
+ # define some random inputs
203
+ input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features
204
+ decoder_start_token_id = student_model.config.decoder_start_token_id
205
+ decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id
206
+
207
+ # do a forward pass - outputs will be gibberish for the initialised model so we can't check them
208
+ # but we make can sure the model runs as expected
209
+ logger.info("Checking we can run the converted model forward...")
210
+ _ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits
211
+ logger.info("Conversion successful!")
212
+
213
+ if push_to_hub:
214
+ student_model.push_to_hub(save_dir)
215
+ processor.push_to_hub(save_dir)
216
+ generation_config.push_to_hub(save_dir)
217
+
218
+
219
+ if __name__ == "__main__":
220
+ args = parse_args()
221
+
222
+ init_student_model_from_teacher(
223
+ teacher_checkpoint=args.teacher_checkpoint,
224
+ encoder_layers=args.encoder_layers,
225
+ decoder_layers=args.decoder_layers,
226
+ decoder_layers_numbers=args.decoder_layers_numbers,
227
+ save_dir=args.save_dir,
228
+ push_to_hub=args.push_to_hub,
229
+ cache_dir=args.cache_dir,
230
+ subfolder=args.subfolder,
231
+ )
distil-whisper/events.out.tfevents.1730988960.a100-80-west4a.48904.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e58b3108b1d79614ead26bc71b9b79ba8d077ea73b93f222724dde951c1e8ab6
3
+ size 88
distil-whisper/events.out.tfevents.1730989066.a100-80-west4a.49408.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:102f2dfcdad2383784b10a2967b71981fae5a614ad250eec9e83968e1215469e
3
+ size 88
distil-whisper/events.out.tfevents.1730989452.a100-80-west4a.68077.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd082b899310b318f1ae0d551d3be510323bf1da0270e0c3e78d7ccdecd4d696
3
+ size 88
distil-whisper/events.out.tfevents.1730990001.a100-80-west4a.87125.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b9c1f6dc20114fde3decac69cde54d7a3b8d982cae625be16f3c2c2aafe78e9
3
+ size 1055
distil_whisper/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __version__ = "0.0.1"
17
+
18
+ from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
19
+ from .partitioner import PjitPartitioner
20
+ from .pipeline import FlaxWhisperPipeline
21
+ from .train_state import InferenceState
distil_whisper/layers.py ADDED
@@ -0,0 +1,1338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dense attention classes and mask/weighting functions."""
16
+
17
+ # pylint: disable=attribute-defined-outside-init,g-bare-generic
18
+
19
+ import dataclasses
20
+ import functools
21
+ import operator
22
+ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import numpy as np
27
+ from flax import linen as nn
28
+ from flax.linen import partitioning as nn_partitioning
29
+ from flax.linen.dtypes import promote_dtype
30
+ from jax import lax, random
31
+
32
+
33
+ # from flax.linen.partitioning import param_with_axes, with_sharding_constraint
34
+ param_with_axes = nn_partitioning.param_with_axes
35
+ with_sharding_constraint = nn_partitioning.with_sharding_constraint
36
+
37
+
38
+ # Type annotations
39
+ Array = jnp.ndarray
40
+ DType = jnp.dtype
41
+ PRNGKey = jnp.ndarray
42
+ Shape = Iterable[int]
43
+ Activation = Callable[..., Array]
44
+ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]]
45
+ DotGeneralT = Callable[..., Array]
46
+ ConvGeneralDilatedT = Callable[..., Array]
47
+ PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
48
+ LaxPadding = Union[str, Sequence[Tuple[int, int]]]
49
+
50
+ # Parameter initializers.
51
+ Initializer = Callable[[PRNGKey, Shape, DType], Array]
52
+ InitializerAxis = Union[int, Tuple[int, ...]]
53
+ NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
54
+
55
+ default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
56
+
57
+
58
+ # ------------------------------------------------------------------------------
59
+ # Temporary inlined JAX N-d initializer code
60
+ # TODO(levskaya): remove once new JAX release is out.
61
+ # ------------------------------------------------------------------------------
62
+ def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
63
+ """Inlined JAX `nn.initializer._compute_fans`."""
64
+ if isinstance(in_axis, int):
65
+ in_size = shape[in_axis]
66
+ else:
67
+ in_size = int(np.prod([shape[i] for i in in_axis]))
68
+ if isinstance(out_axis, int):
69
+ out_size = shape[out_axis]
70
+ else:
71
+ out_size = int(np.prod([shape[i] for i in out_axis]))
72
+ receptive_field_size = shape.total / in_size / out_size
73
+ fan_in = in_size * receptive_field_size
74
+ fan_out = out_size * receptive_field_size
75
+ return fan_in, fan_out
76
+
77
+
78
+ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
79
+ """Inlined JAX `nn.initializer.variance_scaling`."""
80
+
81
+ def init(key, shape, dtype=dtype):
82
+ return jnp.zeros(shape, dtype=dtype)
83
+ dtype = jax.dtypes.canonicalize_dtype(dtype)
84
+ shape = jax.core.as_named_shape(shape)
85
+ fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
86
+ if mode == "fan_in":
87
+ denominator = fan_in
88
+ elif mode == "fan_out":
89
+ denominator = fan_out
90
+ elif mode == "fan_avg":
91
+ denominator = (fan_in + fan_out) / 2
92
+ else:
93
+ raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
94
+ variance = jnp.array(scale / denominator, dtype=dtype)
95
+
96
+ if distribution == "truncated_normal":
97
+ # constant is stddev of standard normal truncated to (-2, 2)
98
+ stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
99
+ return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
100
+ elif distribution == "normal":
101
+ return random.normal(key, shape, dtype) * jnp.sqrt(variance)
102
+ elif distribution == "uniform":
103
+ return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
104
+ else:
105
+ raise ValueError("invalid distribution for variance scaling initializer: {}".format(distribution))
106
+
107
+ return init
108
+
109
+
110
+ # ------------------------------------------------------------------------------
111
+
112
+
113
+ def nd_dense_init(scale, mode, distribution):
114
+ """Initializer with in_axis, out_axis set at call time."""
115
+
116
+ def init_fn(key, shape, dtype, in_axis, out_axis):
117
+ fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
118
+ return fn(key, shape, dtype)
119
+
120
+ return init_fn
121
+
122
+
123
+ def dot_product_attention(
124
+ query: Array,
125
+ key: Array,
126
+ value: Array,
127
+ bias: Optional[Array] = None,
128
+ dropout_rng: Optional[PRNGKey] = None,
129
+ dropout_rate: float = 0.0,
130
+ deterministic: bool = False,
131
+ dtype: DType = jnp.float32,
132
+ float32_logits: bool = False,
133
+ ):
134
+ """Computes dot-product attention given query, key, and value.
135
+
136
+ This is the core function for applying attention based on
137
+ https://arxiv.org/abs/1706.03762. It calculates the attention weights given
138
+ query and key and combines the values using the attention weights.
139
+
140
+ Args:
141
+ query: queries for calculating attention with shape of `[batch, q_length,
142
+ num_heads, qk_depth_per_head]`.
143
+ key: keys for calculating attention with shape of `[batch, kv_length,
144
+ num_heads, qk_depth_per_head]`.
145
+ value: values to be used in attention with shape of `[batch, kv_length,
146
+ num_heads, v_depth_per_head]`.
147
+ bias: bias for the attention weights. This should be broadcastable to the
148
+ shape `[batch, num_heads, q_length, kv_length]` This can be used for
149
+ incorporating causal masks, padding masks, proximity bias, etc.
150
+ dropout_rng: JAX PRNGKey: to be used for dropout
151
+ dropout_rate: dropout rate
152
+ deterministic: bool, deterministic or not (to apply dropout)
153
+ dtype: the dtype of the computation (default: float32)
154
+ float32_logits: bool, if True then compute logits in float32 to avoid
155
+ numerical issues with bfloat16.
156
+
157
+ Returns:
158
+ Output of shape `[batch, length, num_heads, v_depth_per_head]`.
159
+ """
160
+ assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
161
+ assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
162
+ assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
163
+ assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
164
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
165
+
166
+ # Casting logits and softmax computation for float32 for model stability.
167
+ if float32_logits:
168
+ query = query.astype(jnp.float32)
169
+ key = key.astype(jnp.float32)
170
+
171
+ # `attn_weights`: [batch, num_heads, q_length, kv_length]
172
+ attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)
173
+
174
+ # Apply attention bias: masking, dropout, proximity bias, etc.
175
+ if bias is not None:
176
+ attn_weights = attn_weights + bias.astype(attn_weights.dtype)
177
+
178
+ # Normalize the attention weights across `kv_length` dimension.
179
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
180
+
181
+ # Apply attention dropout.
182
+ if not deterministic and dropout_rate > 0.0:
183
+ keep_prob = 1.0 - dropout_rate
184
+ # T5 broadcasts along the "length" dim, but unclear which one that
185
+ # corresponds to in positional dimensions here, assuming query dim.
186
+ dropout_shape = list(attn_weights.shape)
187
+ dropout_shape[-2] = 1
188
+ keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
189
+ keep = jnp.broadcast_to(keep, attn_weights.shape)
190
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
191
+ attn_weights = attn_weights * multiplier
192
+
193
+ # Take the linear combination of `value`.
194
+ return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
195
+
196
+
197
+ dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
198
+
199
+
200
+ class MultiHeadDotProductAttention(nn.Module):
201
+ """Multi-head dot-product attention.
202
+
203
+ Attributes:
204
+ num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
205
+ should be divisible by the number of heads.
206
+ head_dim: dimension of each head.
207
+ dtype: the dtype of the computation.
208
+ dropout_rate: dropout rate
209
+ kernel_init: initializer for the kernel of the Dense layers.
210
+ float32_logits: bool, if True then compute logits in float32 to avoid
211
+ numerical issues with bfloat16.
212
+ """
213
+
214
+ num_heads: int
215
+ head_dim: int
216
+ dtype: DType = jnp.float32
217
+ dropout_rate: float = 0.0
218
+ kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
219
+ float32_logits: bool = False # computes logits in float32 for stability.
220
+
221
+ @nn.compact
222
+ def __call__(
223
+ self,
224
+ inputs_q: Array,
225
+ inputs_kv: Array,
226
+ mask: Optional[Array] = None,
227
+ bias: Optional[Array] = None,
228
+ *,
229
+ decode: bool = False,
230
+ deterministic: bool = False,
231
+ ) -> Array:
232
+ """Applies multi-head dot product attention on the input data.
233
+
234
+ Projects the inputs into multi-headed query, key, and value vectors,
235
+ applies dot-product attention and project the results to an output vector.
236
+
237
+ There are two modes: decoding and non-decoding (e.g., training). The mode is
238
+ determined by `decode` argument. For decoding, this method is called twice,
239
+ first to initialize the cache and then for an actual decoding process. The
240
+ two calls are differentiated by the presence of 'cached_key' in the variable
241
+ dict. In the cache initialization stage, the cache variables are initialized
242
+ as zeros and will be filled in the subsequent decoding process.
243
+
244
+ In the cache initialization call, `inputs_q` has a shape [batch, length,
245
+ q_features] and `inputs_kv`: [batch, length, kv_features]. During the
246
+ incremental decoding stage, query, key and value all have the shape [batch,
247
+ 1, qkv_features] corresponding to a single step.
248
+
249
+ Args:
250
+ inputs_q: input queries of shape `[batch, q_length, q_features]`.
251
+ inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
252
+ mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
253
+ bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
254
+ decode: Whether to prepare and use an autoregressive cache.
255
+ deterministic: Disables dropout if set to True.
256
+
257
+ Returns:
258
+ output of shape `[batch, length, q_features]`.
259
+ """
260
+ projection = functools.partial(
261
+ DenseGeneral,
262
+ axis=-1,
263
+ features=(self.num_heads, self.head_dim),
264
+ kernel_axes=("embed", "heads", "kv"),
265
+ dtype=self.dtype,
266
+ )
267
+
268
+ # NOTE: T5 does not explicitly rescale the attention logits by
269
+ # 1/sqrt(depth_kq)! This is folded into the initializers of the
270
+ # linear transformations, which is equivalent under Adafactor.
271
+ depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
272
+
273
+ def query_init(*args):
274
+ return self.kernel_init(*args) / depth_scaling
275
+
276
+ # Project inputs_q to multi-headed q/k/v
277
+ # dimensions are then [batch, length, num_heads, head_dim]
278
+ query = projection(kernel_init=query_init, name="query")(inputs_q)
279
+ key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
280
+ value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
281
+
282
+ query = with_sharding_constraint(query, ("batch", "length", "heads", "kv"))
283
+ key = with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
284
+ value = with_sharding_constraint(value, ("batch", "length", "heads", "kv"))
285
+
286
+ if decode:
287
+ # Detect if we're initializing by absence of existing cache data.
288
+ is_initialized = self.has_variable("cache", "cached_key")
289
+
290
+ # The key and value have dimension [batch, length, num_heads, head_dim],
291
+ # but we cache them as [batch, num_heads, head_dim, length] as a TPU
292
+ # fusion optimization. This also enables the "scatter via one-hot
293
+ # broadcast" trick, which means we do a one-hot broadcast instead of a
294
+ # scatter/gather operations, resulting in a 3-4x speedup in practice.
295
+ def swap_dims(x):
296
+ return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
297
+
298
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
299
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
300
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
301
+ if is_initialized:
302
+ batch, num_heads, head_dim, length = cached_key.value.shape
303
+ # During fast autoregressive decoding, we feed one position at a time,
304
+ # and cache the keys and values step by step.
305
+ # Sanity shape check of cached key against input query.
306
+ expected_shape = (batch, 1, num_heads, head_dim)
307
+ if expected_shape != query.shape:
308
+ raise ValueError(
309
+ "Autoregressive cache shape error, "
310
+ "expected query shape %s instead got %s." % (expected_shape, query.shape)
311
+ )
312
+
313
+ # Create a OHE of the current index. NOTE: the index is increased below.
314
+ cur_index = cache_index.value
315
+ one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
316
+ # In order to update the key, value caches with the current key and
317
+ # value, we move the length axis to the back, similar to what we did for
318
+ # the cached ones above.
319
+ # Note these are currently the key and value of a single position, since
320
+ # we feed one position at a time.
321
+ one_token_key = jnp.moveaxis(key, -3, -1)
322
+ one_token_value = jnp.moveaxis(value, -3, -1)
323
+ # Update key, value caches with our new 1d spatial slices.
324
+ # We implement an efficient scatter into the cache via one-hot
325
+ # broadcast and addition.
326
+ key = cached_key.value + one_token_key * one_hot_indices
327
+ value = cached_value.value + one_token_value * one_hot_indices
328
+ cached_key.value = key
329
+ cached_value.value = value
330
+ cache_index.value = cache_index.value + 1
331
+ # Move the keys and values back to their original shapes.
332
+ key = jnp.moveaxis(key, -1, -3)
333
+ value = jnp.moveaxis(value, -1, -3)
334
+
335
+ # Causal mask for cached decoder self-attention: our single query
336
+ # position should only attend to those key positions that have already
337
+ # been generated and cached, not the remaining zero elements.
338
+ mask = combine_masks(
339
+ mask,
340
+ jnp.broadcast_to(
341
+ jnp.arange(length) <= cur_index,
342
+ # (1, 1, length) represent (head dim, query length, key length)
343
+ # query length is 1 because during decoding we deal with one
344
+ # index.
345
+ # The same mask is applied to all batch elements and heads.
346
+ (batch, 1, 1, length),
347
+ ),
348
+ )
349
+
350
+ # Grab the correct relative attention bias during decoding. This is
351
+ # only required during single step decoding.
352
+ if bias is not None:
353
+ # The bias is a full attention matrix, but during decoding we only
354
+ # have to take a slice of it.
355
+ # This is equivalent to bias[..., cur_index:cur_index+1, :].
356
+ bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)
357
+
358
+ # Convert the boolean attention mask to an attention bias.
359
+ if mask is not None:
360
+ # attention mask in the form of attention bias
361
+ attention_bias = lax.select(
362
+ mask > 0,
363
+ jnp.full(mask.shape, 0.0).astype(self.dtype),
364
+ jnp.full(mask.shape, -1e10).astype(self.dtype),
365
+ )
366
+ else:
367
+ attention_bias = None
368
+
369
+ # Add provided bias term (e.g. relative position embedding).
370
+ if bias is not None:
371
+ attention_bias = combine_biases(attention_bias, bias)
372
+
373
+ dropout_rng = None
374
+ if not deterministic and self.dropout_rate > 0.0:
375
+ dropout_rng = self.make_rng("dropout")
376
+
377
+ # Apply attention.
378
+ x = dot_product_attention(
379
+ query,
380
+ key,
381
+ value,
382
+ bias=attention_bias,
383
+ dropout_rng=dropout_rng,
384
+ dropout_rate=self.dropout_rate,
385
+ deterministic=deterministic,
386
+ dtype=self.dtype,
387
+ float32_logits=self.float32_logits,
388
+ )
389
+
390
+ # Back to the original inputs dimensions.
391
+ out = DenseGeneral(
392
+ features=inputs_q.shape[-1], # output dim is set to the input dim.
393
+ axis=(-2, -1),
394
+ kernel_init=self.kernel_init,
395
+ kernel_axes=("heads", "kv", "embed"),
396
+ dtype=self.dtype,
397
+ name="out",
398
+ )(x)
399
+ return out
400
+
401
+
402
+ def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
403
+ # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
404
+ return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
405
+
406
+
407
+ def _canonicalize_tuple(x):
408
+ if isinstance(x, Iterable):
409
+ return tuple(x)
410
+ else:
411
+ return (x,)
412
+
413
+
414
+ # ------------------------------------------------------------------------------
415
+ # DenseGeneral for attention layers.
416
+ # ------------------------------------------------------------------------------
417
+ class DenseGeneral(nn.Module):
418
+ """A linear transformation (without bias) with flexible axes.
419
+
420
+ Attributes:
421
+ features: tuple with numbers of output features.
422
+ axis: tuple with axes to apply the transformation on.
423
+ dtype: the dtype of the computation (default: float32).
424
+ kernel_init: initializer function for the weight matrix.
425
+ """
426
+
427
+ features: Union[Iterable[int], int]
428
+ axis: Union[Iterable[int], int] = -1
429
+ dtype: DType = jnp.float32
430
+ params_dtype: DType = jnp.float32
431
+ kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
432
+ kernel_axes: Tuple[str, ...] = ()
433
+ use_bias: bool = True
434
+ bias_init: Any = nn.initializers.zeros
435
+
436
+ @nn.compact
437
+ def __call__(self, inputs: Array) -> Array:
438
+ """Applies a linear transformation to the inputs along multiple dimensions.
439
+
440
+ Args:
441
+ inputs: The nd-array to be transformed.
442
+
443
+ Returns:
444
+ The transformed input.
445
+ """
446
+ features = _canonicalize_tuple(self.features)
447
+ axis = _canonicalize_tuple(self.axis)
448
+
449
+ inputs = jnp.asarray(inputs, self.dtype)
450
+ axis = _normalize_axes(axis, inputs.ndim)
451
+
452
+ kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
453
+ kernel_in_axis = np.arange(len(axis))
454
+ kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
455
+ kernel = param_with_axes(
456
+ "kernel",
457
+ self.kernel_init,
458
+ kernel_shape,
459
+ self.params_dtype,
460
+ kernel_in_axis,
461
+ kernel_out_axis,
462
+ axes=self.kernel_axes,
463
+ )
464
+ if self.use_bias:
465
+ bias = param_with_axes(
466
+ "bias",
467
+ self.bias_init,
468
+ features,
469
+ self.params_dtype,
470
+ axes=(self.kernel_axes[-1],),
471
+ )
472
+ kernel = jnp.asarray(kernel, self.dtype)
473
+
474
+ contract_ind = tuple(range(0, len(axis)))
475
+ y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
476
+ if self.use_bias:
477
+ bias = jnp.asarray(bias, self.dtype)
478
+ # y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
479
+ y += jnp.reshape(bias, (1,) * (len(features) - y.ndim) + bias.shape[:])
480
+ return y
481
+
482
+
483
+ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
484
+ """Convert a string to an activation function."""
485
+ if fn_or_string == "linear":
486
+ return lambda x: x
487
+ elif isinstance(fn_or_string, str):
488
+ return getattr(nn, fn_or_string)
489
+ elif callable(fn_or_string):
490
+ return fn_or_string
491
+ else:
492
+ raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,))
493
+
494
+
495
+ class MlpBlock(nn.Module):
496
+ """Transformer MLP / feed-forward block.
497
+
498
+ Attributes:
499
+ intermediate_dim: Shared dimension of hidden layers.
500
+ activations: Type of activations for each layer. Each element is either
501
+ 'linear', a string function name in flax.linen, or a function.
502
+ kernel_init: Kernel function, passed to the dense layers.
503
+ deterministic: Whether the dropout layers should be deterministic.
504
+ intermediate_dropout_rate: Dropout rate used after the intermediate layers.
505
+ dtype: Type for the dense layer.
506
+ """
507
+
508
+ intermediate_dim: int = 2048
509
+ activations: Sequence[Union[str, Callable]] = ("relu",)
510
+ kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal")
511
+ intermediate_dropout_rate: float = 0.1
512
+ dtype: Any = jnp.float32
513
+
514
+ @nn.compact
515
+ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
516
+ """Applies Transformer MlpBlock module."""
517
+ # Iterate over specified MLP input activation functions.
518
+ # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
519
+ activations = []
520
+ for idx, act_fn in enumerate(self.activations):
521
+ dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
522
+ x = DenseGeneral(
523
+ self.intermediate_dim,
524
+ dtype=self.dtype,
525
+ kernel_init=self.kernel_init,
526
+ kernel_axes=("embed", "mlp"),
527
+ name=dense_name,
528
+ )(inputs)
529
+ x = _convert_to_activation_function(act_fn)(x)
530
+ activations.append(x)
531
+
532
+ # Take elementwise product of above intermediate activations.
533
+ x = functools.reduce(operator.mul, activations)
534
+ # Apply dropout and final dense output projection.
535
+ x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
536
+ x, deterministic=deterministic
537
+ ) # Broadcast along length.
538
+ x = with_sharding_constraint(x, ("batch", "length", "mlp"))
539
+ output = DenseGeneral(
540
+ inputs.shape[-1],
541
+ dtype=self.dtype,
542
+ kernel_init=self.kernel_init,
543
+ kernel_axes=("mlp", "embed"),
544
+ name="wo",
545
+ )(x)
546
+ return output
547
+
548
+
549
+ class Embed(nn.Module):
550
+ """A parameterized function from integers [0, n) to d-dimensional vectors.
551
+
552
+ Attributes:
553
+ num_embeddings: number of embeddings.
554
+ features: number of feature dimensions for each embedding.
555
+ dtype: the dtype of the embedding vectors (default: float32).
556
+ embedding_init: embedding initializer.
557
+ one_hot: performs the gather with a one-hot contraction rather than a true
558
+ gather. This is currently needed for SPMD partitioning.
559
+ """
560
+
561
+ num_embeddings: int
562
+ features: int
563
+ cast_input_dtype: Optional[DType] = None
564
+ dtype: DType = jnp.float32
565
+ params_dtype: DType = jnp.float32
566
+ attend_dtype: Optional[DType] = None
567
+ embedding_init: Initializer = default_embed_init
568
+ one_hot: bool = True
569
+ embedding: Array = dataclasses.field(init=False)
570
+
571
+ def setup(self):
572
+ self.embedding = param_with_axes(
573
+ "embedding",
574
+ self.embedding_init,
575
+ (self.num_embeddings, self.features),
576
+ self.params_dtype,
577
+ axes=("vocab", "embed"),
578
+ )
579
+
580
+ def __call__(self, inputs: Array) -> Array:
581
+ """Embeds the inputs along the last dimension.
582
+
583
+ Args:
584
+ inputs: input data, all dimensions are considered batch dimensions.
585
+
586
+ Returns:
587
+ Output which is embedded input data. The output shape follows the input,
588
+ with an additional `features` dimension appended.
589
+ """
590
+ if self.cast_input_dtype:
591
+ inputs = inputs.astype(self.cast_input_dtype)
592
+ if not jnp.issubdtype(inputs.dtype, jnp.integer):
593
+ raise ValueError("Input type must be an integer or unsigned integer.")
594
+ if self.one_hot:
595
+ iota = lax.iota(jnp.int32, self.num_embeddings)
596
+ one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
597
+ output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
598
+ else:
599
+ output = jnp.asarray(self.embedding, self.dtype)[inputs]
600
+ output = with_sharding_constraint(output, ("batch", "length", "embed"))
601
+ return output
602
+
603
+ def attend(self, query: Array) -> Array:
604
+ """Attend over the embedding using a query array.
605
+
606
+ Args:
607
+ query: array with last dimension equal the feature depth `features` of the
608
+ embedding.
609
+
610
+ Returns:
611
+ An array with final dim `num_embeddings` corresponding to the batched
612
+ inner-product of the array of query vectors against each embedding.
613
+ Commonly used for weight-sharing between embeddings and logit transform
614
+ in NLP models.
615
+ """
616
+ dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
617
+ return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
618
+
619
+
620
+ class RelativePositionBiases(nn.Module):
621
+ """Adds T5-style relative positional embeddings to the attention logits.
622
+
623
+ Attributes:
624
+ num_buckets: Number of buckets to bucket distances between key and query
625
+ positions into.
626
+ max_distance: Maximum distance before everything is lumped into the last
627
+ distance bucket.
628
+ num_heads: Number of heads in the attention layer. Each head will get a
629
+ different relative position weighting.
630
+ dtype: Type of arrays through this module.
631
+ embedding_init: initializer for relative embedding table.
632
+ """
633
+
634
+ num_buckets: int
635
+ max_distance: int
636
+ num_heads: int
637
+ dtype: Any
638
+ embedding_init: Callable[..., Array] = nn.linear.default_embed_init
639
+
640
+ @staticmethod
641
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
642
+ """Translate relative position to a bucket number for relative attention.
643
+
644
+ The relative position is defined as memory_position - query_position, i.e.
645
+ the distance in tokens from the attending position to the attended-to
646
+ position. If bidirectional=False, then positive relative positions are
647
+ invalid.
648
+ We use smaller buckets for small absolute relative_position and larger
649
+ buckets for larger absolute relative_positions. All relative
650
+ positions >=max_distance map to the same bucket. All relative
651
+ positions <=-max_distance map to the same bucket. This should allow for
652
+ more graceful generalization to longer sequences than the model has been
653
+ trained on.
654
+
655
+ Args:
656
+ relative_position: an int32 array
657
+ bidirectional: a boolean - whether the attention is bidirectional
658
+ num_buckets: an integer
659
+ max_distance: an integer
660
+
661
+ Returns:
662
+ a Tensor with the same shape as relative_position, containing int32
663
+ values in the range [0, num_buckets)
664
+ """
665
+ ret = 0
666
+ n = -relative_position
667
+ if bidirectional:
668
+ num_buckets //= 2
669
+ ret += (n < 0).astype(np.int32) * num_buckets
670
+ n = np.abs(n)
671
+ else:
672
+ n = np.maximum(n, 0)
673
+ # now n is in the range [0, inf)
674
+ max_exact = num_buckets // 2
675
+ is_small = n < max_exact
676
+ val_if_large = max_exact + (
677
+ np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
678
+ / np.log(max_distance / max_exact)
679
+ * (num_buckets - max_exact)
680
+ ).astype(np.int32)
681
+ val_if_large = np.minimum(val_if_large, num_buckets - 1)
682
+ ret += np.where(is_small, n, val_if_large)
683
+ return ret
684
+
685
+ @nn.compact
686
+ def __call__(self, qlen, klen, bidirectional=True):
687
+ """Produce relative position embedding attention biases.
688
+
689
+ Args:
690
+ qlen: attention query length.
691
+ klen: attention key length.
692
+ bidirectional: whether to allow positive memory-query relative position
693
+ embeddings.
694
+
695
+ Returns:
696
+ output: `(1, len, q_len, k_len)` attention bias
697
+ """
698
+ # TODO(levskaya): should we be computing this w. numpy as a program
699
+ # constant?
700
+ context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
701
+ memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
702
+ relative_position = memory_position - context_position # shape (qlen, klen)
703
+ rp_bucket = self._relative_position_bucket(
704
+ relative_position,
705
+ bidirectional=bidirectional,
706
+ num_buckets=self.num_buckets,
707
+ max_distance=self.max_distance,
708
+ )
709
+ relative_attention_bias = param_with_axes(
710
+ "rel_embedding",
711
+ self.embedding_init,
712
+ (self.num_heads, self.num_buckets),
713
+ jnp.float32,
714
+ axes=("heads", "relpos_buckets"),
715
+ )
716
+
717
+ relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
718
+ # Instead of using a slow gather, we create a leading-dimension one-hot
719
+ # array from rp_bucket and use it to perform the gather-equivalent via a
720
+ # contraction, i.e.:
721
+ # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
722
+ # This is equivalent to relative_attention_bias[:, rp_bucket]
723
+ bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
724
+ rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
725
+ # --> shape (qlen, klen, num_heads)
726
+ values = lax.dot_general(
727
+ relative_attention_bias,
728
+ rp_bucket_one_hot,
729
+ (((1,), (0,)), ((), ())), # rhs, lhs contracting dims
730
+ ) # no batched dims
731
+ # Add a singleton batch dimension.
732
+ # --> shape (1, num_heads, qlen, klen)
733
+ return values[jnp.newaxis, ...]
734
+
735
+
736
+ # ------------------------------------------------------------------------------
737
+ # T5 Layernorm - no subtraction of mean or bias.
738
+ # ------------------------------------------------------------------------------
739
+ # class LayerNorm(nn.Module):
740
+ # """T5 Layer normalization operating on the last axis of the input data."""
741
+ # epsilon: float = 1e-6
742
+ # dtype: Any = jnp.float32
743
+ # scale_init: Initializer = nn.initializers.ones
744
+
745
+ # @nn.compact
746
+ # def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
747
+ # """Applies layer normalization on the input."""
748
+ # x = jnp.asarray(x, jnp.float32)
749
+ # features = x.shape[-1]
750
+ # mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
751
+ # y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
752
+ # scale = param_with_axes(
753
+ # 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))
754
+
755
+ # scale = jnp.asarray(scale, self.dtype)
756
+ # return y * scale
757
+
758
+
759
+ class LayerNorm(nn.Module):
760
+ """Layer normalization (https://arxiv.org/abs/1607.06450).
761
+ Operates on the last axis of the input data.
762
+ It normalizes the activations of the layer for each given example in a
763
+ batch independently, rather than across a batch like Batch Normalization.
764
+ i.e. applies a transformation that maintains the mean activation within
765
+ each example close to 0 and the activation standard deviation close to 1.
766
+ Attributes:
767
+ epsilon: A small float added to variance to avoid dividing by zero.
768
+ dtype: the dtype of the computation (default: float32).
769
+ use_bias: If True, bias (beta) is added.
770
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
771
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
772
+ by the next layer.
773
+ bias_init: Initializer for bias, by default, zero.
774
+ scale_init: Initializer for scale, by default, one.
775
+ """
776
+
777
+ epsilon: float = 1e-6
778
+ dtype: Any = jnp.float32
779
+ params_dtype: DType = jnp.float32
780
+ use_bias: bool = True
781
+ use_scale: bool = True
782
+ bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros
783
+ scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones
784
+
785
+ @nn.compact
786
+ def __call__(self, x):
787
+ """Applies layer normalization on the input.
788
+ Args:
789
+ x: the inputs
790
+ Returns:
791
+ Normalized inputs (the same shape as inputs).
792
+ """
793
+ x = jnp.asarray(x, jnp.float32)
794
+ features = x.shape[-1]
795
+ mean = jnp.mean(x, axis=-1, keepdims=True)
796
+ mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
797
+ var = mean2 - lax.square(mean)
798
+ mul = lax.rsqrt(var + self.epsilon)
799
+ if self.use_scale:
800
+ scale = param_with_axes(
801
+ "scale",
802
+ self.scale_init,
803
+ (features,),
804
+ self.params_dtype,
805
+ axes=("embed",),
806
+ )
807
+ mul = mul * jnp.asarray(scale, self.dtype)
808
+ y = (x - mean) * mul
809
+ if self.use_bias:
810
+ bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
811
+ y = y + jnp.asarray(bias, self.dtype)
812
+ return jnp.asarray(y, self.dtype)
813
+
814
+
815
+ # ------------------------------------------------------------------------------
816
+ # Mask-making utility functions.
817
+ # ------------------------------------------------------------------------------
818
+ def make_attention_mask(
819
+ query_input: Array,
820
+ key_input: Array,
821
+ pairwise_fn: Callable = jnp.multiply,
822
+ extra_batch_dims: int = 0,
823
+ dtype: DType = jnp.float32,
824
+ ) -> Array:
825
+ """Mask-making helper for attention weights.
826
+
827
+ In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the
828
+ attention weights will be `[batch, heads, len_q, len_kv]` and this
829
+ function will produce `[batch, 1, len_q, len_kv]`.
830
+
831
+ Args:
832
+ query_input: a batched, flat input of query_length size
833
+ key_input: a batched, flat input of key_length size
834
+ pairwise_fn: broadcasting elementwise comparison function
835
+ extra_batch_dims: number of extra batch dims to add singleton axes for, none
836
+ by default
837
+ dtype: mask return dtype
838
+
839
+ Returns:
840
+ A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
841
+ """
842
+ # [batch, len_q, len_kv]
843
+ mask = pairwise_fn(
844
+ # [batch, len_q] -> [batch, len_q, 1]
845
+ jnp.expand_dims(query_input, axis=-1),
846
+ # [batch, len_q] -> [batch, 1, len_kv]
847
+ jnp.expand_dims(key_input, axis=-2),
848
+ )
849
+
850
+ # [batch, 1, len_q, len_kv]. This creates the head dim.
851
+ mask = jnp.expand_dims(mask, axis=-3)
852
+ mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
853
+ return mask.astype(dtype)
854
+
855
+
856
+ def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
857
+ """Make a causal mask for self-attention.
858
+
859
+ In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights
860
+ will be `[batch, heads, len, len]` and this function will produce a
861
+ causal mask of shape `[batch, 1, len, len]`.
862
+
863
+ Note that a causal mask does not depend on the values of x; it only depends on
864
+ the shape. If x has padding elements, they will not be treated in a special
865
+ manner.
866
+
867
+ Args:
868
+ x: input array of shape `[batch, len]`
869
+ extra_batch_dims: number of batch dims to add singleton axes for, none by
870
+ default
871
+ dtype: mask return dtype
872
+
873
+ Returns:
874
+ A `[batch, 1, len, len]` shaped causal mask for 1d attention.
875
+ """
876
+ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
877
+ return make_attention_mask(idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype)
878
+
879
+
880
+ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
881
+ """Combine attention masks.
882
+
883
+ Args:
884
+ *masks: set of attention mask arguments to combine, some can be None.
885
+ dtype: final mask dtype
886
+
887
+ Returns:
888
+ Combined mask, reduced by logical and, returns None if no masks given.
889
+ """
890
+ masks = [m for m in masks if m is not None]
891
+ if not masks:
892
+ return None
893
+ assert all(
894
+ (x.ndim == masks[0].ndim for x in masks)
895
+ ), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
896
+ mask, *other_masks = masks
897
+ for other_mask in other_masks:
898
+ mask = jnp.logical_and(mask, other_mask)
899
+ return mask.astype(dtype)
900
+
901
+
902
+ def combine_biases(*masks: Optional[Array]):
903
+ """Combine attention biases.
904
+
905
+ Args:
906
+ *masks: set of attention bias arguments to combine, some can be None.
907
+
908
+ Returns:
909
+ Combined mask, reduced by summation, returns None if no masks given.
910
+ """
911
+ masks = [m for m in masks if m is not None]
912
+ if not masks:
913
+ return None
914
+ assert all(
915
+ (x.ndim == masks[0].ndim for x in masks)
916
+ ), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
917
+ mask, *other_masks = masks
918
+ for other_mask in other_masks:
919
+ mask = mask + other_mask
920
+ return mask
921
+
922
+
923
+ def make_decoder_mask(
924
+ decoder_target_tokens: Array,
925
+ dtype: DType,
926
+ decoder_causal_attention: Optional[Array] = None,
927
+ decoder_segment_ids: Optional[Array] = None,
928
+ ) -> Array:
929
+ """Compute the self-attention mask for a decoder.
930
+
931
+ Decoder mask is formed by combining a causal mask, a padding mask and an
932
+ optional packing mask. If decoder_causal_attention is passed, it makes the
933
+ masking non-causal for positions that have value of 1.
934
+
935
+ A prefix LM is applied to a dataset which has a notion of "inputs" and
936
+ "targets", e.g., a machine translation task. The inputs and targets are
937
+ concatenated to form a new target. `decoder_target_tokens` is the concatenated
938
+ decoder output tokens.
939
+
940
+ The "inputs" portion of the concatenated sequence can attend to other "inputs"
941
+ tokens even for those at a later time steps. In order to control this
942
+ behavior, `decoder_causal_attention` is necessary. This is a binary mask with
943
+ a value of 1 indicating that the position belonged to "inputs" portion of the
944
+ original dataset.
945
+
946
+ Example:
947
+
948
+ Suppose we have a dataset with two examples.
949
+
950
+ ds = [{"inputs": [6, 7], "targets": [8]},
951
+ {"inputs": [3, 4], "targets": [5]}]
952
+
953
+ After the data preprocessing with packing, the two examples are packed into
954
+ one example with the following three fields (some fields are skipped for
955
+ simplicity).
956
+
957
+ decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
958
+ decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
959
+ decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
960
+
961
+ where each array has [batch, length] shape with batch size being 1. Then,
962
+ this function computes the following mask.
963
+
964
+ mask = [[[[1, 1, 0, 0, 0, 0, 0],
965
+ [1, 1, 0, 0, 0, 0, 0],
966
+ [1, 1, 1, 0, 0, 0, 0],
967
+ [0, 0, 0, 1, 1, 0, 0],
968
+ [0, 0, 0, 1, 1, 0, 0],
969
+ [0, 0, 0, 1, 1, 1, 0],
970
+ [0, 0, 0, 0, 0, 0, 0]]]]
971
+
972
+ mask[b, 1, :, :] represents the mask for the example `b` in the batch.
973
+ Because mask is for a self-attention layer, the mask's shape is a square of
974
+ shape [query length, key length].
975
+
976
+ mask[b, 1, i, j] = 1 means that the query token at position i can attend to
977
+ the key token at position j.
978
+
979
+ Args:
980
+ decoder_target_tokens: decoder output tokens. [batch, length]
981
+ dtype: dtype of the output mask.
982
+ decoder_causal_attention: a binary mask indicating which position should
983
+ only attend to earlier positions in the sequence. Others will attend
984
+ bidirectionally. [batch, length]
985
+ decoder_segment_ids: decoder segmentation info for packed examples. [batch,
986
+ length]
987
+
988
+ Returns:
989
+ the combined decoder mask.
990
+ """
991
+ masks = []
992
+ # The same mask is applied to all attention heads. So the head dimension is 1,
993
+ # i.e., the mask will be broadcast along the heads dim.
994
+ # [batch, 1, length, length]
995
+ causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
996
+
997
+ # Positions with value 1 in `decoder_causal_attneition` can attend
998
+ # bidirectionally.
999
+ if decoder_causal_attention is not None:
1000
+ # [batch, 1, length, length]
1001
+ inputs_mask = make_attention_mask(
1002
+ decoder_causal_attention,
1003
+ decoder_causal_attention,
1004
+ jnp.logical_and,
1005
+ dtype=dtype,
1006
+ )
1007
+ masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
1008
+ else:
1009
+ masks.append(causal_mask)
1010
+
1011
+ # Padding mask.
1012
+ masks.append(make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
1013
+
1014
+ # Packing mask
1015
+ if decoder_segment_ids is not None:
1016
+ masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
1017
+
1018
+ return combine_masks(*masks, dtype=dtype)
1019
+
1020
+
1021
+ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
1022
+ """ "Canonicalizes conv padding to a jax.lax supported format."""
1023
+ if isinstance(padding, str):
1024
+ return padding
1025
+ if isinstance(padding, int):
1026
+ return [(padding, padding)] * rank
1027
+ if isinstance(padding, Sequence) and len(padding) == rank:
1028
+ new_pad = []
1029
+ for p in padding:
1030
+ if isinstance(p, int):
1031
+ new_pad.append((p, p))
1032
+ elif isinstance(p, tuple) and len(p) == 2:
1033
+ new_pad.append(p)
1034
+ else:
1035
+ break
1036
+ if len(new_pad) == rank:
1037
+ return new_pad
1038
+ raise ValueError(
1039
+ f"Invalid padding format: {padding}, should be str, int,"
1040
+ f" or a sequence of len {rank} where each element is an"
1041
+ " int or pair of ints."
1042
+ )
1043
+
1044
+
1045
+ def _conv_dimension_numbers(input_shape):
1046
+ """Computes the dimension numbers based on the input shape."""
1047
+ ndim = len(input_shape)
1048
+ lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
1049
+ rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
1050
+ out_spec = lhs_spec
1051
+ return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
1052
+
1053
+
1054
+ class _Conv(nn.Module):
1055
+ """Convolution Module wrapping `lax.conv_general_dilated[_local]`.
1056
+
1057
+ Attributes:
1058
+ features: number of convolution filters.
1059
+ kernel_size: shape of the convolutional kernel. For 1D convolution,
1060
+ the kernel size can be passed as an integer. For all other cases, it must
1061
+ be a sequence of integers.
1062
+ strides: an integer or a sequence of `n` integers, representing the
1063
+ inter-window strides (default: 1).
1064
+ padding: either the string `'SAME'`, the string `'VALID'`, the string
1065
+ `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
1066
+ high)` integer pairs that give the padding to apply before and after each
1067
+ spatial dimension. A single int is interpeted as applying the same padding
1068
+ in all dims and passign a single int in a sequence causes the same padding
1069
+ to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
1070
+ left-pad the convolution axis, resulting in same-sized output.
1071
+ input_dilation: an integer or a sequence of `n` integers, giving the
1072
+ dilation factor to apply in each spatial dimension of `inputs`
1073
+ (default: 1). Convolution with input dilation `d` is equivalent to
1074
+ transposed convolution with stride `d`.
1075
+ kernel_dilation: an integer or a sequence of `n` integers, giving the
1076
+ dilation factor to apply in each spatial dimension of the convolution
1077
+ kernel (default: 1). Convolution with kernel dilation
1078
+ is also known as 'atrous convolution'.
1079
+ feature_group_count: integer, default 1. If specified divides the input
1080
+ features into groups.
1081
+ use_bias: whether to add a bias to the output (default: True).
1082
+ mask: Optional mask for the weights during masked convolution. The mask must
1083
+ be the same shape as the convolution weight matrix.
1084
+ dtype: the dtype of the computation (default: infer from input and params).
1085
+ params_dtype: the dtype passed to parameter initializers (default: float32).
1086
+ precision: numerical precision of the computation see `jax.lax.Precision`
1087
+ for details.
1088
+ kernel_init: initializer for the convolutional kernel.
1089
+ bias_init: initializer for the bias.
1090
+ """
1091
+
1092
+ features: int
1093
+ kernel_size: Sequence[int]
1094
+ strides: Union[None, int, Sequence[int]] = 1
1095
+ padding: PaddingLike = "SAME"
1096
+ input_dilation: Union[None, int, Sequence[int]] = 1
1097
+ kernel_dilation: Union[None, int, Sequence[int]] = 1
1098
+ feature_group_count: int = 1
1099
+ use_bias: bool = True
1100
+ mask: Optional[Array] = None
1101
+ dtype: Optional[DType] = None
1102
+ params_dtype: DType = jnp.float32
1103
+ precision: PrecisionLike = None
1104
+ kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal()
1105
+ bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros
1106
+ conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
1107
+ kernel_axes: Tuple[str, ...] = ()
1108
+
1109
+ @property
1110
+ def shared_weights(self) -> bool: # type: ignore
1111
+ """Defines whether weights are shared or not between different pixels.
1112
+
1113
+ Returns:
1114
+ `True` to use shared weights in convolution (regular convolution).
1115
+ `False` to use different weights at different pixels, a.k.a.
1116
+ "locally connected layer", "unshared convolution", or "local convolution".
1117
+
1118
+ """
1119
+ ...
1120
+
1121
+ @nn.compact
1122
+ def __call__(self, inputs: Array) -> Array:
1123
+ """Applies a (potentially unshared) convolution to the inputs.
1124
+
1125
+ Args:
1126
+ inputs: input data with dimensions (*batch_dims, spatial_dims...,
1127
+ features). This is the channels-last convention, i.e. NHWC for a 2d
1128
+ convolution and NDHWC for a 3D convolution. Note: this is different from
1129
+ the input convention used by `lax.conv_general_dilated`, which puts the
1130
+ spatial dimensions last.
1131
+ Note: If the input has more than 1 batch dimension, all batch dimensions
1132
+ are flattened into a single dimension for the convolution and restored
1133
+ before returning. In some cases directly vmap'ing the layer may yield
1134
+ better performance than this default flattening approach. If the input
1135
+ lacks a batch dimension it will be added for the convolution and removed
1136
+ n return, an allowance made to enable writing single-example code.
1137
+
1138
+ Returns:
1139
+ The convolved data.
1140
+ """
1141
+
1142
+ if isinstance(self.kernel_size, int):
1143
+ raise TypeError(
1144
+ "Expected Conv kernel_size to be a"
1145
+ " tuple/list of integers (eg.: [3, 3]) but got"
1146
+ f" {self.kernel_size}."
1147
+ )
1148
+ else:
1149
+ kernel_size = tuple(self.kernel_size)
1150
+
1151
+ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
1152
+ if x is None:
1153
+ # backward compatibility with using None as sentinel for
1154
+ # broadcast 1
1155
+ x = 1
1156
+ if isinstance(x, int):
1157
+ return (x,) * len(kernel_size)
1158
+ return tuple(x)
1159
+
1160
+ # Combine all input batch dimensions into a single leading batch axis.
1161
+ num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
1162
+ if num_batch_dimensions != 1:
1163
+ input_batch_shape = inputs.shape[:num_batch_dimensions]
1164
+ total_batch_size = int(np.prod(input_batch_shape))
1165
+ flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
1166
+ inputs = jnp.reshape(inputs, flat_input_shape)
1167
+
1168
+ # self.strides or (1,) * (inputs.ndim - 2)
1169
+ strides = maybe_broadcast(self.strides)
1170
+ input_dilation = maybe_broadcast(self.input_dilation)
1171
+ kernel_dilation = maybe_broadcast(self.kernel_dilation)
1172
+
1173
+ padding_lax = canonicalize_padding(self.padding, len(kernel_size))
1174
+ if padding_lax == "CIRCULAR":
1175
+ kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
1176
+ zero_pad: List[Tuple[int, int]] = [(0, 0)]
1177
+ pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
1178
+ inputs = jnp.pad(inputs, pads, mode="wrap")
1179
+ padding_lax = "VALID"
1180
+ elif padding_lax == "CAUSAL":
1181
+ if len(kernel_size) != 1:
1182
+ raise ValueError("Causal padding is only implemented for 1D convolutions.")
1183
+ left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
1184
+ pads = [(0, 0), (left_pad, 0), (0, 0)]
1185
+ inputs = jnp.pad(inputs, pads)
1186
+ padding_lax = "VALID"
1187
+
1188
+ dimension_numbers = _conv_dimension_numbers(inputs.shape)
1189
+ in_features = jnp.shape(inputs)[-1]
1190
+
1191
+ if self.shared_weights:
1192
+ # One shared convolutional kernel for all pixels in the output.
1193
+ assert in_features % self.feature_group_count == 0
1194
+ kernel_shape = kernel_size + (
1195
+ in_features // self.feature_group_count,
1196
+ self.features,
1197
+ )
1198
+
1199
+ else:
1200
+ if self.feature_group_count != 1:
1201
+ raise NotImplementedError(
1202
+ "`lax.conv_general_dilated_local` does not support "
1203
+ f"`feature_group_count != 1`, got `{self.feature_group_count}`."
1204
+ )
1205
+
1206
+ # Need to know the spatial output shape of a standard convolution to
1207
+ # create the unshared convolution kernel.
1208
+ conv_output_shape = jax.eval_shape(
1209
+ lambda lhs, rhs: self.conv_general_dilated( # pylint: disable=g-long-lambda
1210
+ lhs=lhs,
1211
+ rhs=rhs,
1212
+ window_strides=strides,
1213
+ padding=padding_lax,
1214
+ dimension_numbers=dimension_numbers,
1215
+ lhs_dilation=input_dilation,
1216
+ rhs_dilation=kernel_dilation,
1217
+ ),
1218
+ inputs,
1219
+ jax.ShapedArray(kernel_size + (in_features, self.features), inputs.dtype),
1220
+ ).shape
1221
+
1222
+ # One (unshared) convolutional kernel per each pixel in the output.
1223
+ kernel_shape = conv_output_shape[1:-1] + (
1224
+ np.prod(kernel_size) * in_features,
1225
+ self.features,
1226
+ )
1227
+
1228
+ if self.mask is not None and self.mask.shape != kernel_shape:
1229
+ raise ValueError(
1230
+ "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
1231
+ )
1232
+
1233
+ kernel = param_with_axes(
1234
+ "kernel",
1235
+ self.kernel_init,
1236
+ kernel_shape,
1237
+ self.params_dtype,
1238
+ axes=self.kernel_axes,
1239
+ )
1240
+
1241
+ if self.mask is not None:
1242
+ kernel *= self.mask
1243
+
1244
+ if self.use_bias:
1245
+ if self.shared_weights:
1246
+ # One bias weight per output channel, shared between pixels.
1247
+ bias_shape = (self.features,)
1248
+ else:
1249
+ # One bias weight per output entry, unshared betwen pixels.
1250
+ bias_shape = conv_output_shape[1:]
1251
+
1252
+ bias = param_with_axes(
1253
+ "bias",
1254
+ self.bias_init,
1255
+ bias_shape,
1256
+ self.params_dtype,
1257
+ axes=(self.kernel_axes[-1],),
1258
+ )
1259
+ else:
1260
+ bias = None
1261
+
1262
+ inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
1263
+ if self.shared_weights:
1264
+ y = self.conv_general_dilated(
1265
+ inputs,
1266
+ kernel,
1267
+ strides,
1268
+ padding_lax,
1269
+ lhs_dilation=input_dilation,
1270
+ rhs_dilation=kernel_dilation,
1271
+ dimension_numbers=dimension_numbers,
1272
+ feature_group_count=self.feature_group_count,
1273
+ precision=self.precision,
1274
+ )
1275
+ else:
1276
+ y = lax.conv_general_dilated_local(
1277
+ lhs=inputs,
1278
+ rhs=kernel,
1279
+ window_strides=strides,
1280
+ padding=padding_lax,
1281
+ filter_shape=kernel_size,
1282
+ lhs_dilation=input_dilation,
1283
+ rhs_dilation=kernel_dilation,
1284
+ dimension_numbers=dimension_numbers,
1285
+ precision=self.precision,
1286
+ )
1287
+
1288
+ if self.use_bias:
1289
+ bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
1290
+ y += bias
1291
+
1292
+ if num_batch_dimensions != 1:
1293
+ output_shape = input_batch_shape + y.shape[1:]
1294
+ y = jnp.reshape(y, output_shape)
1295
+ return y
1296
+
1297
+
1298
+ class Conv(_Conv):
1299
+ """Convolution Module wrapping `lax.conv_general_dilated`.
1300
+
1301
+ Attributes:
1302
+ features: number of convolution filters.
1303
+ kernel_size: shape of the convolutional kernel. For 1D convolution,
1304
+ the kernel size can be passed as an integer. For all other cases, it must
1305
+ be a sequence of integers.
1306
+ strides: an integer or a sequence of `n` integers, representing the
1307
+ inter-window strides (default: 1).
1308
+ padding: either the string `'SAME'`, the string `'VALID'`, the string
1309
+ `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
1310
+ high)` integer pairs that give the padding to apply before and after each
1311
+ spatial dimension. A single int is interpeted as applying the same padding
1312
+ in all dims and passign a single int in a sequence causes the same padding
1313
+ to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
1314
+ left-pad the convolution axis, resulting in same-sized output.
1315
+ input_dilation: an integer or a sequence of `n` integers, giving the
1316
+ dilation factor to apply in each spatial dimension of `inputs`
1317
+ (default: 1). Convolution with input dilation `d` is equivalent to
1318
+ transposed convolution with stride `d`.
1319
+ kernel_dilation: an integer or a sequence of `n` integers, giving the
1320
+ dilation factor to apply in each spatial dimension of the convolution
1321
+ kernel (default: 1). Convolution with kernel dilation
1322
+ is also known as 'atrous convolution'.
1323
+ feature_group_count: integer, default 1. If specified divides the input
1324
+ features into groups.
1325
+ use_bias: whether to add a bias to the output (default: True).
1326
+ mask: Optional mask for the weights during masked convolution. The mask must
1327
+ be the same shape as the convolution weight matrix.
1328
+ dtype: the dtype of the computation (default: infer from input and params).
1329
+ params_dtype: the dtype passed to parameter initializers (default: float32).
1330
+ precision: numerical precision of the computation see `jax.lax.Precision`
1331
+ for details.
1332
+ kernel_init: initializer for the convolutional kernel.
1333
+ bias_init: initializer for the bias.
1334
+ """
1335
+
1336
+ @property
1337
+ def shared_weights(self) -> bool:
1338
+ return True
distil_whisper/modeling_flax_whisper.py ADDED
@@ -0,0 +1,2135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax whisper model."""
16
+
17
+ import random
18
+ from functools import partial
19
+ from typing import Dict, Optional, Tuple, Union
20
+
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
25
+ from flax.linen import combine_masks, make_causal_mask
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from flax.linen.partitioning import remat, scan_with_axes
28
+ from flax.traverse_util import flatten_dict, unflatten_dict
29
+ from jax import lax
30
+ from jax.random import PRNGKey
31
+ from transformers import WhisperConfig
32
+ from transformers.generation.flax_logits_process import (
33
+ FlaxLogitsProcessor,
34
+ FlaxLogitsProcessorList,
35
+ FlaxWhisperTimeStampLogitsProcessor,
36
+ )
37
+ from transformers.modeling_flax_outputs import (
38
+ FlaxBaseModelOutput,
39
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
40
+ FlaxCausalLMOutputWithCrossAttentions,
41
+ FlaxSeq2SeqLMOutput,
42
+ FlaxSeq2SeqModelOutput,
43
+ )
44
+ from transformers.modeling_flax_utils import (
45
+ ACT2FN,
46
+ FlaxPreTrainedModel,
47
+ append_call_sample_docstring,
48
+ append_replace_return_docstrings,
49
+ overwrite_call_docstring,
50
+ )
51
+ from transformers.utils import (
52
+ add_start_docstrings,
53
+ add_start_docstrings_to_model_forward,
54
+ logging,
55
+ replace_return_docstrings,
56
+ )
57
+
58
+ from .layers import Conv, DenseGeneral, Embed, LayerNorm, with_sharding_constraint
59
+
60
+
61
+ logger = logging.get_logger(__name__)
62
+
63
+
64
+ _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
65
+ _CONFIG_FOR_DOC = "WhisperConfig"
66
+
67
+
68
+ WHISPER_START_DOCSTRING = r"""
69
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
70
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
71
+ etc.) This model is also a Flax Linen
72
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
74
+ Finally, this model supports inherent JAX features such as:
75
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
76
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
77
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
78
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
79
+
80
+ Parameters:
81
+ config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
82
+ Initializing with a config file does not load the weights associated with the model, only the
83
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
84
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
85
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
86
+ `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
87
+ inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
88
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
89
+ parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
90
+ and [`~FlaxPreTrainedModel.to_bf16`].
91
+ """
92
+
93
+ WHISPER_INPUTS_DOCSTRING = r"""
94
+ Args:
95
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
96
+ Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
97
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
98
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
99
+ [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
100
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
101
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
103
+ is not used. By default the silence in the input log mel spectrogram are ignored.
104
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
105
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
106
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
107
+ [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
108
+ the starting token for `decoder_input_ids` generation.
109
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
110
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
111
+ be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
112
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
113
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
114
+ Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
115
+ use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
116
+ spectrogram are ignored.
117
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
118
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
119
+ range `[0, config.max_position_embeddings - 1]`.
120
+ output_attentions (`bool`, *optional*):
121
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
122
+ tensors for more detail.
123
+ output_hidden_states (`bool`, *optional*):
124
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
125
+ more detail.
126
+ return_dict (`bool`, *optional*):
127
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
128
+ """
129
+
130
+ WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
131
+ Args:
132
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
133
+ Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
134
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
135
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
136
+ [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
137
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
138
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
139
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
140
+ is not used. By default the silence in the input log mel spectrogram are ignored.
141
+ output_attentions (`bool`, *optional*):
142
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
143
+ tensors for more detail.
144
+ output_hidden_states (`bool`, *optional*):
145
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
146
+ more detail.
147
+ return_dict (`bool`, *optional*):
148
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
149
+ """
150
+
151
+ WHISPER_DECODE_INPUTS_DOCSTRING = r"""
152
+ Args:
153
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
154
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
155
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
156
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
157
+ encoder_outputs (`tuple(tuple(numpy.ndarray)`):
158
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
159
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
160
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
161
+ encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
162
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
163
+ but it is not used. By default the silence in the input log mel spectrogram are ignored.
164
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
165
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
166
+ be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
167
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
168
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
169
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
170
+ range `[0, config.max_position_embeddings - 1]`.
171
+ past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
172
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
173
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
174
+ output_attentions (`bool`, *optional*):
175
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
176
+ tensors for more detail.
177
+ output_hidden_states (`bool`, *optional*):
178
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
179
+ more detail.
180
+ return_dict (`bool`, *optional*):
181
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
182
+ """
183
+
184
+
185
+ class FlaxStaticForceTokensLogitsProcessor(FlaxLogitsProcessor):
186
+ r"""
187
+ [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
188
+ token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
189
+ to `-inf` so that they are sampled at their corresponding index. This is a static version of the `transformers` logit
190
+ processor [`FlaxForceTokensLogitsProcessor`] that is compatible with sharded forced tokens.
191
+
192
+ Args:
193
+ force_token_map (`list`):
194
+ Map giving token ids and indices where they will be forced to be sampled.
195
+ """
196
+
197
+ def __init__(self, force_token_map):
198
+ # The generic `transformers` logit processor builds `force_token_array` as a dictionary - this is not a valid
199
+ # JAX type, and so we switch to using a JAX array instead
200
+ force_token_map = jnp.array(force_token_map)
201
+ # Converts the array of format [[index, token]] containing the tokens to be forced to an array, where the
202
+ # index of the array corresponds to the index of the token to be forced. For XLA compatibility,
203
+ # indexes without forced tokens will have a negative value. Note that the last token we ever need to force in
204
+ # Whisper is at position 3, so we only construct an array up to this index. The native version constructs a tensor
205
+ # dynamically according to the length of the `force_token_map`. Array shapes need to be concrete for XLA compatibility,
206
+ # so this is not permitted here.
207
+ force_token_array = jnp.ones(3, dtype=jnp.int32) * -1
208
+ for index, token in force_token_map:
209
+ force_token_array = force_token_array.at[index].set(token)
210
+ self.force_token_array = jnp.int32(force_token_array)
211
+
212
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
213
+ def _force_token(generation_idx):
214
+ batch_size = scores.shape[0]
215
+ current_token = self.force_token_array[generation_idx]
216
+
217
+ new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
218
+ updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
219
+ new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
220
+ return new_scores
221
+
222
+ scores = lax.cond(
223
+ cur_len >= self.force_token_array.shape[0],
224
+ # If the current length is geq than the length of force_token_array, the processor does nothing.
225
+ lambda: scores,
226
+ # Otherwise, it may force a certain token.
227
+ lambda: lax.cond(
228
+ self.force_token_array[cur_len] >= 0,
229
+ # Only valid (positive) tokens are forced
230
+ lambda: _force_token(cur_len),
231
+ # Otherwise, the processor does nothing.
232
+ lambda: scores,
233
+ ),
234
+ )
235
+ return scores
236
+
237
+
238
+ class FlaxWhisperAttention(nn.Module):
239
+ config: WhisperConfig
240
+ embed_dim: int
241
+ num_heads: int
242
+ dropout: float = 0.0
243
+ causal: bool = False
244
+ bias: bool = True
245
+ dtype: jnp.dtype = jnp.float32
246
+ params_dtype: jnp.dtype = jnp.float32
247
+
248
+ def setup(self) -> None:
249
+ self.head_dim = self.embed_dim // self.num_heads
250
+ if self.head_dim * self.num_heads != self.embed_dim:
251
+ raise ValueError(
252
+ "embed_dim must be divisible by num_heads (got `embed_dim`:"
253
+ f" {self.embed_dim} and `num_heads`: {self.num_heads})."
254
+ )
255
+
256
+ dense = partial(
257
+ DenseGeneral,
258
+ self.embed_dim,
259
+ axis=-1,
260
+ dtype=self.dtype,
261
+ params_dtype=self.params_dtype,
262
+ kernel_axes=("embed", "joined_kv"),
263
+ )
264
+
265
+ self.q_proj = dense(use_bias=self.bias)
266
+ self.k_proj = dense(use_bias=False)
267
+ self.v_proj = dense(use_bias=self.bias)
268
+
269
+ self.out_proj = DenseGeneral(
270
+ self.embed_dim,
271
+ axis=-1,
272
+ dtype=self.dtype,
273
+ params_dtype=self.params_dtype,
274
+ kernel_axes=("joined_kv", "embed"),
275
+ use_bias=self.bias,
276
+ )
277
+
278
+ if self.causal:
279
+ self.causal_mask = make_causal_mask(
280
+ jnp.ones((1, self.config.max_target_positions), dtype="bool"),
281
+ dtype="bool",
282
+ )
283
+
284
+ def __call__(
285
+ self,
286
+ hidden_states: jnp.ndarray,
287
+ key_value_states: Optional[jnp.ndarray] = None,
288
+ attention_mask: Optional[jnp.ndarray] = None,
289
+ init_cache: bool = False,
290
+ deterministic: bool = True,
291
+ ) -> Tuple[jnp.ndarray]:
292
+ is_cross_attention = key_value_states is not None
293
+ batch_size = hidden_states.shape[0]
294
+
295
+ query_states = self.q_proj(hidden_states)
296
+
297
+ if is_cross_attention:
298
+ key_states = self.k_proj(key_value_states)
299
+ value_states = self.v_proj(key_value_states)
300
+ else:
301
+ key_states = self.k_proj(hidden_states)
302
+ value_states = self.v_proj(hidden_states)
303
+
304
+ query_states = self._split_heads(query_states)
305
+ key_states = self._split_heads(key_states)
306
+ value_states = self._split_heads(value_states)
307
+
308
+ query_states = with_sharding_constraint(query_states, ("batch", "length", "heads", "kv"))
309
+ key_states = with_sharding_constraint(key_states, ("batch", "length", "heads", "kv"))
310
+ value_states = with_sharding_constraint(value_states, ("batch", "length", "heads", "kv"))
311
+
312
+ if self.causal:
313
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
314
+ if self.has_variable("cache", "cached_key"):
315
+ mask_shift = self.variables["cache"]["cache_index"]
316
+ # max_length of cached_key is last dim
317
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[-1]
318
+ causal_mask = lax.dynamic_slice(
319
+ self.causal_mask,
320
+ (0, 0, mask_shift, 0),
321
+ (1, 1, query_length, max_decoder_length),
322
+ )
323
+ else:
324
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
325
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
326
+
327
+ # combine masks if needed
328
+ if attention_mask is not None and self.causal:
329
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
330
+ attention_mask = combine_masks(attention_mask, causal_mask)
331
+ elif self.causal:
332
+ attention_mask = causal_mask
333
+ elif attention_mask is not None:
334
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
335
+
336
+ # During fast autoregressive decoding, we feed one position at a time,
337
+ # and cache the keys and values step by step.
338
+
339
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
340
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
341
+ key_states, value_states, query_states, attention_mask
342
+ )
343
+
344
+ # Convert the boolean attention mask to an attention bias.
345
+ if attention_mask is not None:
346
+ # attention mask in the form of attention bias
347
+ attention_bias = lax.select(
348
+ attention_mask > 0,
349
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
350
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
351
+ )
352
+ else:
353
+ attention_bias = None
354
+
355
+ dropout_rng = None
356
+ if not deterministic and self.dropout > 0.0:
357
+ dropout_rng = self.make_rng("dropout")
358
+
359
+ attn_weights = dot_product_attention_weights(
360
+ query_states,
361
+ key_states,
362
+ bias=attention_bias,
363
+ dropout_rng=dropout_rng,
364
+ dropout_rate=self.dropout,
365
+ broadcast_dropout=True,
366
+ deterministic=deterministic,
367
+ dtype=self.dtype,
368
+ precision=None,
369
+ )
370
+
371
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
372
+ attn_output = self._merge_heads(attn_output)
373
+ attn_output = self.out_proj(attn_output)
374
+
375
+ return attn_output, attn_weights
376
+
377
+ def _split_heads(self, hidden_state) -> jnp.ndarray:
378
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
379
+
380
+ def _merge_heads(self, hidden_state) -> jnp.ndarray:
381
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
382
+
383
+ @nn.compact
384
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
385
+ # The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
386
+ is_initialized = self.has_variable("cache", "cached_key")
387
+
388
+ # The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
389
+ # but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
390
+ # fusion optimization. This also enables the "scatter via one-hot
391
+ # broadcast" trick, which means we do a one-hot broadcast instead of a
392
+ # scatter/gather operations, resulting in a 3-4x speedup in practice.
393
+ def swap_dims(x):
394
+ return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
395
+
396
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
397
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
398
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
399
+
400
+ if is_initialized:
401
+ batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
402
+ # During fast autoregressive decoding, we feed one position at a time,
403
+ # and cache the keys and values step by step.
404
+ # Sanity shape check of cached key against input query.
405
+ num_updated_cache_vectors = query.shape[1]
406
+ expected_shape = (batch_size, 1, num_heads, head_dim)
407
+ if num_updated_cache_vectors == 1 and expected_shape != query.shape:
408
+ raise ValueError(
409
+ "Autoregressive cache shape error, expected query shape"
410
+ f" {expected_shape} instead got {query.shape}"
411
+ )
412
+
413
+ # Create a OHE of the current index. NOTE: the index is increased below.
414
+ cur_index = cache_index.value
415
+
416
+ # In order to update the key, value caches with the current key and
417
+ # value, we move the seq_length axis to the back, similar to what we did for
418
+ # the cached ones above.
419
+ # Note these are currently the key and value of a single position, since
420
+ # we feed one position at a time.
421
+ one_token_key = jnp.moveaxis(key, -3, -1)
422
+ one_token_value = jnp.moveaxis(value, -3, -1)
423
+
424
+ # Update key, value caches with our new 1d spatial slices.
425
+ # We implement an efficient scatter into the cache via one-hot
426
+ # broadcast and addition.
427
+ if num_updated_cache_vectors > 1:
428
+ indices = jnp.eye(num_updated_cache_vectors, seq_length)[None, None]
429
+ key = cached_key.value + jnp.matmul(one_token_key, indices)
430
+ value = cached_value.value + jnp.matmul(one_token_value, indices)
431
+ else:
432
+ one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
433
+ key = cached_key.value + one_token_key * one_hot_indices
434
+ value = cached_value.value + one_token_value * one_hot_indices
435
+
436
+ cached_key.value = key
437
+ cached_value.value = value
438
+ cache_index.value = cache_index.value + num_updated_cache_vectors
439
+
440
+ # Move the keys and values back to their original shapes.
441
+ key = jnp.moveaxis(key, -1, -3)
442
+ value = jnp.moveaxis(value, -1, -3)
443
+
444
+ # causal mask for cached decoder self-attention: our single query position should only
445
+ # attend to those key positions that have already been generated and cached, not the
446
+ # remaining zero elements.
447
+ pad_mask = jnp.broadcast_to(
448
+ jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
449
+ (batch_size,) + (1, num_updated_cache_vectors, seq_length),
450
+ )
451
+ attention_mask = combine_masks(pad_mask, attention_mask)
452
+
453
+ return key, value, attention_mask
454
+
455
+
456
+ class FlaxWhisperEncoderLayer(nn.Module):
457
+ config: WhisperConfig
458
+ dtype: jnp.dtype = jnp.float32
459
+ params_dtype: jnp.dtype = jnp.float32
460
+ use_scan: bool = False
461
+
462
+ def setup(self) -> None:
463
+ self.embed_dim = self.config.d_model
464
+ self.self_attn = FlaxWhisperAttention(
465
+ config=self.config,
466
+ embed_dim=self.embed_dim,
467
+ num_heads=self.config.encoder_attention_heads,
468
+ dropout=self.config.attention_dropout,
469
+ dtype=self.dtype,
470
+ params_dtype=self.params_dtype,
471
+ )
472
+ self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
473
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
474
+ self.activation_fn = ACT2FN[self.config.activation_function]
475
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
476
+ self.fc1 = DenseGeneral(
477
+ self.config.encoder_ffn_dim,
478
+ dtype=self.dtype,
479
+ params_dtype=self.params_dtype,
480
+ kernel_axes=("embed", "mlp"),
481
+ )
482
+ self.fc2 = DenseGeneral(
483
+ self.embed_dim,
484
+ dtype=self.dtype,
485
+ params_dtype=self.params_dtype,
486
+ kernel_axes=("mlp", "embed"),
487
+ )
488
+ self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
489
+
490
+ def __call__(
491
+ self,
492
+ hidden_states: jnp.ndarray,
493
+ attention_mask: jnp.ndarray,
494
+ output_attentions: bool = True,
495
+ deterministic: bool = True,
496
+ all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
497
+ ) -> Tuple[jnp.ndarray]:
498
+ if self.use_scan:
499
+ hidden_states = hidden_states[0]
500
+
501
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
502
+
503
+ residual = hidden_states
504
+
505
+ layernorm_output = self.self_attn_layer_norm(hidden_states)
506
+ layernorm_output = with_sharding_constraint(layernorm_output, ("batch", "length", "embed"))
507
+
508
+ attn_output, attn_weights = self.self_attn(hidden_states=layernorm_output, attention_mask=attention_mask)
509
+ attn_output = self.dropout_layer(attn_output, deterministic=deterministic)
510
+ attn_output = residual + attn_output
511
+ attn_output = with_sharding_constraint(attn_output, ("batch", "length", "embed"))
512
+
513
+ residual = attn_output
514
+
515
+ post_layer_norm = self.final_layer_norm(attn_output)
516
+ post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
517
+
518
+ fc1_output = self.activation_fn(self.fc1(post_layer_norm))
519
+ fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
520
+ fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
521
+
522
+ hidden_states = self.fc2(fc1_output)
523
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
524
+ hidden_states = residual + hidden_states
525
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
526
+
527
+ outputs = (hidden_states,)
528
+
529
+ if output_attentions:
530
+ outputs += (attn_weights,)
531
+
532
+ if self.use_scan:
533
+ if all_hidden_states is not None:
534
+ all_hidden_states = all_hidden_states + (hidden_states,)
535
+ outputs = (
536
+ outputs,
537
+ all_hidden_states,
538
+ )
539
+
540
+ return outputs
541
+
542
+
543
+ class FlaxWhisperEncoderLayerCollection(nn.Module):
544
+ config: WhisperConfig
545
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
546
+ params_dtype: jnp.dtype = jnp.float32
547
+ use_scan: bool = False
548
+ gradient_checkpointing: bool = False
549
+
550
+ @nn.compact
551
+ def __call__(
552
+ self,
553
+ hidden_states,
554
+ attention_mask,
555
+ deterministic: bool = True,
556
+ output_attentions: bool = False,
557
+ output_hidden_states: bool = False,
558
+ return_dict: bool = True,
559
+ ):
560
+ all_attentions = () if output_attentions else None
561
+ all_hidden_states = () if output_hidden_states else None
562
+
563
+ FlaxWhisperEncoderCheckpointLayer = (
564
+ remat(
565
+ FlaxWhisperEncoderLayer,
566
+ static_argnums=(2, 3),
567
+ prevent_cse=not self.use_scan,
568
+ )
569
+ if self.gradient_checkpointing
570
+ else FlaxWhisperEncoderLayer
571
+ )
572
+
573
+ if self.use_scan:
574
+ if output_attentions:
575
+ raise ValueError("Cannot use `scan` with `output_attentions` set to True")
576
+
577
+ # nicest behaviour for scan is to let the compiler figure out the correct shapes for the hidden states
578
+ # so we'll just pass an empty tuple as the carry initializer and hold on to the first hidden states for later
579
+ input_hidden_states = hidden_states
580
+ hidden_states = (hidden_states,)
581
+
582
+ hidden_states, all_hidden_states = scan_with_axes(
583
+ FlaxWhisperEncoderCheckpointLayer,
584
+ variable_axes={"params": 0, "cache": 0},
585
+ split_rngs={"params": True, "dropout": True},
586
+ in_axes=(
587
+ nn.broadcast,
588
+ nn.broadcast,
589
+ nn.broadcast,
590
+ nn.broadcast,
591
+ ),
592
+ variable_carry="all_hidden_states",
593
+ length=self.config.encoder_layers,
594
+ )(
595
+ self.config,
596
+ dtype=self.dtype,
597
+ params_dtype=self.params_dtype,
598
+ use_scan=True,
599
+ name="FlaxEncoderScanLayers",
600
+ )(
601
+ hidden_states,
602
+ attention_mask,
603
+ output_attentions,
604
+ deterministic,
605
+ all_hidden_states, # tuple intializer (or None if not using output_hidden_states)
606
+ )
607
+
608
+ # remove the scan dimension
609
+ hidden_states = hidden_states[0]
610
+
611
+ if output_hidden_states:
612
+ # if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
613
+ all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
614
+
615
+ else:
616
+ for layer_idx in range(self.config.encoder_layers):
617
+ if output_hidden_states:
618
+ all_hidden_states = all_hidden_states + (hidden_states,)
619
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
620
+ dropout_probability = random.uniform(0, 1)
621
+ if not deterministic and (dropout_probability < self.config.encoder_layerdrop): # skip the layer
622
+ layer_outputs = (None, None)
623
+ else:
624
+ layer_outputs = FlaxWhisperEncoderCheckpointLayer(
625
+ self.config,
626
+ dtype=self.dtype,
627
+ params_dtype=self.params_dtype,
628
+ name=str(layer_idx),
629
+ )(
630
+ hidden_states,
631
+ attention_mask,
632
+ output_attentions,
633
+ deterministic,
634
+ )
635
+ hidden_states = layer_outputs[0]
636
+ if output_attentions:
637
+ all_attentions = all_attentions + (layer_outputs[1],)
638
+
639
+ if output_hidden_states:
640
+ all_hidden_states += (hidden_states,)
641
+
642
+ outputs = (hidden_states, all_hidden_states, all_attentions)
643
+
644
+ if not return_dict:
645
+ return tuple(v for v in outputs if v is not None)
646
+
647
+ return FlaxBaseModelOutput(
648
+ last_hidden_state=hidden_states,
649
+ hidden_states=all_hidden_states,
650
+ attentions=all_attentions,
651
+ )
652
+
653
+
654
+ class FlaxWhisperDecoderLayer(nn.Module):
655
+ config: WhisperConfig
656
+ dtype: jnp.dtype = jnp.float32
657
+ params_dtype: jnp.dtype = jnp.float32
658
+ use_scan: bool = False
659
+
660
+ def setup(self) -> None:
661
+ self.embed_dim = self.config.d_model
662
+ self.self_attn = FlaxWhisperAttention(
663
+ config=self.config,
664
+ embed_dim=self.embed_dim,
665
+ num_heads=self.config.decoder_attention_heads,
666
+ dropout=self.config.attention_dropout,
667
+ causal=True,
668
+ dtype=self.dtype,
669
+ params_dtype=self.params_dtype,
670
+ )
671
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
672
+ self.activation_fn = ACT2FN[self.config.activation_function]
673
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
674
+
675
+ self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
676
+ self.encoder_attn = FlaxWhisperAttention(
677
+ config=self.config,
678
+ embed_dim=self.embed_dim,
679
+ num_heads=self.config.decoder_attention_heads,
680
+ dropout=self.config.attention_dropout,
681
+ dtype=self.dtype,
682
+ params_dtype=self.params_dtype,
683
+ )
684
+ self.encoder_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
685
+ self.fc1 = DenseGeneral(
686
+ self.config.decoder_ffn_dim,
687
+ dtype=self.dtype,
688
+ params_dtype=self.params_dtype,
689
+ kernel_axes=("embed", "mlp"),
690
+ )
691
+ self.fc2 = DenseGeneral(
692
+ self.embed_dim,
693
+ dtype=self.dtype,
694
+ params_dtype=self.params_dtype,
695
+ kernel_axes=("mlp", "embed"),
696
+ )
697
+ self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
698
+
699
+ def __call__(
700
+ self,
701
+ hidden_states: jnp.ndarray,
702
+ attention_mask: jnp.ndarray,
703
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
704
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
705
+ init_cache: bool = False,
706
+ output_attentions: bool = True,
707
+ deterministic: bool = True,
708
+ all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
709
+ ) -> Tuple[jnp.ndarray]:
710
+ if self.use_scan:
711
+ hidden_states = hidden_states[0]
712
+
713
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
714
+
715
+ residual = hidden_states
716
+
717
+ layer_norm_output = self.self_attn_layer_norm(hidden_states)
718
+ layer_norm_output = with_sharding_constraint(layer_norm_output, ("batch", "length", "embed"))
719
+
720
+ # Self Attention
721
+ self_attn_output, self_attn_weights = self.self_attn(
722
+ hidden_states=layer_norm_output,
723
+ attention_mask=attention_mask,
724
+ init_cache=init_cache,
725
+ )
726
+ self_attn_output = self.dropout_layer(self_attn_output, deterministic=deterministic)
727
+ self_attn_output = residual + self_attn_output
728
+ self_attn_output = with_sharding_constraint(self_attn_output, ("batch", "length", "embed"))
729
+
730
+ # Cross-Attention Block
731
+ cross_attn_weights = None
732
+ if encoder_hidden_states is not None:
733
+ residual = self_attn_output
734
+
735
+ encoder_layer_norm_output = self.encoder_attn_layer_norm(self_attn_output)
736
+ encoder_layer_norm_output = with_sharding_constraint(
737
+ encoder_layer_norm_output, ("batch", "length", "embed")
738
+ )
739
+
740
+ cross_attn_output, cross_attn_weights = self.encoder_attn(
741
+ hidden_states=encoder_layer_norm_output,
742
+ key_value_states=encoder_hidden_states,
743
+ attention_mask=encoder_attention_mask,
744
+ )
745
+ cross_attn_output = self.dropout_layer(cross_attn_output, deterministic=deterministic)
746
+ cross_attn_output = residual + cross_attn_output
747
+ cross_attn_output = with_sharding_constraint(cross_attn_output, ("batch", "length", "embed"))
748
+
749
+ # Fully Connected
750
+ residual = cross_attn_output
751
+
752
+ post_layer_norm = self.final_layer_norm(cross_attn_output)
753
+ post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
754
+
755
+ fc1_output = self.activation_fn(self.fc1(post_layer_norm))
756
+ fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
757
+ fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
758
+
759
+ hidden_states = self.fc2(fc1_output)
760
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
761
+ hidden_states = residual + hidden_states
762
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
763
+
764
+ outputs = (hidden_states,)
765
+
766
+ if output_attentions:
767
+ outputs += (self_attn_weights, cross_attn_weights)
768
+
769
+ if self.use_scan:
770
+ if all_hidden_states is not None:
771
+ all_hidden_states = all_hidden_states + (hidden_states,)
772
+ outputs = (
773
+ outputs,
774
+ all_hidden_states,
775
+ )
776
+
777
+ return outputs
778
+
779
+
780
+ class FlaxWhisperDecoderLayerCollection(nn.Module):
781
+ config: WhisperConfig
782
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
783
+ params_dtype: jnp.dtype = jnp.float32
784
+ use_scan: bool = False
785
+ gradient_checkpointing: bool = False
786
+
787
+ @nn.compact
788
+ def __call__(
789
+ self,
790
+ hidden_states,
791
+ attention_mask,
792
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
793
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
794
+ deterministic: bool = True,
795
+ init_cache: bool = False,
796
+ output_attentions: bool = False,
797
+ output_hidden_states: bool = False,
798
+ return_dict: bool = True,
799
+ ):
800
+ # decoder layers
801
+ all_hidden_states = () if output_hidden_states else None
802
+ all_self_attns = () if output_attentions else None
803
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
804
+
805
+ FlaxWhisperDecoderCheckpointLayer = (
806
+ remat(
807
+ FlaxWhisperDecoderLayer,
808
+ static_argnums=(4, 5, 6),
809
+ prevent_cse=not self.use_scan,
810
+ )
811
+ if self.gradient_checkpointing
812
+ else FlaxWhisperDecoderLayer
813
+ )
814
+
815
+ if self.use_scan:
816
+ if output_attentions:
817
+ raise ValueError("Cannot use `scan` with `output_attentions` set to True")
818
+
819
+ input_hidden_states = hidden_states
820
+ hidden_states = (hidden_states,)
821
+
822
+ hidden_states, all_hidden_states = scan_with_axes(
823
+ FlaxWhisperDecoderCheckpointLayer,
824
+ variable_axes={"params": 0, "cache": 0},
825
+ split_rngs={"params": True, "dropout": True},
826
+ in_axes=(
827
+ nn.broadcast,
828
+ nn.broadcast,
829
+ nn.broadcast,
830
+ nn.broadcast,
831
+ nn.broadcast,
832
+ nn.broadcast,
833
+ nn.broadcast,
834
+ ),
835
+ variable_carry="all_hidden_states",
836
+ length=self.config.decoder_layers,
837
+ )(
838
+ self.config,
839
+ dtype=self.dtype,
840
+ params_dtype=self.params_dtype,
841
+ use_scan=True,
842
+ name="FlaxDecoderScanLayers",
843
+ )(
844
+ hidden_states,
845
+ attention_mask,
846
+ encoder_hidden_states,
847
+ encoder_attention_mask,
848
+ init_cache,
849
+ output_attentions,
850
+ deterministic,
851
+ all_hidden_states,
852
+ )
853
+ hidden_states = hidden_states[0]
854
+
855
+ if output_hidden_states:
856
+ # if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
857
+ all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
858
+
859
+ else:
860
+ for layer_idx in range(self.config.decoder_layers):
861
+ if output_hidden_states:
862
+ all_hidden_states += (hidden_states,)
863
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
864
+ dropout_probability = random.uniform(0, 1)
865
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
866
+ layer_outputs = (None, None, None)
867
+ else:
868
+ layer_outputs = FlaxWhisperDecoderCheckpointLayer(
869
+ self.config,
870
+ dtype=self.dtype,
871
+ params_dtype=self.params_dtype,
872
+ name=str(layer_idx),
873
+ )(
874
+ hidden_states,
875
+ attention_mask,
876
+ encoder_hidden_states,
877
+ encoder_attention_mask,
878
+ init_cache,
879
+ output_attentions,
880
+ deterministic,
881
+ )
882
+
883
+ hidden_states = layer_outputs[0]
884
+ if output_attentions:
885
+ all_self_attns += (layer_outputs[1],)
886
+
887
+ if encoder_hidden_states is not None:
888
+ all_cross_attentions += (layer_outputs[2],)
889
+
890
+ # add hidden states from the last decoder layer
891
+ if output_hidden_states:
892
+ all_hidden_states += (hidden_states,)
893
+
894
+ outputs = [
895
+ hidden_states,
896
+ all_hidden_states,
897
+ all_self_attns,
898
+ all_cross_attentions,
899
+ ]
900
+
901
+ if not return_dict:
902
+ return tuple(v for v in outputs if v is not None)
903
+
904
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
905
+ last_hidden_state=hidden_states,
906
+ hidden_states=all_hidden_states,
907
+ attentions=all_self_attns,
908
+ cross_attentions=all_cross_attentions,
909
+ )
910
+
911
+
912
+ class FlaxWhisperEncoder(nn.Module):
913
+ config: WhisperConfig
914
+ dtype: jnp.dtype = jnp.float32
915
+ params_dtype: jnp.dtype = jnp.float32
916
+ use_scan: bool = False
917
+ gradient_checkpointing: bool = False
918
+
919
+ def setup(self) -> None:
920
+ self.conv1 = Conv(
921
+ self.config.d_model,
922
+ kernel_size=(3,),
923
+ padding=1,
924
+ dtype=self.dtype,
925
+ params_dtype=self.params_dtype,
926
+ kernel_axes=("channels", "num_mel", "embed"),
927
+ )
928
+ self.conv2 = Conv(
929
+ self.config.d_model,
930
+ kernel_size=(3,),
931
+ strides=2,
932
+ padding=1,
933
+ dtype=self.dtype,
934
+ params_dtype=self.params_dtype,
935
+ kernel_axes=("channels", "embed", "num_mel"),
936
+ )
937
+
938
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
939
+
940
+ self.layers = FlaxWhisperEncoderLayerCollection(
941
+ self.config,
942
+ dtype=self.dtype,
943
+ params_dtype=self.params_dtype,
944
+ use_scan=self.use_scan,
945
+ gradient_checkpointing=self.gradient_checkpointing,
946
+ )
947
+ self.embed_positions = Embed(
948
+ self.config.max_source_positions,
949
+ self.config.d_model,
950
+ dtype=self.dtype,
951
+ params_dtype=self.params_dtype,
952
+ )
953
+
954
+ self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
955
+
956
+ def __call__(
957
+ self,
958
+ input_features: jnp.ndarray,
959
+ output_attentions: bool = False,
960
+ output_hidden_states: bool = False,
961
+ return_dict: bool = True,
962
+ deterministic: bool = True,
963
+ ) -> Tuple[jnp.ndarray]:
964
+ if input_features.shape[1:] != (
965
+ self.config.num_mel_bins,
966
+ self.config.max_source_positions * 2,
967
+ ):
968
+ raise ValueError(
969
+ "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
970
+ " self.config.max_source_positions * 2) (got"
971
+ f" {input_features.shape[1:]}, but should be"
972
+ f" ({self.config.num_mel_bins},"
973
+ f" {self.config.max_source_positions * 2}))"
974
+ )
975
+
976
+ input_features = input_features.transpose(0, 2, 1)
977
+ hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
978
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
979
+ hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
980
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
981
+
982
+ embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
983
+ # sinusoidal positional embeddings should not be trained
984
+ embed_positions = jax.lax.stop_gradient(embed_positions)
985
+ hidden_states = hidden_states + embed_positions
986
+
987
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
988
+
989
+ outputs = self.layers(
990
+ hidden_states,
991
+ attention_mask=None,
992
+ deterministic=deterministic,
993
+ output_attentions=output_attentions,
994
+ output_hidden_states=output_hidden_states,
995
+ return_dict=return_dict,
996
+ )
997
+
998
+ last_hidden_states = outputs[0]
999
+ last_hidden_states = self.layer_norm(last_hidden_states)
1000
+
1001
+ # update the last element in `hidden_states` after applying `layernorm` above
1002
+ hidden_states = None
1003
+ if output_hidden_states:
1004
+ hidden_states = outputs[1]
1005
+ if self.use_scan:
1006
+ hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
1007
+ else:
1008
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
1009
+
1010
+ if not return_dict:
1011
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
1012
+ return tuple(v for v in outputs if v is not None)
1013
+
1014
+ return FlaxBaseModelOutput(
1015
+ last_hidden_state=last_hidden_states,
1016
+ hidden_states=hidden_states,
1017
+ attentions=outputs.attentions,
1018
+ )
1019
+
1020
+
1021
+ class FlaxWhisperDecoder(nn.Module):
1022
+ config: WhisperConfig
1023
+ dtype: jnp.dtype = jnp.float32
1024
+ params_dtype: jnp.dtype = jnp.float32
1025
+ use_scan: bool = False
1026
+ gradient_checkpointing: bool = False
1027
+
1028
+ def setup(self) -> None:
1029
+ self.embed_tokens = Embed(
1030
+ self.config.vocab_size,
1031
+ self.config.d_model,
1032
+ dtype=self.dtype,
1033
+ params_dtype=self.params_dtype,
1034
+ )
1035
+ self.embed_positions = Embed(
1036
+ self.config.max_target_positions,
1037
+ self.config.d_model,
1038
+ dtype=self.dtype,
1039
+ params_dtype=self.params_dtype,
1040
+ )
1041
+
1042
+ self.layers = FlaxWhisperDecoderLayerCollection(
1043
+ self.config,
1044
+ dtype=self.dtype,
1045
+ params_dtype=self.params_dtype,
1046
+ use_scan=self.use_scan,
1047
+ gradient_checkpointing=self.gradient_checkpointing,
1048
+ )
1049
+
1050
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
1051
+
1052
+ self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-5, params_dtype=self.params_dtype)
1053
+
1054
+ def __call__(
1055
+ self,
1056
+ input_ids: jnp.ndarray,
1057
+ attention_mask: jnp.ndarray,
1058
+ position_ids: jnp.ndarray,
1059
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1060
+ init_cache: bool = False,
1061
+ output_attentions: bool = False,
1062
+ output_hidden_states: bool = False,
1063
+ return_dict: bool = True,
1064
+ deterministic: bool = True,
1065
+ ) -> Tuple[jnp.ndarray]:
1066
+ input_embeds = self.embed_tokens(input_ids)
1067
+ position_embeds = self.embed_positions(position_ids)
1068
+
1069
+ hidden_states = input_embeds + position_embeds
1070
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
1071
+
1072
+ outputs = self.layers(
1073
+ hidden_states,
1074
+ attention_mask=attention_mask,
1075
+ encoder_hidden_states=encoder_hidden_states,
1076
+ deterministic=deterministic,
1077
+ init_cache=init_cache,
1078
+ output_attentions=output_attentions,
1079
+ output_hidden_states=output_hidden_states,
1080
+ return_dict=return_dict,
1081
+ )
1082
+
1083
+ last_hidden_states = outputs[0]
1084
+ last_hidden_states = self.layer_norm(last_hidden_states)
1085
+
1086
+ # update the last element in `hidden_states` after applying `layernorm` above
1087
+ hidden_states = None
1088
+ if output_hidden_states:
1089
+ hidden_states = outputs[1]
1090
+ if self.use_scan:
1091
+ hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
1092
+ else:
1093
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
1094
+
1095
+ if not return_dict:
1096
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
1097
+ return tuple(v for v in outputs if v is not None)
1098
+
1099
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
1100
+ last_hidden_state=last_hidden_states,
1101
+ hidden_states=hidden_states,
1102
+ attentions=outputs.attentions,
1103
+ cross_attentions=outputs.cross_attentions,
1104
+ )
1105
+
1106
+
1107
+ class FlaxWhisperModule(nn.Module):
1108
+ config: WhisperConfig
1109
+ dtype: jnp.dtype = jnp.float32
1110
+ params_dtype: jnp.dtype = jnp.float32
1111
+ use_scan: bool = False
1112
+ gradient_checkpointing: bool = False
1113
+
1114
+ def setup(self) -> None:
1115
+ self.encoder = FlaxWhisperEncoder(
1116
+ self.config,
1117
+ dtype=self.dtype,
1118
+ params_dtype=self.params_dtype,
1119
+ use_scan=self.use_scan,
1120
+ gradient_checkpointing=self.gradient_checkpointing,
1121
+ )
1122
+ self.decoder = FlaxWhisperDecoder(
1123
+ self.config,
1124
+ dtype=self.dtype,
1125
+ params_dtype=self.params_dtype,
1126
+ use_scan=self.use_scan,
1127
+ gradient_checkpointing=self.gradient_checkpointing,
1128
+ )
1129
+
1130
+ def __call__(
1131
+ self,
1132
+ input_features: jnp.ndarray,
1133
+ decoder_input_ids: jnp.ndarray,
1134
+ decoder_attention_mask: jnp.ndarray,
1135
+ decoder_position_ids: jnp.ndarray,
1136
+ output_attentions: bool = False,
1137
+ output_hidden_states: bool = False,
1138
+ freeze_encoder: bool = False,
1139
+ return_dict: bool = True,
1140
+ deterministic: bool = True,
1141
+ ):
1142
+ encoder_outputs = self.encoder(
1143
+ input_features,
1144
+ output_attentions=output_attentions,
1145
+ output_hidden_states=output_hidden_states,
1146
+ return_dict=return_dict,
1147
+ deterministic=deterministic,
1148
+ )
1149
+
1150
+ encoder_hidden_states = encoder_outputs[0]
1151
+
1152
+ if freeze_encoder:
1153
+ encoder_hidden_states = jax.lax.stop_gradient(encoder_hidden_states)
1154
+
1155
+ decoder_outputs = self.decoder(
1156
+ input_ids=decoder_input_ids,
1157
+ attention_mask=decoder_attention_mask,
1158
+ position_ids=decoder_position_ids,
1159
+ encoder_hidden_states=encoder_hidden_states,
1160
+ output_attentions=output_attentions,
1161
+ output_hidden_states=output_hidden_states,
1162
+ return_dict=return_dict,
1163
+ deterministic=deterministic,
1164
+ )
1165
+
1166
+ if not return_dict:
1167
+ return decoder_outputs + encoder_outputs
1168
+
1169
+ return FlaxSeq2SeqModelOutput(
1170
+ last_hidden_state=decoder_outputs.last_hidden_state,
1171
+ decoder_hidden_states=decoder_outputs.hidden_states,
1172
+ decoder_attentions=decoder_outputs.attentions,
1173
+ cross_attentions=decoder_outputs.cross_attentions,
1174
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1175
+ encoder_hidden_states=encoder_outputs.hidden_states,
1176
+ encoder_attentions=encoder_outputs.attentions,
1177
+ )
1178
+
1179
+ def _get_encoder_module(self):
1180
+ return self.encoder
1181
+
1182
+ def _get_decoder_module(self):
1183
+ return self.decoder
1184
+
1185
+
1186
+ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
1187
+ config_class = WhisperConfig
1188
+ base_model_prefix: str = "model"
1189
+ main_input_name = "input_features"
1190
+ module_class: nn.Module = None
1191
+
1192
+ def __init__(
1193
+ self,
1194
+ config: WhisperConfig,
1195
+ input_shape: Tuple[int, int, int] = None,
1196
+ seed: int = 0,
1197
+ dtype: jnp.dtype = jnp.float32,
1198
+ params_dtype: jnp.dtype = jnp.float32,
1199
+ _do_init: bool = True,
1200
+ # Can only use_scan=True in init if loading scanned weights -> need to handle use_scan=True and unrolled weights
1201
+ use_scan: bool = False,
1202
+ gradient_checkpointing: bool = False,
1203
+ **kwargs,
1204
+ ):
1205
+ self.use_scan = use_scan
1206
+ self.gradient_checkpointing = gradient_checkpointing
1207
+
1208
+ module = self.module_class(
1209
+ config=config,
1210
+ dtype=dtype,
1211
+ params_dtype=params_dtype,
1212
+ use_scan=use_scan,
1213
+ gradient_checkpointing=gradient_checkpointing,
1214
+ **kwargs,
1215
+ )
1216
+
1217
+ if input_shape is None:
1218
+ input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
1219
+
1220
+ super().__init__(
1221
+ config,
1222
+ module,
1223
+ input_shape=input_shape,
1224
+ seed=seed,
1225
+ dtype=dtype,
1226
+ _do_init=_do_init,
1227
+ )
1228
+
1229
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
1230
+ # init input tensors
1231
+ input_features = jnp.zeros(input_shape, dtype="f4")
1232
+ input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
1233
+
1234
+ decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
1235
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1236
+
1237
+ batch_size, sequence_length = decoder_input_ids.shape
1238
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
1239
+
1240
+ params_rng, dropout_rng = jax.random.split(rng)
1241
+ rngs = {"params": params_rng, "dropout": dropout_rng}
1242
+
1243
+ random_params = self.module.init(
1244
+ rngs,
1245
+ input_features=input_features,
1246
+ decoder_input_ids=decoder_input_ids,
1247
+ decoder_attention_mask=decoder_attention_mask,
1248
+ decoder_position_ids=decoder_position_ids,
1249
+ )["params"]
1250
+
1251
+ if params is not None:
1252
+ random_params = flatten_dict(unfreeze(random_params))
1253
+ params = flatten_dict(unfreeze(params))
1254
+ for missing_key in self._missing_keys:
1255
+ params[missing_key] = random_params[missing_key]
1256
+ self._missing_keys = set()
1257
+ return freeze(unflatten_dict(params))
1258
+ else:
1259
+ return random_params
1260
+
1261
+ def enable_gradient_checkpointing(self):
1262
+ self.gradient_checkpointing = True
1263
+ self._module = self.module_class(
1264
+ config=self.config,
1265
+ dtype=self.dtype,
1266
+ use_scan=self.use_scan,
1267
+ gradient_checkpointing=self.gradient_checkpointing,
1268
+ )
1269
+
1270
+ def enable_scan(self):
1271
+ self.use_scan = True
1272
+ self._module = self.module_class(
1273
+ config=self.config,
1274
+ dtype=self.dtype,
1275
+ use_scan=self.use_scan,
1276
+ gradient_checkpointing=self.gradient_checkpointing,
1277
+ )
1278
+ init_fn = partial(self.init_weights, input_shape=self.input_shape)
1279
+ params_shape_tree = jax.eval_shape(init_fn, self.key)
1280
+
1281
+ # get the shape of the parameters
1282
+ self._params_shape_tree = params_shape_tree
1283
+
1284
+ # save required_params as set
1285
+ self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
1286
+
1287
+ # initialize the parameters
1288
+ if self._is_initialized:
1289
+ self.params = self.convert_unroll_to_scan(self.params)
1290
+
1291
+ def disable_scan(self):
1292
+ self.use_scan = False
1293
+ self._module = self.module_class(
1294
+ config=self.config,
1295
+ dtype=self.dtype,
1296
+ use_scan=self.use_scan,
1297
+ gradient_checkpointing=self.gradient_checkpointing,
1298
+ )
1299
+ init_fn = partial(self.init_weights, input_shape=self.input_shape)
1300
+ params_shape_tree = jax.eval_shape(init_fn, self.key)
1301
+
1302
+ # get the shape of the parameters
1303
+ self._params_shape_tree = params_shape_tree
1304
+
1305
+ # save required_params as set
1306
+ self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
1307
+
1308
+ # initialize the parameters
1309
+ if self._is_initialized:
1310
+ self.params = self.convert_scan_to_unroll(self.params)
1311
+
1312
+ def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
1313
+ r"""
1314
+ Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
1315
+ to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
1316
+ convert the `params` in place.
1317
+
1318
+ To illustrate the workings of this method, take the Flax BERT model. The unrolled structure for the query
1319
+ projection params is as follows:
1320
+ ('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
1321
+ 'q_proj') ... ('bert', 'encoder', 'layer', '23', 'self_attn', 'q_proj')
1322
+ This method takes each of the `q_proj` matrices for layers (0, ..., 23) and stacks them into a single 'super'
1323
+ matrix, giving a *single* block of weights for all 24 layers compatible with the scanned model:
1324
+ ('bert', 'encoder', 'layer', 'ScanLayers', 'self_attn', 'q_proj')
1325
+
1326
+ When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
1327
+ _do_init=False, it will have to be called explicitly (see example below).
1328
+
1329
+ Arguments:
1330
+ params (`Union[Dict, FrozenDict]`):
1331
+ A `PyTree` of model parameters.
1332
+
1333
+ Examples:
1334
+
1335
+ ```python
1336
+ >>> from distil_whisper import FlaxWhisperForConditionalGeneration
1337
+
1338
+ >>> # Download model and configuration from huggingface.co
1339
+ >>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
1340
+ >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
1341
+ >>> # we'll first convert to scan format and then back to unrolled
1342
+ >>> model.enable_scan()
1343
+ >>> params = model.convert_unroll_to_scan(params)
1344
+ >>> # now convert back to unrolled
1345
+ >>> model.disable_scan()
1346
+ >>> params = model.convert_scan_to_unroll(params)
1347
+ ```"""
1348
+ if isinstance(params, FrozenDict):
1349
+ params = unfreeze(params)
1350
+
1351
+ params = flatten_dict(params, sep="/")
1352
+ keys = list(params.keys())
1353
+
1354
+ for k in keys:
1355
+ # Identify all "unrolled" layers formed as part of the FlaxBertLayerCollection
1356
+ # These params contain the identifier `layer` in their key
1357
+ if "layers/0" in k:
1358
+ if "decoder" in k:
1359
+ block_prefix = "Decoder"
1360
+ num_hidden_layers = self.config.decoder_layers
1361
+ else:
1362
+ block_prefix = "Encoder"
1363
+ num_hidden_layers = self.config.encoder_layers
1364
+
1365
+ # Squash the keys for the N unrolled layers into one single key:
1366
+ # (layer/0, ..., layer/N) -> layer/FlaxScanLayers
1367
+ scan_key = k.replace("0", f"Flax{block_prefix}ScanLayers")
1368
+ stacked_params = []
1369
+
1370
+ # Iterate over the unrolled layers (1,...,N)
1371
+ for i in range(num_hidden_layers):
1372
+ # Stack the params for the N layers into one super block
1373
+ # and remove the unrolled layer params on the fly
1374
+ # -> no memory overhead for conversion!
1375
+ unrolled_layer = params.pop(k.replace("0", str(i)))
1376
+ stacked_params.append(unrolled_layer)
1377
+
1378
+ params[scan_key] = jnp.stack(stacked_params)
1379
+
1380
+ # Finally, unflatten the dict to restore the nested pytree structure
1381
+ params = unflatten_dict(params, sep="/")
1382
+ return params
1383
+
1384
+ def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]):
1385
+ r"""
1386
+ Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
1387
+ used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
1388
+ not convert the `params` in place.
1389
+
1390
+ To illustrate the workings of this method, take the Flax BERT model. The scanned structure for the query
1391
+ projection (`q_proj`) params is a single, stacked matrix of parameters over all N layers:
1392
+ ('bert', 'encoder', 'layer', 'FlaxScanLayers', 'self_attn', 'q_proj')
1393
+
1394
+ This method slices each layer of the `q_proj` scanned matrix into single, standalone layers, and replaces the
1395
+ scanned matrix of parameteres on the fly:
1396
+ ('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
1397
+ 'q_proj') ... ('bert', 'encoder', 'layer', 'N', 'self_attn', 'q_proj')
1398
+
1399
+ When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
1400
+ _do_init=False, it will have to be called explicitly (see example below).
1401
+
1402
+ Arguments:
1403
+ params (`Union[Dict, FrozenDict]`):
1404
+ A `PyTree` of model parameters.
1405
+
1406
+ Examples:
1407
+
1408
+ ```python
1409
+ >>> from distil_whisper import FlaxWhisperForConditionalGeneration
1410
+
1411
+ >>> # Download model and configuration from huggingface.co
1412
+ >>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
1413
+ >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
1414
+ >>> # we'll first convert to scan format and then back to unrolled
1415
+ >>> model.enable_scan()
1416
+ >>> params = model.convert_unroll_to_scan(params)
1417
+ >>> # now convert back to unrolled
1418
+ >>> model.disable_scan()
1419
+ >>> params = model.convert_scan_to_unroll(params)
1420
+ ```"""
1421
+
1422
+ if isinstance(params, FrozenDict):
1423
+ params = unfreeze(params)
1424
+
1425
+ params = flatten_dict(params, sep="/")
1426
+ keys = list(params.keys())
1427
+
1428
+ for k in keys:
1429
+ # Identify all "scan" layers formed as part of the FlaxBertLayerCollection
1430
+ # These params contain the identifier `FlaxScanLayers` in their key
1431
+ if "FlaxEncoderScanLayers" in k:
1432
+ # Remove the scan layer from the PyTree of params
1433
+ scan_layer = params.pop(k)
1434
+
1435
+ # Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
1436
+ # layer/FlaxScanLayers -> (layer/0, ..., layer/N)
1437
+ for i in range(self.config.encoder_layers):
1438
+ # Unstack the params for the i-th scan layer to unrolled
1439
+ # and remove corresponding scan params on the fly
1440
+ # -> no memory overhead for conversion!
1441
+ unrolled_key = k.replace("FlaxEncoderScanLayers", str(i))
1442
+ params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
1443
+
1444
+ elif "FlaxDecoderScanLayers" in k:
1445
+ # Remove the scan layer from the PyTree of params
1446
+ scan_layer = params.pop(k)
1447
+
1448
+ # Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
1449
+ # layer/FlaxScanLayers -> (layer/0, ..., layer/N)
1450
+ for i in range(self.config.decoder_layers):
1451
+ # Unstack the params for the i-th scan layer to unrolled
1452
+ # and remove corresponding scan params on the fly
1453
+ # -> no memory overhead for conversion!
1454
+ unrolled_key = k.replace("FlaxDecoderScanLayers", str(i))
1455
+ params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
1456
+
1457
+ params = unflatten_dict(params, sep="/")
1458
+ return params
1459
+
1460
+ # Copied from transformers.models.whisper.modeling_flax_whisper.FlaxWhisperPreTrainedModel.init_cache
1461
+ def init_cache(self, batch_size, max_length, encoder_outputs):
1462
+ r"""
1463
+ Args:
1464
+ batch_size (`int`):
1465
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
1466
+ max_length (`int`):
1467
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
1468
+ cache.
1469
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
1470
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
1471
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
1472
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
1473
+ cross-attention of the decoder.
1474
+ """
1475
+ # init input variables to retrieve cache
1476
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
1477
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1478
+ decoder_position_ids = jnp.broadcast_to(
1479
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
1480
+ decoder_input_ids.shape,
1481
+ )
1482
+
1483
+ def _decoder_forward(
1484
+ module,
1485
+ decoder_input_ids,
1486
+ decoder_attention_mask,
1487
+ decoder_position_ids,
1488
+ **kwargs,
1489
+ ):
1490
+ decoder_module = module._get_decoder_module()
1491
+ return decoder_module(
1492
+ decoder_input_ids,
1493
+ decoder_attention_mask,
1494
+ decoder_position_ids,
1495
+ **kwargs,
1496
+ )
1497
+
1498
+ init_variables = self.module.init(
1499
+ jax.random.PRNGKey(0),
1500
+ decoder_input_ids=decoder_input_ids,
1501
+ decoder_attention_mask=decoder_attention_mask,
1502
+ decoder_position_ids=decoder_position_ids,
1503
+ encoder_hidden_states=encoder_outputs[0],
1504
+ init_cache=True,
1505
+ method=_decoder_forward, # we only need to call the decoder to init the cache
1506
+ )
1507
+ return unfreeze(init_variables["cache"])
1508
+
1509
+ @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
1510
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
1511
+ def encode(
1512
+ self,
1513
+ input_features: jnp.ndarray,
1514
+ attention_mask: Optional[jnp.ndarray] = None,
1515
+ output_attentions: Optional[bool] = None,
1516
+ output_hidden_states: Optional[bool] = None,
1517
+ return_dict: Optional[bool] = None,
1518
+ train: bool = False,
1519
+ params: dict = None,
1520
+ dropout_rng: PRNGKey = None,
1521
+ **kwargs,
1522
+ ):
1523
+ r"""
1524
+ Returns:
1525
+
1526
+ Example:
1527
+
1528
+ ```python
1529
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1530
+ >>> from datasets import load_dataset
1531
+
1532
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1533
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1534
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1535
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1536
+ >>> input_features = inputs.input_features
1537
+ >>> encoder_outputs = model.encode(input_features=input_features)
1538
+ ```"""
1539
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1540
+ output_hidden_states = (
1541
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1542
+ )
1543
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1544
+
1545
+ # Handle any PRNG if needed
1546
+ rngs = {}
1547
+ if dropout_rng is not None:
1548
+ rngs["dropout"] = dropout_rng
1549
+
1550
+ def _encoder_forward(module, input_features, **kwargs):
1551
+ encode_module = module._get_encoder_module()
1552
+ return encode_module(input_features, **kwargs)
1553
+
1554
+ return self.module.apply(
1555
+ {"params": params or self.params},
1556
+ input_features=jnp.array(input_features, dtype="f4"),
1557
+ output_attentions=output_attentions,
1558
+ output_hidden_states=output_hidden_states,
1559
+ return_dict=return_dict,
1560
+ deterministic=not train,
1561
+ rngs=rngs,
1562
+ method=_encoder_forward,
1563
+ )
1564
+
1565
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1566
+ @replace_return_docstrings(
1567
+ output_type=FlaxBaseModelOutputWithPastAndCrossAttentions,
1568
+ config_class=WhisperConfig,
1569
+ )
1570
+ def decode(
1571
+ self,
1572
+ decoder_input_ids,
1573
+ encoder_outputs,
1574
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1575
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1576
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1577
+ past_key_values: dict = None,
1578
+ output_attentions: Optional[bool] = None,
1579
+ output_hidden_states: Optional[bool] = None,
1580
+ return_dict: Optional[bool] = None,
1581
+ train: bool = False,
1582
+ params: dict = None,
1583
+ dropout_rng: PRNGKey = None,
1584
+ ):
1585
+ r"""
1586
+ Returns:
1587
+
1588
+ Example:
1589
+
1590
+ ```python
1591
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1592
+ >>> from datasets import load_dataset
1593
+
1594
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1595
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1596
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1597
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1598
+ >>> input_features = inputs.input_features
1599
+ >>> encoder_outputs = model.encode(input_features=input_features)
1600
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1601
+
1602
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
1603
+
1604
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1605
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1606
+ ```"""
1607
+
1608
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1609
+ output_hidden_states = (
1610
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1611
+ )
1612
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1613
+
1614
+ encoder_hidden_states = encoder_outputs[0]
1615
+
1616
+ batch_size, sequence_length = decoder_input_ids.shape
1617
+ if decoder_position_ids is None:
1618
+ if past_key_values is not None:
1619
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1620
+
1621
+ if decoder_attention_mask is not None:
1622
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1623
+ else:
1624
+ decoder_position_ids = jnp.broadcast_to(
1625
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1626
+ )
1627
+
1628
+ if decoder_attention_mask is None:
1629
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1630
+
1631
+ # Handle any PRNG if needed
1632
+ rngs = {}
1633
+ if dropout_rng is not None:
1634
+ rngs["dropout"] = dropout_rng
1635
+
1636
+ inputs = {"params": params or self.params}
1637
+
1638
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
1639
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
1640
+ # it can be changed by FlaxWhisperAttention module
1641
+ if past_key_values:
1642
+ inputs["cache"] = past_key_values
1643
+ mutable = ["cache"]
1644
+ else:
1645
+ mutable = False
1646
+
1647
+ def _decoder_forward(
1648
+ module,
1649
+ decoder_input_ids,
1650
+ decoder_attention_mask,
1651
+ decoder_position_ids,
1652
+ **kwargs,
1653
+ ):
1654
+ decoder_module = module._get_decoder_module()
1655
+ return decoder_module(
1656
+ input_ids=decoder_input_ids,
1657
+ attention_mask=decoder_attention_mask,
1658
+ position_ids=decoder_position_ids,
1659
+ **kwargs,
1660
+ )
1661
+
1662
+ outputs = self.module.apply(
1663
+ inputs,
1664
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1665
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1666
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1667
+ encoder_hidden_states=encoder_hidden_states,
1668
+ output_attentions=output_attentions,
1669
+ output_hidden_states=output_hidden_states,
1670
+ return_dict=return_dict,
1671
+ deterministic=not train,
1672
+ rngs=rngs,
1673
+ mutable=mutable,
1674
+ method=_decoder_forward,
1675
+ )
1676
+
1677
+ # add updated cache to model output
1678
+ if past_key_values is not None and return_dict:
1679
+ outputs, past = outputs
1680
+ outputs["past_key_values"] = unfreeze(past["cache"])
1681
+ return outputs
1682
+ elif past_key_values is not None and not return_dict:
1683
+ outputs, past = outputs
1684
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1685
+
1686
+ return outputs
1687
+
1688
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1689
+ def __call__(
1690
+ self,
1691
+ input_features: jnp.ndarray,
1692
+ decoder_input_ids: jnp.ndarray,
1693
+ attention_mask: Optional[jnp.ndarray] = None,
1694
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1695
+ position_ids: Optional[jnp.ndarray] = None,
1696
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1697
+ output_attentions: Optional[bool] = None,
1698
+ output_hidden_states: Optional[bool] = None,
1699
+ freeze_encoder: Optional[bool] = None,
1700
+ return_dict: Optional[bool] = None,
1701
+ train: bool = False,
1702
+ params: dict = None,
1703
+ dropout_rng: PRNGKey = None,
1704
+ ):
1705
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1706
+ output_hidden_states = (
1707
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1708
+ )
1709
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1710
+
1711
+ # prepare decoder inputs
1712
+ if decoder_position_ids is None:
1713
+ if decoder_attention_mask is not None:
1714
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1715
+ else:
1716
+ batch_size, sequence_length = decoder_input_ids.shape
1717
+ decoder_position_ids = jnp.broadcast_to(
1718
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1719
+ )
1720
+ if decoder_attention_mask is None:
1721
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1722
+
1723
+ # Handle any PRNG if needed
1724
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
1725
+
1726
+ return self.module.apply(
1727
+ {"params": params or self.params},
1728
+ input_features=jnp.array(input_features, dtype="f4"),
1729
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1730
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1731
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1732
+ output_attentions=output_attentions,
1733
+ output_hidden_states=output_hidden_states,
1734
+ freeze_encoder=freeze_encoder,
1735
+ return_dict=return_dict,
1736
+ deterministic=not train,
1737
+ rngs=rngs,
1738
+ )
1739
+
1740
+
1741
+ @add_start_docstrings(
1742
+ ("The bare Whisper Model transformer outputting raw hidden-states without any specific head on top."),
1743
+ WHISPER_START_DOCSTRING,
1744
+ )
1745
+ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
1746
+ config: WhisperConfig
1747
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1748
+ params_dtype: jnp.dtype = jnp.float32
1749
+ module_class = FlaxWhisperModule
1750
+
1751
+
1752
+ append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
1753
+
1754
+
1755
+ class FlaxWhisperForConditionalGenerationModule(nn.Module):
1756
+ config: WhisperConfig
1757
+ dtype: jnp.dtype = jnp.float32
1758
+ params_dtype: jnp.dtype = jnp.float32
1759
+ use_scan: bool = False
1760
+ gradient_checkpointing: bool = False
1761
+
1762
+ def setup(self) -> None:
1763
+ self.model = FlaxWhisperModule(
1764
+ config=self.config,
1765
+ dtype=self.dtype,
1766
+ params_dtype=self.params_dtype,
1767
+ use_scan=self.use_scan,
1768
+ gradient_checkpointing=self.gradient_checkpointing,
1769
+ )
1770
+ self.lm_head = DenseGeneral(
1771
+ self.config.vocab_size,
1772
+ use_bias=False,
1773
+ dtype=self.dtype,
1774
+ params_dtype=self.params_dtype,
1775
+ kernel_axes=("embed", "vocab"),
1776
+ )
1777
+
1778
+ def _get_encoder_module(self):
1779
+ return self.model.encoder
1780
+
1781
+ def _get_decoder_module(self):
1782
+ return self.model.decoder
1783
+
1784
+ def __call__(
1785
+ self,
1786
+ input_features,
1787
+ decoder_input_ids,
1788
+ decoder_attention_mask: jnp.ndarray = None,
1789
+ decoder_position_ids: jnp.ndarray = None,
1790
+ position_ids: jnp.ndarray = None,
1791
+ attention_mask: jnp.ndarray = None,
1792
+ output_attentions: bool = False,
1793
+ output_hidden_states: bool = False,
1794
+ freeze_encoder: bool = False,
1795
+ return_dict: bool = True,
1796
+ deterministic: bool = True,
1797
+ ):
1798
+ outputs = self.model(
1799
+ input_features=input_features,
1800
+ decoder_input_ids=decoder_input_ids,
1801
+ decoder_attention_mask=decoder_attention_mask,
1802
+ decoder_position_ids=decoder_position_ids,
1803
+ output_attentions=output_attentions,
1804
+ output_hidden_states=output_hidden_states,
1805
+ freeze_encoder=freeze_encoder,
1806
+ return_dict=return_dict,
1807
+ deterministic=deterministic,
1808
+ )
1809
+
1810
+ hidden_states = outputs[0]
1811
+
1812
+ if self.config.tie_word_embeddings:
1813
+ shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
1814
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1815
+ else:
1816
+ lm_logits = self.lm_head(hidden_states)
1817
+
1818
+ if not return_dict:
1819
+ output = (lm_logits,) + outputs[1:]
1820
+ return output
1821
+
1822
+ return FlaxSeq2SeqLMOutput(
1823
+ logits=lm_logits,
1824
+ decoder_hidden_states=outputs.decoder_hidden_states,
1825
+ decoder_attentions=outputs.decoder_attentions,
1826
+ cross_attentions=outputs.cross_attentions,
1827
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1828
+ encoder_hidden_states=outputs.encoder_hidden_states,
1829
+ encoder_attentions=outputs.encoder_attentions,
1830
+ )
1831
+
1832
+
1833
+ @add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
1834
+ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
1835
+ module_class = FlaxWhisperForConditionalGenerationModule
1836
+
1837
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1838
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
1839
+ def decode(
1840
+ self,
1841
+ decoder_input_ids,
1842
+ encoder_outputs,
1843
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1844
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1845
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1846
+ past_key_values: dict = None,
1847
+ output_attentions: Optional[bool] = None,
1848
+ output_hidden_states: Optional[bool] = None,
1849
+ return_dict: Optional[bool] = None,
1850
+ train: bool = False,
1851
+ params: dict = None,
1852
+ dropout_rng: PRNGKey = None,
1853
+ ):
1854
+ r"""
1855
+ Returns:
1856
+
1857
+ Example:
1858
+
1859
+ ```python
1860
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1861
+ >>> from datasets import load_dataset
1862
+
1863
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1864
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1865
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1866
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1867
+ >>> input_features = inputs.input_features
1868
+ >>> encoder_outputs = model.encode(input_features=input_features)
1869
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1870
+
1871
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
1872
+
1873
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1874
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1875
+ ```"""
1876
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1877
+ output_hidden_states = (
1878
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1879
+ )
1880
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1881
+
1882
+ encoder_hidden_states = encoder_outputs[0]
1883
+
1884
+ batch_size, sequence_length = decoder_input_ids.shape
1885
+ if decoder_position_ids is None:
1886
+ if past_key_values is not None:
1887
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1888
+
1889
+ if decoder_attention_mask is not None:
1890
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1891
+ else:
1892
+ decoder_position_ids = jnp.broadcast_to(
1893
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1894
+ )
1895
+ if decoder_attention_mask is None:
1896
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
1897
+
1898
+ # Handle any PRNG if needed
1899
+ rngs = {}
1900
+ if dropout_rng is not None:
1901
+ rngs["dropout"] = dropout_rng
1902
+
1903
+ inputs = {"params": params or self.params}
1904
+
1905
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
1906
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
1907
+ # it can be changed by FlaxWhisperAttention module
1908
+ if past_key_values:
1909
+ inputs["cache"] = past_key_values
1910
+ mutable = ["cache"]
1911
+ else:
1912
+ mutable = False
1913
+
1914
+ def _decoder_forward(
1915
+ module,
1916
+ decoder_input_ids,
1917
+ decoder_attention_mask,
1918
+ decoder_position_ids,
1919
+ **kwargs,
1920
+ ):
1921
+ decoder_module = module._get_decoder_module()
1922
+ outputs = decoder_module(
1923
+ input_ids=decoder_input_ids,
1924
+ attention_mask=decoder_attention_mask,
1925
+ position_ids=decoder_position_ids,
1926
+ **kwargs,
1927
+ )
1928
+ hidden_states = outputs[0]
1929
+
1930
+ if self.config.tie_word_embeddings:
1931
+ shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
1932
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1933
+ else:
1934
+ lm_logits = module.lm_head(hidden_states)
1935
+
1936
+ return lm_logits, outputs
1937
+
1938
+ outputs = self.module.apply(
1939
+ inputs,
1940
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1941
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1942
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1943
+ encoder_hidden_states=encoder_hidden_states,
1944
+ output_attentions=output_attentions,
1945
+ output_hidden_states=output_hidden_states,
1946
+ return_dict=return_dict,
1947
+ deterministic=not train,
1948
+ rngs=rngs,
1949
+ mutable=mutable,
1950
+ method=_decoder_forward,
1951
+ )
1952
+
1953
+ if past_key_values is None:
1954
+ lm_logits, decoder_outputs = outputs
1955
+ else:
1956
+ (lm_logits, decoder_outputs), past = outputs
1957
+
1958
+ if return_dict:
1959
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1960
+ logits=lm_logits,
1961
+ hidden_states=decoder_outputs.hidden_states,
1962
+ attentions=decoder_outputs.attentions,
1963
+ cross_attentions=decoder_outputs.cross_attentions,
1964
+ )
1965
+ else:
1966
+ outputs = (lm_logits,) + decoder_outputs[1:]
1967
+
1968
+ # add updated cache to model output
1969
+ if past_key_values is not None and return_dict:
1970
+ outputs["past_key_values"] = unfreeze(past["cache"])
1971
+ return outputs
1972
+ elif past_key_values is not None and not return_dict:
1973
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1974
+
1975
+ return outputs
1976
+
1977
+ def generate(
1978
+ self,
1979
+ input_features,
1980
+ generation_config=None,
1981
+ logits_processor=None,
1982
+ return_timestamps=None,
1983
+ task=None,
1984
+ language=None,
1985
+ is_multilingual=None,
1986
+ **kwargs,
1987
+ ):
1988
+ if generation_config is None:
1989
+ generation_config = self.generation_config
1990
+
1991
+ if return_timestamps is not None:
1992
+ generation_config.return_timestamps = return_timestamps
1993
+
1994
+ if task is not None:
1995
+ generation_config.task = task
1996
+
1997
+ if is_multilingual is not None:
1998
+ generation_config.is_multilingual = is_multilingual
1999
+
2000
+ if language is not None:
2001
+ generation_config.language = language
2002
+
2003
+ if kwargs is not None and "decoder_input_ids" in kwargs:
2004
+ decoder_input_length = len(kwargs["decoder_input_ids"])
2005
+ else:
2006
+ decoder_input_length = 1
2007
+
2008
+ forced_decoder_ids = []
2009
+
2010
+ if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
2011
+ if hasattr(generation_config, "language"):
2012
+ forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
2013
+ else:
2014
+ forced_decoder_ids.append((1, None))
2015
+
2016
+ if hasattr(generation_config, "task"):
2017
+ forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
2018
+ else:
2019
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
2020
+
2021
+ if (
2022
+ hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
2023
+ ) or return_timestamps:
2024
+ logits_processor = [
2025
+ FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
2026
+ ]
2027
+ else:
2028
+ if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
2029
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
2030
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
2031
+
2032
+ if len(forced_decoder_ids) > 0:
2033
+ generation_config.forced_decoder_ids = forced_decoder_ids
2034
+
2035
+ return super().generate(
2036
+ input_features,
2037
+ generation_config,
2038
+ logits_processor=logits_processor,
2039
+ **kwargs,
2040
+ )
2041
+
2042
+ def pipeline_generate(
2043
+ self,
2044
+ input_features,
2045
+ forced_decoder_ids,
2046
+ return_timestamps=False,
2047
+ generation_config=None,
2048
+ **kwargs,
2049
+ ):
2050
+ if generation_config is None:
2051
+ generation_config = self.generation_config
2052
+
2053
+ # override the generation config forced decoder ids in preference of the ones we have set
2054
+ generation_config.forced_decoder_ids = None
2055
+
2056
+ logits_processor = FlaxLogitsProcessorList()
2057
+ logits_processor.append(FlaxStaticForceTokensLogitsProcessor(forced_decoder_ids))
2058
+
2059
+ if hasattr(generation_config, "return_timestamps") and return_timestamps:
2060
+ logits_processor.append(FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, 1))
2061
+
2062
+ return super().generate(
2063
+ input_features,
2064
+ generation_config,
2065
+ logits_processor=logits_processor,
2066
+ **kwargs,
2067
+ )
2068
+
2069
+ def prepare_inputs_for_generation(
2070
+ self,
2071
+ decoder_input_ids,
2072
+ max_length,
2073
+ attention_mask: Optional[jax.Array] = None,
2074
+ decoder_attention_mask: Optional[jax.Array] = None,
2075
+ encoder_outputs=None,
2076
+ **kwargs,
2077
+ ):
2078
+ # initializing the cache
2079
+ batch_size, seq_length = decoder_input_ids.shape
2080
+
2081
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
2082
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
2083
+ # But since the decoder uses a causal mask, those positions are masked anyways.
2084
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
2085
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
2086
+ if decoder_attention_mask is not None:
2087
+ position_ids = decoder_attention_mask.cumsum(-1) - 1
2088
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
2089
+ else:
2090
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
2091
+
2092
+ return {
2093
+ "past_key_values": past_key_values,
2094
+ "encoder_outputs": encoder_outputs,
2095
+ "encoder_attention_mask": attention_mask,
2096
+ "decoder_attention_mask": extended_attention_mask,
2097
+ "decoder_position_ids": position_ids,
2098
+ }
2099
+
2100
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
2101
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
2102
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
2103
+ return model_kwargs
2104
+
2105
+
2106
+ FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
2107
+ Returns:
2108
+
2109
+ Transcription example:
2110
+
2111
+ ```python
2112
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
2113
+ >>> from datasets import load_dataset
2114
+
2115
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
2116
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
2117
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
2118
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
2119
+ >>> input_features = inputs.input_features
2120
+ >>> generated_ids = model.generate(input_ids=input_features)
2121
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
2122
+ >>> transcription
2123
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
2124
+ ```
2125
+ """
2126
+
2127
+ overwrite_call_docstring(
2128
+ FlaxWhisperForConditionalGeneration,
2129
+ WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING,
2130
+ )
2131
+ append_replace_return_docstrings(
2132
+ FlaxWhisperForConditionalGeneration,
2133
+ output_type=FlaxSeq2SeqLMOutput,
2134
+ config_class=_CONFIG_FOR_DOC,
2135
+ )
distil_whisper/partitioner.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for partitioning."""
16
+
17
+ import abc
18
+ import collections
19
+ import dataclasses
20
+ import typing
21
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
22
+
23
+ import cached_property
24
+ import jax
25
+ import numpy as np
26
+ from absl import logging
27
+ from flax import traverse_util
28
+ from flax.linen import partitioning as flax_partitioning
29
+ from jax import numpy as jnp
30
+ from jax import random
31
+ from jax.experimental import multihost_utils
32
+ from jax.experimental.mesh_utils import create_hybrid_device_mesh
33
+ from jax.experimental.pjit import pjit as jax_pjit
34
+ from jax.sharding import Mesh, PartitionSpec
35
+
36
+
37
+ JaxDevice = Any
38
+ TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores).
39
+ OtherMesh = Tuple[int, int]
40
+ HardwareMesh = Union[TpuMesh, OtherMesh]
41
+ PyTreeDef = type(jax.tree_util.tree_structure(None))
42
+ TrainState = Any
43
+ LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]
44
+
45
+ if typing.TYPE_CHECKING: # See b/163639353
46
+ cached_property = property # pylint: disable=invalid-name
47
+ else:
48
+ cached_property = cached_property.cached_property
49
+
50
+
51
+ class AxisNames(tuple):
52
+ """Tuple of strings specifying name for each axis.
53
+
54
+ We create a separate class for this so JAX's pytree utilities can distinguish
55
+ it from a tuple that should be treated as a pytree, instead treating it as a
56
+ leaf.
57
+ """
58
+
59
+ def __new__(cls, *names):
60
+ return tuple.__new__(AxisNames, names)
61
+
62
+ def __repr__(self):
63
+ return "AxisNames%s" % tuple.__repr__(self)
64
+
65
+
66
+ # pjit wrappers for cpu fallback.
67
+ # ----------------------------------------------------------------------------
68
+ # TODO(levskaya): This function is now no different than jax_pjit, but callers
69
+ # currently depend on `backend` argument
70
+ def pjit(
71
+ fun: Callable, # pylint: disable=g-bare-generic
72
+ in_axis_resources,
73
+ out_axis_resources,
74
+ static_argnums: Union[int, Sequence[int]] = (),
75
+ donate_argnums: Union[int, Sequence[int]] = (),
76
+ backend: Optional[str] = None,
77
+ ):
78
+ """Wrapper for pjit."""
79
+ del backend
80
+ return jax_pjit(
81
+ fun,
82
+ in_axis_resources,
83
+ out_axis_resources,
84
+ static_argnums=static_argnums,
85
+ donate_argnums=donate_argnums,
86
+ )
87
+
88
+
89
+ # pjit wrappers for cpu fallback.
90
+ # -----------------------------------------------------------------------------
91
+ # TODO(levskaya): upstream this fallback behavior to jax pjit.
92
+ def pjit_with_cpu_fallback(
93
+ fun: Callable, # pylint: disable=g-bare-generic
94
+ in_axis_resources,
95
+ out_axis_resources,
96
+ static_argnums: Union[int, Sequence[int]] = (),
97
+ donate_argnums: Union[int, Sequence[int]] = (),
98
+ backend: Optional[str] = None,
99
+ ):
100
+ """Wrapper for pjit that calls normal jit on cpu."""
101
+ if jax.devices(backend)[0].platform == "cpu":
102
+ return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
103
+ else:
104
+ return jax_pjit(
105
+ fun,
106
+ in_axis_resources,
107
+ out_axis_resources,
108
+ static_argnums=static_argnums,
109
+ donate_argnums=donate_argnums,
110
+ )
111
+
112
+
113
+ def with_sharding_constraint(x, axis_resources):
114
+ """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
115
+ if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
116
+ return x
117
+ else:
118
+ return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
119
+
120
+
121
+ # pjit Mesh creation functions.
122
+ # -----------------------------------------------------------------------------
123
+ def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh:
124
+ """Get the bound from the given last device."""
125
+ # Must be passed the device at the highest-coordinate corner of the
126
+ # relevant mesh, which is a requirement we know is satisfied by the last
127
+ # device in jax.devices().
128
+ if hasattr(last_device, "coords"):
129
+ x, y, z = last_device.coords
130
+ return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
131
+ else:
132
+ # On non-TPU platforms, the "mesh" is hosts x devices per host in order
133
+ # to take advantage of faster within-host interconnect.
134
+ return jax.host_count(), jax.local_device_count()
135
+
136
+
137
+ def get_coords(device: JaxDevice) -> HardwareMesh:
138
+ """Returns the coordinates of the given device."""
139
+ if hasattr(device, "coords"):
140
+ return (*device.coords, device.core_on_chip)
141
+ return (device.process_index, device.id % jax.local_device_count())
142
+
143
+
144
+ def global_mesh_defined():
145
+ """Checks if global xmap/pjit mesh resource environment is defined."""
146
+ maps_env = jax.experimental.maps.thread_resources.env
147
+ return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
148
+
149
+
150
+ def get_mesh(
151
+ model_parallel_submesh: HardwareMesh,
152
+ input_devices: Sequence[JaxDevice] = (),
153
+ input_local_devices: Sequence[JaxDevice] = (),
154
+ tile_by_host_if_needed: bool = True,
155
+ backend: Optional[str] = None,
156
+ ) -> Mesh:
157
+ """Construct an xmap/pjit Mesh for the given model-parallel submesh.
158
+
159
+ The resulting mesh has two resource axes: 'model', with the provided submesh
160
+ shape, and 'data', which covers the rest of the mesh.
161
+
162
+ Args:
163
+ model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for
164
+ a single model-parallel replica's "tile" in the physical device mesh. The
165
+ first three elements (`x`, `y`, and `z`) should be factors of the pod
166
+ slice; e.g., if you are using df_4x8, then `x` should be a factor of 4
167
+ (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z`
168
+ must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4
169
+ (and maybe later TPUs) that allow 3D slices. `core` is the number of cores
170
+ to use from each TPU node. As communication is usually fastest inside the
171
+ same node, if you need a tile of more than 1 core, then
172
+ you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better
173
+ than (2,1,1,1). To pick a good spec, try a few possible values until you
174
+ get high TPU utilization.
175
+ input_devices: the devices to use, will use jax.devices() if this is not
176
+ set.
177
+ input_local_devices: the local devices to use, will use jax.local_devices()
178
+ if this is not set.
179
+ tile_by_host_if_needed: JAX currently requires that the parts of any sharded
180
+ array that are located on one host's local devices form a single
181
+ contiguous slice. A best effort will be made to achieve this without
182
+ "tiling" the device assignment over hosts (which can reduce XLA collective
183
+ performance). If this flag is True, then the device assignment will be
184
+ tiled over hosts if necessary to satisfy this constraint and create a
185
+ buildable mesh; if false, mesh construction will fail instead.
186
+ backend: get devices from the pinned backend, if specified. This is
187
+ useful for explicitly specifying the devices other than relying on
188
+ jax_platform_name.
189
+
190
+ Returns:
191
+ A xmap / pjit Mesh containing the virtual device mesh with data, model axes.
192
+ """
193
+ input_devices = input_devices or jax.devices(backend)
194
+ input_local_devices = input_local_devices or jax.local_devices(0, backend)
195
+ # Sort input_devices based on coords, as backends might not return devices
196
+ # in order.
197
+ last_device = sorted(input_devices, key=get_coords)[-1]
198
+ last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1]
199
+ logging.info(
200
+ "last device coords : %r\nlast local device coords: %r",
201
+ get_coords(last_device),
202
+ get_coords(last_input_local_devices),
203
+ )
204
+ global_hardware_mesh = bounds_from_last_device(last_device)
205
+ mesh_ndim = len(global_hardware_mesh)
206
+ local_hardware_mesh = bounds_from_last_device(last_input_local_devices)
207
+ mesh_err = (
208
+ f"each dimension of the model parallel submesh {model_parallel_submesh} "
209
+ "must be a factor of the corresponding dimension of the global device "
210
+ f"mesh {global_hardware_mesh}"
211
+ )
212
+ assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
213
+ assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
214
+ devices = np.empty(global_hardware_mesh, dtype=object)
215
+ for device in input_devices:
216
+ device_coords = get_coords(device)
217
+ devices[device_coords] = device
218
+ tile_by_host = tile_by_host_if_needed
219
+ if len(global_hardware_mesh) == 4:
220
+ # enable contiguous local chunks without host tiling by making Z major
221
+ global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh)
222
+ model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh)
223
+ gx, gy, gz, gc = global_hardware_mesh
224
+ mx, my, mz, mc = model_parallel_submesh
225
+ if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1):
226
+ logging.info("ensuring YZ plane has a Z-major device order")
227
+ # YZ should be ZY
228
+ assert mc == gc, (mc, gc)
229
+ global_hardware_mesh = gx, gz, gy, gc
230
+ model_parallel_submesh = mx, mz, my, mc
231
+ devices = devices.swapaxes(1, 2)
232
+ tile_by_host = False
233
+ if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1):
234
+ logging.info("ensuring XZ plane has a Z-major device order")
235
+ # XZ should be ZX
236
+ assert mc == gc, (mc, gc)
237
+ global_hardware_mesh = gz, gy, gx, gc
238
+ model_parallel_submesh = mz, my, mx, mc
239
+ devices = devices.swapaxes(0, 2)
240
+ tile_by_host = False
241
+ if tile_by_host:
242
+ logging.warning(
243
+ "Tiling device assignment mesh by hosts, which may lead to "
244
+ "reduced XLA collective performance. To avoid this, modify "
245
+ "the model parallel submesh or run with more tasks per host."
246
+ )
247
+ tile_err = (
248
+ "to tile the mesh by hosts, each dimension of the model parallel "
249
+ "submesh must be either a factor or a multiple of the corresponding "
250
+ "dimension of the per-host submesh"
251
+ )
252
+
253
+ def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
254
+ """Split a global mesh dimension into four tiling components.
255
+
256
+ Args:
257
+ g: global mesh bounds dimension size
258
+ m: model-parallel submesh bounds dimension size
259
+ l: local submesh bounds dimension size
260
+
261
+ Returns:
262
+ The resulting tuple divides the dimension into the hosts component of
263
+ the data-parallel submesh, the devices component of the data-parallel
264
+ submesh, the hosts component of the model-parallel submesh, and the
265
+ devices component of the model-parallel submesh.
266
+ """
267
+ d = g // m
268
+ if m >= l:
269
+ assert not m % l, tile_err
270
+ return (d, 1, m // l, l)
271
+ else:
272
+ assert not l % m, tile_err
273
+ return (d // (l // m), l // m, 1, m)
274
+
275
+ # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
276
+ dh_dd_mh_md_tups = map(
277
+ dh_dd_mh_md,
278
+ global_hardware_mesh,
279
+ model_parallel_submesh,
280
+ local_hardware_mesh,
281
+ )
282
+ # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
283
+ devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension
284
+ # TODO(jekbradbury): reorder local subgroups for ring locality
285
+ # Transpose to [data_host], [data_device], [model_host], [model_device]
286
+ # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
287
+ devices = devices.transpose(
288
+ *(4 * i for i in range(mesh_ndim)),
289
+ *(4 * i + 1 for i in range(mesh_ndim)),
290
+ *(4 * i + 2 for i in range(mesh_ndim)),
291
+ *(4 * i + 3 for i in range(mesh_ndim)),
292
+ )
293
+ else:
294
+ # e.g. [(x_data, x_model), (y_data, y_model), ...]
295
+ model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)]
296
+ # reshape to e.g. (x_data, x_model, y_data, y_model...)
297
+ devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension
298
+ # TODO(jekbradbury): reorder small subgroups for ring locality
299
+ # transpose to e.g. (x_data, y_data, ..., x_model, ...)
300
+ devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim)))
301
+ # reshape to (data, model)
302
+ devices = devices.reshape(-1, np.prod(model_parallel_submesh))
303
+ global_mesh = Mesh(devices, ["data", "model"])
304
+ logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
305
+ logging.info("global_mesh devices: %s", global_mesh.devices)
306
+ logging.info("global_mesh devices shape: %s", global_mesh.devices.shape)
307
+ return global_mesh
308
+
309
+
310
+ def get_cpu_mesh() -> Mesh:
311
+ """Trivial mesh for CPU Testing."""
312
+ devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
313
+ for device in jax.devices():
314
+ devices[device.process_index, device.id % jax.local_device_count()] = device
315
+ return Mesh(devices, ["data", "model"])
316
+
317
+
318
+ def get_gpu_mesh(num_partitions: int) -> Mesh:
319
+ """Mesh for GPUs that preferentially places 'model' on NVLink."""
320
+ nvlink_size = jax.local_device_count()
321
+ dcn_size = jax.process_count()
322
+ nvlink_mp = min(num_partitions, nvlink_size)
323
+ nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
324
+ dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
325
+ assert not (
326
+ extra1 or extra2
327
+ ), "number of partitions on GPU must be a factor or multiple of the number of local devices"
328
+ dcn_dp = dcn_size // dcn_mp
329
+
330
+ devices = create_hybrid_device_mesh(
331
+ mesh_shape=[nvlink_dp, nvlink_mp],
332
+ dcn_mesh_shape=[dcn_dp, dcn_mp],
333
+ process_is_granule=True,
334
+ )
335
+
336
+ global_mesh = Mesh(devices, ["data", "model"])
337
+ logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
338
+ logging.info("global_mesh devices: %s", global_mesh.devices)
339
+ return global_mesh
340
+
341
+
342
+ def default_mesh(
343
+ num_partitions: int,
344
+ model_parallel_submesh: Optional[HardwareMesh] = None,
345
+ backend: Optional[str] = None,
346
+ ) -> Mesh:
347
+ """Attempt to return a default mesh for simple cases.
348
+
349
+ Args:
350
+ num_partitions: number of partitions to use, will be ignored if
351
+ model_parallel_submesh is provided.
352
+ model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as
353
+ the model-parallel device tile.
354
+ backend: get devices from the pinned backend, if specified. This is useful
355
+ for explicitly specifying the devices other than relying on
356
+ jax_platform_name.
357
+
358
+ Returns:
359
+ xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
360
+ """
361
+ last_device = jax.devices(backend)[-1]
362
+ platform = last_device.platform
363
+ device_kind = last_device.device_kind
364
+ bounds = bounds_from_last_device(last_device)
365
+
366
+ if model_parallel_submesh:
367
+ return get_mesh(model_parallel_submesh, backend=backend)
368
+
369
+ if platform == "cpu":
370
+ return get_cpu_mesh()
371
+ elif platform == "gpu":
372
+ return get_gpu_mesh(num_partitions)
373
+
374
+ mps = None
375
+ if device_kind in ("TPU v2", "TPU v3"):
376
+ if num_partitions == 1:
377
+ mps = (1, 1, 1, 1)
378
+ elif num_partitions == 2:
379
+ mps = (1, 1, 1, 2)
380
+ elif num_partitions == 4:
381
+ mps = (2, 1, 1, 2)
382
+ elif num_partitions == 8:
383
+ mps = (2, 2, 1, 2)
384
+ elif num_partitions == 16:
385
+ mps = (4, 2, 1, 2)
386
+ # assume the use of megacore on TPU v4
387
+ elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1:
388
+ if num_partitions == 1:
389
+ mps = (1, 1, 1, 1)
390
+ elif num_partitions == 2:
391
+ mps = (1, 2, 1, 1)
392
+ elif num_partitions == 4:
393
+ if bounds[0] >= 4:
394
+ mps = (4, 1, 1, 1)
395
+ else:
396
+ mps = (2, 2, 1, 1)
397
+ elif num_partitions == 8:
398
+ if bounds[2] >= 8:
399
+ mps = (1, 1, 8, 1)
400
+ else:
401
+ mps = (4, 2, 1, 1)
402
+ elif num_partitions == 16:
403
+ if bounds[2] >= 16:
404
+ mps = (1, 1, 16, 1)
405
+ elif bounds[0] >= 8:
406
+ mps = (8, 2, 1, 1)
407
+ elif bounds[0] >= 4:
408
+ mps = (4, 4, 1, 1)
409
+ else:
410
+ mps = (2, 2, 4, 1)
411
+
412
+ if mps is None:
413
+ raise ValueError(
414
+ "No default mesh for this configuration: specify " "config.model_parallel_submesh explicitly."
415
+ )
416
+ return get_mesh(mps, backend=backend)
417
+
418
+
419
+ # Data chunking helper.
420
+ # -----------------------------------------------------------------------------
421
+ @dataclasses.dataclass
422
+ class LocalChunkInfo:
423
+ # The logical slice of an array located on this host's local devices.
424
+ slice: Tuple[slice, ...]
425
+ # A unique index for this host/local chunk among chunks with the same slice.
426
+ replica_id: int
427
+
428
+
429
+ class LocalChunker:
430
+ """Utility class to aid chunking of sharded arrays in multihost settings."""
431
+
432
+ def __init__(self, global_mesh: Mesh):
433
+ self.global_mesh = global_mesh
434
+ local_mesh = global_mesh.local_mesh
435
+ first_local_device = local_mesh.devices.reshape(-1)[0]
436
+ host_location = collections.OrderedDict(
437
+ zip(
438
+ global_mesh.shape.keys(),
439
+ list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0],
440
+ )
441
+ )
442
+ self.num_chunks = collections.OrderedDict()
443
+ self.chunk_ids = collections.OrderedDict()
444
+ self.mesh_axes = list(global_mesh.shape.keys())
445
+ for mesh_axis in self.mesh_axes:
446
+ num_devices_per_chunk = local_mesh.shape[mesh_axis]
447
+ self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk
448
+ self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk
449
+
450
+ def get_local_chunk_info(
451
+ self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
452
+ ) -> LocalChunkInfo:
453
+ """Get the local chunk info for a given array shape and sharded axes.
454
+
455
+ Args:
456
+ global_shape: the global, unsharded shape of the array to chunk.
457
+ mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
458
+ that specifies which mesh dimensions the array is sharded along.
459
+
460
+ Returns:
461
+ LocalChunkInfo containing the logical slices of the array found on this
462
+ host's local devices, as well as the replica index for this chunk among
463
+ chunks with the same slice. The latter is used to determine which
464
+ host should write this chunk during checkpointing.
465
+ """
466
+ local_slice = [slice(None) for dim in global_shape]
467
+ sharded_mesh_axes = set()
468
+ for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
469
+ if not mesh_axis:
470
+ continue
471
+ sharded_mesh_axes.add(mesh_axis)
472
+ if not isinstance(mesh_axis, str):
473
+ raise NotImplementedError("TODO(jekbradbury)")
474
+ chunk_id = self.chunk_ids[mesh_axis]
475
+ chunk_size = size // self.num_chunks[mesh_axis]
476
+ local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
477
+
478
+ replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes]
479
+ replica_id = 0
480
+ for mesh_axis in replicated_mesh_axes:
481
+ chunk_id = self.chunk_ids[mesh_axis]
482
+ replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
483
+
484
+ return LocalChunkInfo(tuple(local_slice), replica_id)
485
+
486
+
487
+ def standard_logical_axis_rules(
488
+ activation_partitioning_dims: int = 1,
489
+ parameter_partitioning_dims: int = 1,
490
+ additional_rules: Optional[LogicalAxisRules] = None,
491
+ ) -> LogicalAxisRules:
492
+ """Default sharding rules for T5X model in terms of logical axis names.
493
+
494
+ Args:
495
+ activation_partitioning_dims: enables 2-D activation sharding when set to 2.
496
+ parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
497
+ additional_rules: additional rules (a sequence of tuples) that will be
498
+ appended to the standard rules.
499
+
500
+ Returns:
501
+ Sequence of logical axis rules
502
+ """
503
+ logging.info(
504
+ "`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d",
505
+ activation_partitioning_dims,
506
+ parameter_partitioning_dims,
507
+ )
508
+
509
+ if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
510
+ rules = [
511
+ ("batch", "data"),
512
+ ("vocab", "model"),
513
+ ("embed", None),
514
+ ("mlp", "model"),
515
+ ("heads", "model"),
516
+ ("kv", None),
517
+ ("joined_kv", "model"), # joined heads+kv dim in 2D attn param layouts
518
+ ]
519
+ elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
520
+ rules = [
521
+ ("batch", "data"),
522
+ ("vocab", "model"),
523
+ ("mlp", "model"),
524
+ ("heads", "model"),
525
+ ("kv", None),
526
+ ("joined_kv", "model"),
527
+ ("embed", "model"),
528
+ ]
529
+ elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
530
+ rules = [
531
+ ("batch", "data"),
532
+ ("vocab", "model"),
533
+ ("mlp", "model"),
534
+ ("heads", "model"),
535
+ ("kv", None),
536
+ ("joined_kv", "model"),
537
+ ("embed", "data"),
538
+ ]
539
+ elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
540
+ rules = [
541
+ ("batch", "data"),
542
+ ("vocab", "model"),
543
+ ("mlp", "model"),
544
+ ("heads", "model"),
545
+ ("kv", None),
546
+ ("joined_kv", "model"),
547
+ ("embed", "model"),
548
+ ("embed", "data"),
549
+ ]
550
+ else:
551
+ raise ValueError(
552
+ f"`activation_partitioning_dims` = {activation_partitioning_dims} "
553
+ f"`parameter_partitioning_dims` = {parameter_partitioning_dims} "
554
+ "is not supported."
555
+ )
556
+
557
+ # Add the common rules for the replicated logical axes names.
558
+ replicated_rules = [
559
+ ("relpos_buckets", None),
560
+ ("abspos_buckets", None),
561
+ ("length", None),
562
+ ("layers", None),
563
+ ("stack", None),
564
+ ("mlp_activations", None),
565
+ ]
566
+ rules.extend(replicated_rules)
567
+
568
+ if additional_rules:
569
+ rules.extend(additional_rules)
570
+
571
+ return rules
572
+
573
+
574
+ # NB: This needs to be top-level for the jax compilation cache.
575
+ def _id_fn(x, ix):
576
+ """Identity function for copying parameters to the devices, sharded."""
577
+ # A pure identity such as `lambda x, *: x` can get optimized away, so we
578
+ # include a random.split as a cheap function that cannot be optimized away.
579
+ y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
580
+ return x, y
581
+
582
+
583
+ @dataclasses.dataclass
584
+ class DataLayout:
585
+ """Represents data layout for the partitioned model."""
586
+
587
+ batch_size: int
588
+ shard_id: int
589
+ num_shards: int
590
+ is_first_host_in_replica_set: bool
591
+
592
+
593
+ PartitionedCallable = Callable[..., Any]
594
+ CompiledPartitionedCallable = Callable[..., Any]
595
+
596
+
597
+ class BasePartitioner(metaclass=abc.ABCMeta):
598
+ """Interface for partitioning computations across hardware devices."""
599
+
600
+ def __init__(
601
+ self,
602
+ num_partitions: Optional[int] = None,
603
+ model_parallel_submesh: Optional[HardwareMesh] = None,
604
+ params_on_devices: bool = True,
605
+ backend: Optional[str] = None,
606
+ ):
607
+ """Configures the partitioner.
608
+
609
+ Args:
610
+ num_partitions: the number of partitions to use. Ignored if
611
+ `model_parallel_submesh` is provided.
612
+ model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
613
+ as the model-parallel device tile. This submesh is used for the larger
614
+ of the two parameter dimensions, and, if 2-D activation sharding is
615
+ enabled, for the model dimension of activations. The rest of the mesh is
616
+ used for data parallelism and, if 2-D parameter sharding is enabled, the
617
+ other parameter dimension.
618
+ params_on_devices: whether to keep the params on devices, if False -
619
+ params stay in the host memory. Note that some partitioners might ignore
620
+ this setting, for example if they don't support storing all params on
621
+ device memory.
622
+ backend: get devices from the pinned backend, if specified. This is useful
623
+ for explicitly specifying the devices other than relying on
624
+ jax_platform_name.
625
+ """
626
+
627
+ if not num_partitions and not model_parallel_submesh:
628
+ raise ValueError("At least one of `num_partitions` or " "`model_parallel_submesh` must be set.")
629
+
630
+ if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
631
+ logging.error(
632
+ (
633
+ "`model_parallel_submesh` must be either None or a 4-tuple. Got"
634
+ " `model_parallel_submesh`=%s. A ValueError will be raised"
635
+ " beginning March 1, 2022."
636
+ ),
637
+ model_parallel_submesh,
638
+ )
639
+
640
+ if bool(num_partitions) and bool(model_parallel_submesh):
641
+ logging.error(
642
+ (
643
+ "At most one of `num_partitions` or `model_parallel_submesh` can be"
644
+ " set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A"
645
+ " ValueError will be raised beginning March 21, 2022."
646
+ ),
647
+ num_partitions,
648
+ model_parallel_submesh,
649
+ )
650
+
651
+ self._num_partitions = num_partitions
652
+ self._model_parallel_submesh = model_parallel_submesh
653
+ self._params_on_devices = params_on_devices
654
+ self._data_axis = "data"
655
+ self._backend = backend
656
+
657
+ @property
658
+ def mesh(self) -> Mesh:
659
+ raise NotImplementedError
660
+
661
+ @property
662
+ def data_partition_spec(self) -> PartitionSpec:
663
+ return PartitionSpec(self._data_axis)
664
+
665
+ def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
666
+ """Returns filled `DataLayout` based on the partitioned model layout.
667
+
668
+ Args:
669
+ batch_size: if set, indicates the requested batch size. The exception will
670
+ be raised if this batch size is not compatible with the layout. If not
671
+ set, the batch size is inferred from the layout.
672
+ host_index: indicates the host index to use for the calculations, if not
673
+ set - use JAX-provided one. Should be in [0, num_hosts) interval and the
674
+ order should match the order of corresponding CPU devices in
675
+ `jax.devices()`.
676
+
677
+ Returns:
678
+ Filled `DataLayout` structure.
679
+ """
680
+ if host_index is not None:
681
+ raise NotImplementedError("Explicit host_index is not yet implemented.")
682
+ if self._data_axis is None:
683
+ return DataLayout(
684
+ batch_size=batch_size,
685
+ shard_id=0,
686
+ num_shards=1,
687
+ is_first_host_in_replica_set=(jax.process_index() == 0),
688
+ )
689
+ mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
690
+ batch_size = batch_size or mesh_size
691
+ if batch_size % mesh_size:
692
+ raise ValueError(
693
+ f"Batch size ({batch_size}) must be divisible by corresponding " f"mesh size ({mesh_size})."
694
+ )
695
+ num_shards = self._local_chunker.num_chunks[self._data_axis]
696
+ if batch_size % num_shards:
697
+ raise ValueError(f"Batch size ({batch_size}) must be divisible by number of " f"replicas ({num_shards}).")
698
+ replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id
699
+ return DataLayout(
700
+ batch_size=int(batch_size),
701
+ shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
702
+ num_shards=int(num_shards),
703
+ is_first_host_in_replica_set=(replica_id == 0),
704
+ )
705
+
706
+ def get_local_chunk_info(
707
+ self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
708
+ ) -> LocalChunkInfo:
709
+ """Returns the local chunk info for a given array shape and sharded axes."""
710
+ return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)
711
+
712
+ @property
713
+ def params_on_devices(self):
714
+ return self._params_on_devices
715
+
716
+ def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
717
+ """Moves the optimizer parameters to devices."""
718
+ p_id_fn = self.partition(
719
+ _id_fn,
720
+ in_axis_resources=(train_state_axes, None),
721
+ out_axis_resources=(train_state_axes, None),
722
+ donate_argnums=(0,),
723
+ )
724
+ if jax.config.jax_array and jax.process_count() > 1:
725
+ train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
726
+ train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
727
+ return train_state
728
+
729
+ @property
730
+ @abc.abstractmethod
731
+ def _local_chunker(self):
732
+ """Returns the chunker that matches the parameters of this partitioner."""
733
+ raise NotImplementedError
734
+
735
+ def get_logical_axes(self, train_state: TrainState) -> TrainState:
736
+ """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
737
+ # By default, return None for the logical axes.
738
+ return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict()))
739
+
740
+ def get_mesh_axes(self, train_state: TrainState) -> TrainState:
741
+ """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
742
+ raise NotImplementedError
743
+
744
+ @abc.abstractmethod
745
+ def partition(
746
+ self,
747
+ fn: Callable, # pylint: disable=g-bare-generic
748
+ in_axis_resources,
749
+ out_axis_resources,
750
+ static_argnums: Union[int, Sequence[int]] = (),
751
+ donate_argnums: Union[int, Sequence[int]] = (),
752
+ ) -> PartitionedCallable:
753
+ """Partitions the computation using partitioner-specific implementation.
754
+
755
+ Args:
756
+ fn: the function to partition.
757
+ in_axis_resources: Pytree of structure matching that of arguments to `fn`,
758
+ with all actual arguments replaced by resource assignment
759
+ specifications. It is also valid to specify a pytree prefix (e.g. one
760
+ value in place of a whole subtree), in which case the leaves get
761
+ broadcast to all values in that subtree.
762
+ The valid resource assignment specifications are:
763
+ `None`: in which case the value will be replicated on all devices
764
+ `PartitionSpec`: a tuple of length at most equal to the rank of the
765
+ partitioned value. Each element can be a `None`, a mesh axis or a
766
+ tuple of mesh axes, and specifies the set of resources assigned to
767
+ partition the value's dimension matching its position in the spec.
768
+ out_axis_resources: Like `in_axis_resources`, but specifies resource
769
+ assignment for function outputs.
770
+ static_argnums: an optional int or collection of ints that specify which
771
+ positional arguments to treat as static (compile-time constant) in the
772
+ partitioned function.
773
+ donate_argnums: an optional int or collection of ints that specify which
774
+ argument buffers are "donated" to the computation. It is safe to donate
775
+ argument buffers if you no longer need them once the computation has
776
+ finished.
777
+
778
+ Returns:
779
+ A partitioned version of the input function.
780
+ """
781
+ raise NotImplementedError
782
+
783
+ @abc.abstractmethod
784
+ def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable:
785
+ """Compiles and returns the partitioned function, or the original.
786
+
787
+ Args:
788
+ partitioned_fn: The partitioned function.
789
+ *args: Sample arguments to the partitioned function matching the input
790
+ shapes that will be passed to the compiled function.
791
+
792
+ Returns:
793
+ The compiled function, or the original if this partitioner does not
794
+ support compilation.
795
+ """
796
+ raise NotImplementedError
797
+
798
+
799
+ class PjittedFnWithContext(PartitionedCallable):
800
+ """Wraps pjitted function to apply the appropriate contexts."""
801
+
802
+ def __init__(
803
+ self,
804
+ pjitted_fn,
805
+ partition_mesh: Mesh,
806
+ logical_axis_rules: flax_partitioning.LogicalRules = (),
807
+ ):
808
+ self._pjitted_fn = pjitted_fn
809
+ self._mesh = partition_mesh
810
+ self._logical_axis_rules = logical_axis_rules
811
+
812
+ def __call__(self, *args):
813
+ with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
814
+ return self._pjitted_fn(*args)
815
+
816
+ def lower(self, *args):
817
+ with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
818
+ return self._pjitted_fn.lower(*args)
819
+
820
+
821
+ class BasePjitPartitioner(BasePartitioner):
822
+ """Partitioner that uses T5X version of jax.pjit."""
823
+
824
+ @cached_property
825
+ def _local_chunker(self) -> LocalChunker:
826
+ return LocalChunker(self.mesh)
827
+
828
+ @cached_property
829
+ def mesh(self) -> Mesh:
830
+ return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend)
831
+
832
+ def partition(
833
+ self,
834
+ fn: Callable, # pylint: disable=g-bare-generic
835
+ in_axis_resources,
836
+ out_axis_resources,
837
+ static_argnums: Union[int, Sequence[int]] = (),
838
+ donate_argnums: Union[int, Sequence[int]] = (),
839
+ ) -> PjittedFnWithContext:
840
+ pjitted = pjit(
841
+ fn,
842
+ in_axis_resources=in_axis_resources,
843
+ out_axis_resources=out_axis_resources,
844
+ static_argnums=static_argnums,
845
+ donate_argnums=donate_argnums,
846
+ backend=self._backend,
847
+ )
848
+
849
+ return PjittedFnWithContext(pjitted, self.mesh)
850
+
851
+ def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable:
852
+ return partitioned_fn.lower(*args).compile()
853
+
854
+
855
+ class PjitPartitioner(BasePjitPartitioner):
856
+ """Partitioner that uses named axes and jax.pjit."""
857
+
858
+ def __init__(
859
+ self,
860
+ num_partitions: Optional[int] = None,
861
+ model_parallel_submesh: Optional[HardwareMesh] = None,
862
+ params_on_devices: bool = True,
863
+ backend: Optional[str] = None,
864
+ logical_axis_rules: Optional[LogicalAxisRules] = None,
865
+ use_cpu_pjit: Optional[bool] = False,
866
+ ):
867
+ """PjitPartitioner constructor.
868
+
869
+ See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.
870
+
871
+ Args:
872
+ num_partitions: an integer that specifies the size of the model parallel
873
+ submesh to be automatically selected for the current topology. See
874
+ `model_parallel_submesh` for details on how this submesh is used.
875
+ Mutually exlusive with `model_parallel_submesh`.
876
+ model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)`
877
+ submesh model-parallel device tile, an axis of accelerator parallelism
878
+ orthogonal to data parallelism. Array axes in a model's parameters or
879
+ activations can be sharded over this submesh using axis rules (see
880
+ `logical_axis_rules`) that map them to 'model'. The effective number of
881
+ model sub-partitions is equal to `np.prod(model_parallel_submesh)` and
882
+ must evenly divide the total number of devices (i.e.,
883
+ `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest
884
+ of the TPU mesh is the data parallel submesh, providing
885
+ `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It
886
+ is used for data (batch) parallelism and to shard other array axes that
887
+ are mapped to 'data'. This argument is mutually exclusive with
888
+ `num_partitions`.
889
+ params_on_devices: whether to keep the params on devices, if False -
890
+ params stay in the host memory. Note that some partitioners might ignore
891
+ this setting, for example if they don't support storing all params on
892
+ device memory.
893
+ backend: get devices from the pinned backend, if specified. This is
894
+ useful for explicitly specifying the devices other than relying on
895
+ jax_platform_name.
896
+ logical_axis_rules: a priority-ordered sequence of KV tuples that maps
897
+ logical axis names to either `None` (not sharded), 'model' (to shard
898
+ across the model-parallel submesh), or 'data' (to shard across the
899
+ data-parallel submesh).
900
+ use_cpu_pjit: enables wrapper function for pjit which just jits the
901
+ function if using CPU backend.
902
+ """
903
+ super().__init__(
904
+ num_partitions=num_partitions,
905
+ model_parallel_submesh=model_parallel_submesh,
906
+ params_on_devices=params_on_devices,
907
+ backend=backend,
908
+ )
909
+ if logical_axis_rules is None:
910
+ logical_axis_rules = standard_logical_axis_rules()
911
+ self._logical_axis_rules = tuple(logical_axis_rules)
912
+ (self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules)
913
+ self._use_cpu_pjit = use_cpu_pjit
914
+
915
+ def partition(
916
+ self,
917
+ fn: Callable, # pylint: disable=g-bare-generic
918
+ in_axis_resources,
919
+ out_axis_resources,
920
+ static_argnums: Union[int, Sequence[int]] = (),
921
+ donate_argnums: Union[int, Sequence[int]] = (),
922
+ ) -> PjittedFnWithContext:
923
+ """Partitions the function using jax.pjit."""
924
+ if self._use_cpu_pjit:
925
+ pjit_fn = pjit_with_cpu_fallback
926
+ else:
927
+ pjit_fn = pjit
928
+ pjitted = pjit_fn(
929
+ fn,
930
+ in_axis_resources=in_axis_resources,
931
+ out_axis_resources=out_axis_resources,
932
+ static_argnums=static_argnums,
933
+ donate_argnums=donate_argnums,
934
+ backend=self._backend,
935
+ )
936
+
937
+ return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)
938
+
939
+ @property
940
+ def logical_axis_rules(self):
941
+ """Returns the logical axis rules."""
942
+ return self._logical_axis_rules
943
+
944
+ def get_logical_axes(self, train_state: TrainState) -> TrainState:
945
+ """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
946
+ return train_state.as_logical_axes()
947
+
948
+ def get_mesh_axes(self, train_state: TrainState) -> TrainState:
949
+ """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
950
+ logical_axes = self.get_logical_axes(train_state)
951
+
952
+ def _logical_to_mesh_axes(param_name, logical_axes):
953
+ if logical_axes is None:
954
+ return None
955
+ elif logical_axes is traverse_util.empty_node:
956
+ return traverse_util.empty_node
957
+ try:
958
+ return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules)
959
+ except ValueError as e:
960
+ raise ValueError(f"Failed to map logical axes for {param_name}") from e
961
+
962
+ flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/")
963
+ flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()}
964
+
965
+ return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/"))
distil_whisper/pipeline.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Whisper JAX pipeline compatible with Distil Whisper checkpoints. Copied from https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py"""
17
+
18
+ import math
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ import requests
24
+ import torch
25
+ from flax import jax_utils
26
+ from flax.core.frozen_dict import freeze
27
+ from flax.training.common_utils import shard
28
+ from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
29
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
30
+ from transformers.pipelines.audio_utils import ffmpeg_read
31
+ from transformers.utils import logging
32
+
33
+ from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ class FlaxWhisperFeatureExtractor(WhisperFeatureExtractor):
40
+ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
41
+ """
42
+ Compute the log-mel spectrogram of the provided audio using torch filters. Using the torch implementation
43
+ computes stft filter banks approx 5x faster than its numpy counterpart, which is the native implementation
44
+ in transformers, and matches to within 1e-5 abs tolerance.
45
+ """
46
+ waveform = torch.from_numpy(waveform).type(torch.float32)
47
+
48
+ window = torch.hann_window(self.n_fft)
49
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
50
+ magnitudes = stft[..., :-1].abs() ** 2
51
+
52
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
53
+ mel_spec = mel_filters.T @ magnitudes
54
+
55
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
56
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
57
+ log_spec = (log_spec + 4.0) / 4.0
58
+ return log_spec.numpy()
59
+
60
+
61
+ class FlaxWhisperPipeline:
62
+ def __init__(
63
+ self,
64
+ checkpoint="openai/whisper-large-v2",
65
+ dtype=jnp.float32,
66
+ batch_size=None,
67
+ max_length=None,
68
+ **kwargs,
69
+ ):
70
+ """
71
+ Args
72
+ checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"):
73
+ The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub
74
+ with Flax weights.
75
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
76
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
77
+ `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs.
78
+ If specified all the computation will be performed with the given `dtype`. **Note that this only
79
+ specifies the dtype of the computation and does not influence the dtype of model parameters.**
80
+ batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
81
+ The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
82
+ a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method.
83
+ max_length (`int`, *optional*):
84
+ The maximum numbers of tokens to generate. Defaults to `model.config.max_length`.
85
+ """
86
+ self.checkpoint = checkpoint
87
+ self.dtype = dtype
88
+
89
+ self.feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(self.checkpoint)
90
+ self.tokenizer = WhisperTokenizerFast.from_pretrained(self.checkpoint)
91
+
92
+ self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
93
+ self.checkpoint,
94
+ _do_init=False,
95
+ dtype=self.dtype,
96
+ **kwargs,
97
+ )
98
+
99
+ self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
100
+ self.min_batch_size = jax.local_device_count()
101
+ self.batch_size = (
102
+ batch_size if batch_size is not None else self.min_batch_size
103
+ ) # we need a minimum of 1 batch per-device
104
+
105
+ def generate(
106
+ params,
107
+ input_features,
108
+ forced_decoder_ids,
109
+ return_timestamps,
110
+ num_beams,
111
+ length_penalty,
112
+ do_sample,
113
+ top_k,
114
+ temperature,
115
+ ):
116
+ output_ids = self.model.pipeline_generate(
117
+ input_features,
118
+ params=params,
119
+ forced_decoder_ids=forced_decoder_ids,
120
+ return_timestamps=return_timestamps,
121
+ max_length=self.max_length,
122
+ num_beams=num_beams,
123
+ length_penalty=length_penalty,
124
+ do_sample=do_sample,
125
+ top_k=top_k,
126
+ temperature=temperature,
127
+ )
128
+ return output_ids
129
+
130
+ self.params = jax_utils.replicate(self.params)
131
+ self.p_generate = jax.pmap(
132
+ generate,
133
+ "input_features",
134
+ in_axes=(0, 0, None, None, None, None, None, None, None),
135
+ static_broadcasted_argnums=(
136
+ 3,
137
+ 4,
138
+ 5,
139
+ 6,
140
+ 7,
141
+ 8,
142
+ ),
143
+ )
144
+
145
+ def generate(
146
+ self,
147
+ input_features,
148
+ language=None,
149
+ task=None,
150
+ return_timestamps=False,
151
+ num_beams=1,
152
+ length_penalty=1.0,
153
+ do_sample=False,
154
+ top_k=50,
155
+ temperature=1.0,
156
+ ):
157
+ forced_decoder_ids = self.get_forced_decoder_ids(
158
+ language=language, task=task, return_timestamps=return_timestamps
159
+ )
160
+ # if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
161
+ output_ids = self.p_generate(
162
+ freeze(self.params),
163
+ shard(input_features),
164
+ forced_decoder_ids,
165
+ return_timestamps,
166
+ num_beams,
167
+ length_penalty,
168
+ do_sample,
169
+ top_k,
170
+ temperature,
171
+ ).sequences
172
+ output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
173
+ return output_ids
174
+
175
+ def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
176
+ if generation_config is None:
177
+ generation_config = self.model.generation_config
178
+
179
+ if hasattr(generation_config, "is_multilingual"):
180
+ is_multilingual = generation_config.is_multilingual
181
+ else:
182
+ is_multilingual = None
183
+
184
+ forced_decoder_ids = []
185
+
186
+ if is_multilingual:
187
+ if language is not None:
188
+ language = language.lower()
189
+ if language in generation_config.lang_to_id.keys():
190
+ language_token = language
191
+ elif language in TO_LANGUAGE_CODE.values():
192
+ language_token = f"<|{language}|>"
193
+ elif language in TO_LANGUAGE_CODE.keys():
194
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
195
+ else:
196
+ if len(language) == 2:
197
+ # ISO 639-1 language code
198
+ acceptable_languages = list(TO_LANGUAGE_CODE.values())
199
+ elif "<" in language or "|" in language or ">" in language:
200
+ # generation config language code
201
+ acceptable_languages = list(generation_config.lang_to_id.keys())
202
+ else:
203
+ # language passed as a string
204
+ acceptable_languages = list(TO_LANGUAGE_CODE.keys())
205
+ raise ValueError(
206
+ f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
207
+ )
208
+ forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
209
+
210
+ if task is not None:
211
+ forced_decoder_ids.append((2, generation_config.task_to_id[task]))
212
+ else:
213
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
214
+
215
+ if not return_timestamps:
216
+ if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
217
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
218
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
219
+ else:
220
+ forced_decoder_ids.append((1, generation_config.no_timestamps_token_id))
221
+
222
+ return forced_decoder_ids
223
+
224
+ def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
225
+ inputs_len = inputs.shape[0]
226
+ step = chunk_len - stride_left - stride_right
227
+
228
+ all_chunk_start_idx = np.arange(0, inputs_len, step)
229
+ num_samples = len(all_chunk_start_idx)
230
+
231
+ num_batches = math.ceil(num_samples / batch_size)
232
+ batch_idx = np.array_split(np.arange(num_samples), num_batches)
233
+
234
+ for idx in batch_idx:
235
+ chunk_start_idx = all_chunk_start_idx[idx]
236
+
237
+ chunk_end_idx = chunk_start_idx + chunk_len
238
+
239
+ chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
240
+ processed = self.feature_extractor(
241
+ chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
242
+ )
243
+
244
+ _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
245
+ is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
246
+ _stride_right = np.where(is_last, 0, stride_right)
247
+
248
+ chunk_lens = [chunk.shape[0] for chunk in chunks]
249
+ strides = [
250
+ (chunk_l, _stride_l, _stride_r)
251
+ for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
252
+ ]
253
+
254
+ yield {"stride": strides, **processed}
255
+
256
+ def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
257
+ if isinstance(inputs, np.ndarray):
258
+ logger.warning(
259
+ "Numpy array passed as input - no sampling rate checks will be performed."
260
+ "It is strongly recommended to pass the input as a dictionary with an 'array' key "
261
+ "containing the numpy array representing the audio, and a 'sampling_rate' key "
262
+ "containing the sampling rate associated with the audio array."
263
+ "Failing to do so can result in silent errors that might be hard to debug."
264
+ )
265
+
266
+ if isinstance(inputs, str):
267
+ if inputs.startswith("http://") or inputs.startswith("https://"):
268
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
269
+ # like http_huggingface_co.png
270
+ inputs = requests.get(inputs).content
271
+ else:
272
+ with open(inputs, "rb") as f:
273
+ inputs = f.read()
274
+
275
+ if isinstance(inputs, bytes):
276
+ inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
277
+
278
+ stride = None
279
+ if isinstance(inputs, dict):
280
+ stride = inputs.get("stride", None)
281
+ # Accepting `"array"` which is the key defined in `datasets` for
282
+ # better integration
283
+ if not ("sampling_rate" in inputs and "array" in inputs):
284
+ raise ValueError(
285
+ "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
286
+ "containing the numpy array representing the audio, and a 'sampling_rate' key "
287
+ "containing the sampling rate associated with the audio array."
288
+ )
289
+
290
+ in_sampling_rate = inputs.get("sampling_rate")
291
+ inputs = inputs.get("array", None)
292
+
293
+ if in_sampling_rate != self.feature_extractor.sampling_rate:
294
+ try:
295
+ import librosa
296
+ except ImportError as err:
297
+ raise ImportError(
298
+ "To support resampling audio files, please install 'librosa' and 'soundfile'."
299
+ ) from err
300
+
301
+ inputs = librosa.resample(
302
+ inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
303
+ )
304
+ ratio = self.feature_extractor.sampling_rate / in_sampling_rate
305
+ else:
306
+ ratio = 1
307
+
308
+ if not isinstance(inputs, np.ndarray):
309
+ raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
310
+ if len(inputs.shape) != 1:
311
+ raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
312
+
313
+ if stride is not None:
314
+ if stride[0] + stride[1] > inputs.shape[0]:
315
+ raise ValueError("Stride is too large for input")
316
+
317
+ # Stride needs to get the chunk length here, it's going to get
318
+ # swallowed by the `feature_extractor` later, and then batching
319
+ # can add extra data in the inputs, so we need to keep track
320
+ # of the original length in the stride so we can cut properly.
321
+ stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
322
+
323
+ if chunk_length_s:
324
+ if stride_length_s is None:
325
+ stride_length_s = chunk_length_s / 6
326
+
327
+ if isinstance(stride_length_s, (int, float)):
328
+ stride_length_s = [stride_length_s, stride_length_s]
329
+
330
+ chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
331
+ stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
332
+ stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
333
+
334
+ if chunk_len < stride_left + stride_right:
335
+ raise ValueError("Chunk length must be superior to stride length")
336
+
337
+ for item in self.chunk_iter_with_batch(
338
+ inputs,
339
+ chunk_len,
340
+ stride_left,
341
+ stride_right,
342
+ batch_size,
343
+ ):
344
+ yield item
345
+ else:
346
+ processed = self.feature_extractor(
347
+ inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
348
+ )
349
+ if stride is not None:
350
+ processed["stride"] = stride
351
+ yield processed
352
+
353
+ def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
354
+ # unpack the outputs from list(dict(list)) to list(dict)
355
+ model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
356
+
357
+ time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
358
+ # Send the chunking back to seconds, it's easier to handle in whisper
359
+ sampling_rate = self.feature_extractor.sampling_rate
360
+ for output in model_outputs:
361
+ if "stride" in output:
362
+ chunk_len, stride_left, stride_right = output["stride"]
363
+ # Go back in seconds
364
+ chunk_len /= sampling_rate
365
+ stride_left /= sampling_rate
366
+ stride_right /= sampling_rate
367
+ output["stride"] = chunk_len, stride_left, stride_right
368
+
369
+ text, optional = self.tokenizer._decode_asr(
370
+ model_outputs,
371
+ return_timestamps=return_timestamps,
372
+ return_language=return_language,
373
+ time_precision=time_precision,
374
+ )
375
+ return {"text": text, **optional}
376
+
377
+ def forward(
378
+ self,
379
+ model_inputs,
380
+ batch_size=None,
381
+ language=None,
382
+ task=None,
383
+ return_timestamps=False,
384
+ num_beams=1,
385
+ length_penalty=1.0,
386
+ do_sample=False,
387
+ top_k=50,
388
+ temperature=1.0,
389
+ ):
390
+ # We need to keep track of some additional input arguments for post-processing so need to forward these on after running generation
391
+ input_features = model_inputs.pop("input_features")
392
+ input_batch_size = input_features.shape[0]
393
+
394
+ if input_batch_size != batch_size:
395
+ padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
396
+ input_features = np.concatenate([input_features, padding])
397
+
398
+ pred_ids = self.generate(
399
+ input_features,
400
+ language=language,
401
+ task=task,
402
+ return_timestamps=return_timestamps,
403
+ num_beams=num_beams,
404
+ length_penalty=length_penalty,
405
+ do_sample=do_sample,
406
+ top_k=top_k,
407
+ temperature=temperature,
408
+ )[:input_batch_size]
409
+
410
+ # tokenizer's decode method expects an extra dim - we insert it here for convenience
411
+ out = {"tokens": pred_ids[:, None, :]}
412
+
413
+ stride = model_inputs.pop("stride", None)
414
+ if stride is not None:
415
+ out["stride"] = stride
416
+
417
+ return out
418
+
419
+ def __call__(
420
+ self,
421
+ inputs,
422
+ chunk_length_s=30.0,
423
+ stride_length_s=None,
424
+ batch_size=None,
425
+ language=None,
426
+ task=None,
427
+ return_timestamps=None,
428
+ num_beams=1,
429
+ length_penalty=1.0,
430
+ do_sample=False,
431
+ top_k=50,
432
+ temperature=1.0,
433
+ ):
434
+ """
435
+ Transcribe an audio input sequence to a text transcription, optionally with timestamps.
436
+
437
+ Args:
438
+ inputs (`np.ndarray` or `bytes` or `str` or `dict`):
439
+ The inputs is either:
440
+ - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
441
+ to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
442
+ - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
443
+ same way.
444
+ - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
445
+ Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
446
+ rate check will be done.
447
+ - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
448
+ pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
449
+ np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
450
+ ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
451
+ decoding (but used at inference to provide more context to the model). In general, this additional
452
+ stride argument is not required.
453
+ chunk_length_s (`float`, *optional*, defaults to 30.0):
454
+ The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk
455
+ length is set 30.0s, equal to Whisper's context window.
456
+ stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
457
+ The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
458
+ the model to *see* more context and infer letters better than without this context but the pipeline
459
+ discards the stride bits at the end to make the final reconstitution as perfect as possible.
460
+
461
+ <Tip>
462
+
463
+ For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
464
+ blog post](https://huggingface.co/blog/asr-chunking).
465
+
466
+ </Tip>
467
+ batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
468
+ The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
469
+ a batch size in the `__call__` method will supersede any batch size passed to the `__init__`.
470
+ task (`str`, *optional*):
471
+ Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
472
+ language (`str`, *optional*):
473
+ Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
474
+ Defaults to `None`, meaning the language is automatically inferred from the audio input.
475
+ return_timestamps (*optional*, `bool`):
476
+ Whether to return timestamps in the prediction. Defaults to False. If set to true, the pipeline
477
+ will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"`
478
+ containing the transcription segments chunked by their utterance-level timestamps.
479
+ length_penalty (*optional*, `float`):
480
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an
481
+ exponent to the sequence length, which in turn is used to divide the score of the sequence. Since
482
+ the score is the log likelihood of the sequence (i.e. negative), length_penalty > 1.0 promotes
483
+ longer sequences, while length_penalty < 1.0 encourages shorter sequences.
484
+ do_sample (*optional*, `bool`):
485
+ Whether or not to use sampling ; use greedy decoding otherwise.
486
+ top_k (*optional*, `int`):
487
+ The number of the highest probability vocabulary tokens to keep for top-k-filtering.
488
+ temperature (*optional*, `float`):
489
+ The value used to modulate the next token probabilities if sampling.
490
+
491
+ Return:
492
+ `Dict`: A dictionary with the following keys:
493
+ - **text** (`str` ) -- The recognised text.
494
+ - **chunks** (*optional(, `List[Dict]`)
495
+ When using `return_timestamps`, the `chunks` will become a list containing all the various text
496
+ chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
497
+ "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
498
+ `"".join(chunk["text"] for chunk in output["chunks"])`.
499
+ """
500
+ batch_size = batch_size if batch_size is not None else self.batch_size
501
+ if batch_size % self.min_batch_size != 0:
502
+ raise ValueError(
503
+ f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
504
+ )
505
+
506
+ dataloader = self.preprocess_batch(
507
+ inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
508
+ )
509
+ model_outputs = []
510
+ # iterate over our chunked audio samples
511
+ for batch in dataloader:
512
+ model_outputs.append(
513
+ self.forward(
514
+ batch,
515
+ batch_size=batch_size,
516
+ language=language,
517
+ task=task,
518
+ return_timestamps=return_timestamps,
519
+ num_beams=num_beams,
520
+ length_penalty=length_penalty,
521
+ do_sample=do_sample,
522
+ top_k=top_k,
523
+ temperature=temperature,
524
+ )
525
+ )
526
+ post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps)
527
+ return post_processed
distil_whisper/train_state.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Mapping, MutableMapping, Optional, Tuple
2
+
3
+ import flax.core
4
+ import flax.serialization
5
+ import flax.struct
6
+ import jax.numpy as jnp
7
+ from flax import traverse_util
8
+ from flax.core import scope as flax_scope
9
+ from flax.linen import partitioning as flax_partitioning
10
+
11
+
12
+ EMPTY_DICT = flax.core.freeze({})
13
+ FrozenDict = flax_scope.FrozenDict
14
+ FrozenVariableDict = flax_scope.FrozenVariableDict
15
+ MutableVariableDict = flax_scope.MutableVariableDict
16
+ VariableDict = flax_scope.VariableDict
17
+
18
+
19
+ def _validate_params_axes(params_axes, params):
20
+ axis_names = flax_partitioning.get_axis_names(params_axes)
21
+ missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set(
22
+ traverse_util.flatten_dict(axis_names, sep="/")
23
+ )
24
+ if missing_params_axes:
25
+ raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
26
+
27
+
28
+ def _split_variables_and_axes(
29
+ variables_and_axes: FrozenVariableDict,
30
+ ) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
31
+ """Splits `variables_and_axes` into two separate dicts with the same keys."""
32
+ # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`.
33
+ variables = {}
34
+ axes = {}
35
+ for k, v in variables_and_axes.items():
36
+ if k.endswith("_axes"):
37
+ axes[k[:-5]] = v # k without "_axes".
38
+ _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes".
39
+ else:
40
+ variables[k] = v
41
+ return flax.core.freeze(variables), flax.core.freeze(axes)
42
+
43
+
44
+ class InferenceState(flax.struct.PyTreeNode):
45
+ """State compatible with FlaxOptimTrainState without optimizer state."""
46
+
47
+ step: jnp.ndarray
48
+ params: flax_scope.FrozenVariableDict
49
+ params_axes: Optional[flax_scope.FrozenVariableDict] = None
50
+ flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
51
+ flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None
52
+
53
+ @classmethod
54
+ def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
55
+ other_variables, params = model_variables.pop("params")
56
+ if "params_axes" in other_variables:
57
+ other_variables, params_axes = other_variables.pop("params_axes")
58
+ _validate_params_axes(params_axes, params)
59
+ else:
60
+ params_axes = None
61
+
62
+ # Split other_variables into mutables and their corresponding axes.
63
+ flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
64
+ flax_mutables_axes = flax_mutables_axes or None
65
+ return InferenceState(
66
+ step=jnp.array(0),
67
+ params=params,
68
+ params_axes=params_axes,
69
+ flax_mutables=flax_mutables,
70
+ flax_mutables_axes=flax_mutables_axes,
71
+ )
72
+
73
+ @property
74
+ def param_states(self) -> FrozenVariableDict:
75
+ """The optimizer states of the parameters as a PyTree."""
76
+ raise NotImplementedError("InferenceState has no optimizer states.")
77
+
78
+ def apply_gradient(self, *args, **kwargs) -> "InferenceState":
79
+ raise NotImplementedError("InferenceState does not support `apply_gradient`.")
80
+
81
+ def state_dict(self) -> MutableMapping[str, Any]:
82
+ state_dict = {
83
+ "target": flax.core.unfreeze(self.params),
84
+ "state": {"step": self.step},
85
+ }
86
+ if self.flax_mutables:
87
+ state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
88
+ return state_dict
89
+
90
+ def replace_step(self, step: jnp.ndarray) -> "InferenceState":
91
+ return self.replace(step=step)
92
+
93
+ def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
94
+ return self.replace(params=params)
95
+
96
+ def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
97
+ return self.replace(flax_mutables=flax_mutables)
98
+
99
+ def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
100
+ return self.replace(
101
+ params=flax.core.freeze(state_dict["target"]),
102
+ step=state_dict["state"]["step"],
103
+ flax_mutables=(
104
+ flax.core.freeze(state_dict["flax_mutables"]) if "flax_mutables" in state_dict else EMPTY_DICT
105
+ ),
106
+ )
107
+
108
+ def as_logical_axes(self) -> "InferenceState":
109
+ # Set step to None so that when the logical axes are processed by the
110
+ # flax.partitioning.logical_to_mesh_axes function, it will be skipped
111
+ # because jax.tree_map will short circut and never call the function on the
112
+ # step.
113
+ flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
114
+ return InferenceState(
115
+ step=None,
116
+ params=flax_partitioning.get_axis_names(self.params_axes),
117
+ flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
118
+ )
generation_config.json ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alignment_heads": [
3
+ [
4
+ 7,
5
+ 0
6
+ ],
7
+ [
8
+ 10,
9
+ 17
10
+ ],
11
+ [
12
+ 12,
13
+ 18
14
+ ],
15
+ [
16
+ 13,
17
+ 12
18
+ ],
19
+ [
20
+ 16,
21
+ 1
22
+ ],
23
+ [
24
+ 17,
25
+ 14
26
+ ],
27
+ [
28
+ 19,
29
+ 11
30
+ ],
31
+ [
32
+ 21,
33
+ 4
34
+ ],
35
+ [
36
+ 24,
37
+ 1
38
+ ],
39
+ [
40
+ 25,
41
+ 6
42
+ ]
43
+ ],
44
+ "begin_suppress_tokens": [
45
+ 220,
46
+ 50257
47
+ ],
48
+ "bos_token_id": 50257,
49
+ "decoder_start_token_id": 50258,
50
+ "eos_token_id": 50257,
51
+ "is_multilingual": true,
52
+ "lang_to_id": {
53
+ "<|af|>": 50327,
54
+ "<|am|>": 50334,
55
+ "<|ar|>": 50272,
56
+ "<|as|>": 50350,
57
+ "<|az|>": 50304,
58
+ "<|ba|>": 50355,
59
+ "<|be|>": 50330,
60
+ "<|bg|>": 50292,
61
+ "<|bn|>": 50302,
62
+ "<|bo|>": 50347,
63
+ "<|br|>": 50309,
64
+ "<|bs|>": 50315,
65
+ "<|ca|>": 50270,
66
+ "<|cs|>": 50283,
67
+ "<|cy|>": 50297,
68
+ "<|da|>": 50285,
69
+ "<|de|>": 50261,
70
+ "<|el|>": 50281,
71
+ "<|en|>": 50259,
72
+ "<|es|>": 50262,
73
+ "<|et|>": 50307,
74
+ "<|eu|>": 50310,
75
+ "<|fa|>": 50300,
76
+ "<|fi|>": 50277,
77
+ "<|fo|>": 50338,
78
+ "<|fr|>": 50265,
79
+ "<|gl|>": 50319,
80
+ "<|gu|>": 50333,
81
+ "<|haw|>": 50352,
82
+ "<|ha|>": 50354,
83
+ "<|he|>": 50279,
84
+ "<|hi|>": 50276,
85
+ "<|hr|>": 50291,
86
+ "<|ht|>": 50339,
87
+ "<|hu|>": 50286,
88
+ "<|hy|>": 50312,
89
+ "<|id|>": 50275,
90
+ "<|is|>": 50311,
91
+ "<|it|>": 50274,
92
+ "<|ja|>": 50266,
93
+ "<|jw|>": 50356,
94
+ "<|ka|>": 50329,
95
+ "<|kk|>": 50316,
96
+ "<|km|>": 50323,
97
+ "<|kn|>": 50306,
98
+ "<|ko|>": 50264,
99
+ "<|la|>": 50294,
100
+ "<|lb|>": 50345,
101
+ "<|ln|>": 50353,
102
+ "<|lo|>": 50336,
103
+ "<|lt|>": 50293,
104
+ "<|lv|>": 50301,
105
+ "<|mg|>": 50349,
106
+ "<|mi|>": 50295,
107
+ "<|mk|>": 50308,
108
+ "<|ml|>": 50296,
109
+ "<|mn|>": 50314,
110
+ "<|mr|>": 50320,
111
+ "<|ms|>": 50282,
112
+ "<|mt|>": 50343,
113
+ "<|my|>": 50346,
114
+ "<|ne|>": 50313,
115
+ "<|nl|>": 50271,
116
+ "<|nn|>": 50342,
117
+ "<|no|>": 50288,
118
+ "<|oc|>": 50328,
119
+ "<|pa|>": 50321,
120
+ "<|pl|>": 50269,
121
+ "<|ps|>": 50340,
122
+ "<|pt|>": 50267,
123
+ "<|ro|>": 50284,
124
+ "<|ru|>": 50263,
125
+ "<|sa|>": 50344,
126
+ "<|sd|>": 50332,
127
+ "<|si|>": 50322,
128
+ "<|sk|>": 50298,
129
+ "<|sl|>": 50305,
130
+ "<|sn|>": 50324,
131
+ "<|so|>": 50326,
132
+ "<|sq|>": 50317,
133
+ "<|sr|>": 50303,
134
+ "<|su|>": 50357,
135
+ "<|sv|>": 50273,
136
+ "<|sw|>": 50318,
137
+ "<|ta|>": 50287,
138
+ "<|te|>": 50299,
139
+ "<|tg|>": 50331,
140
+ "<|th|>": 50289,
141
+ "<|tk|>": 50341,
142
+ "<|tl|>": 50348,
143
+ "<|tr|>": 50268,
144
+ "<|tt|>": 50351,
145
+ "<|uk|>": 50280,
146
+ "<|ur|>": 50290,
147
+ "<|uz|>": 50337,
148
+ "<|vi|>": 50278,
149
+ "<|yi|>": 50335,
150
+ "<|yo|>": 50325,
151
+ "<|yue|>": 50358,
152
+ "<|zh|>": 50260
153
+ },
154
+ "language": "no",
155
+ "max_initial_timestamp_index": 1,
156
+ "max_length": 448,
157
+ "no_timestamps_token_id": 50364,
158
+ "pad_token_id": 50257,
159
+ "return_timestamps": false,
160
+ "suppress_tokens": [
161
+ 1,
162
+ 2,
163
+ 7,
164
+ 8,
165
+ 9,
166
+ 10,
167
+ 14,
168
+ 25,
169
+ 26,
170
+ 27,
171
+ 28,
172
+ 29,
173
+ 31,
174
+ 58,
175
+ 59,
176
+ 60,
177
+ 61,
178
+ 62,
179
+ 63,
180
+ 90,
181
+ 91,
182
+ 92,
183
+ 93,
184
+ 359,
185
+ 503,
186
+ 522,
187
+ 542,
188
+ 873,
189
+ 893,
190
+ 902,
191
+ 918,
192
+ 922,
193
+ 931,
194
+ 1350,
195
+ 1853,
196
+ 1982,
197
+ 2460,
198
+ 2627,
199
+ 3246,
200
+ 3253,
201
+ 3268,
202
+ 3536,
203
+ 3846,
204
+ 3961,
205
+ 4183,
206
+ 4667,
207
+ 6585,
208
+ 6647,
209
+ 7273,
210
+ 9061,
211
+ 9383,
212
+ 10428,
213
+ 10929,
214
+ 11938,
215
+ 12033,
216
+ 12331,
217
+ 12562,
218
+ 13793,
219
+ 14157,
220
+ 14635,
221
+ 15265,
222
+ 15618,
223
+ 16553,
224
+ 16604,
225
+ 18362,
226
+ 18956,
227
+ 20075,
228
+ 21675,
229
+ 22520,
230
+ 26130,
231
+ 26161,
232
+ 26435,
233
+ 28279,
234
+ 29464,
235
+ 31650,
236
+ 32302,
237
+ 32470,
238
+ 36865,
239
+ 42863,
240
+ 47425,
241
+ 49870,
242
+ 50254,
243
+ 50258,
244
+ 50359,
245
+ 50360,
246
+ 50361,
247
+ 50362,
248
+ 50363
249
+ ],
250
+ "task": "transcribe",
251
+ "task_to_id": {
252
+ "transcribe": 50360,
253
+ "translate": 50359
254
+ },
255
+ "transformers_version": "4.46.2"
256
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dceef1c98c82eee48a3a948d1ca88682b48946a66a0d989d6be8c1c49205bed
3
+ size 3025686376
nb-distil-large-init/added_tokens.json ADDED
@@ -0,0 +1,1611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|0.00|>": 50365,
3
+ "<|0.02|>": 50366,
4
+ "<|0.04|>": 50367,
5
+ "<|0.06|>": 50368,
6
+ "<|0.08|>": 50369,
7
+ "<|0.10|>": 50370,
8
+ "<|0.12|>": 50371,
9
+ "<|0.14|>": 50372,
10
+ "<|0.16|>": 50373,
11
+ "<|0.18|>": 50374,
12
+ "<|0.20|>": 50375,
13
+ "<|0.22|>": 50376,
14
+ "<|0.24|>": 50377,
15
+ "<|0.26|>": 50378,
16
+ "<|0.28|>": 50379,
17
+ "<|0.30|>": 50380,
18
+ "<|0.32|>": 50381,
19
+ "<|0.34|>": 50382,
20
+ "<|0.36|>": 50383,
21
+ "<|0.38|>": 50384,
22
+ "<|0.40|>": 50385,
23
+ "<|0.42|>": 50386,
24
+ "<|0.44|>": 50387,
25
+ "<|0.46|>": 50388,
26
+ "<|0.48|>": 50389,
27
+ "<|0.50|>": 50390,
28
+ "<|0.52|>": 50391,
29
+ "<|0.54|>": 50392,
30
+ "<|0.56|>": 50393,
31
+ "<|0.58|>": 50394,
32
+ "<|0.60|>": 50395,
33
+ "<|0.62|>": 50396,
34
+ "<|0.64|>": 50397,
35
+ "<|0.66|>": 50398,
36
+ "<|0.68|>": 50399,
37
+ "<|0.70|>": 50400,
38
+ "<|0.72|>": 50401,
39
+ "<|0.74|>": 50402,
40
+ "<|0.76|>": 50403,
41
+ "<|0.78|>": 50404,
42
+ "<|0.80|>": 50405,
43
+ "<|0.82|>": 50406,
44
+ "<|0.84|>": 50407,
45
+ "<|0.86|>": 50408,
46
+ "<|0.88|>": 50409,
47
+ "<|0.90|>": 50410,
48
+ "<|0.92|>": 50411,
49
+ "<|0.94|>": 50412,
50
+ "<|0.96|>": 50413,
51
+ "<|0.98|>": 50414,
52
+ "<|1.00|>": 50415,
53
+ "<|1.02|>": 50416,
54
+ "<|1.04|>": 50417,
55
+ "<|1.06|>": 50418,
56
+ "<|1.08|>": 50419,
57
+ "<|1.10|>": 50420,
58
+ "<|1.12|>": 50421,
59
+ "<|1.14|>": 50422,
60
+ "<|1.16|>": 50423,
61
+ "<|1.18|>": 50424,
62
+ "<|1.20|>": 50425,
63
+ "<|1.22|>": 50426,
64
+ "<|1.24|>": 50427,
65
+ "<|1.26|>": 50428,
66
+ "<|1.28|>": 50429,
67
+ "<|1.30|>": 50430,
68
+ "<|1.32|>": 50431,
69
+ "<|1.34|>": 50432,
70
+ "<|1.36|>": 50433,
71
+ "<|1.38|>": 50434,
72
+ "<|1.40|>": 50435,
73
+ "<|1.42|>": 50436,
74
+ "<|1.44|>": 50437,
75
+ "<|1.46|>": 50438,
76
+ "<|1.48|>": 50439,
77
+ "<|1.50|>": 50440,
78
+ "<|1.52|>": 50441,
79
+ "<|1.54|>": 50442,
80
+ "<|1.56|>": 50443,
81
+ "<|1.58|>": 50444,
82
+ "<|1.60|>": 50445,
83
+ "<|1.62|>": 50446,
84
+ "<|1.64|>": 50447,
85
+ "<|1.66|>": 50448,
86
+ "<|1.68|>": 50449,
87
+ "<|1.70|>": 50450,
88
+ "<|1.72|>": 50451,
89
+ "<|1.74|>": 50452,
90
+ "<|1.76|>": 50453,
91
+ "<|1.78|>": 50454,
92
+ "<|1.80|>": 50455,
93
+ "<|1.82|>": 50456,
94
+ "<|1.84|>": 50457,
95
+ "<|1.86|>": 50458,
96
+ "<|1.88|>": 50459,
97
+ "<|1.90|>": 50460,
98
+ "<|1.92|>": 50461,
99
+ "<|1.94|>": 50462,
100
+ "<|1.96|>": 50463,
101
+ "<|1.98|>": 50464,
102
+ "<|10.00|>": 50865,
103
+ "<|10.02|>": 50866,
104
+ "<|10.04|>": 50867,
105
+ "<|10.06|>": 50868,
106
+ "<|10.08|>": 50869,
107
+ "<|10.10|>": 50870,
108
+ "<|10.12|>": 50871,
109
+ "<|10.14|>": 50872,
110
+ "<|10.16|>": 50873,
111
+ "<|10.18|>": 50874,
112
+ "<|10.20|>": 50875,
113
+ "<|10.22|>": 50876,
114
+ "<|10.24|>": 50877,
115
+ "<|10.26|>": 50878,
116
+ "<|10.28|>": 50879,
117
+ "<|10.30|>": 50880,
118
+ "<|10.32|>": 50881,
119
+ "<|10.34|>": 50882,
120
+ "<|10.36|>": 50883,
121
+ "<|10.38|>": 50884,
122
+ "<|10.40|>": 50885,
123
+ "<|10.42|>": 50886,
124
+ "<|10.44|>": 50887,
125
+ "<|10.46|>": 50888,
126
+ "<|10.48|>": 50889,
127
+ "<|10.50|>": 50890,
128
+ "<|10.52|>": 50891,
129
+ "<|10.54|>": 50892,
130
+ "<|10.56|>": 50893,
131
+ "<|10.58|>": 50894,
132
+ "<|10.60|>": 50895,
133
+ "<|10.62|>": 50896,
134
+ "<|10.64|>": 50897,
135
+ "<|10.66|>": 50898,
136
+ "<|10.68|>": 50899,
137
+ "<|10.70|>": 50900,
138
+ "<|10.72|>": 50901,
139
+ "<|10.74|>": 50902,
140
+ "<|10.76|>": 50903,
141
+ "<|10.78|>": 50904,
142
+ "<|10.80|>": 50905,
143
+ "<|10.82|>": 50906,
144
+ "<|10.84|>": 50907,
145
+ "<|10.86|>": 50908,
146
+ "<|10.88|>": 50909,
147
+ "<|10.90|>": 50910,
148
+ "<|10.92|>": 50911,
149
+ "<|10.94|>": 50912,
150
+ "<|10.96|>": 50913,
151
+ "<|10.98|>": 50914,
152
+ "<|11.00|>": 50915,
153
+ "<|11.02|>": 50916,
154
+ "<|11.04|>": 50917,
155
+ "<|11.06|>": 50918,
156
+ "<|11.08|>": 50919,
157
+ "<|11.10|>": 50920,
158
+ "<|11.12|>": 50921,
159
+ "<|11.14|>": 50922,
160
+ "<|11.16|>": 50923,
161
+ "<|11.18|>": 50924,
162
+ "<|11.20|>": 50925,
163
+ "<|11.22|>": 50926,
164
+ "<|11.24|>": 50927,
165
+ "<|11.26|>": 50928,
166
+ "<|11.28|>": 50929,
167
+ "<|11.30|>": 50930,
168
+ "<|11.32|>": 50931,
169
+ "<|11.34|>": 50932,
170
+ "<|11.36|>": 50933,
171
+ "<|11.38|>": 50934,
172
+ "<|11.40|>": 50935,
173
+ "<|11.42|>": 50936,
174
+ "<|11.44|>": 50937,
175
+ "<|11.46|>": 50938,
176
+ "<|11.48|>": 50939,
177
+ "<|11.50|>": 50940,
178
+ "<|11.52|>": 50941,
179
+ "<|11.54|>": 50942,
180
+ "<|11.56|>": 50943,
181
+ "<|11.58|>": 50944,
182
+ "<|11.60|>": 50945,
183
+ "<|11.62|>": 50946,
184
+ "<|11.64|>": 50947,
185
+ "<|11.66|>": 50948,
186
+ "<|11.68|>": 50949,
187
+ "<|11.70|>": 50950,
188
+ "<|11.72|>": 50951,
189
+ "<|11.74|>": 50952,
190
+ "<|11.76|>": 50953,
191
+ "<|11.78|>": 50954,
192
+ "<|11.80|>": 50955,
193
+ "<|11.82|>": 50956,
194
+ "<|11.84|>": 50957,
195
+ "<|11.86|>": 50958,
196
+ "<|11.88|>": 50959,
197
+ "<|11.90|>": 50960,
198
+ "<|11.92|>": 50961,
199
+ "<|11.94|>": 50962,
200
+ "<|11.96|>": 50963,
201
+ "<|11.98|>": 50964,
202
+ "<|12.00|>": 50965,
203
+ "<|12.02|>": 50966,
204
+ "<|12.04|>": 50967,
205
+ "<|12.06|>": 50968,
206
+ "<|12.08|>": 50969,
207
+ "<|12.10|>": 50970,
208
+ "<|12.12|>": 50971,
209
+ "<|12.14|>": 50972,
210
+ "<|12.16|>": 50973,
211
+ "<|12.18|>": 50974,
212
+ "<|12.20|>": 50975,
213
+ "<|12.22|>": 50976,
214
+ "<|12.24|>": 50977,
215
+ "<|12.26|>": 50978,
216
+ "<|12.28|>": 50979,
217
+ "<|12.30|>": 50980,
218
+ "<|12.32|>": 50981,
219
+ "<|12.34|>": 50982,
220
+ "<|12.36|>": 50983,
221
+ "<|12.38|>": 50984,
222
+ "<|12.40|>": 50985,
223
+ "<|12.42|>": 50986,
224
+ "<|12.44|>": 50987,
225
+ "<|12.46|>": 50988,
226
+ "<|12.48|>": 50989,
227
+ "<|12.50|>": 50990,
228
+ "<|12.52|>": 50991,
229
+ "<|12.54|>": 50992,
230
+ "<|12.56|>": 50993,
231
+ "<|12.58|>": 50994,
232
+ "<|12.60|>": 50995,
233
+ "<|12.62|>": 50996,
234
+ "<|12.64|>": 50997,
235
+ "<|12.66|>": 50998,
236
+ "<|12.68|>": 50999,
237
+ "<|12.70|>": 51000,
238
+ "<|12.72|>": 51001,
239
+ "<|12.74|>": 51002,
240
+ "<|12.76|>": 51003,
241
+ "<|12.78|>": 51004,
242
+ "<|12.80|>": 51005,
243
+ "<|12.82|>": 51006,
244
+ "<|12.84|>": 51007,
245
+ "<|12.86|>": 51008,
246
+ "<|12.88|>": 51009,
247
+ "<|12.90|>": 51010,
248
+ "<|12.92|>": 51011,
249
+ "<|12.94|>": 51012,
250
+ "<|12.96|>": 51013,
251
+ "<|12.98|>": 51014,
252
+ "<|13.00|>": 51015,
253
+ "<|13.02|>": 51016,
254
+ "<|13.04|>": 51017,
255
+ "<|13.06|>": 51018,
256
+ "<|13.08|>": 51019,
257
+ "<|13.10|>": 51020,
258
+ "<|13.12|>": 51021,
259
+ "<|13.14|>": 51022,
260
+ "<|13.16|>": 51023,
261
+ "<|13.18|>": 51024,
262
+ "<|13.20|>": 51025,
263
+ "<|13.22|>": 51026,
264
+ "<|13.24|>": 51027,
265
+ "<|13.26|>": 51028,
266
+ "<|13.28|>": 51029,
267
+ "<|13.30|>": 51030,
268
+ "<|13.32|>": 51031,
269
+ "<|13.34|>": 51032,
270
+ "<|13.36|>": 51033,
271
+ "<|13.38|>": 51034,
272
+ "<|13.40|>": 51035,
273
+ "<|13.42|>": 51036,
274
+ "<|13.44|>": 51037,
275
+ "<|13.46|>": 51038,
276
+ "<|13.48|>": 51039,
277
+ "<|13.50|>": 51040,
278
+ "<|13.52|>": 51041,
279
+ "<|13.54|>": 51042,
280
+ "<|13.56|>": 51043,
281
+ "<|13.58|>": 51044,
282
+ "<|13.60|>": 51045,
283
+ "<|13.62|>": 51046,
284
+ "<|13.64|>": 51047,
285
+ "<|13.66|>": 51048,
286
+ "<|13.68|>": 51049,
287
+ "<|13.70|>": 51050,
288
+ "<|13.72|>": 51051,
289
+ "<|13.74|>": 51052,
290
+ "<|13.76|>": 51053,
291
+ "<|13.78|>": 51054,
292
+ "<|13.80|>": 51055,
293
+ "<|13.82|>": 51056,
294
+ "<|13.84|>": 51057,
295
+ "<|13.86|>": 51058,
296
+ "<|13.88|>": 51059,
297
+ "<|13.90|>": 51060,
298
+ "<|13.92|>": 51061,
299
+ "<|13.94|>": 51062,
300
+ "<|13.96|>": 51063,
301
+ "<|13.98|>": 51064,
302
+ "<|14.00|>": 51065,
303
+ "<|14.02|>": 51066,
304
+ "<|14.04|>": 51067,
305
+ "<|14.06|>": 51068,
306
+ "<|14.08|>": 51069,
307
+ "<|14.10|>": 51070,
308
+ "<|14.12|>": 51071,
309
+ "<|14.14|>": 51072,
310
+ "<|14.16|>": 51073,
311
+ "<|14.18|>": 51074,
312
+ "<|14.20|>": 51075,
313
+ "<|14.22|>": 51076,
314
+ "<|14.24|>": 51077,
315
+ "<|14.26|>": 51078,
316
+ "<|14.28|>": 51079,
317
+ "<|14.30|>": 51080,
318
+ "<|14.32|>": 51081,
319
+ "<|14.34|>": 51082,
320
+ "<|14.36|>": 51083,
321
+ "<|14.38|>": 51084,
322
+ "<|14.40|>": 51085,
323
+ "<|14.42|>": 51086,
324
+ "<|14.44|>": 51087,
325
+ "<|14.46|>": 51088,
326
+ "<|14.48|>": 51089,
327
+ "<|14.50|>": 51090,
328
+ "<|14.52|>": 51091,
329
+ "<|14.54|>": 51092,
330
+ "<|14.56|>": 51093,
331
+ "<|14.58|>": 51094,
332
+ "<|14.60|>": 51095,
333
+ "<|14.62|>": 51096,
334
+ "<|14.64|>": 51097,
335
+ "<|14.66|>": 51098,
336
+ "<|14.68|>": 51099,
337
+ "<|14.70|>": 51100,
338
+ "<|14.72|>": 51101,
339
+ "<|14.74|>": 51102,
340
+ "<|14.76|>": 51103,
341
+ "<|14.78|>": 51104,
342
+ "<|14.80|>": 51105,
343
+ "<|14.82|>": 51106,
344
+ "<|14.84|>": 51107,
345
+ "<|14.86|>": 51108,
346
+ "<|14.88|>": 51109,
347
+ "<|14.90|>": 51110,
348
+ "<|14.92|>": 51111,
349
+ "<|14.94|>": 51112,
350
+ "<|14.96|>": 51113,
351
+ "<|14.98|>": 51114,
352
+ "<|15.00|>": 51115,
353
+ "<|15.02|>": 51116,
354
+ "<|15.04|>": 51117,
355
+ "<|15.06|>": 51118,
356
+ "<|15.08|>": 51119,
357
+ "<|15.10|>": 51120,
358
+ "<|15.12|>": 51121,
359
+ "<|15.14|>": 51122,
360
+ "<|15.16|>": 51123,
361
+ "<|15.18|>": 51124,
362
+ "<|15.20|>": 51125,
363
+ "<|15.22|>": 51126,
364
+ "<|15.24|>": 51127,
365
+ "<|15.26|>": 51128,
366
+ "<|15.28|>": 51129,
367
+ "<|15.30|>": 51130,
368
+ "<|15.32|>": 51131,
369
+ "<|15.34|>": 51132,
370
+ "<|15.36|>": 51133,
371
+ "<|15.38|>": 51134,
372
+ "<|15.40|>": 51135,
373
+ "<|15.42|>": 51136,
374
+ "<|15.44|>": 51137,
375
+ "<|15.46|>": 51138,
376
+ "<|15.48|>": 51139,
377
+ "<|15.50|>": 51140,
378
+ "<|15.52|>": 51141,
379
+ "<|15.54|>": 51142,
380
+ "<|15.56|>": 51143,
381
+ "<|15.58|>": 51144,
382
+ "<|15.60|>": 51145,
383
+ "<|15.62|>": 51146,
384
+ "<|15.64|>": 51147,
385
+ "<|15.66|>": 51148,
386
+ "<|15.68|>": 51149,
387
+ "<|15.70|>": 51150,
388
+ "<|15.72|>": 51151,
389
+ "<|15.74|>": 51152,
390
+ "<|15.76|>": 51153,
391
+ "<|15.78|>": 51154,
392
+ "<|15.80|>": 51155,
393
+ "<|15.82|>": 51156,
394
+ "<|15.84|>": 51157,
395
+ "<|15.86|>": 51158,
396
+ "<|15.88|>": 51159,
397
+ "<|15.90|>": 51160,
398
+ "<|15.92|>": 51161,
399
+ "<|15.94|>": 51162,
400
+ "<|15.96|>": 51163,
401
+ "<|15.98|>": 51164,
402
+ "<|16.00|>": 51165,
403
+ "<|16.02|>": 51166,
404
+ "<|16.04|>": 51167,
405
+ "<|16.06|>": 51168,
406
+ "<|16.08|>": 51169,
407
+ "<|16.10|>": 51170,
408
+ "<|16.12|>": 51171,
409
+ "<|16.14|>": 51172,
410
+ "<|16.16|>": 51173,
411
+ "<|16.18|>": 51174,
412
+ "<|16.20|>": 51175,
413
+ "<|16.22|>": 51176,
414
+ "<|16.24|>": 51177,
415
+ "<|16.26|>": 51178,
416
+ "<|16.28|>": 51179,
417
+ "<|16.30|>": 51180,
418
+ "<|16.32|>": 51181,
419
+ "<|16.34|>": 51182,
420
+ "<|16.36|>": 51183,
421
+ "<|16.38|>": 51184,
422
+ "<|16.40|>": 51185,
423
+ "<|16.42|>": 51186,
424
+ "<|16.44|>": 51187,
425
+ "<|16.46|>": 51188,
426
+ "<|16.48|>": 51189,
427
+ "<|16.50|>": 51190,
428
+ "<|16.52|>": 51191,
429
+ "<|16.54|>": 51192,
430
+ "<|16.56|>": 51193,
431
+ "<|16.58|>": 51194,
432
+ "<|16.60|>": 51195,
433
+ "<|16.62|>": 51196,
434
+ "<|16.64|>": 51197,
435
+ "<|16.66|>": 51198,
436
+ "<|16.68|>": 51199,
437
+ "<|16.70|>": 51200,
438
+ "<|16.72|>": 51201,
439
+ "<|16.74|>": 51202,
440
+ "<|16.76|>": 51203,
441
+ "<|16.78|>": 51204,
442
+ "<|16.80|>": 51205,
443
+ "<|16.82|>": 51206,
444
+ "<|16.84|>": 51207,
445
+ "<|16.86|>": 51208,
446
+ "<|16.88|>": 51209,
447
+ "<|16.90|>": 51210,
448
+ "<|16.92|>": 51211,
449
+ "<|16.94|>": 51212,
450
+ "<|16.96|>": 51213,
451
+ "<|16.98|>": 51214,
452
+ "<|17.00|>": 51215,
453
+ "<|17.02|>": 51216,
454
+ "<|17.04|>": 51217,
455
+ "<|17.06|>": 51218,
456
+ "<|17.08|>": 51219,
457
+ "<|17.10|>": 51220,
458
+ "<|17.12|>": 51221,
459
+ "<|17.14|>": 51222,
460
+ "<|17.16|>": 51223,
461
+ "<|17.18|>": 51224,
462
+ "<|17.20|>": 51225,
463
+ "<|17.22|>": 51226,
464
+ "<|17.24|>": 51227,
465
+ "<|17.26|>": 51228,
466
+ "<|17.28|>": 51229,
467
+ "<|17.30|>": 51230,
468
+ "<|17.32|>": 51231,
469
+ "<|17.34|>": 51232,
470
+ "<|17.36|>": 51233,
471
+ "<|17.38|>": 51234,
472
+ "<|17.40|>": 51235,
473
+ "<|17.42|>": 51236,
474
+ "<|17.44|>": 51237,
475
+ "<|17.46|>": 51238,
476
+ "<|17.48|>": 51239,
477
+ "<|17.50|>": 51240,
478
+ "<|17.52|>": 51241,
479
+ "<|17.54|>": 51242,
480
+ "<|17.56|>": 51243,
481
+ "<|17.58|>": 51244,
482
+ "<|17.60|>": 51245,
483
+ "<|17.62|>": 51246,
484
+ "<|17.64|>": 51247,
485
+ "<|17.66|>": 51248,
486
+ "<|17.68|>": 51249,
487
+ "<|17.70|>": 51250,
488
+ "<|17.72|>": 51251,
489
+ "<|17.74|>": 51252,
490
+ "<|17.76|>": 51253,
491
+ "<|17.78|>": 51254,
492
+ "<|17.80|>": 51255,
493
+ "<|17.82|>": 51256,
494
+ "<|17.84|>": 51257,
495
+ "<|17.86|>": 51258,
496
+ "<|17.88|>": 51259,
497
+ "<|17.90|>": 51260,
498
+ "<|17.92|>": 51261,
499
+ "<|17.94|>": 51262,
500
+ "<|17.96|>": 51263,
501
+ "<|17.98|>": 51264,
502
+ "<|18.00|>": 51265,
503
+ "<|18.02|>": 51266,
504
+ "<|18.04|>": 51267,
505
+ "<|18.06|>": 51268,
506
+ "<|18.08|>": 51269,
507
+ "<|18.10|>": 51270,
508
+ "<|18.12|>": 51271,
509
+ "<|18.14|>": 51272,
510
+ "<|18.16|>": 51273,
511
+ "<|18.18|>": 51274,
512
+ "<|18.20|>": 51275,
513
+ "<|18.22|>": 51276,
514
+ "<|18.24|>": 51277,
515
+ "<|18.26|>": 51278,
516
+ "<|18.28|>": 51279,
517
+ "<|18.30|>": 51280,
518
+ "<|18.32|>": 51281,
519
+ "<|18.34|>": 51282,
520
+ "<|18.36|>": 51283,
521
+ "<|18.38|>": 51284,
522
+ "<|18.40|>": 51285,
523
+ "<|18.42|>": 51286,
524
+ "<|18.44|>": 51287,
525
+ "<|18.46|>": 51288,
526
+ "<|18.48|>": 51289,
527
+ "<|18.50|>": 51290,
528
+ "<|18.52|>": 51291,
529
+ "<|18.54|>": 51292,
530
+ "<|18.56|>": 51293,
531
+ "<|18.58|>": 51294,
532
+ "<|18.60|>": 51295,
533
+ "<|18.62|>": 51296,
534
+ "<|18.64|>": 51297,
535
+ "<|18.66|>": 51298,
536
+ "<|18.68|>": 51299,
537
+ "<|18.70|>": 51300,
538
+ "<|18.72|>": 51301,
539
+ "<|18.74|>": 51302,
540
+ "<|18.76|>": 51303,
541
+ "<|18.78|>": 51304,
542
+ "<|18.80|>": 51305,
543
+ "<|18.82|>": 51306,
544
+ "<|18.84|>": 51307,
545
+ "<|18.86|>": 51308,
546
+ "<|18.88|>": 51309,
547
+ "<|18.90|>": 51310,
548
+ "<|18.92|>": 51311,
549
+ "<|18.94|>": 51312,
550
+ "<|18.96|>": 51313,
551
+ "<|18.98|>": 51314,
552
+ "<|19.00|>": 51315,
553
+ "<|19.02|>": 51316,
554
+ "<|19.04|>": 51317,
555
+ "<|19.06|>": 51318,
556
+ "<|19.08|>": 51319,
557
+ "<|19.10|>": 51320,
558
+ "<|19.12|>": 51321,
559
+ "<|19.14|>": 51322,
560
+ "<|19.16|>": 51323,
561
+ "<|19.18|>": 51324,
562
+ "<|19.20|>": 51325,
563
+ "<|19.22|>": 51326,
564
+ "<|19.24|>": 51327,
565
+ "<|19.26|>": 51328,
566
+ "<|19.28|>": 51329,
567
+ "<|19.30|>": 51330,
568
+ "<|19.32|>": 51331,
569
+ "<|19.34|>": 51332,
570
+ "<|19.36|>": 51333,
571
+ "<|19.38|>": 51334,
572
+ "<|19.40|>": 51335,
573
+ "<|19.42|>": 51336,
574
+ "<|19.44|>": 51337,
575
+ "<|19.46|>": 51338,
576
+ "<|19.48|>": 51339,
577
+ "<|19.50|>": 51340,
578
+ "<|19.52|>": 51341,
579
+ "<|19.54|>": 51342,
580
+ "<|19.56|>": 51343,
581
+ "<|19.58|>": 51344,
582
+ "<|19.60|>": 51345,
583
+ "<|19.62|>": 51346,
584
+ "<|19.64|>": 51347,
585
+ "<|19.66|>": 51348,
586
+ "<|19.68|>": 51349,
587
+ "<|19.70|>": 51350,
588
+ "<|19.72|>": 51351,
589
+ "<|19.74|>": 51352,
590
+ "<|19.76|>": 51353,
591
+ "<|19.78|>": 51354,
592
+ "<|19.80|>": 51355,
593
+ "<|19.82|>": 51356,
594
+ "<|19.84|>": 51357,
595
+ "<|19.86|>": 51358,
596
+ "<|19.88|>": 51359,
597
+ "<|19.90|>": 51360,
598
+ "<|19.92|>": 51361,
599
+ "<|19.94|>": 51362,
600
+ "<|19.96|>": 51363,
601
+ "<|19.98|>": 51364,
602
+ "<|2.00|>": 50465,
603
+ "<|2.02|>": 50466,
604
+ "<|2.04|>": 50467,
605
+ "<|2.06|>": 50468,
606
+ "<|2.08|>": 50469,
607
+ "<|2.10|>": 50470,
608
+ "<|2.12|>": 50471,
609
+ "<|2.14|>": 50472,
610
+ "<|2.16|>": 50473,
611
+ "<|2.18|>": 50474,
612
+ "<|2.20|>": 50475,
613
+ "<|2.22|>": 50476,
614
+ "<|2.24|>": 50477,
615
+ "<|2.26|>": 50478,
616
+ "<|2.28|>": 50479,
617
+ "<|2.30|>": 50480,
618
+ "<|2.32|>": 50481,
619
+ "<|2.34|>": 50482,
620
+ "<|2.36|>": 50483,
621
+ "<|2.38|>": 50484,
622
+ "<|2.40|>": 50485,
623
+ "<|2.42|>": 50486,
624
+ "<|2.44|>": 50487,
625
+ "<|2.46|>": 50488,
626
+ "<|2.48|>": 50489,
627
+ "<|2.50|>": 50490,
628
+ "<|2.52|>": 50491,
629
+ "<|2.54|>": 50492,
630
+ "<|2.56|>": 50493,
631
+ "<|2.58|>": 50494,
632
+ "<|2.60|>": 50495,
633
+ "<|2.62|>": 50496,
634
+ "<|2.64|>": 50497,
635
+ "<|2.66|>": 50498,
636
+ "<|2.68|>": 50499,
637
+ "<|2.70|>": 50500,
638
+ "<|2.72|>": 50501,
639
+ "<|2.74|>": 50502,
640
+ "<|2.76|>": 50503,
641
+ "<|2.78|>": 50504,
642
+ "<|2.80|>": 50505,
643
+ "<|2.82|>": 50506,
644
+ "<|2.84|>": 50507,
645
+ "<|2.86|>": 50508,
646
+ "<|2.88|>": 50509,
647
+ "<|2.90|>": 50510,
648
+ "<|2.92|>": 50511,
649
+ "<|2.94|>": 50512,
650
+ "<|2.96|>": 50513,
651
+ "<|2.98|>": 50514,
652
+ "<|20.00|>": 51365,
653
+ "<|20.02|>": 51366,
654
+ "<|20.04|>": 51367,
655
+ "<|20.06|>": 51368,
656
+ "<|20.08|>": 51369,
657
+ "<|20.10|>": 51370,
658
+ "<|20.12|>": 51371,
659
+ "<|20.14|>": 51372,
660
+ "<|20.16|>": 51373,
661
+ "<|20.18|>": 51374,
662
+ "<|20.20|>": 51375,
663
+ "<|20.22|>": 51376,
664
+ "<|20.24|>": 51377,
665
+ "<|20.26|>": 51378,
666
+ "<|20.28|>": 51379,
667
+ "<|20.30|>": 51380,
668
+ "<|20.32|>": 51381,
669
+ "<|20.34|>": 51382,
670
+ "<|20.36|>": 51383,
671
+ "<|20.38|>": 51384,
672
+ "<|20.40|>": 51385,
673
+ "<|20.42|>": 51386,
674
+ "<|20.44|>": 51387,
675
+ "<|20.46|>": 51388,
676
+ "<|20.48|>": 51389,
677
+ "<|20.50|>": 51390,
678
+ "<|20.52|>": 51391,
679
+ "<|20.54|>": 51392,
680
+ "<|20.56|>": 51393,
681
+ "<|20.58|>": 51394,
682
+ "<|20.60|>": 51395,
683
+ "<|20.62|>": 51396,
684
+ "<|20.64|>": 51397,
685
+ "<|20.66|>": 51398,
686
+ "<|20.68|>": 51399,
687
+ "<|20.70|>": 51400,
688
+ "<|20.72|>": 51401,
689
+ "<|20.74|>": 51402,
690
+ "<|20.76|>": 51403,
691
+ "<|20.78|>": 51404,
692
+ "<|20.80|>": 51405,
693
+ "<|20.82|>": 51406,
694
+ "<|20.84|>": 51407,
695
+ "<|20.86|>": 51408,
696
+ "<|20.88|>": 51409,
697
+ "<|20.90|>": 51410,
698
+ "<|20.92|>": 51411,
699
+ "<|20.94|>": 51412,
700
+ "<|20.96|>": 51413,
701
+ "<|20.98|>": 51414,
702
+ "<|21.00|>": 51415,
703
+ "<|21.02|>": 51416,
704
+ "<|21.04|>": 51417,
705
+ "<|21.06|>": 51418,
706
+ "<|21.08|>": 51419,
707
+ "<|21.10|>": 51420,
708
+ "<|21.12|>": 51421,
709
+ "<|21.14|>": 51422,
710
+ "<|21.16|>": 51423,
711
+ "<|21.18|>": 51424,
712
+ "<|21.20|>": 51425,
713
+ "<|21.22|>": 51426,
714
+ "<|21.24|>": 51427,
715
+ "<|21.26|>": 51428,
716
+ "<|21.28|>": 51429,
717
+ "<|21.30|>": 51430,
718
+ "<|21.32|>": 51431,
719
+ "<|21.34|>": 51432,
720
+ "<|21.36|>": 51433,
721
+ "<|21.38|>": 51434,
722
+ "<|21.40|>": 51435,
723
+ "<|21.42|>": 51436,
724
+ "<|21.44|>": 51437,
725
+ "<|21.46|>": 51438,
726
+ "<|21.48|>": 51439,
727
+ "<|21.50|>": 51440,
728
+ "<|21.52|>": 51441,
729
+ "<|21.54|>": 51442,
730
+ "<|21.56|>": 51443,
731
+ "<|21.58|>": 51444,
732
+ "<|21.60|>": 51445,
733
+ "<|21.62|>": 51446,
734
+ "<|21.64|>": 51447,
735
+ "<|21.66|>": 51448,
736
+ "<|21.68|>": 51449,
737
+ "<|21.70|>": 51450,
738
+ "<|21.72|>": 51451,
739
+ "<|21.74|>": 51452,
740
+ "<|21.76|>": 51453,
741
+ "<|21.78|>": 51454,
742
+ "<|21.80|>": 51455,
743
+ "<|21.82|>": 51456,
744
+ "<|21.84|>": 51457,
745
+ "<|21.86|>": 51458,
746
+ "<|21.88|>": 51459,
747
+ "<|21.90|>": 51460,
748
+ "<|21.92|>": 51461,
749
+ "<|21.94|>": 51462,
750
+ "<|21.96|>": 51463,
751
+ "<|21.98|>": 51464,
752
+ "<|22.00|>": 51465,
753
+ "<|22.02|>": 51466,
754
+ "<|22.04|>": 51467,
755
+ "<|22.06|>": 51468,
756
+ "<|22.08|>": 51469,
757
+ "<|22.10|>": 51470,
758
+ "<|22.12|>": 51471,
759
+ "<|22.14|>": 51472,
760
+ "<|22.16|>": 51473,
761
+ "<|22.18|>": 51474,
762
+ "<|22.20|>": 51475,
763
+ "<|22.22|>": 51476,
764
+ "<|22.24|>": 51477,
765
+ "<|22.26|>": 51478,
766
+ "<|22.28|>": 51479,
767
+ "<|22.30|>": 51480,
768
+ "<|22.32|>": 51481,
769
+ "<|22.34|>": 51482,
770
+ "<|22.36|>": 51483,
771
+ "<|22.38|>": 51484,
772
+ "<|22.40|>": 51485,
773
+ "<|22.42|>": 51486,
774
+ "<|22.44|>": 51487,
775
+ "<|22.46|>": 51488,
776
+ "<|22.48|>": 51489,
777
+ "<|22.50|>": 51490,
778
+ "<|22.52|>": 51491,
779
+ "<|22.54|>": 51492,
780
+ "<|22.56|>": 51493,
781
+ "<|22.58|>": 51494,
782
+ "<|22.60|>": 51495,
783
+ "<|22.62|>": 51496,
784
+ "<|22.64|>": 51497,
785
+ "<|22.66|>": 51498,
786
+ "<|22.68|>": 51499,
787
+ "<|22.70|>": 51500,
788
+ "<|22.72|>": 51501,
789
+ "<|22.74|>": 51502,
790
+ "<|22.76|>": 51503,
791
+ "<|22.78|>": 51504,
792
+ "<|22.80|>": 51505,
793
+ "<|22.82|>": 51506,
794
+ "<|22.84|>": 51507,
795
+ "<|22.86|>": 51508,
796
+ "<|22.88|>": 51509,
797
+ "<|22.90|>": 51510,
798
+ "<|22.92|>": 51511,
799
+ "<|22.94|>": 51512,
800
+ "<|22.96|>": 51513,
801
+ "<|22.98|>": 51514,
802
+ "<|23.00|>": 51515,
803
+ "<|23.02|>": 51516,
804
+ "<|23.04|>": 51517,
805
+ "<|23.06|>": 51518,
806
+ "<|23.08|>": 51519,
807
+ "<|23.10|>": 51520,
808
+ "<|23.12|>": 51521,
809
+ "<|23.14|>": 51522,
810
+ "<|23.16|>": 51523,
811
+ "<|23.18|>": 51524,
812
+ "<|23.20|>": 51525,
813
+ "<|23.22|>": 51526,
814
+ "<|23.24|>": 51527,
815
+ "<|23.26|>": 51528,
816
+ "<|23.28|>": 51529,
817
+ "<|23.30|>": 51530,
818
+ "<|23.32|>": 51531,
819
+ "<|23.34|>": 51532,
820
+ "<|23.36|>": 51533,
821
+ "<|23.38|>": 51534,
822
+ "<|23.40|>": 51535,
823
+ "<|23.42|>": 51536,
824
+ "<|23.44|>": 51537,
825
+ "<|23.46|>": 51538,
826
+ "<|23.48|>": 51539,
827
+ "<|23.50|>": 51540,
828
+ "<|23.52|>": 51541,
829
+ "<|23.54|>": 51542,
830
+ "<|23.56|>": 51543,
831
+ "<|23.58|>": 51544,
832
+ "<|23.60|>": 51545,
833
+ "<|23.62|>": 51546,
834
+ "<|23.64|>": 51547,
835
+ "<|23.66|>": 51548,
836
+ "<|23.68|>": 51549,
837
+ "<|23.70|>": 51550,
838
+ "<|23.72|>": 51551,
839
+ "<|23.74|>": 51552,
840
+ "<|23.76|>": 51553,
841
+ "<|23.78|>": 51554,
842
+ "<|23.80|>": 51555,
843
+ "<|23.82|>": 51556,
844
+ "<|23.84|>": 51557,
845
+ "<|23.86|>": 51558,
846
+ "<|23.88|>": 51559,
847
+ "<|23.90|>": 51560,
848
+ "<|23.92|>": 51561,
849
+ "<|23.94|>": 51562,
850
+ "<|23.96|>": 51563,
851
+ "<|23.98|>": 51564,
852
+ "<|24.00|>": 51565,
853
+ "<|24.02|>": 51566,
854
+ "<|24.04|>": 51567,
855
+ "<|24.06|>": 51568,
856
+ "<|24.08|>": 51569,
857
+ "<|24.10|>": 51570,
858
+ "<|24.12|>": 51571,
859
+ "<|24.14|>": 51572,
860
+ "<|24.16|>": 51573,
861
+ "<|24.18|>": 51574,
862
+ "<|24.20|>": 51575,
863
+ "<|24.22|>": 51576,
864
+ "<|24.24|>": 51577,
865
+ "<|24.26|>": 51578,
866
+ "<|24.28|>": 51579,
867
+ "<|24.30|>": 51580,
868
+ "<|24.32|>": 51581,
869
+ "<|24.34|>": 51582,
870
+ "<|24.36|>": 51583,
871
+ "<|24.38|>": 51584,
872
+ "<|24.40|>": 51585,
873
+ "<|24.42|>": 51586,
874
+ "<|24.44|>": 51587,
875
+ "<|24.46|>": 51588,
876
+ "<|24.48|>": 51589,
877
+ "<|24.50|>": 51590,
878
+ "<|24.52|>": 51591,
879
+ "<|24.54|>": 51592,
880
+ "<|24.56|>": 51593,
881
+ "<|24.58|>": 51594,
882
+ "<|24.60|>": 51595,
883
+ "<|24.62|>": 51596,
884
+ "<|24.64|>": 51597,
885
+ "<|24.66|>": 51598,
886
+ "<|24.68|>": 51599,
887
+ "<|24.70|>": 51600,
888
+ "<|24.72|>": 51601,
889
+ "<|24.74|>": 51602,
890
+ "<|24.76|>": 51603,
891
+ "<|24.78|>": 51604,
892
+ "<|24.80|>": 51605,
893
+ "<|24.82|>": 51606,
894
+ "<|24.84|>": 51607,
895
+ "<|24.86|>": 51608,
896
+ "<|24.88|>": 51609,
897
+ "<|24.90|>": 51610,
898
+ "<|24.92|>": 51611,
899
+ "<|24.94|>": 51612,
900
+ "<|24.96|>": 51613,
901
+ "<|24.98|>": 51614,
902
+ "<|25.00|>": 51615,
903
+ "<|25.02|>": 51616,
904
+ "<|25.04|>": 51617,
905
+ "<|25.06|>": 51618,
906
+ "<|25.08|>": 51619,
907
+ "<|25.10|>": 51620,
908
+ "<|25.12|>": 51621,
909
+ "<|25.14|>": 51622,
910
+ "<|25.16|>": 51623,
911
+ "<|25.18|>": 51624,
912
+ "<|25.20|>": 51625,
913
+ "<|25.22|>": 51626,
914
+ "<|25.24|>": 51627,
915
+ "<|25.26|>": 51628,
916
+ "<|25.28|>": 51629,
917
+ "<|25.30|>": 51630,
918
+ "<|25.32|>": 51631,
919
+ "<|25.34|>": 51632,
920
+ "<|25.36|>": 51633,
921
+ "<|25.38|>": 51634,
922
+ "<|25.40|>": 51635,
923
+ "<|25.42|>": 51636,
924
+ "<|25.44|>": 51637,
925
+ "<|25.46|>": 51638,
926
+ "<|25.48|>": 51639,
927
+ "<|25.50|>": 51640,
928
+ "<|25.52|>": 51641,
929
+ "<|25.54|>": 51642,
930
+ "<|25.56|>": 51643,
931
+ "<|25.58|>": 51644,
932
+ "<|25.60|>": 51645,
933
+ "<|25.62|>": 51646,
934
+ "<|25.64|>": 51647,
935
+ "<|25.66|>": 51648,
936
+ "<|25.68|>": 51649,
937
+ "<|25.70|>": 51650,
938
+ "<|25.72|>": 51651,
939
+ "<|25.74|>": 51652,
940
+ "<|25.76|>": 51653,
941
+ "<|25.78|>": 51654,
942
+ "<|25.80|>": 51655,
943
+ "<|25.82|>": 51656,
944
+ "<|25.84|>": 51657,
945
+ "<|25.86|>": 51658,
946
+ "<|25.88|>": 51659,
947
+ "<|25.90|>": 51660,
948
+ "<|25.92|>": 51661,
949
+ "<|25.94|>": 51662,
950
+ "<|25.96|>": 51663,
951
+ "<|25.98|>": 51664,
952
+ "<|26.00|>": 51665,
953
+ "<|26.02|>": 51666,
954
+ "<|26.04|>": 51667,
955
+ "<|26.06|>": 51668,
956
+ "<|26.08|>": 51669,
957
+ "<|26.10|>": 51670,
958
+ "<|26.12|>": 51671,
959
+ "<|26.14|>": 51672,
960
+ "<|26.16|>": 51673,
961
+ "<|26.18|>": 51674,
962
+ "<|26.20|>": 51675,
963
+ "<|26.22|>": 51676,
964
+ "<|26.24|>": 51677,
965
+ "<|26.26|>": 51678,
966
+ "<|26.28|>": 51679,
967
+ "<|26.30|>": 51680,
968
+ "<|26.32|>": 51681,
969
+ "<|26.34|>": 51682,
970
+ "<|26.36|>": 51683,
971
+ "<|26.38|>": 51684,
972
+ "<|26.40|>": 51685,
973
+ "<|26.42|>": 51686,
974
+ "<|26.44|>": 51687,
975
+ "<|26.46|>": 51688,
976
+ "<|26.48|>": 51689,
977
+ "<|26.50|>": 51690,
978
+ "<|26.52|>": 51691,
979
+ "<|26.54|>": 51692,
980
+ "<|26.56|>": 51693,
981
+ "<|26.58|>": 51694,
982
+ "<|26.60|>": 51695,
983
+ "<|26.62|>": 51696,
984
+ "<|26.64|>": 51697,
985
+ "<|26.66|>": 51698,
986
+ "<|26.68|>": 51699,
987
+ "<|26.70|>": 51700,
988
+ "<|26.72|>": 51701,
989
+ "<|26.74|>": 51702,
990
+ "<|26.76|>": 51703,
991
+ "<|26.78|>": 51704,
992
+ "<|26.80|>": 51705,
993
+ "<|26.82|>": 51706,
994
+ "<|26.84|>": 51707,
995
+ "<|26.86|>": 51708,
996
+ "<|26.88|>": 51709,
997
+ "<|26.90|>": 51710,
998
+ "<|26.92|>": 51711,
999
+ "<|26.94|>": 51712,
1000
+ "<|26.96|>": 51713,
1001
+ "<|26.98|>": 51714,
1002
+ "<|27.00|>": 51715,
1003
+ "<|27.02|>": 51716,
1004
+ "<|27.04|>": 51717,
1005
+ "<|27.06|>": 51718,
1006
+ "<|27.08|>": 51719,
1007
+ "<|27.10|>": 51720,
1008
+ "<|27.12|>": 51721,
1009
+ "<|27.14|>": 51722,
1010
+ "<|27.16|>": 51723,
1011
+ "<|27.18|>": 51724,
1012
+ "<|27.20|>": 51725,
1013
+ "<|27.22|>": 51726,
1014
+ "<|27.24|>": 51727,
1015
+ "<|27.26|>": 51728,
1016
+ "<|27.28|>": 51729,
1017
+ "<|27.30|>": 51730,
1018
+ "<|27.32|>": 51731,
1019
+ "<|27.34|>": 51732,
1020
+ "<|27.36|>": 51733,
1021
+ "<|27.38|>": 51734,
1022
+ "<|27.40|>": 51735,
1023
+ "<|27.42|>": 51736,
1024
+ "<|27.44|>": 51737,
1025
+ "<|27.46|>": 51738,
1026
+ "<|27.48|>": 51739,
1027
+ "<|27.50|>": 51740,
1028
+ "<|27.52|>": 51741,
1029
+ "<|27.54|>": 51742,
1030
+ "<|27.56|>": 51743,
1031
+ "<|27.58|>": 51744,
1032
+ "<|27.60|>": 51745,
1033
+ "<|27.62|>": 51746,
1034
+ "<|27.64|>": 51747,
1035
+ "<|27.66|>": 51748,
1036
+ "<|27.68|>": 51749,
1037
+ "<|27.70|>": 51750,
1038
+ "<|27.72|>": 51751,
1039
+ "<|27.74|>": 51752,
1040
+ "<|27.76|>": 51753,
1041
+ "<|27.78|>": 51754,
1042
+ "<|27.80|>": 51755,
1043
+ "<|27.82|>": 51756,
1044
+ "<|27.84|>": 51757,
1045
+ "<|27.86|>": 51758,
1046
+ "<|27.88|>": 51759,
1047
+ "<|27.90|>": 51760,
1048
+ "<|27.92|>": 51761,
1049
+ "<|27.94|>": 51762,
1050
+ "<|27.96|>": 51763,
1051
+ "<|27.98|>": 51764,
1052
+ "<|28.00|>": 51765,
1053
+ "<|28.02|>": 51766,
1054
+ "<|28.04|>": 51767,
1055
+ "<|28.06|>": 51768,
1056
+ "<|28.08|>": 51769,
1057
+ "<|28.10|>": 51770,
1058
+ "<|28.12|>": 51771,
1059
+ "<|28.14|>": 51772,
1060
+ "<|28.16|>": 51773,
1061
+ "<|28.18|>": 51774,
1062
+ "<|28.20|>": 51775,
1063
+ "<|28.22|>": 51776,
1064
+ "<|28.24|>": 51777,
1065
+ "<|28.26|>": 51778,
1066
+ "<|28.28|>": 51779,
1067
+ "<|28.30|>": 51780,
1068
+ "<|28.32|>": 51781,
1069
+ "<|28.34|>": 51782,
1070
+ "<|28.36|>": 51783,
1071
+ "<|28.38|>": 51784,
1072
+ "<|28.40|>": 51785,
1073
+ "<|28.42|>": 51786,
1074
+ "<|28.44|>": 51787,
1075
+ "<|28.46|>": 51788,
1076
+ "<|28.48|>": 51789,
1077
+ "<|28.50|>": 51790,
1078
+ "<|28.52|>": 51791,
1079
+ "<|28.54|>": 51792,
1080
+ "<|28.56|>": 51793,
1081
+ "<|28.58|>": 51794,
1082
+ "<|28.60|>": 51795,
1083
+ "<|28.62|>": 51796,
1084
+ "<|28.64|>": 51797,
1085
+ "<|28.66|>": 51798,
1086
+ "<|28.68|>": 51799,
1087
+ "<|28.70|>": 51800,
1088
+ "<|28.72|>": 51801,
1089
+ "<|28.74|>": 51802,
1090
+ "<|28.76|>": 51803,
1091
+ "<|28.78|>": 51804,
1092
+ "<|28.80|>": 51805,
1093
+ "<|28.82|>": 51806,
1094
+ "<|28.84|>": 51807,
1095
+ "<|28.86|>": 51808,
1096
+ "<|28.88|>": 51809,
1097
+ "<|28.90|>": 51810,
1098
+ "<|28.92|>": 51811,
1099
+ "<|28.94|>": 51812,
1100
+ "<|28.96|>": 51813,
1101
+ "<|28.98|>": 51814,
1102
+ "<|29.00|>": 51815,
1103
+ "<|29.02|>": 51816,
1104
+ "<|29.04|>": 51817,
1105
+ "<|29.06|>": 51818,
1106
+ "<|29.08|>": 51819,
1107
+ "<|29.10|>": 51820,
1108
+ "<|29.12|>": 51821,
1109
+ "<|29.14|>": 51822,
1110
+ "<|29.16|>": 51823,
1111
+ "<|29.18|>": 51824,
1112
+ "<|29.20|>": 51825,
1113
+ "<|29.22|>": 51826,
1114
+ "<|29.24|>": 51827,
1115
+ "<|29.26|>": 51828,
1116
+ "<|29.28|>": 51829,
1117
+ "<|29.30|>": 51830,
1118
+ "<|29.32|>": 51831,
1119
+ "<|29.34|>": 51832,
1120
+ "<|29.36|>": 51833,
1121
+ "<|29.38|>": 51834,
1122
+ "<|29.40|>": 51835,
1123
+ "<|29.42|>": 51836,
1124
+ "<|29.44|>": 51837,
1125
+ "<|29.46|>": 51838,
1126
+ "<|29.48|>": 51839,
1127
+ "<|29.50|>": 51840,
1128
+ "<|29.52|>": 51841,
1129
+ "<|29.54|>": 51842,
1130
+ "<|29.56|>": 51843,
1131
+ "<|29.58|>": 51844,
1132
+ "<|29.60|>": 51845,
1133
+ "<|29.62|>": 51846,
1134
+ "<|29.64|>": 51847,
1135
+ "<|29.66|>": 51848,
1136
+ "<|29.68|>": 51849,
1137
+ "<|29.70|>": 51850,
1138
+ "<|29.72|>": 51851,
1139
+ "<|29.74|>": 51852,
1140
+ "<|29.76|>": 51853,
1141
+ "<|29.78|>": 51854,
1142
+ "<|29.80|>": 51855,
1143
+ "<|29.82|>": 51856,
1144
+ "<|29.84|>": 51857,
1145
+ "<|29.86|>": 51858,
1146
+ "<|29.88|>": 51859,
1147
+ "<|29.90|>": 51860,
1148
+ "<|29.92|>": 51861,
1149
+ "<|29.94|>": 51862,
1150
+ "<|29.96|>": 51863,
1151
+ "<|29.98|>": 51864,
1152
+ "<|3.00|>": 50515,
1153
+ "<|3.02|>": 50516,
1154
+ "<|3.04|>": 50517,
1155
+ "<|3.06|>": 50518,
1156
+ "<|3.08|>": 50519,
1157
+ "<|3.10|>": 50520,
1158
+ "<|3.12|>": 50521,
1159
+ "<|3.14|>": 50522,
1160
+ "<|3.16|>": 50523,
1161
+ "<|3.18|>": 50524,
1162
+ "<|3.20|>": 50525,
1163
+ "<|3.22|>": 50526,
1164
+ "<|3.24|>": 50527,
1165
+ "<|3.26|>": 50528,
1166
+ "<|3.28|>": 50529,
1167
+ "<|3.30|>": 50530,
1168
+ "<|3.32|>": 50531,
1169
+ "<|3.34|>": 50532,
1170
+ "<|3.36|>": 50533,
1171
+ "<|3.38|>": 50534,
1172
+ "<|3.40|>": 50535,
1173
+ "<|3.42|>": 50536,
1174
+ "<|3.44|>": 50537,
1175
+ "<|3.46|>": 50538,
1176
+ "<|3.48|>": 50539,
1177
+ "<|3.50|>": 50540,
1178
+ "<|3.52|>": 50541,
1179
+ "<|3.54|>": 50542,
1180
+ "<|3.56|>": 50543,
1181
+ "<|3.58|>": 50544,
1182
+ "<|3.60|>": 50545,
1183
+ "<|3.62|>": 50546,
1184
+ "<|3.64|>": 50547,
1185
+ "<|3.66|>": 50548,
1186
+ "<|3.68|>": 50549,
1187
+ "<|3.70|>": 50550,
1188
+ "<|3.72|>": 50551,
1189
+ "<|3.74|>": 50552,
1190
+ "<|3.76|>": 50553,
1191
+ "<|3.78|>": 50554,
1192
+ "<|3.80|>": 50555,
1193
+ "<|3.82|>": 50556,
1194
+ "<|3.84|>": 50557,
1195
+ "<|3.86|>": 50558,
1196
+ "<|3.88|>": 50559,
1197
+ "<|3.90|>": 50560,
1198
+ "<|3.92|>": 50561,
1199
+ "<|3.94|>": 50562,
1200
+ "<|3.96|>": 50563,
1201
+ "<|3.98|>": 50564,
1202
+ "<|30.00|>": 51865,
1203
+ "<|4.00|>": 50565,
1204
+ "<|4.02|>": 50566,
1205
+ "<|4.04|>": 50567,
1206
+ "<|4.06|>": 50568,
1207
+ "<|4.08|>": 50569,
1208
+ "<|4.10|>": 50570,
1209
+ "<|4.12|>": 50571,
1210
+ "<|4.14|>": 50572,
1211
+ "<|4.16|>": 50573,
1212
+ "<|4.18|>": 50574,
1213
+ "<|4.20|>": 50575,
1214
+ "<|4.22|>": 50576,
1215
+ "<|4.24|>": 50577,
1216
+ "<|4.26|>": 50578,
1217
+ "<|4.28|>": 50579,
1218
+ "<|4.30|>": 50580,
1219
+ "<|4.32|>": 50581,
1220
+ "<|4.34|>": 50582,
1221
+ "<|4.36|>": 50583,
1222
+ "<|4.38|>": 50584,
1223
+ "<|4.40|>": 50585,
1224
+ "<|4.42|>": 50586,
1225
+ "<|4.44|>": 50587,
1226
+ "<|4.46|>": 50588,
1227
+ "<|4.48|>": 50589,
1228
+ "<|4.50|>": 50590,
1229
+ "<|4.52|>": 50591,
1230
+ "<|4.54|>": 50592,
1231
+ "<|4.56|>": 50593,
1232
+ "<|4.58|>": 50594,
1233
+ "<|4.60|>": 50595,
1234
+ "<|4.62|>": 50596,
1235
+ "<|4.64|>": 50597,
1236
+ "<|4.66|>": 50598,
1237
+ "<|4.68|>": 50599,
1238
+ "<|4.70|>": 50600,
1239
+ "<|4.72|>": 50601,
1240
+ "<|4.74|>": 50602,
1241
+ "<|4.76|>": 50603,
1242
+ "<|4.78|>": 50604,
1243
+ "<|4.80|>": 50605,
1244
+ "<|4.82|>": 50606,
1245
+ "<|4.84|>": 50607,
1246
+ "<|4.86|>": 50608,
1247
+ "<|4.88|>": 50609,
1248
+ "<|4.90|>": 50610,
1249
+ "<|4.92|>": 50611,
1250
+ "<|4.94|>": 50612,
1251
+ "<|4.96|>": 50613,
1252
+ "<|4.98|>": 50614,
1253
+ "<|5.00|>": 50615,
1254
+ "<|5.02|>": 50616,
1255
+ "<|5.04|>": 50617,
1256
+ "<|5.06|>": 50618,
1257
+ "<|5.08|>": 50619,
1258
+ "<|5.10|>": 50620,
1259
+ "<|5.12|>": 50621,
1260
+ "<|5.14|>": 50622,
1261
+ "<|5.16|>": 50623,
1262
+ "<|5.18|>": 50624,
1263
+ "<|5.20|>": 50625,
1264
+ "<|5.22|>": 50626,
1265
+ "<|5.24|>": 50627,
1266
+ "<|5.26|>": 50628,
1267
+ "<|5.28|>": 50629,
1268
+ "<|5.30|>": 50630,
1269
+ "<|5.32|>": 50631,
1270
+ "<|5.34|>": 50632,
1271
+ "<|5.36|>": 50633,
1272
+ "<|5.38|>": 50634,
1273
+ "<|5.40|>": 50635,
1274
+ "<|5.42|>": 50636,
1275
+ "<|5.44|>": 50637,
1276
+ "<|5.46|>": 50638,
1277
+ "<|5.48|>": 50639,
1278
+ "<|5.50|>": 50640,
1279
+ "<|5.52|>": 50641,
1280
+ "<|5.54|>": 50642,
1281
+ "<|5.56|>": 50643,
1282
+ "<|5.58|>": 50644,
1283
+ "<|5.60|>": 50645,
1284
+ "<|5.62|>": 50646,
1285
+ "<|5.64|>": 50647,
1286
+ "<|5.66|>": 50648,
1287
+ "<|5.68|>": 50649,
1288
+ "<|5.70|>": 50650,
1289
+ "<|5.72|>": 50651,
1290
+ "<|5.74|>": 50652,
1291
+ "<|5.76|>": 50653,
1292
+ "<|5.78|>": 50654,
1293
+ "<|5.80|>": 50655,
1294
+ "<|5.82|>": 50656,
1295
+ "<|5.84|>": 50657,
1296
+ "<|5.86|>": 50658,
1297
+ "<|5.88|>": 50659,
1298
+ "<|5.90|>": 50660,
1299
+ "<|5.92|>": 50661,
1300
+ "<|5.94|>": 50662,
1301
+ "<|5.96|>": 50663,
1302
+ "<|5.98|>": 50664,
1303
+ "<|6.00|>": 50665,
1304
+ "<|6.02|>": 50666,
1305
+ "<|6.04|>": 50667,
1306
+ "<|6.06|>": 50668,
1307
+ "<|6.08|>": 50669,
1308
+ "<|6.10|>": 50670,
1309
+ "<|6.12|>": 50671,
1310
+ "<|6.14|>": 50672,
1311
+ "<|6.16|>": 50673,
1312
+ "<|6.18|>": 50674,
1313
+ "<|6.20|>": 50675,
1314
+ "<|6.22|>": 50676,
1315
+ "<|6.24|>": 50677,
1316
+ "<|6.26|>": 50678,
1317
+ "<|6.28|>": 50679,
1318
+ "<|6.30|>": 50680,
1319
+ "<|6.32|>": 50681,
1320
+ "<|6.34|>": 50682,
1321
+ "<|6.36|>": 50683,
1322
+ "<|6.38|>": 50684,
1323
+ "<|6.40|>": 50685,
1324
+ "<|6.42|>": 50686,
1325
+ "<|6.44|>": 50687,
1326
+ "<|6.46|>": 50688,
1327
+ "<|6.48|>": 50689,
1328
+ "<|6.50|>": 50690,
1329
+ "<|6.52|>": 50691,
1330
+ "<|6.54|>": 50692,
1331
+ "<|6.56|>": 50693,
1332
+ "<|6.58|>": 50694,
1333
+ "<|6.60|>": 50695,
1334
+ "<|6.62|>": 50696,
1335
+ "<|6.64|>": 50697,
1336
+ "<|6.66|>": 50698,
1337
+ "<|6.68|>": 50699,
1338
+ "<|6.70|>": 50700,
1339
+ "<|6.72|>": 50701,
1340
+ "<|6.74|>": 50702,
1341
+ "<|6.76|>": 50703,
1342
+ "<|6.78|>": 50704,
1343
+ "<|6.80|>": 50705,
1344
+ "<|6.82|>": 50706,
1345
+ "<|6.84|>": 50707,
1346
+ "<|6.86|>": 50708,
1347
+ "<|6.88|>": 50709,
1348
+ "<|6.90|>": 50710,
1349
+ "<|6.92|>": 50711,
1350
+ "<|6.94|>": 50712,
1351
+ "<|6.96|>": 50713,
1352
+ "<|6.98|>": 50714,
1353
+ "<|7.00|>": 50715,
1354
+ "<|7.02|>": 50716,
1355
+ "<|7.04|>": 50717,
1356
+ "<|7.06|>": 50718,
1357
+ "<|7.08|>": 50719,
1358
+ "<|7.10|>": 50720,
1359
+ "<|7.12|>": 50721,
1360
+ "<|7.14|>": 50722,
1361
+ "<|7.16|>": 50723,
1362
+ "<|7.18|>": 50724,
1363
+ "<|7.20|>": 50725,
1364
+ "<|7.22|>": 50726,
1365
+ "<|7.24|>": 50727,
1366
+ "<|7.26|>": 50728,
1367
+ "<|7.28|>": 50729,
1368
+ "<|7.30|>": 50730,
1369
+ "<|7.32|>": 50731,
1370
+ "<|7.34|>": 50732,
1371
+ "<|7.36|>": 50733,
1372
+ "<|7.38|>": 50734,
1373
+ "<|7.40|>": 50735,
1374
+ "<|7.42|>": 50736,
1375
+ "<|7.44|>": 50737,
1376
+ "<|7.46|>": 50738,
1377
+ "<|7.48|>": 50739,
1378
+ "<|7.50|>": 50740,
1379
+ "<|7.52|>": 50741,
1380
+ "<|7.54|>": 50742,
1381
+ "<|7.56|>": 50743,
1382
+ "<|7.58|>": 50744,
1383
+ "<|7.60|>": 50745,
1384
+ "<|7.62|>": 50746,
1385
+ "<|7.64|>": 50747,
1386
+ "<|7.66|>": 50748,
1387
+ "<|7.68|>": 50749,
1388
+ "<|7.70|>": 50750,
1389
+ "<|7.72|>": 50751,
1390
+ "<|7.74|>": 50752,
1391
+ "<|7.76|>": 50753,
1392
+ "<|7.78|>": 50754,
1393
+ "<|7.80|>": 50755,
1394
+ "<|7.82|>": 50756,
1395
+ "<|7.84|>": 50757,
1396
+ "<|7.86|>": 50758,
1397
+ "<|7.88|>": 50759,
1398
+ "<|7.90|>": 50760,
1399
+ "<|7.92|>": 50761,
1400
+ "<|7.94|>": 50762,
1401
+ "<|7.96|>": 50763,
1402
+ "<|7.98|>": 50764,
1403
+ "<|8.00|>": 50765,
1404
+ "<|8.02|>": 50766,
1405
+ "<|8.04|>": 50767,
1406
+ "<|8.06|>": 50768,
1407
+ "<|8.08|>": 50769,
1408
+ "<|8.10|>": 50770,
1409
+ "<|8.12|>": 50771,
1410
+ "<|8.14|>": 50772,
1411
+ "<|8.16|>": 50773,
1412
+ "<|8.18|>": 50774,
1413
+ "<|8.20|>": 50775,
1414
+ "<|8.22|>": 50776,
1415
+ "<|8.24|>": 50777,
1416
+ "<|8.26|>": 50778,
1417
+ "<|8.28|>": 50779,
1418
+ "<|8.30|>": 50780,
1419
+ "<|8.32|>": 50781,
1420
+ "<|8.34|>": 50782,
1421
+ "<|8.36|>": 50783,
1422
+ "<|8.38|>": 50784,
1423
+ "<|8.40|>": 50785,
1424
+ "<|8.42|>": 50786,
1425
+ "<|8.44|>": 50787,
1426
+ "<|8.46|>": 50788,
1427
+ "<|8.48|>": 50789,
1428
+ "<|8.50|>": 50790,
1429
+ "<|8.52|>": 50791,
1430
+ "<|8.54|>": 50792,
1431
+ "<|8.56|>": 50793,
1432
+ "<|8.58|>": 50794,
1433
+ "<|8.60|>": 50795,
1434
+ "<|8.62|>": 50796,
1435
+ "<|8.64|>": 50797,
1436
+ "<|8.66|>": 50798,
1437
+ "<|8.68|>": 50799,
1438
+ "<|8.70|>": 50800,
1439
+ "<|8.72|>": 50801,
1440
+ "<|8.74|>": 50802,
1441
+ "<|8.76|>": 50803,
1442
+ "<|8.78|>": 50804,
1443
+ "<|8.80|>": 50805,
1444
+ "<|8.82|>": 50806,
1445
+ "<|8.84|>": 50807,
1446
+ "<|8.86|>": 50808,
1447
+ "<|8.88|>": 50809,
1448
+ "<|8.90|>": 50810,
1449
+ "<|8.92|>": 50811,
1450
+ "<|8.94|>": 50812,
1451
+ "<|8.96|>": 50813,
1452
+ "<|8.98|>": 50814,
1453
+ "<|9.00|>": 50815,
1454
+ "<|9.02|>": 50816,
1455
+ "<|9.04|>": 50817,
1456
+ "<|9.06|>": 50818,
1457
+ "<|9.08|>": 50819,
1458
+ "<|9.10|>": 50820,
1459
+ "<|9.12|>": 50821,
1460
+ "<|9.14|>": 50822,
1461
+ "<|9.16|>": 50823,
1462
+ "<|9.18|>": 50824,
1463
+ "<|9.20|>": 50825,
1464
+ "<|9.22|>": 50826,
1465
+ "<|9.24|>": 50827,
1466
+ "<|9.26|>": 50828,
1467
+ "<|9.28|>": 50829,
1468
+ "<|9.30|>": 50830,
1469
+ "<|9.32|>": 50831,
1470
+ "<|9.34|>": 50832,
1471
+ "<|9.36|>": 50833,
1472
+ "<|9.38|>": 50834,
1473
+ "<|9.40|>": 50835,
1474
+ "<|9.42|>": 50836,
1475
+ "<|9.44|>": 50837,
1476
+ "<|9.46|>": 50838,
1477
+ "<|9.48|>": 50839,
1478
+ "<|9.50|>": 50840,
1479
+ "<|9.52|>": 50841,
1480
+ "<|9.54|>": 50842,
1481
+ "<|9.56|>": 50843,
1482
+ "<|9.58|>": 50844,
1483
+ "<|9.60|>": 50845,
1484
+ "<|9.62|>": 50846,
1485
+ "<|9.64|>": 50847,
1486
+ "<|9.66|>": 50848,
1487
+ "<|9.68|>": 50849,
1488
+ "<|9.70|>": 50850,
1489
+ "<|9.72|>": 50851,
1490
+ "<|9.74|>": 50852,
1491
+ "<|9.76|>": 50853,
1492
+ "<|9.78|>": 50854,
1493
+ "<|9.80|>": 50855,
1494
+ "<|9.82|>": 50856,
1495
+ "<|9.84|>": 50857,
1496
+ "<|9.86|>": 50858,
1497
+ "<|9.88|>": 50859,
1498
+ "<|9.90|>": 50860,
1499
+ "<|9.92|>": 50861,
1500
+ "<|9.94|>": 50862,
1501
+ "<|9.96|>": 50863,
1502
+ "<|9.98|>": 50864,
1503
+ "<|af|>": 50327,
1504
+ "<|am|>": 50334,
1505
+ "<|ar|>": 50272,
1506
+ "<|as|>": 50350,
1507
+ "<|az|>": 50304,
1508
+ "<|ba|>": 50355,
1509
+ "<|be|>": 50330,
1510
+ "<|bg|>": 50292,
1511
+ "<|bn|>": 50302,
1512
+ "<|bo|>": 50347,
1513
+ "<|br|>": 50309,
1514
+ "<|bs|>": 50315,
1515
+ "<|ca|>": 50270,
1516
+ "<|cs|>": 50283,
1517
+ "<|cy|>": 50297,
1518
+ "<|da|>": 50285,
1519
+ "<|de|>": 50261,
1520
+ "<|el|>": 50281,
1521
+ "<|endoftext|>": 50257,
1522
+ "<|en|>": 50259,
1523
+ "<|es|>": 50262,
1524
+ "<|et|>": 50307,
1525
+ "<|eu|>": 50310,
1526
+ "<|fa|>": 50300,
1527
+ "<|fi|>": 50277,
1528
+ "<|fo|>": 50338,
1529
+ "<|fr|>": 50265,
1530
+ "<|gl|>": 50319,
1531
+ "<|gu|>": 50333,
1532
+ "<|haw|>": 50352,
1533
+ "<|ha|>": 50354,
1534
+ "<|he|>": 50279,
1535
+ "<|hi|>": 50276,
1536
+ "<|hr|>": 50291,
1537
+ "<|ht|>": 50339,
1538
+ "<|hu|>": 50286,
1539
+ "<|hy|>": 50312,
1540
+ "<|id|>": 50275,
1541
+ "<|is|>": 50311,
1542
+ "<|it|>": 50274,
1543
+ "<|ja|>": 50266,
1544
+ "<|jw|>": 50356,
1545
+ "<|ka|>": 50329,
1546
+ "<|kk|>": 50316,
1547
+ "<|km|>": 50323,
1548
+ "<|kn|>": 50306,
1549
+ "<|ko|>": 50264,
1550
+ "<|la|>": 50294,
1551
+ "<|lb|>": 50345,
1552
+ "<|ln|>": 50353,
1553
+ "<|lo|>": 50336,
1554
+ "<|lt|>": 50293,
1555
+ "<|lv|>": 50301,
1556
+ "<|mg|>": 50349,
1557
+ "<|mi|>": 50295,
1558
+ "<|mk|>": 50308,
1559
+ "<|ml|>": 50296,
1560
+ "<|mn|>": 50314,
1561
+ "<|mr|>": 50320,
1562
+ "<|ms|>": 50282,
1563
+ "<|mt|>": 50343,
1564
+ "<|my|>": 50346,
1565
+ "<|ne|>": 50313,
1566
+ "<|nl|>": 50271,
1567
+ "<|nn|>": 50342,
1568
+ "<|nospeech|>": 50363,
1569
+ "<|notimestamps|>": 50364,
1570
+ "<|no|>": 50288,
1571
+ "<|oc|>": 50328,
1572
+ "<|pa|>": 50321,
1573
+ "<|pl|>": 50269,
1574
+ "<|ps|>": 50340,
1575
+ "<|pt|>": 50267,
1576
+ "<|ro|>": 50284,
1577
+ "<|ru|>": 50263,
1578
+ "<|sa|>": 50344,
1579
+ "<|sd|>": 50332,
1580
+ "<|si|>": 50322,
1581
+ "<|sk|>": 50298,
1582
+ "<|sl|>": 50305,
1583
+ "<|sn|>": 50324,
1584
+ "<|so|>": 50326,
1585
+ "<|sq|>": 50317,
1586
+ "<|sr|>": 50303,
1587
+ "<|startoflm|>": 50361,
1588
+ "<|startofprev|>": 50362,
1589
+ "<|startoftranscript|>": 50258,
1590
+ "<|su|>": 50357,
1591
+ "<|sv|>": 50273,
1592
+ "<|sw|>": 50318,
1593
+ "<|ta|>": 50287,
1594
+ "<|te|>": 50299,
1595
+ "<|tg|>": 50331,
1596
+ "<|th|>": 50289,
1597
+ "<|tk|>": 50341,
1598
+ "<|tl|>": 50348,
1599
+ "<|transcribe|>": 50360,
1600
+ "<|translate|>": 50359,
1601
+ "<|tr|>": 50268,
1602
+ "<|tt|>": 50351,
1603
+ "<|uk|>": 50280,
1604
+ "<|ur|>": 50290,
1605
+ "<|uz|>": 50337,
1606
+ "<|vi|>": 50278,
1607
+ "<|yi|>": 50335,
1608
+ "<|yo|>": 50325,
1609
+ "<|yue|>": 50358,
1610
+ "<|zh|>": 50260
1611
+ }
nb-distil-large-init/config.json ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "NbAiLab/nb-whisper-large",
3
+ "activation_dropout": 0.1,
4
+ "activation_function": "gelu",
5
+ "alignment_heads": [
6
+ [
7
+ 7,
8
+ 0
9
+ ],
10
+ [
11
+ 10,
12
+ 17
13
+ ],
14
+ [
15
+ 12,
16
+ 18
17
+ ],
18
+ [
19
+ 13,
20
+ 12
21
+ ],
22
+ [
23
+ 16,
24
+ 1
25
+ ],
26
+ [
27
+ 17,
28
+ 14
29
+ ],
30
+ [
31
+ 19,
32
+ 11
33
+ ],
34
+ [
35
+ 21,
36
+ 4
37
+ ],
38
+ [
39
+ 24,
40
+ 1
41
+ ],
42
+ [
43
+ 25,
44
+ 6
45
+ ]
46
+ ],
47
+ "apply_spec_augment": false,
48
+ "architectures": [
49
+ "WhisperForConditionalGeneration"
50
+ ],
51
+ "attention_dropout": 0,
52
+ "begin_suppress_tokens": null,
53
+ "bos_token_id": 50257,
54
+ "classifier_proj_size": 256,
55
+ "d_model": 1280,
56
+ "decoder_attention_heads": 20,
57
+ "decoder_ffn_dim": 5120,
58
+ "decoder_layerdrop": 0,
59
+ "decoder_layers": 2,
60
+ "decoder_start_token_id": 50258,
61
+ "dropout": 0,
62
+ "encoder_attention_heads": 20,
63
+ "encoder_ffn_dim": 5120,
64
+ "encoder_layerdrop": 0,
65
+ "encoder_layers": 32,
66
+ "eos_token_id": 50257,
67
+ "init_std": 0.02,
68
+ "is_encoder_decoder": true,
69
+ "lang_ids": [
70
+ 50259,
71
+ 50260,
72
+ 50261,
73
+ 50262,
74
+ 50263,
75
+ 50264,
76
+ 50265,
77
+ 50266,
78
+ 50267,
79
+ 50268,
80
+ 50269,
81
+ 50270,
82
+ 50271,
83
+ 50272,
84
+ 50273,
85
+ 50274,
86
+ 50275,
87
+ 50276,
88
+ 50277,
89
+ 50278,
90
+ 50279,
91
+ 50280,
92
+ 50281,
93
+ 50282,
94
+ 50283,
95
+ 50284,
96
+ 50285,
97
+ 50286,
98
+ 50287,
99
+ 50288,
100
+ 50289,
101
+ 50290,
102
+ 50291,
103
+ 50292,
104
+ 50293,
105
+ 50294,
106
+ 50295,
107
+ 50296,
108
+ 50297,
109
+ 50298,
110
+ 50299,
111
+ 50300,
112
+ 50301,
113
+ 50302,
114
+ 50303,
115
+ 50304,
116
+ 50305,
117
+ 50306,
118
+ 50307,
119
+ 50308,
120
+ 50309,
121
+ 50310,
122
+ 50311,
123
+ 50312,
124
+ 50313,
125
+ 50314,
126
+ 50315,
127
+ 50316,
128
+ 50317,
129
+ 50318,
130
+ 50319,
131
+ 50320,
132
+ 50321,
133
+ 50322,
134
+ 50323,
135
+ 50324,
136
+ 50325,
137
+ 50326,
138
+ 50327,
139
+ 50328,
140
+ 50329,
141
+ 50330,
142
+ 50331,
143
+ 50332,
144
+ 50333,
145
+ 50334,
146
+ 50335,
147
+ 50336,
148
+ 50337,
149
+ 50338,
150
+ 50339,
151
+ 50340,
152
+ 50341,
153
+ 50342,
154
+ 50343,
155
+ 50344,
156
+ 50345,
157
+ 50346,
158
+ 50347,
159
+ 50348,
160
+ 50349,
161
+ 50350,
162
+ 50351,
163
+ 50352,
164
+ 50353,
165
+ 50354,
166
+ 50355,
167
+ 50356,
168
+ 50357,
169
+ 50358
170
+ ],
171
+ "mask_feature_length": 10,
172
+ "mask_feature_min_masks": 0,
173
+ "mask_feature_prob": 0,
174
+ "mask_time_length": 10,
175
+ "mask_time_min_masks": 2,
176
+ "mask_time_prob": 0.05,
177
+ "max_length": null,
178
+ "max_source_positions": 1500,
179
+ "max_target_positions": 448,
180
+ "median_filter_width": 7,
181
+ "model_type": "whisper",
182
+ "num_hidden_layers": 32,
183
+ "num_mel_bins": 128,
184
+ "pad_token_id": 50256,
185
+ "scale_embedding": false,
186
+ "suppress_ids": [
187
+ 1,
188
+ 2,
189
+ 7,
190
+ 8,
191
+ 9,
192
+ 10,
193
+ 14,
194
+ 25,
195
+ 26,
196
+ 27,
197
+ 28,
198
+ 29,
199
+ 31,
200
+ 58,
201
+ 59,
202
+ 60,
203
+ 61,
204
+ 62,
205
+ 63,
206
+ 90,
207
+ 91,
208
+ 92,
209
+ 93,
210
+ 359,
211
+ 503,
212
+ 522,
213
+ 542,
214
+ 873,
215
+ 893,
216
+ 902,
217
+ 918,
218
+ 922,
219
+ 931,
220
+ 1350,
221
+ 1853,
222
+ 1982,
223
+ 2460,
224
+ 2627,
225
+ 3246,
226
+ 3253,
227
+ 3268,
228
+ 3536,
229
+ 3846,
230
+ 3961,
231
+ 4183,
232
+ 4667,
233
+ 6585,
234
+ 6647,
235
+ 7273,
236
+ 9061,
237
+ 9383,
238
+ 10428,
239
+ 10929,
240
+ 11938,
241
+ 12033,
242
+ 12331,
243
+ 12562,
244
+ 13793,
245
+ 14157,
246
+ 14635,
247
+ 15265,
248
+ 15618,
249
+ 16553,
250
+ 16604,
251
+ 18362,
252
+ 18956,
253
+ 20075,
254
+ 21675,
255
+ 22520,
256
+ 26130,
257
+ 26161,
258
+ 26435,
259
+ 28279,
260
+ 29464,
261
+ 31650,
262
+ 32302,
263
+ 32470,
264
+ 36865,
265
+ 42863,
266
+ 47425,
267
+ 49870,
268
+ 50254,
269
+ 50258,
270
+ 50359,
271
+ 50360,
272
+ 50361,
273
+ 50362,
274
+ 50363
275
+ ],
276
+ "suppress_ids_begin": [
277
+ 220,
278
+ 50257
279
+ ],
280
+ "torch_dtype": "float32",
281
+ "transformers_version": "4.46.2",
282
+ "use_cache": true,
283
+ "use_weighted_layer_sum": false,
284
+ "vocab_size": 51866
285
+ }
nb-distil-large-init/generation_config.json ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alignment_heads": [
3
+ [
4
+ 7,
5
+ 0
6
+ ],
7
+ [
8
+ 10,
9
+ 17
10
+ ],
11
+ [
12
+ 12,
13
+ 18
14
+ ],
15
+ [
16
+ 13,
17
+ 12
18
+ ],
19
+ [
20
+ 16,
21
+ 1
22
+ ],
23
+ [
24
+ 17,
25
+ 14
26
+ ],
27
+ [
28
+ 19,
29
+ 11
30
+ ],
31
+ [
32
+ 21,
33
+ 4
34
+ ],
35
+ [
36
+ 24,
37
+ 1
38
+ ],
39
+ [
40
+ 25,
41
+ 6
42
+ ]
43
+ ],
44
+ "begin_suppress_tokens": [
45
+ 220,
46
+ 50257
47
+ ],
48
+ "bos_token_id": 50257,
49
+ "decoder_start_token_id": 50258,
50
+ "eos_token_id": 50257,
51
+ "is_multilingual": true,
52
+ "lang_to_id": {
53
+ "<|af|>": 50327,
54
+ "<|am|>": 50334,
55
+ "<|ar|>": 50272,
56
+ "<|as|>": 50350,
57
+ "<|az|>": 50304,
58
+ "<|ba|>": 50355,
59
+ "<|be|>": 50330,
60
+ "<|bg|>": 50292,
61
+ "<|bn|>": 50302,
62
+ "<|bo|>": 50347,
63
+ "<|br|>": 50309,
64
+ "<|bs|>": 50315,
65
+ "<|ca|>": 50270,
66
+ "<|cs|>": 50283,
67
+ "<|cy|>": 50297,
68
+ "<|da|>": 50285,
69
+ "<|de|>": 50261,
70
+ "<|el|>": 50281,
71
+ "<|en|>": 50259,
72
+ "<|es|>": 50262,
73
+ "<|et|>": 50307,
74
+ "<|eu|>": 50310,
75
+ "<|fa|>": 50300,
76
+ "<|fi|>": 50277,
77
+ "<|fo|>": 50338,
78
+ "<|fr|>": 50265,
79
+ "<|gl|>": 50319,
80
+ "<|gu|>": 50333,
81
+ "<|haw|>": 50352,
82
+ "<|ha|>": 50354,
83
+ "<|he|>": 50279,
84
+ "<|hi|>": 50276,
85
+ "<|hr|>": 50291,
86
+ "<|ht|>": 50339,
87
+ "<|hu|>": 50286,
88
+ "<|hy|>": 50312,
89
+ "<|id|>": 50275,
90
+ "<|is|>": 50311,
91
+ "<|it|>": 50274,
92
+ "<|ja|>": 50266,
93
+ "<|jw|>": 50356,
94
+ "<|ka|>": 50329,
95
+ "<|kk|>": 50316,
96
+ "<|km|>": 50323,
97
+ "<|kn|>": 50306,
98
+ "<|ko|>": 50264,
99
+ "<|la|>": 50294,
100
+ "<|lb|>": 50345,
101
+ "<|ln|>": 50353,
102
+ "<|lo|>": 50336,
103
+ "<|lt|>": 50293,
104
+ "<|lv|>": 50301,
105
+ "<|mg|>": 50349,
106
+ "<|mi|>": 50295,
107
+ "<|mk|>": 50308,
108
+ "<|ml|>": 50296,
109
+ "<|mn|>": 50314,
110
+ "<|mr|>": 50320,
111
+ "<|ms|>": 50282,
112
+ "<|mt|>": 50343,
113
+ "<|my|>": 50346,
114
+ "<|ne|>": 50313,
115
+ "<|nl|>": 50271,
116
+ "<|nn|>": 50342,
117
+ "<|no|>": 50288,
118
+ "<|oc|>": 50328,
119
+ "<|pa|>": 50321,
120
+ "<|pl|>": 50269,
121
+ "<|ps|>": 50340,
122
+ "<|pt|>": 50267,
123
+ "<|ro|>": 50284,
124
+ "<|ru|>": 50263,
125
+ "<|sa|>": 50344,
126
+ "<|sd|>": 50332,
127
+ "<|si|>": 50322,
128
+ "<|sk|>": 50298,
129
+ "<|sl|>": 50305,
130
+ "<|sn|>": 50324,
131
+ "<|so|>": 50326,
132
+ "<|sq|>": 50317,
133
+ "<|sr|>": 50303,
134
+ "<|su|>": 50357,
135
+ "<|sv|>": 50273,
136
+ "<|sw|>": 50318,
137
+ "<|ta|>": 50287,
138
+ "<|te|>": 50299,
139
+ "<|tg|>": 50331,
140
+ "<|th|>": 50289,
141
+ "<|tk|>": 50341,
142
+ "<|tl|>": 50348,
143
+ "<|tr|>": 50268,
144
+ "<|tt|>": 50351,
145
+ "<|uk|>": 50280,
146
+ "<|ur|>": 50290,
147
+ "<|uz|>": 50337,
148
+ "<|vi|>": 50278,
149
+ "<|yi|>": 50335,
150
+ "<|yo|>": 50325,
151
+ "<|yue|>": 50358,
152
+ "<|zh|>": 50260
153
+ },
154
+ "language": "<|no|>",
155
+ "max_initial_timestamp_index": 1,
156
+ "max_length": 448,
157
+ "no_timestamps_token_id": 50364,
158
+ "pad_token_id": 50257,
159
+ "return_timestamps": false,
160
+ "suppress_tokens": [
161
+ 1,
162
+ 2,
163
+ 7,
164
+ 8,
165
+ 9,
166
+ 10,
167
+ 14,
168
+ 25,
169
+ 26,
170
+ 27,
171
+ 28,
172
+ 29,
173
+ 31,
174
+ 58,
175
+ 59,
176
+ 60,
177
+ 61,
178
+ 62,
179
+ 63,
180
+ 90,
181
+ 91,
182
+ 92,
183
+ 93,
184
+ 359,
185
+ 503,
186
+ 522,
187
+ 542,
188
+ 873,
189
+ 893,
190
+ 902,
191
+ 918,
192
+ 922,
193
+ 931,
194
+ 1350,
195
+ 1853,
196
+ 1982,
197
+ 2460,
198
+ 2627,
199
+ 3246,
200
+ 3253,
201
+ 3268,
202
+ 3536,
203
+ 3846,
204
+ 3961,
205
+ 4183,
206
+ 4667,
207
+ 6585,
208
+ 6647,
209
+ 7273,
210
+ 9061,
211
+ 9383,
212
+ 10428,
213
+ 10929,
214
+ 11938,
215
+ 12033,
216
+ 12331,
217
+ 12562,
218
+ 13793,
219
+ 14157,
220
+ 14635,
221
+ 15265,
222
+ 15618,
223
+ 16553,
224
+ 16604,
225
+ 18362,
226
+ 18956,
227
+ 20075,
228
+ 21675,
229
+ 22520,
230
+ 26130,
231
+ 26161,
232
+ 26435,
233
+ 28279,
234
+ 29464,
235
+ 31650,
236
+ 32302,
237
+ 32470,
238
+ 36865,
239
+ 42863,
240
+ 47425,
241
+ 49870,
242
+ 50254,
243
+ 50258,
244
+ 50359,
245
+ 50360,
246
+ 50361,
247
+ 50362,
248
+ 50363
249
+ ],
250
+ "task": "transcribe",
251
+ "task_to_id": {
252
+ "transcribe": 50360,
253
+ "translate": 50359
254
+ },
255
+ "transformers_version": "4.46.2"
256
+ }
nb-distil-large-init/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
nb-distil-large-init/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5c6cce6fdc832cc805fe6ea8af8db5527ba4b5d0c3381ba404c82e3e8161db6
3
+ size 3025686376
nb-distil-large-init/preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "processor_class": "WhisperProcessor",
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000
14
+ }
nb-distil-large-init/special_tokens_map.json ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|startoftranscript|>",
4
+ "<|en|>",
5
+ "<|zh|>",
6
+ "<|de|>",
7
+ "<|es|>",
8
+ "<|ru|>",
9
+ "<|ko|>",
10
+ "<|fr|>",
11
+ "<|ja|>",
12
+ "<|pt|>",
13
+ "<|tr|>",
14
+ "<|pl|>",
15
+ "<|ca|>",
16
+ "<|nl|>",
17
+ "<|ar|>",
18
+ "<|sv|>",
19
+ "<|it|>",
20
+ "<|id|>",
21
+ "<|hi|>",
22
+ "<|fi|>",
23
+ "<|vi|>",
24
+ "<|he|>",
25
+ "<|uk|>",
26
+ "<|el|>",
27
+ "<|ms|>",
28
+ "<|cs|>",
29
+ "<|ro|>",
30
+ "<|da|>",
31
+ "<|hu|>",
32
+ "<|ta|>",
33
+ "<|no|>",
34
+ "<|th|>",
35
+ "<|ur|>",
36
+ "<|hr|>",
37
+ "<|bg|>",
38
+ "<|lt|>",
39
+ "<|la|>",
40
+ "<|mi|>",
41
+ "<|ml|>",
42
+ "<|cy|>",
43
+ "<|sk|>",
44
+ "<|te|>",
45
+ "<|fa|>",
46
+ "<|lv|>",
47
+ "<|bn|>",
48
+ "<|sr|>",
49
+ "<|az|>",
50
+ "<|sl|>",
51
+ "<|kn|>",
52
+ "<|et|>",
53
+ "<|mk|>",
54
+ "<|br|>",
55
+ "<|eu|>",
56
+ "<|is|>",
57
+ "<|hy|>",
58
+ "<|ne|>",
59
+ "<|mn|>",
60
+ "<|bs|>",
61
+ "<|kk|>",
62
+ "<|sq|>",
63
+ "<|sw|>",
64
+ "<|gl|>",
65
+ "<|mr|>",
66
+ "<|pa|>",
67
+ "<|si|>",
68
+ "<|km|>",
69
+ "<|sn|>",
70
+ "<|yo|>",
71
+ "<|so|>",
72
+ "<|af|>",
73
+ "<|oc|>",
74
+ "<|ka|>",
75
+ "<|be|>",
76
+ "<|tg|>",
77
+ "<|sd|>",
78
+ "<|gu|>",
79
+ "<|am|>",
80
+ "<|yi|>",
81
+ "<|lo|>",
82
+ "<|uz|>",
83
+ "<|fo|>",
84
+ "<|ht|>",
85
+ "<|ps|>",
86
+ "<|tk|>",
87
+ "<|nn|>",
88
+ "<|mt|>",
89
+ "<|sa|>",
90
+ "<|lb|>",
91
+ "<|my|>",
92
+ "<|bo|>",
93
+ "<|tl|>",
94
+ "<|mg|>",
95
+ "<|as|>",
96
+ "<|tt|>",
97
+ "<|haw|>",
98
+ "<|ln|>",
99
+ "<|ha|>",
100
+ "<|ba|>",
101
+ "<|jw|>",
102
+ "<|su|>",
103
+ "<|yue|>",
104
+ "<|translate|>",
105
+ "<|transcribe|>",
106
+ "<|startoflm|>",
107
+ "<|startofprev|>",
108
+ "<|nospeech|>",
109
+ "<|notimestamps|>"
110
+ ],
111
+ "bos_token": {
112
+ "content": "<|endoftext|>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "eos_token": {
119
+ "content": "<|endoftext|>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ },
125
+ "pad_token": {
126
+ "content": "<|endoftext|>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false
131
+ },
132
+ "unk_token": {
133
+ "content": "<|endoftext|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false
138
+ }
139
+ }
nb-distil-large-init/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
nb-distil-large-init/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "processor_class": "WhisperProcessor",
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000
14
+ }
run_distillation.py ADDED
@@ -0,0 +1,1827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training the Whisper model for sequence to sequence speech recognition via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import re
24
+ import shutil
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from functools import partial
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Optional, Union
31
+
32
+ import datasets
33
+ import evaluate
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import set_seed
41
+ from datasets import (
42
+ DatasetDict,
43
+ IterableDataset,
44
+ IterableDatasetDict,
45
+ concatenate_datasets,
46
+ interleave_datasets,
47
+ load_dataset,
48
+ )
49
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
50
+ from torch.utils.data import DataLoader
51
+ from tqdm import tqdm
52
+ from transformers import (
53
+ AddedToken,
54
+ HfArgumentParser,
55
+ Seq2SeqTrainingArguments,
56
+ WhisperConfig,
57
+ WhisperFeatureExtractor,
58
+ WhisperForConditionalGeneration,
59
+ WhisperProcessor,
60
+ WhisperTokenizerFast,
61
+ get_scheduler
62
+ )
63
+ from transformers.modeling_outputs import BaseModelOutput
64
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
65
+ from transformers.utils import check_min_version
66
+ from transformers.utils.versions import require_version
67
+
68
+
69
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
70
+ check_min_version("4.34.0.dev0")
71
+
72
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
81
+ """
82
+
83
+ model_name_or_path: str = field(
84
+ metadata={"help": "Path to pretrained Whisper model or model identifier from huggingface.co/models"}
85
+ )
86
+ teacher_model_name_or_path: str = field(
87
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
88
+ )
89
+ config_name: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
96
+ )
97
+ feature_extractor_name: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "feature extractor name or path if not the same as model_name"},
100
+ )
101
+ cache_dir: Optional[str] = field(
102
+ default=None,
103
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
104
+ )
105
+ use_fast_tokenizer: bool = field(
106
+ default=True,
107
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
108
+ )
109
+ model_revision: str = field(
110
+ default="main",
111
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
112
+ )
113
+ subfolder: str = field(
114
+ default="",
115
+ metadata={
116
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
117
+ "specify the folder name here."
118
+ },
119
+ )
120
+ token: str = field(
121
+ default=None,
122
+ metadata={
123
+ "help": (
124
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
125
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
126
+ )
127
+ },
128
+ )
129
+ attn_implementation: Optional[str] = field(
130
+ default=None,
131
+ metadata={
132
+ "help": (
133
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
134
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
135
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
136
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
137
+ )
138
+ },
139
+ )
140
+
141
+ def __post_init__(self):
142
+ if self.attn_implementation not in [None, "eager", "sdpa", "flash_attention_2"]:
143
+ raise ValueError(
144
+ f"Got `--attn_implementation={self.attn_implementation}`, which is an invalid attention type. Should be one of:\n"
145
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
146
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
147
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
148
+ )
149
+
150
+
151
+ @dataclass
152
+ class DataTrainingArguments:
153
+ """
154
+ Arguments pertaining to what data we are going to input our model for training and eval.
155
+ """
156
+
157
+ train_dataset_name: str = field(
158
+ default=None,
159
+ metadata={
160
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
161
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
162
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
163
+ },
164
+ )
165
+ train_dataset_config_name: Optional[str] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
169
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
170
+ "match the order of the datasets."
171
+ },
172
+ )
173
+ train_dataset_samples: str = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
177
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
178
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
179
+ "sample from every dataset is used once per epoch."
180
+ },
181
+ )
182
+ eval_dataset_name: str = field(
183
+ default=None,
184
+ metadata={
185
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
186
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
187
+ "ids by a '+' symbol."
188
+ },
189
+ )
190
+ eval_dataset_config_name: Optional[str] = field(
191
+ default=None,
192
+ metadata={
193
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
194
+ "training dataset config name if unspecified."
195
+ },
196
+ )
197
+ dataset_cache_dir: Optional[str] = field(
198
+ default=None,
199
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
200
+ )
201
+ overwrite_cache: bool = field(
202
+ default=False,
203
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
204
+ )
205
+ preprocessing_num_workers: Optional[int] = field(
206
+ default=None,
207
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
208
+ )
209
+ preprocessing_batch_size: Optional[int] = field(
210
+ default=256,
211
+ metadata={"help": "Number of examples per batch provided to the `prepare_dataset` function."},
212
+ )
213
+ max_train_samples: Optional[int] = field(
214
+ default=None,
215
+ metadata={
216
+ "help": (
217
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
218
+ )
219
+ },
220
+ )
221
+ max_eval_samples: Optional[int] = field(
222
+ default=None,
223
+ metadata={
224
+ "help": (
225
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
226
+ )
227
+ },
228
+ )
229
+ audio_column_name: str = field(
230
+ default="audio",
231
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
232
+ )
233
+ text_column_name: str = field(
234
+ default=None,
235
+ metadata={"help": "The name of the dataset column containing the text data in the training set."},
236
+ )
237
+ eval_text_column_name: str = field(
238
+ default="text",
239
+ metadata={"help": ("The name of the dataset column containing the text data in the evaluation set.")},
240
+ )
241
+ max_duration_in_seconds: float = field(
242
+ default=30.0,
243
+ metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
244
+ )
245
+ min_duration_in_seconds: float = field(
246
+ default=0.0,
247
+ metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
248
+ )
249
+ max_label_length: int = field(
250
+ default=448,
251
+ metadata={"help": "Truncate transcriptions that are longer `max_label_length` tokens."},
252
+ )
253
+ pad_target_to_multiple_of: Optional[int] = field(
254
+ default=None,
255
+ metadata={
256
+ "help": (
257
+ "If set will pad the target sequence to a multiple of the provided"
258
+ " value. This is important to avoid triggering recompilations on TPU."
259
+ " If unspecified, will default to padding the targets to max length."
260
+ )
261
+ },
262
+ )
263
+ preprocessing_only: bool = field(
264
+ default=False,
265
+ metadata={
266
+ "help": (
267
+ "Whether to only do data preprocessing and skip training. This is"
268
+ " especially useful when data preprocessing errors out in distributed"
269
+ " training due to timeout. In this case, one should run the"
270
+ " preprocessing in a non-distributed setup with"
271
+ " `preprocessing_only=True` so that the cached datasets can"
272
+ " consequently be loaded in distributed training"
273
+ )
274
+ },
275
+ )
276
+ train_split_name: str = field(
277
+ default="train",
278
+ metadata={
279
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
280
+ },
281
+ )
282
+ eval_split_name: str = field(
283
+ default="validation",
284
+ metadata={
285
+ "help": (
286
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
287
+ )
288
+ },
289
+ )
290
+ streaming: bool = field(
291
+ default=True,
292
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
293
+ )
294
+ wer_threshold: float = field(
295
+ default=None,
296
+ metadata={
297
+ "help": "Filter training data with Whisper transcriptions that have greater than `wer_threshold` "
298
+ "WER with the normalised transcriptions. This only takes effect if training on pseudo-labels targets."
299
+ "If `--use_pseudo_labels=False`, then no WER filtering is performed, since we train directly on the text"
300
+ "transcriptions."
301
+ },
302
+ )
303
+ use_pseudo_labels: bool = field(
304
+ default=True,
305
+ metadata={
306
+ "help": "Whether or not to use pseudo-label transcriptions as the targets. If True, the pseudo-labels "
307
+ "must be in the dataset column `whisper_transcript` from the previous pseudo-labelling step. This is "
308
+ "not currently yet configurable."
309
+ },
310
+ )
311
+ timestamp_probability: float = field(
312
+ default=0.2, metadata={"help": "Probability for training on timestamped tokens if the data contains it."}
313
+ )
314
+ condition_on_prev_probability: float = field(
315
+ default=0.2, metadata={"help": "Probability for conditioning on the previous text example."}
316
+ )
317
+ return_timestamps: bool = field(
318
+ default=False, metadata={"help": "Whether or not to predict timestamps in the generation step."}
319
+ )
320
+ language: str = field(
321
+ default=None,
322
+ metadata={
323
+ "help": (
324
+ "Language for multilingual distillation. This argument should be set for multilingual distillation "
325
+ "only. For English speech recognition, it should be left as `None`."
326
+ )
327
+ },
328
+ )
329
+ task: str = field(
330
+ default="transcribe",
331
+ metadata={
332
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."
333
+ "This argument should be set for multilingual distillation only. For English speech recognition, it should be left as `None`."
334
+ },
335
+ )
336
+ wandb_project: str = field(
337
+ default="distil-whisper",
338
+ metadata={"help": "The name of the wandb project."},
339
+ )
340
+ wandb_name: str = field(
341
+ default=None,
342
+ metadata={"help": "The name of the wandb run."},
343
+ )
344
+ wandb_dir: str = field(
345
+ default="./wandb",
346
+ metadata={"help": "The dir where wandb metadata will be stored."},
347
+ )
348
+
349
+
350
+ @dataclass
351
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
352
+ freeze_encoder: Optional[bool] = field(
353
+ default=False,
354
+ metadata={
355
+ "help": (
356
+ "Whether to freeze the entire encoder model. Only recommended when the entire encoder has been "
357
+ "copied from the teacher model."
358
+ )
359
+ },
360
+ )
361
+ freeze_decoder: Optional[bool] = field(
362
+ default=False,
363
+ metadata={
364
+ "help": (
365
+ "Whether to freeze the entire decoder model. Note that the decoder input embeddings are **not** frozen, since they are tied to the LM head."
366
+ )
367
+ },
368
+ )
369
+ freeze_embed_positions: Optional[bool] = field(
370
+ default=False,
371
+ metadata={"help": "Whether to freeze the decoder embedding positions."},
372
+ )
373
+ temperature: Optional[float] = field(
374
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
375
+ )
376
+ kl_weight: Optional[float] = field(
377
+ default=1.0,
378
+ metadata={
379
+ "help": (
380
+ "Weighting assigned to the MSE loss in the KD formulation. MSE loss is "
381
+ "computed between the teacher-student hidden states and attentions."
382
+ )
383
+ },
384
+ )
385
+ dtype: Optional[str] = field(
386
+ default="float32",
387
+ metadata={
388
+ "help": (
389
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
390
+ "`float16` or `bfloat16` (both half-precision)."
391
+ )
392
+ },
393
+ )
394
+ save_best_total_limit: Optional[int] = field(
395
+ default=1,
396
+ metadata={
397
+ "help": (
398
+ "Number of best models to be saved."
399
+ )
400
+ }
401
+ )
402
+
403
+
404
+ @dataclass
405
+ class DataCollatorSpeechSeq2SeqWithPadding:
406
+ """
407
+ Data collator that will dynamically pad the inputs received.
408
+ Args:
409
+ processor ([`Wav2Vec2Processor`])
410
+ The processor used for proccessing the data.
411
+ decoder_start_token_id (:obj: `int`)
412
+ The start-of-sequence token id of the decoder.
413
+ decoder_prev_token_id (:obj: `int`)
414
+ The start-of-prompt token id of the decoder
415
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
416
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
417
+ among:
418
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
419
+ sequence if provided).
420
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
421
+ maximum acceptable input length for the model if that argument is not provided.
422
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
423
+ different lengths).
424
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
425
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
426
+ See above for details.
427
+ max_target_length (:obj:`int`, `optional`):
428
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
429
+ """
430
+
431
+ processor: Any
432
+ decoder_start_token_id: int
433
+ decoder_prev_token_id: int
434
+ input_padding: Union[bool, str] = "max_length"
435
+ target_padding: Union[bool, str] = "max_length"
436
+ max_target_length: Optional[int] = None
437
+
438
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
439
+ # split inputs and labels since they have to be of different lengths and need
440
+ # different padding methods
441
+
442
+ # dataloader returns a list of features which we convert to a dict
443
+ input_features = {"input_features": [feature["input_features"] for feature in features]}
444
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
445
+
446
+ # reformat list to dict and set to pytorch format
447
+ batch = self.processor.feature_extractor.pad(
448
+ input_features,
449
+ padding=self.input_padding,
450
+ return_tensors="pt",
451
+ )
452
+
453
+ labels_batch = self.processor.tokenizer.pad(
454
+ label_features,
455
+ max_length=self.max_target_length,
456
+ padding=self.target_padding,
457
+ return_tensors="pt",
458
+ )
459
+
460
+ # shift labels to the right to get decoder input ids
461
+ labels = labels_batch["input_ids"]
462
+ decoder_input_ids = labels[:, :-1]
463
+ labels = labels[:, 1:]
464
+ labels_mask = labels_batch.attention_mask[:, 1:]
465
+
466
+ # replace padding with -100 to ignore correctly when computing the loss
467
+ labels = labels.masked_fill(labels_mask.ne(1), -100)
468
+
469
+ # replace initial prompt tokens with -100 to ignore correctly when computing the loss
470
+ bos_index = torch.argmax((labels == self.decoder_start_token_id).long(), dim=1)
471
+ bos_index = torch.where(bos_index > 0, bos_index + 1, bos_index)
472
+ prompt_mask = torch.arange(labels.shape[1]) < bos_index[:, None]
473
+ labels = torch.where(prompt_mask, -100, labels)
474
+
475
+ batch["labels"] = labels
476
+ batch["decoder_input_ids"] = decoder_input_ids
477
+
478
+ return batch
479
+
480
+
481
+ def log_metric(
482
+ accelerator,
483
+ metrics: Dict,
484
+ train_time: float,
485
+ step: int,
486
+ epoch: int,
487
+ learning_rate: float = None,
488
+ prefix: str = "train",
489
+ ):
490
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
491
+ log_metrics = {}
492
+ for k, v in metrics.items():
493
+ log_metrics[f"{prefix}/{k}"] = v
494
+ log_metrics[f"{prefix}/time"] = train_time
495
+ log_metrics[f"{prefix}/epoch"] = epoch
496
+ if learning_rate is not None:
497
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
498
+ accelerator.log(log_metrics, step=step)
499
+
500
+
501
+ def log_pred(
502
+ accelerator,
503
+ pred_str: List[str],
504
+ label_str: List[str],
505
+ norm_pred_str: List[str],
506
+ norm_label_str: List[str],
507
+ step: int,
508
+ prefix: str = "eval",
509
+ num_lines: int = 200000,
510
+ ):
511
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
512
+ if accelerator.is_main_process:
513
+ wandb_tracker = accelerator.get_tracker("wandb")
514
+ # pretty name for current step: step 50000 -> step 50k
515
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
516
+ prefix_pretty = prefix.replace("/", "-")
517
+
518
+ # convert str data to a wandb compatible format
519
+ str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
520
+ # log as a table with the appropriate headers
521
+ wandb_tracker.log_table(
522
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
523
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
524
+ data=str_data[:num_lines],
525
+ step=step,
526
+ )
527
+
528
+ # log incorrect normalised predictions
529
+ str_data = np.asarray(str_data)
530
+ str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
531
+ # log as a table with the appropriate headers
532
+ wandb_tracker.log_table(
533
+ table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
534
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
535
+ data=str_data_incorrect[:num_lines],
536
+ step=step,
537
+ )
538
+
539
+
540
+ def convert_dataset_str_to_list(
541
+ dataset_names,
542
+ dataset_config_names,
543
+ splits=None,
544
+ text_column_names=None,
545
+ dataset_samples=None,
546
+ default_split="train",
547
+ ) -> List[Dict]:
548
+ """
549
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
550
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
551
+ function returns a list of dictionaries, one for each dataset.
552
+ """
553
+ if isinstance(dataset_names, str):
554
+ dataset_names = dataset_names.split("+")
555
+ dataset_config_names = dataset_config_names.split("+") if dataset_config_names is not None else None
556
+ splits = splits.split("+") if splits is not None else None
557
+ text_column_names = text_column_names.split("+") if text_column_names is not None else None
558
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
559
+
560
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
561
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
562
+ raise ValueError(
563
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
564
+ f" {len(dataset_config_names)} configs."
565
+ )
566
+
567
+ if splits is not None and len(splits) != len(dataset_names):
568
+ raise ValueError(
569
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
570
+ )
571
+
572
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
573
+ raise ValueError(
574
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
575
+ f" {len(text_column_names)} text column names."
576
+ )
577
+
578
+ if dataset_samples is not None:
579
+ if len(dataset_samples) != len(dataset_names):
580
+ raise ValueError(
581
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
582
+ f"{len(dataset_samples)} samples."
583
+ )
584
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
585
+ else:
586
+ dataset_samples = [None] * len(dataset_names)
587
+
588
+ dataset_config_names = (
589
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
590
+ )
591
+ text_column_names = (
592
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
593
+ )
594
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
595
+
596
+ dataset_names_dict = []
597
+ for i, ds_name in enumerate(dataset_names):
598
+ dataset_names_dict.append(
599
+ {
600
+ "name": ds_name,
601
+ "config": dataset_config_names[i],
602
+ "split": splits[i],
603
+ "text_column_name": text_column_names[i],
604
+ "samples": dataset_samples[i],
605
+ }
606
+ )
607
+ return dataset_names_dict
608
+
609
+
610
+ def load_multiple_datasets(
611
+ dataset_names: Union[List, str],
612
+ dataset_config_names: Union[List, str],
613
+ splits: Optional[Union[List, str]] = None,
614
+ text_column_names: Optional[List] = None,
615
+ sampling_rate: Optional[int] = 16000,
616
+ stopping_strategy: Optional[str] = "first_exhausted",
617
+ dataset_samples: Optional[Union[List, np.array]] = None,
618
+ streaming: Optional[bool] = True,
619
+ seed: Optional[int] = None,
620
+ accelerator: Optional[Accelerator] = None,
621
+ use_pseudo_labels: float = None,
622
+ **kwargs,
623
+ ) -> IterableDataset:
624
+ dataset_names_dict = convert_dataset_str_to_list(
625
+ dataset_names, dataset_config_names, splits, text_column_names, dataset_samples
626
+ )
627
+
628
+ if dataset_samples is not None:
629
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
630
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
631
+ else:
632
+ probabilities = None
633
+
634
+ all_datasets = []
635
+ # iterate over the datasets we want to interleave
636
+ for dataset_dict in tqdm(
637
+ dataset_names_dict,
638
+ desc="Combining datasets...",
639
+ disable=not accelerator.is_local_main_process if accelerator is not None else False,
640
+ ):
641
+ dataset = load_dataset(
642
+ dataset_dict["name"],
643
+ dataset_dict["config"],
644
+ split=dataset_dict["split"],
645
+ streaming=streaming,
646
+ **kwargs,
647
+ )
648
+ # resample to specified sampling rate
649
+ dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
650
+ dataset_features = dataset.features.keys()
651
+ columns_to_keep = {"audio", "text"}
652
+
653
+ if dataset_dict["text_column_name"] not in dataset_features:
654
+ raise ValueError(
655
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
656
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
657
+ f" correct text column - one of {', '.join(dataset_features)}."
658
+ )
659
+
660
+ # blanket renaming of all transcription columns to text
661
+ if dataset_dict["text_column_name"] != "text":
662
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
663
+
664
+ if use_pseudo_labels:
665
+ if "whisper_transcript" not in dataset_features:
666
+ raise ValueError(
667
+ f"Pseudo-label column `whisper_transcript` not found in dataset {dataset_dict['name']}. Ensure"
668
+ "pseudo-labels are present in the dataset under this column name, or train directly on the text "
669
+ "labels by setting `--use_pseudo_labels=False` and defining the appropriate `--text_column_name`."
670
+ )
671
+ columns_to_keep.add("whisper_transcript")
672
+
673
+ if "condition_on_prev" in dataset_features:
674
+ columns_to_keep.add("condition_on_prev")
675
+
676
+ dataset_features = dataset.features.keys()
677
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
678
+ all_datasets.append(dataset)
679
+
680
+ if len(all_datasets) == 1:
681
+ # we have a single dataset so just return it as is
682
+ return all_datasets[0]
683
+
684
+ if streaming:
685
+ interleaved_dataset = interleave_datasets(
686
+ all_datasets,
687
+ stopping_strategy=stopping_strategy,
688
+ probabilities=probabilities,
689
+ seed=seed,
690
+ )
691
+ else:
692
+ interleaved_dataset = concatenate_datasets(all_datasets)
693
+
694
+ return interleaved_dataset
695
+
696
+
697
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
698
+ """Helper function to sort saved checkpoints from oldest to newest."""
699
+ ordering_and_checkpoint_path = []
700
+
701
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
702
+ glob_checkpoints = [path for path in glob_checkpoints if "val-wer" not in path] # filter out best model checkpoints
703
+
704
+ for path in glob_checkpoints:
705
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
706
+ if regex_match is not None and regex_match.groups() is not None:
707
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
708
+
709
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
710
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
711
+ return checkpoints_sorted
712
+
713
+
714
+ def sorted_best_checkpoints(output_dir=None, checkpoint_prefix="checkpoint"):
715
+ """Helper function to sort saved best checkpoints."""
716
+ ordering_and_checkpoint_path = []
717
+
718
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
719
+ for path in glob_checkpoints:
720
+ regex_match = re.search(r"val-wer-([0-9]+\.[0-9]+)", path)
721
+ if regex_match is not None and regex_match.groups() is not None:
722
+ ordering_and_checkpoint_path.append((regex_match.groups(1), path))
723
+
724
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path, reverse=True)
725
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
726
+ return checkpoints_sorted
727
+
728
+
729
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", sorting_fn=sorted_checkpoints) -> None:
730
+ """Helper function to delete old checkpoints."""
731
+ if save_total_limit is None or save_total_limit <= 0:
732
+ return
733
+ # Check if we should delete older checkpoint(s)
734
+ checkpoints_sorted = sorting_fn(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
735
+ if len(checkpoints_sorted) <= save_total_limit:
736
+ return
737
+
738
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
739
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
740
+ for checkpoint in checkpoints_to_be_deleted:
741
+ logger.info(f"Deleting older checkpoint [{checkpoint}].")
742
+ shutil.rmtree(checkpoint, ignore_errors=True)
743
+
744
+
745
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
746
+
747
+
748
+ def get_last_checkpoint(folder):
749
+ content = os.listdir(folder)
750
+ checkpoints = [
751
+ path
752
+ for path in content
753
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
754
+ ]
755
+ if len(checkpoints) == 0:
756
+ return
757
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
758
+
759
+
760
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
761
+ """
762
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
763
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
764
+ (e.g. if the module is frozen).
765
+ """
766
+ result = []
767
+ for name, child in model.named_children():
768
+ result += [
769
+ f"{name}.{n}"
770
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
771
+ if not (
772
+ isinstance(child, tuple(forbidden_layer_types))
773
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
774
+ )
775
+ ]
776
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
777
+ result += list(model._parameters.keys())
778
+ return result
779
+
780
+
781
+ def main():
782
+ # 1. Parse input arguments
783
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
784
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
785
+
786
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
787
+ # If we pass only one argument to the script and it's the path to a json file,
788
+ # let's parse it to get our arguments.
789
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
790
+ else:
791
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
792
+
793
+ # 2. Initialize the accelerator
794
+ # We will let the accelerator handle device placement for us in this example
795
+ # We simply have to specify the training precision and any trackers being used
796
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
797
+ # it to accelerate format
798
+ if training_args.dtype == "float16":
799
+ mixed_precision = "fp16"
800
+ teacher_dtype = torch.float16
801
+ elif training_args.dtype == "bfloat16":
802
+ mixed_precision = "bf16"
803
+ teacher_dtype = torch.bfloat16
804
+ else:
805
+ mixed_precision = "no"
806
+ teacher_dtype = torch.float32
807
+
808
+ accelerator = Accelerator(
809
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
810
+ mixed_precision=mixed_precision,
811
+ log_with=training_args.report_to,
812
+ project_dir=training_args.output_dir,
813
+ )
814
+
815
+ accelerator.init_trackers(
816
+ project_name=data_args.wandb_project,
817
+ init_kwargs={
818
+ "wandb": {"name": data_args.wandb_name,
819
+ "dir": data_args.wandb_dir}
820
+ }
821
+
822
+ )
823
+
824
+ # 3. Set-up basic logging
825
+ # Create one log on every process with the configuration for debugging
826
+ logging.basicConfig(
827
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
828
+ datefmt="%m/%d/%Y %H:%M:%S",
829
+ level=logging.INFO,
830
+ )
831
+ # Log a small summary on each proces
832
+ logger.warning(
833
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
834
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
835
+ )
836
+
837
+ # Set the verbosity to info of the Transformers logger (on main process only)
838
+ if accelerator.is_local_main_process:
839
+ datasets.utils.logging.set_verbosity_warning()
840
+ transformers.utils.logging.set_verbosity_info()
841
+ else:
842
+ datasets.utils.logging.set_verbosity_error()
843
+ transformers.utils.logging.set_verbosity_error()
844
+ logger.info("Training/evaluation parameters %s", training_args)
845
+
846
+ # 4. Detecting last checkpoint and eventually continue from last checkpoint
847
+ last_checkpoint = None
848
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
849
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
850
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
851
+ raise ValueError(
852
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
853
+ "Use --overwrite_output_dir to overcome."
854
+ )
855
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
856
+ logger.info(
857
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
858
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
859
+ )
860
+
861
+ # 5. Handle the repository creation
862
+ if accelerator.is_main_process:
863
+ if training_args.push_to_hub:
864
+ if training_args.hub_model_id is None:
865
+ repo_name = get_full_repo_name(
866
+ Path(training_args.output_dir).absolute().name,
867
+ token=training_args.hub_token,
868
+ )
869
+ else:
870
+ repo_name = training_args.hub_model_id
871
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
872
+
873
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
874
+ if "wandb" not in gitignore:
875
+ gitignore.write("wandb\n")
876
+ elif training_args.output_dir is not None:
877
+ os.makedirs(training_args.output_dir, exist_ok=True)
878
+ accelerator.wait_for_everyone()
879
+
880
+ # 6. Load dataset - either streaming or non-streaming (offline)
881
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
882
+
883
+ # set seed for determinism
884
+ set_seed(training_args.seed)
885
+
886
+ if training_args.do_train:
887
+ raw_datasets["train"] = load_multiple_datasets(
888
+ data_args.train_dataset_name,
889
+ data_args.train_dataset_config_name,
890
+ splits=data_args.train_split_name,
891
+ text_column_names=data_args.text_column_name,
892
+ use_pseudo_labels=data_args.use_pseudo_labels,
893
+ streaming=data_args.streaming,
894
+ dataset_samples=data_args.train_dataset_samples,
895
+ seed=training_args.seed,
896
+ accelerator=accelerator,
897
+ cache_dir=data_args.dataset_cache_dir,
898
+ token=model_args.token,
899
+ )
900
+ raw_datasets_train_features = list(raw_datasets["train"].features.keys())
901
+
902
+ if training_args.do_eval:
903
+ dataset_names_dict = convert_dataset_str_to_list(
904
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
905
+ (
906
+ data_args.eval_dataset_config_name
907
+ if data_args.eval_dataset_config_name
908
+ else data_args.train_dataset_config_name
909
+ ),
910
+ splits=data_args.eval_split_name,
911
+ text_column_names=data_args.eval_text_column_name,
912
+ )
913
+ all_eval_splits = []
914
+ if len(dataset_names_dict) == 1:
915
+ # load a single eval set
916
+ dataset_dict = dataset_names_dict[0]
917
+ all_eval_splits.append("eval")
918
+ raw_datasets["eval"] = load_dataset(
919
+ dataset_dict["name"],
920
+ dataset_dict["config"],
921
+ split=dataset_dict["split"],
922
+ cache_dir=data_args.dataset_cache_dir,
923
+ token=model_args.token,
924
+ streaming=data_args.streaming,
925
+ )
926
+ if data_args.eval_text_column_name != "text":
927
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
928
+ else:
929
+ # load multiple eval sets
930
+ for dataset_dict in dataset_names_dict:
931
+ if dataset_dict["name"] == "esb/diagnostic-dataset":
932
+ # for the ESB diagnostic dataset, the dataset name is effectively the config
933
+ pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
934
+ else:
935
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
936
+ all_eval_splits.append(pretty_name)
937
+ raw_datasets[pretty_name] = load_dataset(
938
+ dataset_dict["name"],
939
+ dataset_dict["config"],
940
+ split=dataset_dict["split"],
941
+ cache_dir=data_args.dataset_cache_dir,
942
+ token=model_args.token,
943
+ streaming=data_args.streaming,
944
+ )
945
+ # make column names consistent (text, audio)
946
+ if dataset_dict["text_column_name"] != "text":
947
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
948
+ dataset_dict["text_column_name"], "text"
949
+ )
950
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
951
+ set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
952
+ )
953
+
954
+ if not training_args.do_train and not training_args.do_eval:
955
+ raise ValueError(
956
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
957
+ )
958
+
959
+ # 7. Load pretrained model, tokenizer, and feature extractor
960
+ config = WhisperConfig.from_pretrained(
961
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
962
+ cache_dir=model_args.cache_dir,
963
+ revision=model_args.model_revision,
964
+ token=model_args.token,
965
+ )
966
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
967
+ (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
968
+ cache_dir=model_args.cache_dir,
969
+ revision=model_args.model_revision,
970
+ token=model_args.token,
971
+ )
972
+ tokenizer = WhisperTokenizerFast.from_pretrained(
973
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
974
+ cache_dir=model_args.cache_dir,
975
+ use_fast=model_args.use_fast_tokenizer,
976
+ revision=model_args.model_revision,
977
+ token=model_args.token,
978
+ )
979
+
980
+ # override timestamp tokens until tokenizer issues are fixed in transformers
981
+ timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
982
+ tokenizer.add_tokens(timestamps)
983
+
984
+ # The teacher model can safely be cast to the dtype of training since we don't
985
+ # update the params
986
+ teacher_model = WhisperForConditionalGeneration.from_pretrained(
987
+ model_args.teacher_model_name_or_path,
988
+ cache_dir=model_args.cache_dir,
989
+ token=model_args.token,
990
+ low_cpu_mem_usage=True,
991
+ torch_dtype=teacher_dtype,
992
+ attn_implementation=model_args.attn_implementation,
993
+ )
994
+
995
+ student_model = WhisperForConditionalGeneration.from_pretrained(
996
+ model_args.model_name_or_path,
997
+ config=config,
998
+ cache_dir=model_args.cache_dir,
999
+ revision=model_args.model_revision,
1000
+ subfolder=model_args.subfolder,
1001
+ token=model_args.token,
1002
+ low_cpu_mem_usage=True,
1003
+ attn_implementation=model_args.attn_implementation,
1004
+ )
1005
+
1006
+ if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
1007
+ raise ValueError(
1008
+ f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
1009
+ f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
1010
+ f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
1011
+ )
1012
+
1013
+ # enable gradient checkpointing if necessary
1014
+ if training_args.gradient_checkpointing:
1015
+ student_model.gradient_checkpointing_enable()
1016
+
1017
+ def set_trainable_parameters(module, requires_grad=False):
1018
+ for param in module.parameters():
1019
+ param.requires_grad = requires_grad
1020
+ module._requires_grad = requires_grad
1021
+
1022
+ # freeze student encoder if necessary
1023
+ if training_args.freeze_encoder:
1024
+ set_trainable_parameters(student_model.model.encoder, requires_grad=False)
1025
+ student_model.model.encoder.gradient_checkpointing = False
1026
+
1027
+ if training_args.freeze_decoder:
1028
+ set_trainable_parameters(student_model.model.decoder, requires_grad=False)
1029
+ student_model.model.decoder.gradient_checkpointing = False
1030
+ # un-freeze LM head parameters (and consequently word embeddings), frozen when frozing decoder since tied word embedding and LM head
1031
+ set_trainable_parameters(student_model.proj_out, requires_grad=True)
1032
+
1033
+
1034
+ if training_args.freeze_embed_positions:
1035
+ # set_trainable_parameters(student_model.model.decoder.embed_tokens, requires_grad=False)
1036
+ set_trainable_parameters(student_model.model.decoder.embed_positions, requires_grad=False)
1037
+ if student_model.model.decoder.gradient_checkpointing:
1038
+ logger.info(
1039
+ "Disabling gradient checkpointing in the decoder since it's incompatible with `freeze_embed_positions`."
1040
+ )
1041
+
1042
+ logger.info(
1043
+ f"Number of trainable parameters: {sum(p.numel() for p in student_model.parameters() if p.requires_grad):.3e}"
1044
+ )
1045
+
1046
+ share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
1047
+ if share_hidden_states:
1048
+ # tie the weights for the teacher encoder if we're freezing the student and it's the same as the teacher
1049
+ teacher_model.model.encoder = student_model.model.encoder
1050
+
1051
+ if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
1052
+ # We need to set the language and task ids for previously multilingual checkpoints
1053
+ is_multilingual = True
1054
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task, predict_timestamps=False)
1055
+ student_model.generation_config.update(
1056
+ **{
1057
+ "language": data_args.language,
1058
+ "task": data_args.task,
1059
+ }
1060
+ )
1061
+ elif data_args.language is not None:
1062
+ raise ValueError(
1063
+ "Setting language token for an English-only checkpoint is not permitted. The language argument should "
1064
+ "only be set for multilingual checkpoints."
1065
+ )
1066
+ else:
1067
+ is_multilingual = False
1068
+
1069
+ # 8. Create a single speech processor - make sure all processes wait until data is saved
1070
+ if accelerator.is_main_process:
1071
+ feature_extractor.save_pretrained(training_args.output_dir)
1072
+ tokenizer.save_pretrained(training_args.output_dir)
1073
+ # save the config and generation config as well
1074
+ config.save_pretrained(training_args.output_dir)
1075
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1076
+
1077
+ accelerator.wait_for_everyone()
1078
+ processor = WhisperProcessor.from_pretrained(training_args.output_dir)
1079
+
1080
+ # 9. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
1081
+ # so we just need to set the correct target sampling rate.
1082
+ sampling_rate = feature_extractor.sampling_rate
1083
+ raw_datasets = raw_datasets.cast_column(
1084
+ data_args.audio_column_name,
1085
+ datasets.features.Audio(sampling_rate=sampling_rate),
1086
+ )
1087
+
1088
+ # 10. Preprocessing the datasets: we need to read the audio files as arrays and tokenize the targets.
1089
+ # 10.1: Define the pre-processing constants
1090
+ max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
1091
+ min_input_length = int(data_args.min_duration_in_seconds * sampling_rate)
1092
+ max_label_length = (
1093
+ data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
1094
+ )
1095
+
1096
+ timestamp_probability = data_args.timestamp_probability
1097
+ condition_on_prev_probability = data_args.condition_on_prev_probability
1098
+ return_timestamps = data_args.return_timestamps if timestamp_probability > 0 else False
1099
+
1100
+ timestamp_ids = tokenizer.timestamp_ids()
1101
+ timestamp_begin = tokenizer.all_special_ids[-1]
1102
+ timestamp_position = 3 if is_multilingual else 1
1103
+
1104
+ decoder_start_token_id = student_model.config.decoder_start_token_id # <|startoftranscript|>
1105
+ decoder_prev_token_id = tokenizer.all_special_ids[-3] # <|startofprev|>
1106
+ prompt_cutoff_length = max_label_length // 2
1107
+
1108
+ num_workers = data_args.preprocessing_num_workers
1109
+ dataloader_num_workers = training_args.dataloader_num_workers
1110
+ prefetch_factor = training_args.dataloader_prefetch_factor
1111
+
1112
+ metric = evaluate.load("wer")
1113
+ normalizer = (
1114
+ BasicTextNormalizer()
1115
+ if data_args.language is not None
1116
+ else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
1117
+ )
1118
+ wer_threshold = data_args.wer_threshold
1119
+ use_pseudo_labels = data_args.use_pseudo_labels
1120
+ train_text_column_name = "whisper_transcript" if use_pseudo_labels else "text"
1121
+
1122
+ # 10.2: filter based on maximum number of training/evaluation samples
1123
+ if training_args.do_train and data_args.max_train_samples is not None:
1124
+ raw_datasets["train"] = (
1125
+ raw_datasets["train"].take(data_args.max_train_samples)
1126
+ if data_args.streaming
1127
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1128
+ )
1129
+
1130
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1131
+ for eval_split in all_eval_splits:
1132
+ raw_datasets[eval_split] = (
1133
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1134
+ if data_args.streaming
1135
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1136
+ )
1137
+
1138
+ # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
1139
+ def is_wer_in_range(ground_truth, whisper_transcript):
1140
+ norm_ground_truth = normalizer(ground_truth)
1141
+ if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1142
+ # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1143
+ return False
1144
+ elif len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
1145
+ return True
1146
+ elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
1147
+ norm_whisper_transcript = normalizer(whisper_transcript)
1148
+ wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1149
+ return wer < wer_threshold
1150
+ else:
1151
+ # filter automatically since weR
1152
+ return False
1153
+
1154
+ filter_by_wer_threshold = partial(
1155
+ raw_datasets["train"].filter,
1156
+ function=is_wer_in_range,
1157
+ input_columns=["text", "whisper_transcript"],
1158
+ )
1159
+
1160
+ if wer_threshold is not None and use_pseudo_labels:
1161
+ with accelerator.main_process_first():
1162
+ raw_datasets["train"] = (
1163
+ filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1164
+ if not data_args.streaming
1165
+ else filter_by_wer_threshold()
1166
+ )
1167
+
1168
+ # 10.4: pre-process training/evaluation datasets
1169
+ def prepare_train_dataset(batch):
1170
+ """
1171
+ Pre-process the raw dataset in a three stage process:
1172
+ 1. Convert the audio arrays to log-mel spectrogram inputs
1173
+ 2. Possibly filter the timestamp tokens from the token ids (depending on the timestamp probability)
1174
+ 3. Possibly add prompt tokens if conditioning on previous text (depending on the conditioning probability)
1175
+ """
1176
+ # process audio input
1177
+ audio = [sample["array"] for sample in batch["audio"]]
1178
+ inputs = feature_extractor(audio, sampling_rate=sampling_rate)
1179
+ batch["input_features"] = inputs.input_features
1180
+ batch["input_length"] = [len(sample) for sample in audio]
1181
+
1182
+ # process text targets - for training these are the Whisper-generated pseudo-labels
1183
+ input_str_batched = batch[train_text_column_name]
1184
+ condition_on_prev_batched = batch.get("condition_on_prev", len(input_str_batched) * [None])
1185
+
1186
+ all_token_ids = []
1187
+ all_token_ids_unprompted = []
1188
+ for prev_ids, input_str in zip(condition_on_prev_batched, input_str_batched):
1189
+ token_ids = tokenizer(input_str, add_special_tokens=not use_pseudo_labels).input_ids
1190
+
1191
+ # check whether we have timestamps in the PLs and filter if required
1192
+ has_timestamps = len(set(token_ids) & set(timestamp_ids)) > 0
1193
+ if has_timestamps:
1194
+ # sample from binomial distribution to get probability of training on timestamps
1195
+ predict_timestamps = bool(np.random.binomial(1, timestamp_probability))
1196
+ if not predict_timestamps:
1197
+ # filter timestamps and insert the <|notimestamps|> task token
1198
+ token_ids = [token for token in token_ids if token < timestamp_begin]
1199
+ token_ids.insert(timestamp_position, timestamp_begin)
1200
+
1201
+ all_token_ids_unprompted.append(token_ids)
1202
+ # check whether to condition on previous text - we do this with probability condition_on_prev_probability
1203
+ condition_on_prev = bool(np.random.binomial(1, condition_on_prev_probability))
1204
+ if not condition_on_prev:
1205
+ prev_ids = None
1206
+ elif "condition_on_prev" not in batch and len(all_token_ids_unprompted) > 1:
1207
+ # prompt ids are the penultimate token ids in the batch
1208
+ prev_ids = all_token_ids_unprompted[-2]
1209
+
1210
+ if prev_ids is not None:
1211
+ if has_timestamps and not predict_timestamps:
1212
+ # filter timestamp ids from prompt when not predicting timestamps
1213
+ prev_ids = [token for token in prev_ids if token < timestamp_begin]
1214
+
1215
+ # check that the length of the prompt does not exceed more than half the max label length (224)
1216
+ if len(prev_ids) > prompt_cutoff_length:
1217
+ prev_ids = prev_ids[-prompt_cutoff_length + 1 :]
1218
+ prev_ids = [decoder_prev_token_id] + prev_ids
1219
+
1220
+ # and that the total length of the labels does not exceed the max label length (448)
1221
+ if len(prev_ids + token_ids) > max_label_length:
1222
+ trim_length = len(prev_ids + token_ids) - max_label_length + 1
1223
+ prev_ids = prev_ids[trim_length:]
1224
+ prev_ids = [decoder_prev_token_id] + prev_ids
1225
+
1226
+ token_ids = prev_ids + token_ids
1227
+
1228
+ all_token_ids.append(token_ids)
1229
+
1230
+ batch["labels"] = all_token_ids
1231
+ return batch
1232
+
1233
+ def prepare_eval_dataset(batch):
1234
+ # process audio input
1235
+ sample = batch["audio"]
1236
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1237
+ batch["input_features"] = inputs.input_features[0]
1238
+ batch["input_length"] = len(sample["array"])
1239
+
1240
+ # process targets - for evaluation these are the ground-truth transcriptions
1241
+ input_str = batch["text"]
1242
+ batch["labels"] = tokenizer(input_str).input_ids
1243
+ return batch
1244
+
1245
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1246
+ if training_args.do_train:
1247
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1248
+ # we can use `num_workers` (which is much faster)
1249
+ # We gate the pre-processing function accordingly
1250
+ map_fn_train = partial(
1251
+ raw_datasets["train"].map,
1252
+ function=prepare_train_dataset,
1253
+ remove_columns=raw_datasets_train_features,
1254
+ batched=True,
1255
+ batch_size=data_args.preprocessing_batch_size,
1256
+ )
1257
+ with accelerator.main_process_first():
1258
+ vectorized_datasets["train"] = (
1259
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1260
+ if not data_args.streaming
1261
+ else map_fn_train()
1262
+ )
1263
+ if training_args.do_eval:
1264
+ for eval_split in all_eval_splits:
1265
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1266
+ map_fn_eval = partial(
1267
+ raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1268
+ )
1269
+ with accelerator.main_process_first():
1270
+ vectorized_datasets[eval_split] = (
1271
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1272
+ if not data_args.streaming
1273
+ else map_fn_eval()
1274
+ )
1275
+
1276
+ # 10.5: Filter training data with inputs longer than `max_input_length`
1277
+ def is_audio_in_length_range(length):
1278
+ return min_input_length < length < max_input_length
1279
+
1280
+ filter_by_audio_fn = partial(
1281
+ vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1282
+ )
1283
+ with accelerator.main_process_first():
1284
+ vectorized_datasets = (
1285
+ filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1286
+ if not data_args.streaming
1287
+ else filter_by_audio_fn()
1288
+ )
1289
+
1290
+ # 10.6: Filter training data with labels longer than `max_label_length`
1291
+ def is_labels_in_length_range(labels):
1292
+ return 0 < len(labels) <= max_label_length
1293
+
1294
+ filter_by_labels_fn = partial(
1295
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1296
+ )
1297
+ with accelerator.main_process_first():
1298
+ vectorized_datasets = (
1299
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1300
+ if not data_args.streaming
1301
+ else filter_by_labels_fn()
1302
+ )
1303
+
1304
+ # Pre-processing complete!
1305
+ # For large datasets it is advised to run the preprocessing on a
1306
+ # single machine first with `--preprocessing_only` since there will mostly likely
1307
+ # be a timeout when running the script in distributed mode.
1308
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1309
+ # cached dataset
1310
+ if data_args.preprocessing_only:
1311
+ if data_args.streaming:
1312
+ raise ValueError(
1313
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1314
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1315
+ "on the fly with streaming mode."
1316
+ )
1317
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1318
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1319
+ return
1320
+
1321
+ # 11. Define Evaluation Metrics
1322
+ def compute_metrics(preds, labels):
1323
+ # replace padded labels by the padding token
1324
+ for idx in range(len(labels)):
1325
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1326
+
1327
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1328
+ # we do not want to group tokens when computing the metrics
1329
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1330
+ wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1331
+
1332
+ # Normalize everything
1333
+ norm_pred_str = []
1334
+ norm_label_str = []
1335
+
1336
+ # Iterate through all predictions and labels
1337
+ for pred, label in zip(pred_str, label_str):
1338
+ # Normalize the prediction and label
1339
+ normalized_pred = normalizer(pred)
1340
+ normalized_label = normalizer(label)
1341
+
1342
+ # If either normalized string is empty after normalization, replace with "<|nocaptions|>"
1343
+ if not normalized_pred.strip():
1344
+ normalized_pred = "<|nocaptions|>"
1345
+ if not normalized_label.strip():
1346
+ normalized_label = "<|nocaptions|>"
1347
+
1348
+ norm_pred_str.append(normalized_pred)
1349
+ norm_label_str.append(normalized_label)
1350
+
1351
+ # Replace original strings with "<|nocaptions|>" where necessary for consistency
1352
+ pred_str = [pred if len(pred.strip()) > 0 else "<|nocaptions|>" for pred in pred_str]
1353
+ label_str = [label if len(label.strip()) > 0 else "<|nocaptions|>" for label in label_str]
1354
+
1355
+ # Compute WER using all entries, including those with "<|nocaptions|>"
1356
+ wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1357
+ return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1358
+
1359
+ # 12. Define Training Schedule
1360
+ # Store some constants
1361
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1362
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1363
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1364
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1365
+
1366
+ if not data_args.streaming and training_args.max_steps < 0:
1367
+ num_epochs = int(training_args.num_train_epochs)
1368
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1369
+ total_train_steps = steps_per_epoch * num_epochs
1370
+ elif training_args.max_steps > 0:
1371
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1372
+ total_train_steps = int(training_args.max_steps)
1373
+ if not data_args.streaming:
1374
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1375
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1376
+ else:
1377
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1378
+ num_epochs = sys.maxsize
1379
+ steps_per_epoch = total_train_steps
1380
+ else:
1381
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1382
+
1383
+ if training_args.eval_steps is None:
1384
+ logger.info(
1385
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1386
+ )
1387
+ eval_steps = steps_per_epoch
1388
+ else:
1389
+ eval_steps = training_args.eval_steps
1390
+
1391
+ # 13. Define optimizer, LR scheduler, collator
1392
+
1393
+ forbidden_module = [
1394
+ module
1395
+ for module, flag in [
1396
+ (student_model.model.encoder, training_args.freeze_encoder),
1397
+ (student_model.model.decoder, training_args.freeze_decoder)
1398
+ ]
1399
+ if flag
1400
+ ] or None
1401
+
1402
+ decay_parameters = get_parameter_names(
1403
+ student_model,
1404
+ [nn.LayerNorm],
1405
+ forbidden_module=forbidden_module,
1406
+ )
1407
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1408
+ optimizer_grouped_parameters = [
1409
+ {
1410
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1411
+ "weight_decay": training_args.weight_decay,
1412
+ },
1413
+ {
1414
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1415
+ "weight_decay": 0.0,
1416
+ },
1417
+ ]
1418
+ optimizer = torch.optim.AdamW(
1419
+ params=optimizer_grouped_parameters,
1420
+ lr=training_args.learning_rate,
1421
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1422
+ eps=training_args.adam_epsilon,
1423
+ )
1424
+
1425
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1426
+ lr_scheduler = get_scheduler(
1427
+ name=training_args.lr_scheduler_type,
1428
+ optimizer=optimizer,
1429
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1430
+ num_training_steps=total_train_steps * accelerator.num_processes,
1431
+ )
1432
+
1433
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1434
+ processor=processor,
1435
+ decoder_start_token_id=decoder_start_token_id,
1436
+ decoder_prev_token_id=decoder_prev_token_id,
1437
+ input_padding="longest",
1438
+ target_padding="max_length",
1439
+ max_target_length=max_label_length,
1440
+ )
1441
+
1442
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1443
+ # so that we can still access the configs
1444
+ num_beams = (
1445
+ training_args.generation_num_beams
1446
+ if training_args.generation_num_beams is not None
1447
+ else getattr(student_model.generation_config, "num_beams", 1)
1448
+ )
1449
+
1450
+ gen_kwargs = {
1451
+ "max_length": max_label_length,
1452
+ "num_beams": num_beams,
1453
+ "return_timestamps": return_timestamps,
1454
+ }
1455
+ if is_multilingual:
1456
+ # forcing the language and task tokens helps multilingual models in their generations
1457
+ gen_kwargs.update(
1458
+ {
1459
+ "language": data_args.language,
1460
+ "task": data_args.task,
1461
+ }
1462
+ )
1463
+
1464
+ # 15. Prepare everything with accelerate
1465
+ student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1466
+ student_model, teacher_model, optimizer, lr_scheduler
1467
+ )
1468
+
1469
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1470
+ kl_loss = nn.KLDivLoss(reduction="none")
1471
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1472
+ # ignore padded tokens from divergence, i.e. where labels are not set to -100
1473
+ padding_mask = labels >= 0
1474
+ padding_mask = padding_mask.unsqueeze(-1)
1475
+ divergence = divergence * padding_mask
1476
+ # take the average over the mini-batch
1477
+ divergence = divergence.sum() / padding_mask.sum()
1478
+ return divergence
1479
+
1480
+ # Define gradient update step fn
1481
+ def train_step(
1482
+ batch,
1483
+ temperature=2.0,
1484
+ ):
1485
+ student_model.train()
1486
+ teacher_model.eval()
1487
+
1488
+ student_outputs = student_model(**batch)
1489
+ with torch.no_grad():
1490
+ if share_hidden_states:
1491
+ # if the student and teacher share the same frozen encoder then we don't have to recompute the
1492
+ # encoder hidden-states for the teacher model, we can just re-use from the student
1493
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1494
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1495
+ else:
1496
+ # do the full forward pass for the teacher model (encoder + decoder)
1497
+ teacher_outputs = teacher_model(**batch)
1498
+
1499
+ # CE (data) loss
1500
+ ce_loss = student_outputs.loss
1501
+ # rescale distribution by temperature to ensure gradients scale correctly
1502
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1503
+ # log softmax of student predictions for numerical stability
1504
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1505
+ # KL-divergence loss (scaled by temperature)
1506
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1507
+
1508
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1509
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1510
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1511
+ return loss, metrics
1512
+
1513
+ # Define eval fn
1514
+ def eval_step(batch):
1515
+ student_model.eval()
1516
+ teacher_model.eval()
1517
+
1518
+ with torch.no_grad():
1519
+ student_outputs = student_model(**batch)
1520
+ if share_hidden_states:
1521
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1522
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1523
+ else:
1524
+ teacher_outputs = teacher_model(**batch)
1525
+
1526
+ # CE (data) loss
1527
+ ce_loss = student_outputs.loss
1528
+
1529
+ # log softmax / softmax for numerical stability
1530
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1531
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1532
+ # temperature is always 1 for eval
1533
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1534
+
1535
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1536
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1537
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1538
+ return metrics
1539
+
1540
+ def generate_step(batch):
1541
+ student_model.eval()
1542
+ output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1543
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1544
+ return output_ids
1545
+
1546
+ logger.info("***** Running training *****")
1547
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1548
+ if not data_args.streaming:
1549
+ logger.info(f" Num epochs = {num_epochs}")
1550
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1551
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1552
+ logger.info(
1553
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1554
+ )
1555
+ logger.info(f" Total optimization steps = {total_train_steps}")
1556
+
1557
+ # ======================== Training ================================
1558
+ train_time = 0
1559
+ train_start = time.time()
1560
+ steps_trained_progress_bar = tqdm(
1561
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1562
+ )
1563
+ continue_training = True
1564
+ epochs_trained = 0
1565
+ cur_step = 0
1566
+ best_val_wer = np.inf
1567
+
1568
+ checkpoint = None
1569
+ if training_args.resume_from_checkpoint is not None:
1570
+ checkpoint = training_args.resume_from_checkpoint
1571
+ elif last_checkpoint is not None:
1572
+ checkpoint = last_checkpoint
1573
+
1574
+ if checkpoint is not None:
1575
+ accelerator.load_state(checkpoint)
1576
+ # Find num steps and epoch from saved state string pattern
1577
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1578
+ match = re.search(pattern, checkpoint)
1579
+ cur_step = int(match.group(1))
1580
+ epochs_trained = int(match.group(2))
1581
+
1582
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1583
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1584
+ logger.info(f" Continuing training from global step {cur_step}")
1585
+
1586
+ steps_trained_progress_bar.update(cur_step)
1587
+
1588
+ for epoch in range(0, epochs_trained):
1589
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1590
+
1591
+ if not data_args.streaming and training_args.max_steps < 0:
1592
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1593
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1594
+ else:
1595
+ # Currently we don't know how many steps we've taken in the current epoch
1596
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1597
+ # This is "good enough" for our purposes but not fully correct
1598
+ resume_step = None
1599
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1600
+ else:
1601
+ resume_step = None
1602
+
1603
+ for epoch in range(epochs_trained, num_epochs):
1604
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1605
+ train_dataloader = DataLoader(
1606
+ vectorized_datasets["train"],
1607
+ collate_fn=data_collator,
1608
+ batch_size=per_device_train_batch_size,
1609
+ num_workers=dataloader_num_workers,
1610
+ prefetch_factor=prefetch_factor,
1611
+ pin_memory=training_args.dataloader_pin_memory,
1612
+ )
1613
+ train_dataloader = accelerator.prepare(train_dataloader)
1614
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1615
+ train_dataloader.dataset.set_epoch(epoch)
1616
+
1617
+ if resume_step is not None:
1618
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1619
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1620
+ resume_step = None
1621
+
1622
+ for batch in train_dataloader:
1623
+ with accelerator.accumulate(student_model):
1624
+ loss, train_metric = train_step(batch, temperature=training_args.temperature)
1625
+ accelerator.backward(loss)
1626
+ if accelerator.sync_gradients:
1627
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1628
+ optimizer.step()
1629
+ lr_scheduler.step()
1630
+ optimizer.zero_grad()
1631
+
1632
+ # Check if the accelerator has performed an optimization step behind the scenes
1633
+ if accelerator.sync_gradients:
1634
+ steps_trained_progress_bar.update(1)
1635
+ cur_step += 1
1636
+
1637
+ if cur_step % training_args.logging_steps == 0:
1638
+ steps_trained_progress_bar.write(
1639
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1640
+ f" {train_metric['loss']}, Learning Rate:"
1641
+ f" {lr_scheduler.get_last_lr()[0]})"
1642
+ )
1643
+ log_metric(
1644
+ accelerator,
1645
+ metrics=train_metric,
1646
+ learning_rate=lr_scheduler.get_last_lr()[0],
1647
+ train_time=train_time + time.time() - train_start,
1648
+ step=cur_step,
1649
+ epoch=epoch,
1650
+ prefix="train",
1651
+ )
1652
+
1653
+ # save checkpoint and weights after each save_steps and at the end of training
1654
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1655
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1656
+ accelerator.save_state(output_dir=intermediate_dir)
1657
+ feature_extractor.save_pretrained(intermediate_dir)
1658
+ tokenizer.save_pretrained(intermediate_dir)
1659
+ config.save_pretrained(intermediate_dir)
1660
+ student_model.generation_config.save_pretrained(intermediate_dir)
1661
+
1662
+ accelerator.wait_for_everyone()
1663
+ if accelerator.is_main_process:
1664
+ rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1665
+
1666
+ if training_args.push_to_hub:
1667
+ upload_folder(
1668
+ folder_path=training_args.output_dir,
1669
+ repo_id=repo_name,
1670
+ repo_type="model",
1671
+ commit_message=f"Saving train state of step {cur_step}",
1672
+ )
1673
+
1674
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1675
+ train_time += time.time() - train_start
1676
+ student_model.eval()
1677
+ wer_l, labels_l = [], []
1678
+ # ======================== Evaluating ==============================
1679
+ for eval_split in all_eval_splits:
1680
+ eval_metrics = []
1681
+ eval_preds = []
1682
+ eval_labels = []
1683
+ eval_start = time.time()
1684
+
1685
+ validation_dataloader = DataLoader(
1686
+ vectorized_datasets[eval_split],
1687
+ collate_fn=data_collator,
1688
+ batch_size=per_device_eval_batch_size,
1689
+ drop_last=False,
1690
+ num_workers=dataloader_num_workers,
1691
+ prefetch_factor=prefetch_factor,
1692
+ pin_memory=training_args.dataloader_pin_memory,
1693
+ )
1694
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1695
+
1696
+ for batch in tqdm(
1697
+ validation_dataloader,
1698
+ desc=f"Evaluating {eval_split}...",
1699
+ position=2,
1700
+ disable=not accelerator.is_local_main_process,
1701
+ ):
1702
+ # Model forward
1703
+ eval_metric = eval_step(batch)
1704
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1705
+ eval_metrics.append(eval_metric)
1706
+
1707
+ # generation
1708
+ if training_args.predict_with_generate:
1709
+ generated_ids = generate_step(batch)
1710
+ # Gather all predictions and targets
1711
+ generated_ids, labels = accelerator.gather_for_metrics(
1712
+ (generated_ids, batch["labels"])
1713
+ )
1714
+ eval_preds.extend(generated_ids)
1715
+ eval_labels.extend(labels)
1716
+
1717
+ eval_time = time.time() - eval_start
1718
+ # normalize eval metrics
1719
+ eval_metrics = {
1720
+ key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1721
+ }
1722
+
1723
+ # compute WER metric
1724
+ wer_desc = ""
1725
+ if training_args.predict_with_generate:
1726
+ wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1727
+ eval_preds, eval_labels
1728
+ )
1729
+ eval_metrics.update(wer_metric)
1730
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1731
+ log_pred(
1732
+ accelerator,
1733
+ pred_str,
1734
+ label_str,
1735
+ norm_pred_str,
1736
+ norm_label_str,
1737
+ step=cur_step,
1738
+ prefix=eval_split,
1739
+ )
1740
+
1741
+ # Print metrics and update progress bar
1742
+ steps_trained_progress_bar.write(
1743
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1744
+ f" {wer_desc})"
1745
+ )
1746
+
1747
+ wer_l.append(wer_metric)
1748
+ labels_l.append(norm_label_str)
1749
+
1750
+ log_metric(
1751
+ accelerator,
1752
+ metrics=eval_metrics,
1753
+ train_time=eval_time,
1754
+ step=cur_step,
1755
+ epoch=epoch,
1756
+ prefix=eval_split,
1757
+ )
1758
+
1759
+ # flush the train metrics
1760
+ train_start = time.time()
1761
+
1762
+ # save best checkpoint
1763
+ numerators = [wer['wer'] * len(labs) for wer, labs in zip(wer_l, labels_l)]
1764
+ val_wer = sum(numerators) / sum(len(labs) for labs in labels_l)
1765
+
1766
+ if val_wer < best_val_wer:
1767
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}-val-wer-{val_wer:.3f}")
1768
+ logger.info(f"Saving new best model, validation WER: {val_wer:.3f}")
1769
+ accelerator.save_state(output_dir=intermediate_dir)
1770
+ feature_extractor.save_pretrained(intermediate_dir)
1771
+ tokenizer.save_pretrained(intermediate_dir)
1772
+ config.save_pretrained(intermediate_dir)
1773
+ student_model.generation_config.save_pretrained(intermediate_dir)
1774
+
1775
+ accelerator.wait_for_everyone()
1776
+
1777
+ # remove unnecesary checkpoints, save best model and push to hub
1778
+ if accelerator.is_main_process:
1779
+ rotate_checkpoints(training_args.save_best_total_limit, output_dir=training_args.output_dir, sorting_fn=sorted_best_checkpoints)
1780
+
1781
+ accelerator.unwrap_model(student_model).save_pretrained(training_args.output_dir)
1782
+
1783
+ if training_args.push_to_hub:
1784
+ upload_folder(
1785
+ folder_path=training_args.output_dir,
1786
+ repo_id=repo_name,
1787
+ repo_type="model",
1788
+ commit_message=f"Saving best state, step {cur_step}, val wer {val_wer:.3f}",
1789
+ )
1790
+
1791
+ best_val_wer = val_wer
1792
+
1793
+ # break condition
1794
+ if cur_step == total_train_steps:
1795
+
1796
+ # the model under training_args.output_dir is the best model, let's also save end of training weights
1797
+ final_weights_dir = os.path.join(training_args.output_dir, "end-of-training-weights")
1798
+
1799
+ feature_extractor.save_pretrained(final_weights_dir)
1800
+ tokenizer.save_pretrained(final_weights_dir)
1801
+ # save the config and generation config as well
1802
+ config.save_pretrained(final_weights_dir)
1803
+ student_model.generation_config.save_pretrained(final_weights_dir)
1804
+
1805
+ # un-wrap student model for save
1806
+ student_model = accelerator.unwrap_model(student_model)
1807
+ student_model.save_pretrained(final_weights_dir)
1808
+
1809
+ if training_args.push_to_hub:
1810
+ upload_folder(
1811
+ folder_path=training_args.output_dir,
1812
+ repo_id=repo_name,
1813
+ repo_type="model",
1814
+ commit_message=f"Saving final weights of step {cur_step}",
1815
+ )
1816
+
1817
+ continue_training = False
1818
+ break
1819
+
1820
+ if not continue_training:
1821
+ break
1822
+
1823
+ accelerator.end_training()
1824
+
1825
+
1826
+ if __name__ == "__main__":
1827
+ main()
run_large_training.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ accelerate launch run_distillation.py \
3
+ --model_name_or_path "./nb-distil-large-init" \
4
+ --teacher_model_name_or_path "NbAiLab/nb-whisper-large" \
5
+ --train_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_large" \
6
+ --train_dataset_config_name "" \
7
+ --train_split_name "train" \
8
+ --eval_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_large" \
9
+ --eval_dataset_config_name "" \
10
+ --eval_split_name "validation" \
11
+ --eval_steps 500 \
12
+ --save_steps 1000 \
13
+ --warmup_steps 1000 \
14
+ --learning_rate 0.0003 \
15
+ --lr_scheduler_type "constant_with_warmup" \
16
+ --timestamp_probability 0.2 \
17
+ --condition_on_prev_probability 0.2 \
18
+ --language "no" \
19
+ --task "transcribe" \
20
+ --logging_steps 200 \
21
+ --save_total_limit 1 \
22
+ --max_steps 50000 \
23
+ --wer_threshold 20 \
24
+ --per_device_train_batch_size 32 \
25
+ --per_device_eval_batch_size 32 \
26
+ --dataloader_num_workers 8 \
27
+ --preprocessing_num_workers 8 \
28
+ --ddp_timeout 7200 \
29
+ --dtype "bfloat16" \
30
+ --attn_implementation "sdpa" \
31
+ --output_dir "./" \
32
+ --do_train \
33
+ --do_eval \
34
+ --gradient_checkpointing \
35
+ --overwrite_output_dir \
36
+ --predict_with_generate \
37
+ --freeze_encoder \
38
+ --freeze_embed_positions \
39
+ --streaming True \
40
+ --push_to_hub
41
+
special_tokens_map.json ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|startoftranscript|>",
4
+ "<|en|>",
5
+ "<|zh|>",
6
+ "<|de|>",
7
+ "<|es|>",
8
+ "<|ru|>",
9
+ "<|ko|>",
10
+ "<|fr|>",
11
+ "<|ja|>",
12
+ "<|pt|>",
13
+ "<|tr|>",
14
+ "<|pl|>",
15
+ "<|ca|>",
16
+ "<|nl|>",
17
+ "<|ar|>",
18
+ "<|sv|>",
19
+ "<|it|>",
20
+ "<|id|>",
21
+ "<|hi|>",
22
+ "<|fi|>",
23
+ "<|vi|>",
24
+ "<|he|>",
25
+ "<|uk|>",
26
+ "<|el|>",
27
+ "<|ms|>",
28
+ "<|cs|>",
29
+ "<|ro|>",
30
+ "<|da|>",
31
+ "<|hu|>",
32
+ "<|ta|>",
33
+ "<|no|>",
34
+ "<|th|>",
35
+ "<|ur|>",
36
+ "<|hr|>",
37
+ "<|bg|>",
38
+ "<|lt|>",
39
+ "<|la|>",
40
+ "<|mi|>",
41
+ "<|ml|>",
42
+ "<|cy|>",
43
+ "<|sk|>",
44
+ "<|te|>",
45
+ "<|fa|>",
46
+ "<|lv|>",
47
+ "<|bn|>",
48
+ "<|sr|>",
49
+ "<|az|>",
50
+ "<|sl|>",
51
+ "<|kn|>",
52
+ "<|et|>",
53
+ "<|mk|>",
54
+ "<|br|>",
55
+ "<|eu|>",
56
+ "<|is|>",
57
+ "<|hy|>",
58
+ "<|ne|>",
59
+ "<|mn|>",
60
+ "<|bs|>",
61
+ "<|kk|>",
62
+ "<|sq|>",
63
+ "<|sw|>",
64
+ "<|gl|>",
65
+ "<|mr|>",
66
+ "<|pa|>",
67
+ "<|si|>",
68
+ "<|km|>",
69
+ "<|sn|>",
70
+ "<|yo|>",
71
+ "<|so|>",
72
+ "<|af|>",
73
+ "<|oc|>",
74
+ "<|ka|>",
75
+ "<|be|>",
76
+ "<|tg|>",
77
+ "<|sd|>",
78
+ "<|gu|>",
79
+ "<|am|>",
80
+ "<|yi|>",
81
+ "<|lo|>",
82
+ "<|uz|>",
83
+ "<|fo|>",
84
+ "<|ht|>",
85
+ "<|ps|>",
86
+ "<|tk|>",
87
+ "<|nn|>",
88
+ "<|mt|>",
89
+ "<|sa|>",
90
+ "<|lb|>",
91
+ "<|my|>",
92
+ "<|bo|>",
93
+ "<|tl|>",
94
+ "<|mg|>",
95
+ "<|as|>",
96
+ "<|tt|>",
97
+ "<|haw|>",
98
+ "<|ln|>",
99
+ "<|ha|>",
100
+ "<|ba|>",
101
+ "<|jw|>",
102
+ "<|su|>",
103
+ "<|yue|>",
104
+ "<|translate|>",
105
+ "<|transcribe|>",
106
+ "<|startoflm|>",
107
+ "<|startofprev|>",
108
+ "<|nospeech|>",
109
+ "<|notimestamps|>"
110
+ ],
111
+ "bos_token": {
112
+ "content": "<|endoftext|>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "eos_token": {
119
+ "content": "<|endoftext|>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ },
125
+ "pad_token": {
126
+ "content": "<|endoftext|>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false
131
+ },
132
+ "unk_token": {
133
+ "content": "<|endoftext|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false
138
+ }
139
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff