amyeroberts (HF staff) and jbochi committed on
Commit
d1017b4
0 Parent(s):

Duplicate from jbochi/madlad400-8b-lm


Co-authored-by: J Bochi <jbochi@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,462 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - ru
6
+ - es
7
+ - fr
8
+ - de
9
+ - it
10
+ - pt
11
+ - pl
12
+ - nl
13
+ - vi
14
+ - tr
15
+ - sv
16
+ - id
17
+ - ro
18
+ - cs
19
+ - zh
20
+ - hu
21
+ - ja
22
+ - th
23
+ - fi
24
+ - fa
25
+ - uk
26
+ - da
27
+ - el
28
+ - "no"
29
+ - bg
30
+ - sk
31
+ - ko
32
+ - ar
33
+ - lt
34
+ - ca
35
+ - sl
36
+ - he
37
+ - et
38
+ - lv
39
+ - hi
40
+ - sq
41
+ - ms
42
+ - az
43
+ - sr
44
+ - ta
45
+ - hr
46
+ - kk
47
+ - is
48
+ - ml
49
+ - mr
50
+ - te
51
+ - af
52
+ - gl
53
+ - fil
54
+ - be
55
+ - mk
56
+ - eu
57
+ - bn
58
+ - ka
59
+ - mn
60
+ - bs
61
+ - uz
62
+ - ur
63
+ - sw
64
+ - yue
65
+ - ne
66
+ - kn
67
+ - kaa
68
+ - gu
69
+ - si
70
+ - cy
71
+ - eo
72
+ - la
73
+ - hy
74
+ - ky
75
+ - tg
76
+ - ga
77
+ - mt
78
+ - my
79
+ - km
80
+ - tt
81
+ - so
82
+ - ku
83
+ - ps
84
+ - pa
85
+ - rw
86
+ - lo
87
+ - ha
88
+ - dv
89
+ - fy
90
+ - lb
91
+ - ckb
92
+ - mg
93
+ - gd
94
+ - am
95
+ - ug
96
+ - ht
97
+ - grc
98
+ - hmn
99
+ - sd
100
+ - jv
101
+ - mi
102
+ - tk
103
+ - ceb
104
+ - yi
105
+ - ba
106
+ - fo
107
+ - or
108
+ - xh
109
+ - su
110
+ - kl
111
+ - ny
112
+ - sm
113
+ - sn
114
+ - co
115
+ - zu
116
+ - ig
117
+ - yo
118
+ - pap
119
+ - st
120
+ - haw
121
+ - as
122
+ - oc
123
+ - cv
124
+ - lus
125
+ - tet
126
+ - gsw
127
+ - sah
128
+ - br
129
+ - rm
130
+ - sa
131
+ - bo
132
+ - om
133
+ - se
134
+ - ce
135
+ - cnh
136
+ - ilo
137
+ - hil
138
+ - udm
139
+ - os
140
+ - lg
141
+ - ti
142
+ - vec
143
+ - ts
144
+ - tyv
145
+ - kbd
146
+ - ee
147
+ - iba
148
+ - av
149
+ - kha
150
+ - to
151
+ - tn
152
+ - nso
153
+ - fj
154
+ - zza
155
+ - ak
156
+ - ada
157
+ - otq
158
+ - dz
159
+ - bua
160
+ - cfm
161
+ - ln
162
+ - chm
163
+ - gn
164
+ - krc
165
+ - wa
166
+ - hif
167
+ - yua
168
+ - srn
169
+ - war
170
+ - rom
171
+ - bik
172
+ - pam
173
+ - sg
174
+ - lu
175
+ - ady
176
+ - kbp
177
+ - syr
178
+ - ltg
179
+ - myv
180
+ - iso
181
+ - kac
182
+ - bho
183
+ - ay
184
+ - kum
185
+ - qu
186
+ - za
187
+ - pag
188
+ - ngu
189
+ - ve
190
+ - pck
191
+ - zap
192
+ - tyz
193
+ - hui
194
+ - bbc
195
+ - tzo
196
+ - tiv
197
+ - ksd
198
+ - gom
199
+ - min
200
+ - ang
201
+ - nhe
202
+ - bgp
203
+ - nzi
204
+ - nnb
205
+ - nv
206
+ - zxx
207
+ - bci
208
+ - kv
209
+ - new
210
+ - mps
211
+ - alt
212
+ - meu
213
+ - bew
214
+ - fon
215
+ - iu
216
+ - abt
217
+ - mgh
218
+ - mnw
219
+ - tvl
220
+ - dov
221
+ - tlh
222
+ - ho
223
+ - kw
224
+ - mrj
225
+ - meo
226
+ - crh
227
+ - mbt
228
+ - emp
229
+ - ace
230
+ - ium
231
+ - mam
232
+ - gym
233
+ - mai
234
+ - crs
235
+ - pon
236
+ - ubu
237
+ - fip
238
+ - quc
239
+ - gv
240
+ - kj
241
+ - btx
242
+ - ape
243
+ - chk
244
+ - rcf
245
+ - shn
246
+ - tzh
247
+ - mdf
248
+ - ppk
249
+ - ss
250
+ - gag
251
+ - cab
252
+ - kri
253
+ - seh
254
+ - ibb
255
+ - tbz
256
+ - bru
257
+ - enq
258
+ - ach
259
+ - cuk
260
+ - kmb
261
+ - wo
262
+ - kek
263
+ - qub
264
+ - tab
265
+ - bts
266
+ - kos
267
+ - rwo
268
+ - cak
269
+ - tuc
270
+ - bum
271
+ - cjk
272
+ - gil
273
+ - stq
274
+ - tsg
275
+ - quh
276
+ - mak
277
+ - arn
278
+ - ban
279
+ - jiv
280
+ - sja
281
+ - yap
282
+ - tcy
283
+ - toj
284
+ - twu
285
+ - xal
286
+ - amu
287
+ - rmc
288
+ - hus
289
+ - nia
290
+ - kjh
291
+ - bm
292
+ - guh
293
+ - mas
294
+ - acf
295
+ - dtp
296
+ - ksw
297
+ - bzj
298
+ - din
299
+ - zne
300
+ - mad
301
+ - msi
302
+ - mag
303
+ - mkn
304
+ - kg
305
+ - lhu
306
+ - ch
307
+ - qvi
308
+ - mh
309
+ - djk
310
+ - sus
311
+ - mfe
312
+ - srm
313
+ - dyu
314
+ - ctu
315
+ - gui
316
+ - pau
317
+ - inb
318
+ - bi
319
+ - mni
320
+ - guc
321
+ - jam
322
+ - wal
323
+ - jac
324
+ - bas
325
+ - gor
326
+ - skr
327
+ - nyu
328
+ - noa
329
+ - sda
330
+ - gub
331
+ - nog
332
+ - cni
333
+ - teo
334
+ - tdx
335
+ - sxn
336
+ - rki
337
+ - nr
338
+ - frp
339
+ - alz
340
+ - taj
341
+ - lrc
342
+ - cce
343
+ - rn
344
+ - jvn
345
+ - hvn
346
+ - nij
347
+ - dwr
348
+ - izz
349
+ - msm
350
+ - bus
351
+ - ktu
352
+ - chr
353
+ - maz
354
+ - tzj
355
+ - suz
356
+ - knj
357
+ - bim
358
+ - gvl
359
+ - bqc
360
+ - tca
361
+ - pis
362
+ - prk
363
+ - laj
364
+ - mel
365
+ - qxr
366
+ - niq
367
+ - ahk
368
+ - shp
369
+ - hne
370
+ - spp
371
+ - koi
372
+ - krj
373
+ - quf
374
+ - luz
375
+ - agr
376
+ - tsc
377
+ - mqy
378
+ - gof
379
+ - gbm
380
+ - miq
381
+ - dje
382
+ - awa
383
+ - bjj
384
+ - qvz
385
+ - sjp
386
+ - tll
387
+ - raj
388
+ - kjg
389
+ - bgz
390
+ - quy
391
+ - cbk
392
+ - akb
393
+ - oj
394
+ - ify
395
+ - mey
396
+ - ks
397
+ - cac
398
+ - brx
399
+ - qup
400
+ - syl
401
+ - jax
402
+ - ff
403
+ - ber
404
+ - tks
405
+ - trp
406
+ - mrw
407
+ - adh
408
+ - smt
409
+ - srr
410
+ - ffm
411
+ - qvc
412
+ - mtr
413
+ - ann
414
+ - kaa
415
+ - aa
416
+ - noe
417
+ - nut
418
+ - gyn
419
+ - kwi
420
+ - xmm
421
+ - msb
422
+ library_name: transformers
423
+ tags:
424
+ - text-generation-inference
425
+ datasets:
426
+ - allenai/MADLAD-400
427
+ ---
428
+
429
+ This repository contains the safetensors weights for the [MADLAD-400](https://github.com/google-research/google-research/tree/master/madlad_400) 8B-parameter **language model**.
430
+
431
+ The HF transformers code to run inference is not ready yet. The [original implementation](https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L1484) is in JAX/Flaxformer.
432
+
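+ For reference, `config.json` in this repo maps `AutoModelForCausalLM` to the bundled `decoderonlyt5_modeling.DecoderOnlyT5Model`, so once the custom code works end to end, loading should look roughly like the untested sketch below (the repo id, dtype, and prompt are assumptions, not a verified recipe):
+
+ ```python
+ # Untested sketch: relies on the auto_map entries in this repo's config.json and on
+ # the bundled decoderonlyt5_modeling.py being functional end to end.
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ repo = "jbochi/madlad400-8b-lm"  # assumed repo id (this repo was duplicated from it)
+ tokenizer = AutoTokenizer.from_pretrained(repo)  # assumes the bundled tokenizer files load via AutoTokenizer
+ model = AutoModelForCausalLM.from_pretrained(
+     repo,
+     trust_remote_code=True,      # lets auto_map load the custom DecoderOnlyT5Model
+     torch_dtype=torch.bfloat16,  # assumption; pick whatever fits your hardware
+ )
+
+ inputs = tokenizer("MADLAD-400 is", return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=20)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
+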
433
+ The model architecture is the same as [PaLM 8B](https://arxiv.org/pdf/2204.02311.pdf).
434
+
435
+ It's a decoder-only T5 with 32 layers, 16 query heads, 1 KV head, and an embedding size of 4096.
436
+
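+ A rough, back-of-the-envelope parameter count from the values in `config.json` (a sketch, not an official figure) lands close to the advertised 8B and is consistent with the ~34.5 GB of float32 safetensors shards in this repo:
+
+ ```python
+ # Rough estimate from config.json: d_model=4096, d_ff=16384, d_kv=256,
+ # 16 query heads, 1 KV head (multi-query), 32 layers, vocab 256512.
+ d_model, d_ff, d_kv = 4096, 16384, 256
+ n_heads, n_kv_heads, n_layers, vocab = 16, 1, 32, 256512
+
+ attn = d_model * (n_heads * d_kv)          # q projection
+ attn += (n_heads * d_kv) * d_model         # o projection
+ attn += 2 * d_model * (n_kv_heads * d_kv)  # single shared k and v head
+ ffn = 3 * d_model * d_ff                   # gated SwiGLU: wi_0, wi_1, wo
+ per_layer = attn + ffn                     # layer-norm scales are negligible
+ embed = vocab * d_model                    # tied with the LM head
+
+ total = n_layers * per_layer + embed
+ print(f"~{total / 1e9:.2f}B parameters")   # ~8.63B
+ ```
+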
437
+ These are the main differences relative to the original T5 architecture:
438
+
439
+ - SwiGLU Activation
440
+ - Parallel Layers (sketched right after this list)
441
+ - Multi-Query Attention
442
+ - RoPE Embeddings
443
+ - Shared Input-Output Embeddings
444
+ - No biases
445
+ - Bidirectional attention
446
+ - Layer Norm with `center_scale_at_zero` and final layer with `use_scale=False`
447
+
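+ To illustrate the "Parallel Layers" item: the attention and feed-forward branches both read the same pre-normalized input and are merged into a single scaled residual, following the flaxformer reference linked from `decoderonlyt5_modeling.py`. A simplified sketch of the pattern (not a drop-in module):
+
+ ```python
+ # Simplified sketch of the parallel-layers pattern used by this architecture
+ # (see the flaxformer links in decoderonlyt5_modeling.py for the reference).
+ def parallel_block(hidden_states, self_attention, feed_forward, layer_norm, dropout):
+     x = layer_norm(hidden_states)              # one shared pre-norm
+     attn_out = self_attention(x)               # attention branch
+     ff_out = feed_forward(x)                   # MLP branch, computed from the same x
+     merged = (attn_out + ff_out) * 2 ** -0.5   # scaling as in the flaxformer reference
+     return hidden_states + dropout(merged)     # single residual connection
+ ```
+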
448
+ If you are looking for the machine translation models, here are the available versions:
449
+ - [3B](https://huggingface.co/jbochi/madlad400-3b-mt)
450
+ - [7B](https://huggingface.co/jbochi/madlad400-7b-mt)
451
+ - [7B-BT](https://huggingface.co/jbochi/madlad400-7b-mt-bt)
452
+ - [10B](https://huggingface.co/jbochi/madlad400-10b-mt)
453
+
454
+
455
+ Article: [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662)
456
+
457
+ Abstract:
458
+
459
+ > We introduce MADLAD-400, a manually audited, general domain 3T token monolingual dataset based on CommonCrawl, spanning 419 languages. We discuss the limitations revealed by self-auditing MADLAD-400, and the role data auditing had in the dataset creation process. We then train and release a 10.7B-parameter multilingual machine translation model on 250 billion tokens covering over 450 languages using publicly available data, and find that it is competitive with models that are significantly larger, and report the results on different domains. In addition, we train a 8B-parameter language model, and assess the results on few-shot translation. We make the baseline models available to the research community.
460
+
461
+
462
+
added_tokens.json ADDED
@@ -0,0 +1,2 @@
1
+ {
2
+ }
config.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "architectures": [
3
+ "DecoderOnlyT5Model"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "decoderonlyt5_config.DecoderOnlyT5Config",
7
+ "AutoModelForCausalLM": "decoderonlyt5_modeling.DecoderOnlyT5Model"
8
+ },
9
+ "d_ff": 16384,
10
+ "d_kv": 256,
11
+ "d_model": 4096,
12
+ "dropout_rate": 0.0,
13
+ "decoder_start_token_id": 0,
14
+ "pad_token_id": 1,
15
+ "eos_token_id": 3,
16
+ "feed_forward_proj": "gated-swish",
17
+ "initializer_factor": 1.0,
18
+ "is_encoder_decoder": false,
19
+ "is_decoder_only": true,
20
+ "layer_norm_epsilon": 1e-06,
21
+ "model_type": "t5",
22
+ "n_positions": 512,
23
+ "num_layers": 0,
24
+ "num_decoder_layers": 32,
25
+ "num_heads": 16,
26
+ "output_past": true,
27
+ "relative_attention_max_distance": 128,
28
+ "relative_attention_num_buckets": 32,
29
+ "task_specific_params": {},
30
+ "tie_word_embeddings": true,
31
+ "transformers_version": "4.23.1",
32
+ "use_cache": true,
33
+ "vocab_size": 256512,
34
+ "parallel_layers": true,
35
+ "has_relative_attention_bias": false,
36
+ "multi_query_attention": true,
37
+ "use_rotary_embedding": true,
38
+ "rotary_embedding_max_timescale": 1000
39
+ }
decoderonlyt5_config.py ADDED
@@ -0,0 +1,11 @@
1
+ from transformers.models.t5.configuration_t5 import T5Config
2
+
3
+
4
+ class DecoderOnlyT5Config(T5Config):
5
+ is_decoder_only = True
6
+ # whether to call attention and mlp in parallel.
7
+ # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L384
8
+ parallel_layers = True
9
+ has_relative_attention_bias = False
10
+ # https://arxiv.org/abs/1911.02150
11
+ multi_query_attention = True
decoderonlyt5_modeling.py ADDED
@@ -0,0 +1,840 @@
1
+ import copy
2
+ import math
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.models.t5 import modeling_t5
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from transformers.utils import (
11
+ add_start_docstrings_to_model_forward,
12
+ logging,
13
+ replace_return_docstrings,
14
+ )
15
+
16
+ from .decoderonlyt5_config import DecoderOnlyT5Config
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+ _CONFIG_FOR_DOC = "DecoderOnlyT5Config"
21
+
22
+
23
+ class DecoderOnlyT5LayerNorm(nn.Module):
24
+ def __init__(self, hidden_size, eps=1e-6, use_scale=True, center_scale_at_zero=False):
25
+ """
26
+ Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
27
+ """
28
+ super().__init__()
29
+ if use_scale:
30
+ self.weight = nn.Parameter(torch.ones(hidden_size))
31
+ else:
32
+ assert not center_scale_at_zero
33
+ self.weight = None
34
+ self.center_scale_at_zero = center_scale_at_zero
35
+ self.variance_epsilon = eps
36
+
37
+ def forward(self, hidden_states):
38
+ # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/layer_norm.py#L30
39
+
40
+ # layer norm should always be calculated in float32
41
+ mean2 = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
42
+ hidden_states = hidden_states * torch.rsqrt(mean2 + self.variance_epsilon)
43
+
44
+ # convert into float16 if necessary
45
+ if self.weight is None:
46
+ return hidden_states
47
+ if self.weight.dtype == torch.float16:
48
+ hidden_states = hidden_states.to(torch.float16)
49
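+ # center_scale_at_zero stores the scale as an offset from 1.0, matching the flaxformer layer norm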
+ if self.center_scale_at_zero:
50
+ return (self.weight + 1.0) * hidden_states
51
+ else:
52
+ return self.weight * hidden_states
53
+
54
+
55
+
56
+ class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
57
+ def __init__(self, config: DecoderOnlyT5Config):
58
+ super(modeling_t5.T5LayerFF, self).__init__()
59
+ if config.is_gated_act:
60
+ self.DenseReluDense = modeling_t5.T5DenseGatedActDense(config)
61
+ else:
62
+ self.DenseReluDense = modeling_t5.T5DenseActDense(config)
63
+
64
+ if not config.parallel_layers:
65
+ self.layer_norm = DecoderOnlyT5LayerNorm(
66
+ config.d_model, eps=config.layer_norm_epsilon
67
+ )
68
+ else:
69
+ self.layer_norm = nn.Identity()
70
+ self.dropout = nn.Dropout(config.dropout_rate)
71
+
72
+
73
+ # LlamaRotaryEmbedding
74
+ class DecoderOnlyT5RotaryEmbedding(nn.Module):
75
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
76
+ super().__init__()
77
+
78
+ self.dim = dim
79
+ self.max_position_embeddings = max_position_embeddings
80
+ self.base = base
81
+ inv_freq = 1.0 / (
82
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
83
+ )
84
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
85
+
86
+ # Build here to make `torch.jit.trace` work.
87
+ self._set_cos_sin_cache(
88
+ seq_len=max_position_embeddings,
89
+ device=self.inv_freq.device,
90
+ dtype=torch.get_default_dtype(),
91
+ )
92
+
93
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
94
+ self.max_seq_len_cached = seq_len
95
+ t = torch.arange(
96
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
97
+ )
98
+
99
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
100
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
101
+ emb = torch.cat((freqs, freqs), dim=-1)
102
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
103
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
104
+
105
+ def forward(self, x, seq_len=None):
106
+ # x: [bs, num_attention_heads, seq_len, head_size]
107
+ if seq_len > self.max_seq_len_cached:
108
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
109
+
110
+ return (
111
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
112
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
113
+ )
114
+
115
+
116
+ def rotate_half(x):
117
+ """Rotates half the hidden dims of the input."""
118
+ x1 = x[..., : x.shape[-1] // 2]
119
+ x2 = x[..., x.shape[-1] // 2 :]
120
+ return torch.cat((-x2, x1), dim=-1)
121
+
122
+
123
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
124
+ """Applies Rotary Position Embedding to the query and key tensors.
125
+
126
+ Args:
127
+ q (`torch.Tensor`): The query tensor.
128
+ k (`torch.Tensor`): The key tensor.
129
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
130
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
131
+ position_ids (`torch.Tensor`):
132
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
133
+ used to pass offsetted position ids when working with a KV-cache.
134
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
135
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
136
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
137
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
138
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
139
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
140
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
141
+ Returns:
142
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
143
+ """
144
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
145
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
146
+ q_embed = (q * cos) + (rotate_half(q) * sin)
147
+ k_embed = (k * cos) + (rotate_half(k) * sin)
148
+ return q_embed, k_embed
149
+
150
+
151
+ # https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/llama/modeling_llama.py#L263
152
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
153
+ """
154
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
155
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
156
+ """
157
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
158
+ if n_rep == 1:
159
+ return hidden_states
160
+ hidden_states = hidden_states[:, :, None, :, :].expand(
161
+ batch, num_key_value_heads, n_rep, slen, head_dim
162
+ )
163
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
164
+
165
+
166
+ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
167
+ """
168
+ Supports both multi-head and multi-query attention.
169
+ https://arxiv.org/abs/1911.02150
170
+ https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/attention/dense_attention.py#L292
171
+ """
172
+
173
+ def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
174
+ super(modeling_t5.T5Attention, self).__init__()
175
+ self.is_decoder = config.is_decoder
176
+ assert not has_relative_attention_bias
177
+ assert config.use_rotary_embedding
178
+ self.d_model = config.d_model
179
+ self.head_dim = config.d_kv
180
+ self.num_heads = config.num_heads
181
+ self.num_key_value_heads = 1 if config.multi_query_attention else self.num_heads
182
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
183
+ self.attention_dropout = config.dropout_rate
184
+ self.inner_dim = self.num_heads * self.head_dim
185
+ self.kv_inner_dim = self.num_key_value_heads * self.head_dim
186
+ self.rotary_emb = DecoderOnlyT5RotaryEmbedding(
187
+ self.head_dim,
188
+ max_position_embeddings=config.relative_attention_max_distance,
189
+ base=config.rotary_embedding_max_timescale,
190
+ )
191
+
192
+ # Mesh TensorFlow initialization to avoid scaling before softmax
193
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
194
+ self.k = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
195
+ self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
196
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
197
+
198
+ self.pruned_heads = set()
199
+ self.gradient_checkpointing = False
200
+
201
+ def forward(
202
+ self,
203
+ hidden_states: torch.Tensor,
204
+ key_value_states=None,
205
+ position_bias=None,
206
+ mask: Optional[torch.Tensor] = None,
207
+ layer_head_mask=None,
208
+ position_ids: Optional[torch.LongTensor] = None,
209
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
210
+ output_attentions: bool = False,
211
+ use_cache: bool = False,
212
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
213
+ assert key_value_states is None
214
+ assert position_bias is None
215
+ assert layer_head_mask is None
216
+
217
+ bsz, q_len, _ = hidden_states.size()
218
+
219
+ query_states = self.q(hidden_states)
220
+ key_states = self.k(hidden_states)
221
+ value_states = self.v(hidden_states)
222
+
223
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
224
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
225
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
226
+
227
+ kv_seq_len = key_states.shape[-2]
228
+ if past_key_value is not None:
229
+ kv_seq_len += past_key_value[0].shape[-2]
230
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
231
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
232
+
233
+ if past_key_value is not None:
234
+ # reuse k, v, self_attention
235
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
236
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
237
+
238
+ past_key_value = (key_states, value_states) if use_cache else None
239
+
240
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
241
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
242
+
243
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
244
+
245
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
246
+ raise ValueError(
247
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
248
+ f" {attn_weights.size()}"
249
+ )
250
+
251
+ if mask is not None:
252
+ if mask.size() != (bsz, 1, q_len, kv_seq_len):
253
+ raise ValueError(
254
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {mask.size()}"
255
+ )
256
+ attn_weights = attn_weights + mask
257
+
258
+ # upcast attention to fp32
259
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
260
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
261
+ attn_output = torch.matmul(attn_weights, value_states)
262
+
263
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
264
+ raise ValueError(
265
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
266
+ f" {attn_output.size()}"
267
+ )
268
+
269
+ attn_output = attn_output.transpose(1, 2).contiguous()
270
+ attn_output = attn_output.reshape(bsz, q_len, self.inner_dim)
271
+ attn_output = self.o(attn_output)
272
+
273
+ present_key_value_state = (
274
+ (key_states, value_states) if (self.is_decoder and use_cache) else None
275
+ )
276
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
277
+
278
+ if output_attentions:
279
+ outputs = outputs + (attn_weights,)
280
+ return outputs
281
+
282
+
283
+ class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
284
+ def __init__(self, config, has_relative_attention_bias=False):
285
+ super(modeling_t5.T5LayerSelfAttention, self).__init__()
286
+ self.SelfAttention = DecoderOnlyT5Attention(
287
+ config, has_relative_attention_bias=has_relative_attention_bias
288
+ )
289
+ self.layer_norm = DecoderOnlyT5LayerNorm(
290
+ config.d_model,
291
+ eps=config.layer_norm_epsilon,
292
+ use_scale=True,
293
+ center_scale_at_zero=True,
294
+ )
295
+ self.dropout = nn.Dropout(config.dropout_rate)
296
+ self.parallel_layers = config.parallel_layers
297
+
298
+ def forward(
299
+ self,
300
+ hidden_states,
301
+ attention_mask=None,
302
+ position_bias=None,
303
+ position_ids=None,
304
+ layer_head_mask=None,
305
+ past_key_value=None,
306
+ use_cache=False,
307
+ output_attentions=False,
308
+ ):
309
+ if not self.parallel_layers:
310
+ x = self.layer_norm(hidden_states)
311
+ else:
312
+ x = hidden_states
313
+ attention_output = self.SelfAttention(
314
+ x,
315
+ mask=attention_mask,
316
+ position_bias=position_bias,
317
+ position_ids=position_ids,
318
+ layer_head_mask=layer_head_mask,
319
+ past_key_value=past_key_value,
320
+ use_cache=use_cache,
321
+ output_attentions=output_attentions,
322
+ )
323
+ if not self.parallel_layers:
324
+ # When parallel_layers is True, the residual connection is applied
325
+ # in the decoder block instead of here.
326
+ hidden_states = hidden_states + self.dropout(attention_output[0])
327
+ else:
328
+ hidden_states = attention_output[0]
329
+ outputs = (hidden_states,) + attention_output[
330
+ 1:
331
+ ] # add attentions if we output them
332
+ return outputs
333
+
334
+
335
+ class DecoderOnlyT5Block(modeling_t5.T5Block):
336
+ def __init__(self, config, has_relative_attention_bias=False):
337
+ super(modeling_t5.T5Block, self).__init__()
338
+ self.is_decoder = config.is_decoder
339
+ self.is_decoder_only = config.is_decoder_only
340
+ self.layer = nn.ModuleList()
341
+ self.layer.append(
342
+ DecoderOnlyT5LayerSelfAttention(
343
+ config, has_relative_attention_bias=has_relative_attention_bias
344
+ )
345
+ )
346
+ if self.is_decoder:
347
+ if config.is_decoder_only:
348
+ self.layer.append(nn.Identity())
349
+ else:
350
+ self.layer.append(modeling_t5.T5LayerCrossAttention(config))
351
+ self.parallel_layers = config.parallel_layers
352
+ self.layer.append(DecoderOnlyT5LayerFF(config))
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states,
357
+ attention_mask=None,
358
+ position_bias=None,
359
+ position_ids=None,
360
+ encoder_hidden_states=None,
361
+ layer_head_mask=None,
362
+ past_key_value=None,
363
+ use_cache=False,
364
+ output_attentions=False,
365
+ encoder_attention_mask=None,
366
+ encoder_decoder_position_bias=None,
367
+ cross_attn_layer_head_mask=None,
368
+ return_dict=True,
369
+ ):
370
+ assert encoder_attention_mask is None
371
+ assert encoder_decoder_position_bias is None
372
+ assert cross_attn_layer_head_mask is None
373
+ if past_key_value is not None:
374
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
375
+
376
+ if len(past_key_value) != expected_num_past_key_values:
377
+ raise ValueError(
378
+ f"There should be {expected_num_past_key_values} past states. "
379
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
380
+ f"Got {len(past_key_value)} past key / value states"
381
+ )
382
+ self_attn_past_key_value = past_key_value[:2]
383
+ else:
384
+ self_attn_past_key_value = None
385
+
386
+ ff_layer = self.layer[-1]
387
+ if self.parallel_layers:
388
+ # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L563-L568
389
+ x = self.layer[0].layer_norm(hidden_states)
390
+ ff_output = ff_layer(x)
391
+ else:
392
+ x = hidden_states
393
+
394
+ self_attention_outputs = self.layer[0](
395
+ x,
396
+ attention_mask=attention_mask,
397
+ position_bias=position_bias,
398
+ position_ids=position_ids,
399
+ layer_head_mask=layer_head_mask,
400
+ past_key_value=self_attn_past_key_value,
401
+ use_cache=use_cache,
402
+ output_attentions=output_attentions,
403
+ )
404
+ x, present_key_value_state = self_attention_outputs[:2]
405
+ attention_outputs = self_attention_outputs[
406
+ 2:
407
+ ] # Keep self-attention outputs and relative position weights
408
+
409
+ # clamp inf values to enable fp16 training
410
+ if x.dtype == torch.float16:
411
+ clamp_value = torch.where(
412
+ torch.isinf(x).any(),
413
+ torch.finfo(x.dtype).max - 1000,
414
+ torch.finfo(x.dtype).max,
415
+ )
416
+ x = torch.clamp(x, min=-clamp_value, max=clamp_value)
417
+
418
+ do_cross_attention = (
419
+ self.is_decoder
420
+ and not self.is_decoder_only
421
+ and encoder_hidden_states is not None
422
+ )
423
+ assert not do_cross_attention
424
+
425
+ if self.parallel_layers:
426
+ # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
427
+ x = x + ff_output
428
+ x *= 2**-0.5
429
+ hidden_states = hidden_states + self.layer[0].dropout(x)
430
+ else:
431
+ hidden_states = ff_layer(x)
432
+
433
+ # clamp inf values to enable fp16 training
434
+ if hidden_states.dtype == torch.float16:
435
+ clamp_value = torch.where(
436
+ torch.isinf(hidden_states).any(),
437
+ torch.finfo(hidden_states.dtype).max - 1000,
438
+ torch.finfo(hidden_states.dtype).max,
439
+ )
440
+ hidden_states = torch.clamp(
441
+ hidden_states, min=-clamp_value, max=clamp_value
442
+ )
443
+
444
+ outputs = (hidden_states,)
445
+
446
+ if use_cache:
447
+ outputs = outputs + (present_key_value_state,) + attention_outputs
448
+ else:
449
+ outputs = outputs + attention_outputs
450
+
451
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
452
+
453
+
454
+ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
455
+ def __init__(self, config, embed_tokens=None):
456
+ super(modeling_t5.T5Stack, self).__init__(config)
457
+
458
+ self.embed_tokens = embed_tokens
459
+ self.is_decoder = config.is_decoder
460
+
461
+ self.block = nn.ModuleList(
462
+ [
463
+ DecoderOnlyT5Block(
464
+ config,
465
+ has_relative_attention_bias=(
466
+ config.has_relative_attention_bias and bool(i == 0)
467
+ ),
468
+ )
469
+ for i in range(config.num_layers)
470
+ ]
471
+ )
472
+ self.final_layer_norm = DecoderOnlyT5LayerNorm(
473
+ config.d_model,
474
+ eps=config.layer_norm_epsilon,
475
+ use_scale=False,
476
+ center_scale_at_zero=False,
477
+ )
478
+ self.dropout = nn.Dropout(config.dropout_rate)
479
+
480
+ # Initialize weights and apply final processing
481
+ self.post_init()
482
+ # Model parallel
483
+ self.model_parallel = False
484
+ self.device_map = None
485
+ self.gradient_checkpointing = False
486
+
487
+ def forward(
488
+ self,
489
+ input_ids=None,
490
+ position_ids=None,
491
+ attention_mask=None,
492
+ encoder_hidden_states=None,
493
+ encoder_attention_mask=None,
494
+ inputs_embeds=None,
495
+ head_mask=None,
496
+ cross_attn_head_mask=None,
497
+ past_key_values=None,
498
+ use_cache=None,
499
+ output_attentions=None,
500
+ output_hidden_states=None,
501
+ return_dict=None,
502
+ ):
503
+ # Model parallel
504
+ if self.model_parallel:
505
+ torch.cuda.set_device(self.first_device)
506
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
507
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
508
+ output_attentions = (
509
+ output_attentions
510
+ if output_attentions is not None
511
+ else self.config.output_attentions
512
+ )
513
+ output_hidden_states = (
514
+ output_hidden_states
515
+ if output_hidden_states is not None
516
+ else self.config.output_hidden_states
517
+ )
518
+ return_dict = (
519
+ return_dict if return_dict is not None else self.config.use_return_dict
520
+ )
521
+
522
+ if input_ids is not None and inputs_embeds is not None:
523
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
524
+ raise ValueError(
525
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
526
+ )
527
+ elif input_ids is not None:
528
+ input_shape = input_ids.size()
529
+ input_ids = input_ids.view(-1, input_shape[-1])
530
+ elif inputs_embeds is not None:
531
+ input_shape = inputs_embeds.size()[:-1]
532
+ else:
533
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
534
+ raise ValueError(
535
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
536
+ )
537
+
538
+ if position_ids is None:
539
+ seq_length = input_shape[-1]
540
+ past_key_values_length = (
541
+ 0 if past_key_values is None else past_key_values[0][0].shape[2]
542
+ )
543
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
544
+ position_ids = torch.arange(
545
+ past_key_values_length,
546
+ seq_length + past_key_values_length,
547
+ dtype=torch.long,
548
+ device=device,
549
+ ).unsqueeze(0)
550
+
551
+ if inputs_embeds is None:
552
+ if self.embed_tokens is None:
553
+ raise ValueError(
554
+ "You have to initialize the model with valid token embeddings"
555
+ )
556
+ inputs_embeds = self.embed_tokens(input_ids)
557
+
558
+ batch_size, seq_length = input_shape
559
+
560
+ # required mask seq length can be calculated via length of past
561
+ mask_seq_length = (
562
+ past_key_values[0][0].shape[2] + seq_length
563
+ if past_key_values is not None
564
+ else seq_length
565
+ )
566
+
567
+ if use_cache is True:
568
+ if not self.is_decoder:
569
+ raise ValueError(
570
+ f"`use_cache` can only be set to `True` if {self} is used as a decoder"
571
+ )
572
+
573
+ if attention_mask is None:
574
+ attention_mask = torch.ones(
575
+ batch_size, mask_seq_length, device=inputs_embeds.device
576
+ )
577
+
578
+ # initialize past_key_values with `None` if past does not exist
579
+ if past_key_values is None:
580
+ past_key_values = [None] * len(self.block)
581
+
582
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
583
+ # ourselves in which case we just need to make it broadcastable to all heads.
584
+ extended_attention_mask = self.get_extended_attention_mask(
585
+ attention_mask, input_shape
586
+ )
587
+
588
+ if self.gradient_checkpointing and self.training:
589
+ if use_cache:
590
+ logger.warning_once(
591
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
592
+ )
593
+ use_cache = False
594
+
595
+ # Prepare head mask if needed
596
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
597
+ cross_attn_head_mask = self.get_head_mask(
598
+ cross_attn_head_mask, self.config.num_layers
599
+ )
600
+ present_key_value_states = () if use_cache else None
601
+ all_hidden_states = () if output_hidden_states else None
602
+ all_attentions = () if output_attentions else None
603
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
604
+ position_bias = None
605
+
606
+ hidden_states = self.dropout(inputs_embeds)
607
+
608
+ for i, (layer_module, past_key_value) in enumerate(
609
+ zip(self.block, past_key_values)
610
+ ):
611
+ layer_head_mask = head_mask[i]
612
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
613
+ # Model parallel
614
+ if self.model_parallel:
615
+ torch.cuda.set_device(hidden_states.device)
616
+ # Ensure that attention_mask is always on the same device as hidden_states
617
+ if attention_mask is not None:
618
+ attention_mask = attention_mask.to(hidden_states.device)
619
+ if position_bias is not None:
620
+ position_bias = position_bias.to(hidden_states.device)
621
+ if layer_head_mask is not None:
622
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
623
+
624
+ if output_hidden_states:
625
+ all_hidden_states = all_hidden_states + (hidden_states,)
626
+
627
+ if self.gradient_checkpointing and self.training:
628
+ layer_outputs = self._gradient_checkpointing_func(
629
+ layer_module.forward,
630
+ hidden_states,
631
+ extended_attention_mask,
632
+ position_bias,
633
+ None,
634
+ None,
635
+ None,
636
+ layer_head_mask,
637
+ cross_attn_layer_head_mask,
638
+ None, # past_key_value is always None with gradient checkpointing
639
+ use_cache,
640
+ output_attentions,
641
+ )
642
+ else:
643
+ layer_outputs = layer_module(
644
+ hidden_states,
645
+ attention_mask=extended_attention_mask,
646
+ position_bias=position_bias,
647
+ position_ids=position_ids,
648
+ encoder_hidden_states=None,
649
+ encoder_attention_mask=None,
650
+ encoder_decoder_position_bias=None,
651
+ layer_head_mask=layer_head_mask,
652
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
653
+ past_key_value=past_key_value,
654
+ use_cache=use_cache,
655
+ output_attentions=output_attentions,
656
+ )
657
+
658
+ # layer_outputs is a tuple with:
659
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
660
+ if use_cache is False:
661
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
662
+
663
+ hidden_states, present_key_value_state = layer_outputs[:2]
664
+
665
+ # We share the position biases between the layers - the first layer store them
666
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
667
+ # (cross-attention position bias), (cross-attention weights)
668
+ position_bias = layer_outputs[2]
669
+ # append next layer key value states
670
+ if use_cache:
671
+ present_key_value_states = present_key_value_states + (
672
+ present_key_value_state,
673
+ )
674
+
675
+ if output_attentions:
676
+ all_attentions = all_attentions + (layer_outputs[3],)
677
+ if self.is_decoder:
678
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
679
+
680
+ # Model Parallel: If it's the last layer for that device, put things on the next device
681
+ if self.model_parallel:
682
+ for k, v in self.device_map.items():
683
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
684
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
685
+
686
+ hidden_states = self.final_layer_norm(hidden_states)
687
+ hidden_states = self.dropout(hidden_states)
688
+
689
+ # Add last layer
690
+ if output_hidden_states:
691
+ all_hidden_states = all_hidden_states + (hidden_states,)
692
+
693
+ if not return_dict:
694
+ return tuple(
695
+ v
696
+ for v in [
697
+ hidden_states,
698
+ present_key_value_states,
699
+ all_hidden_states,
700
+ all_attentions,
701
+ all_cross_attentions,
702
+ ]
703
+ if v is not None
704
+ )
705
+ return modeling_t5.BaseModelOutputWithPastAndCrossAttentions(
706
+ last_hidden_state=hidden_states,
707
+ past_key_values=present_key_value_states,
708
+ hidden_states=all_hidden_states,
709
+ attentions=all_attentions,
710
+ cross_attentions=all_cross_attentions,
711
+ )
712
+
713
+
714
+ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
715
+ def __init__(self, config: DecoderOnlyT5Config):
716
+ super(modeling_t5.T5ForConditionalGeneration, self).__init__(config)
717
+ self.model_dim = config.d_model
718
+
719
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
720
+ assert (
721
+ self.config.num_layers == 0
722
+ ), "Decoder only model cannot have encoder layers"
723
+ self.encoder = None
724
+
725
+ decoder_config = copy.deepcopy(config)
726
+ decoder_config.is_decoder = True
727
+ decoder_config.is_encoder_decoder = False
728
+ decoder_config.num_layers = config.num_decoder_layers
729
+ self.decoder = DecoderOnlyT5Stack(decoder_config, self.shared)
730
+
731
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
732
+
733
+ # Initialize weights and apply final processing
734
+ self.post_init()
735
+
736
+ # Model parallel
737
+ self.model_parallel = False
738
+ self.device_map = None
739
+
740
+ def _tie_weights(self):
741
+ if not self.config.tie_word_embeddings:
742
+ return
743
+ if self.decoder:
744
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
745
+
746
+ @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
747
+ @replace_return_docstrings(
748
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
749
+ )
750
+ def forward(
751
+ self,
752
+ input_ids: Optional[torch.LongTensor] = None,
753
+ position_ids: Optional[torch.LongTensor] = None,
754
+ attention_mask: Optional[torch.FloatTensor] = None,
755
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
756
+ inputs_embeds: Optional[torch.FloatTensor] = None,
757
+ labels: Optional[torch.LongTensor] = None,
758
+ use_cache: Optional[bool] = None,
759
+ output_attentions: Optional[bool] = None,
760
+ output_hidden_states: Optional[bool] = None,
761
+ return_dict: Optional[bool] = None,
762
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
763
+ r"""
764
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
765
+ Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
766
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
767
+ labels in `[0, ..., config.vocab_size - 1]`
768
+
769
+ Returns:
770
+
771
+ Examples:
772
+
773
+ ```"""
774
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
775
+ return_dict = (
776
+ return_dict if return_dict is not None else self.config.use_return_dict
777
+ )
778
+
779
+ if self.model_parallel:
780
+ torch.cuda.set_device(self.decoder.first_device)
781
+
782
+ # Set device for model parallelism
783
+ if self.model_parallel:
784
+ torch.cuda.set_device(self.decoder.first_device)
785
+ if input_ids is not None:
786
+ input_ids = input_ids.to(self.decoder.first_device)
787
+ if attention_mask is not None:
788
+ attention_mask = attention_mask.to(self.decoder.first_device)
789
+
790
+ # Decode
791
+ outputs = self.decoder(
792
+ input_ids=input_ids,
793
+ position_ids=position_ids,
794
+ attention_mask=attention_mask,
795
+ inputs_embeds=inputs_embeds,
796
+ past_key_values=past_key_values,
797
+ encoder_hidden_states=None,
798
+ encoder_attention_mask=None,
799
+ head_mask=None,
800
+ cross_attn_head_mask=None,
801
+ use_cache=use_cache,
802
+ output_attentions=output_attentions,
803
+ output_hidden_states=output_hidden_states,
804
+ return_dict=return_dict,
805
+ )
806
+
807
+ sequence_output = outputs[0]
808
+
809
+ # Set device for model parallelism
810
+ if self.model_parallel:
811
+ torch.cuda.set_device(self.decoder.first_device)
812
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
813
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
814
+
815
+ if self.config.tie_word_embeddings:
816
+ # Rescale output before projecting on vocab
817
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
818
+ sequence_output = sequence_output * (self.model_dim**-0.5)
819
+
820
+ lm_logits = self.lm_head(sequence_output)
821
+
822
+ loss = None
823
+ if labels is not None:
824
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
825
+ # move labels to correct device to enable PP
826
+ labels = labels.to(lm_logits.device)
827
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
828
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
829
+
830
+ if not return_dict:
831
+ output = (lm_logits,) + outputs[1:]
832
+ return ((loss,) + output) if loss is not None else output
833
+
834
+ return CausalLMOutputWithPast(
835
+ loss=loss,
836
+ logits=lm_logits,
837
+ past_key_values=outputs.past_key_values,
838
+ hidden_states=outputs.hidden_states,
839
+ attentions=outputs.attentions,
840
+ )
model-00000-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eee3b0c4eef668152f9f1106f18bf0a892bd04ba8b26017d7d5865f49dec5f3c
3
+ size 5150622792
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29b44a655d9261522e705d963d1fa7ca1717c3e6bcfb9402fa69cb8ee6156c6f
3
+ size 4739650416
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dbd16405fa07722d953ba5c99aeb8ae05c1068cb4018a9622e5828336c1b9c8
3
+ size 4739650424
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7444d7670f145f12706970a191b260b609a82626c5e417f08f1039c05fbdda75
3
+ size 4739650456
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:471e481023c3f4622d2ff0b1031a34c55469fb10fe1252f5c5c62e8f95418b4b
3
+ size 4739650456
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e934d8fcf5bc54649e82b29c1322b7ce3d13e8915d2cac205e65660e0a4cdbbb
3
+ size 4739650456
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3f82de95453b49cb5a78ec517fac3219556f83edd9075a20a7ac95a577b5e93
3
+ size 4739650456
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f816e875186339eaf9afa8964d34952c80e697ade369154512f9900ea0a33553
3
+ size 947930104
model.safetensors.index.json ADDED
@@ -0,0 +1,262 @@
1
+ {
2
+ "metadata": {},
3
+ "weight_map": {
4
+ "shared.weight": "model-00000-of-00007.safetensors",
5
+ "decoder.block.0.layer.0.layer_norm.weight": "model-00000-of-00007.safetensors",
6
+ "decoder.block.0.layer.0.SelfAttention.k.weight": "model-00000-of-00007.safetensors",
7
+ "decoder.block.0.layer.0.SelfAttention.o.weight": "model-00000-of-00007.safetensors",
8
+ "decoder.block.0.layer.0.SelfAttention.q.weight": "model-00000-of-00007.safetensors",
9
+ "decoder.block.0.layer.0.SelfAttention.v.weight": "model-00000-of-00007.safetensors",
10
+ "decoder.block.0.layer.2.DenseReluDense.wi_0.weight": "model-00000-of-00007.safetensors",
11
+ "decoder.block.0.layer.2.DenseReluDense.wi_1.weight": "model-00000-of-00007.safetensors",
12
+ "decoder.block.0.layer.2.DenseReluDense.wo.weight": "model-00000-of-00007.safetensors",
13
+ "decoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
14
+ "decoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
15
+ "decoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
16
+ "decoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
17
+ "decoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
18
+ "decoder.block.1.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
19
+ "decoder.block.1.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
20
+ "decoder.block.1.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
21
+ "decoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
22
+ "decoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
23
+ "decoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
24
+ "decoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
25
+ "decoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
26
+ "decoder.block.2.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
27
+ "decoder.block.2.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
28
+ "decoder.block.2.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
29
+ "decoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
30
+ "decoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
31
+ "decoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
32
+ "decoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
33
+ "decoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
34
+ "decoder.block.3.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
35
+ "decoder.block.3.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
36
+ "decoder.block.3.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
37
+ "decoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
38
+ "decoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
39
+ "decoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
40
+ "decoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
41
+ "decoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
42
+ "decoder.block.4.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
43
+ "decoder.block.4.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
44
+ "decoder.block.4.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
45
+ "decoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
46
+ "decoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
47
+ "decoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
48
+ "decoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
49
+ "decoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
50
+ "decoder.block.5.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
51
+ "decoder.block.5.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
52
+ "decoder.block.5.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
53
+ "decoder.block.6.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
54
+ "decoder.block.6.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
55
+ "decoder.block.6.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
56
+ "decoder.block.6.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
57
+ "decoder.block.6.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
58
+ "decoder.block.6.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
59
+ "decoder.block.6.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
60
+ "decoder.block.6.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
61
+ "decoder.block.7.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
62
+ "decoder.block.7.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
63
+ "decoder.block.7.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
64
+ "decoder.block.7.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
65
+ "decoder.block.7.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
66
+ "decoder.block.7.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
67
+ "decoder.block.7.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
68
+ "decoder.block.7.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
69
+ "decoder.block.8.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
70
+ "decoder.block.8.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
71
+ "decoder.block.8.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
72
+ "decoder.block.8.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
73
+ "decoder.block.8.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
74
+ "decoder.block.8.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
75
+ "decoder.block.8.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
76
+ "decoder.block.8.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
77
+ "decoder.block.9.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
78
+ "decoder.block.9.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
79
+ "decoder.block.9.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
80
+ "decoder.block.9.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
81
+ "decoder.block.9.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
82
+ "decoder.block.9.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
83
+ "decoder.block.9.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
84
+ "decoder.block.9.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
85
+ "decoder.block.10.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
86
+ "decoder.block.10.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
87
+ "decoder.block.10.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
88
+ "decoder.block.10.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
89
+ "decoder.block.10.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
90
+ "decoder.block.10.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
91
+ "decoder.block.10.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
92
+ "decoder.block.10.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
+ "decoder.block.11.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.11.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.12.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.13.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.14.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.15.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
+ "decoder.block.16.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.16.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.17.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.18.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.19.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.20.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
+ "decoder.block.21.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.21.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.22.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.23.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.24.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.25.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
+ "decoder.block.26.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.26.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.27.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.28.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.29.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.30.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
+ "decoder.block.31.layer.0.layer_norm.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.0.SelfAttention.k.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.0.SelfAttention.o.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.0.SelfAttention.q.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.0.SelfAttention.v.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.2.DenseReluDense.wi_0.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.2.DenseReluDense.wi_1.weight": "model-00007-of-00007.safetensors",
+ "decoder.block.31.layer.2.DenseReluDense.wo.weight": "model-00007-of-00007.safetensors"
+ }
+ }
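
The `weight_map` above tells `transformers` which of the seven shards stores each parameter; `from_pretrained` resolves this automatically. As a minimal sketch of what that lookup does (assuming the repository files are already downloaded locally, e.g. with `huggingface_hub.snapshot_download`; paths are illustrative), a single tensor can also be read by hand:

```python
# Minimal sketch: resolve one tensor through model.safetensors.index.json.
# Assumes a local copy of this repository in the current directory.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "decoder.block.31.layer.2.DenseReluDense.wo.weight"
shard = index["weight_map"][name]  # e.g. "model-00007-of-00007.safetensors"

with safe_open(shard, framework="pt") as shard_file:
    tensor = shard_file.get_tensor(name)  # reads only this tensor from the shard

print(name, tuple(tensor.shape))
```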
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
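
Once the tokenizer is loaded, these entries surface as its special-token attributes. A quick sanity check, assuming a local checkout of this repository (the `T5Tokenizer` class name comes from the `tokenizer_config.json` added below, and `sentencepiece` must be installed):

```python
# Sketch: the special tokens declared above map onto tokenizer attributes.
from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained(".")
print(tok.unk_token, tok.pad_token, tok.eos_token)           # <unk> <s> </s>
print(tok.unk_token_id, tok.pad_token_id, tok.eos_token_id)  # 0 1 2
```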
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef11ac9a22c7503492f56d48dce53be20e339b63605983e9f27d2cd0e0f3922c
+ size 4427844
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2799ccc696b752ba00c34f58726bfe253a04921ceb6cfc620400f560474790b
+ size 16629031
tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [],
+ "clean_up_tokenization_spaces": true,
+ "eos_token": "</s>",
+ "extra_ids": 0,
+ "legacy": false,
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": "<s>",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "T5Tokenizer",
+ "unk_token": "<unk>"
+ }
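
Taken together, `spiece.model`, `tokenizer.json`, and this config define a T5-style SentencePiece tokenizer in which ids 0/1/2 correspond to `<unk>`/`<s>`/`</s>`, as declared in `added_tokens_decoder`. A minimal usage sketch, again assuming a local checkout of this repository (a hub repo id would work the same way):

```python
# Sketch: round-trip through the tokenizer defined by the files in this commit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")
ids = tok("MADLAD-400 is a multilingual language model.").input_ids
print(ids[-1] == tok.eos_token_id)           # True: T5-style tokenizers append </s>
print(tok.convert_ids_to_tokens([0, 1, 2]))  # ['<unk>', '<s>', '</s>']
print(tok.decode(ids, skip_special_tokens=True))
```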