bourdoiscatie committed
Commit bb74a21 (verified) · 1 parent: 18bc4bc

Delete modeling_flash_t5(1).py

Files changed (1)
  1. modeling_flash_t5(1).py +0 -840
modeling_flash_t5(1).py DELETED
@@ -1,840 +0,0 @@
1
- # From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
2
-
3
- from dataclasses import dataclass
4
-
5
- import copy
6
- import math
- import warnings
7
- from typing import Optional, Tuple, Union
8
-
9
- import torch
10
- from torch import nn
11
- from torch.nn import CrossEntropyLoss
12
- import torch.nn.functional as F
13
-
14
- from transformers.modeling_utils import ModuleUtilsMixin
15
- from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput
16
- from transformers import PreTrainedModel
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
17
-
18
- try:
19
- from .rms_norm import fast_rms_layernorm
20
- except ImportError:
21
- fast_rms_layernorm = None
22
-
23
- try:
24
- from .cross_entropy_loss import fast_cross_entropy_loss
25
- except ImportError:
26
- fast_cross_entropy_loss = None
27
-
28
- try:
29
- from .flash_attention_v2_bias import attention as flash_attention_triton
30
- except ImportError:
31
- flash_attention_triton = None
32
-
33
- try:
34
- from .gated_mlp import gated_mlp
35
- except ImportError:
36
- gated_mlp = None
37
-
38
- try:
39
- #from flash_attn import flash_attn_kvpacked_func, flash_attn_func
40
- from .fa2_compilable import flash_attn_kvpacked_func, flash_attn_func
41
- except ImportError:
42
- flash_attn_kvpacked_func, flash_attn_func = None, None
43
-
44
- from .attn_ref import attn_ref
45
-
46
- from .configuration_flash_t5 import FlashT5Config
47
- from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding
48
-
49
- @dataclass
50
- class EncoderOutput(ModelOutput):
51
- hidden_states: torch.FloatTensor = None
52
- attention_mask: torch.FloatTensor = None
53
-
54
- @dataclass
55
- class Seq2SeqLMOutput(ModelOutput):
56
- loss: torch.FloatTensor = None
57
- logits: torch.FloatTensor = None
58
- encoder_outputs: EncoderOutput = None
59
-
60
-
61
- class FlashT5CrossEntropyLoss(nn.Module):
62
- def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False):
63
-
64
- super().__init__()
65
-
66
- if use_triton_crossentropy and fast_cross_entropy_loss is None:
67
- raise ImportError("fast_cross_entropy_loss is not available")
68
-
69
- self.use_triton_crossentropy = use_triton_crossentropy
70
- self.z_loss_factor = z_loss_factor
71
-
72
- self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
73
-
74
- def compute_zloss(self, logits: torch.Tensor, z_loss: float):
75
- logits_sum = torch.logsumexp(logits, dim=-1, keepdim=True)
76
- log_z = torch.squeeze(logits_sum, dim=-1)
77
- total_z_loss = z_loss * torch.square(log_z)
78
- return total_z_loss.mean()
79
-
80
- def forward(self, logits, labels):
81
-
82
- if self.use_triton_crossentropy:
83
- return fast_cross_entropy_loss(logits, labels, z_loss_factor=self.z_loss_factor)
84
-
85
- # use standard method
86
- batch, seq_len, d = logits.shape
87
- logits_flatten = logits.float().view(batch*seq_len, d) # Must cast to float32 for numerical stability
88
- labels_flatten = labels.view(-1)
89
- loss = self.cross_entropy_loss(logits_flatten, labels_flatten)
90
- z_loss = 0.0
91
- if self.z_loss_factor != 0.0:
92
- z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100],
93
- z_loss=self.z_loss_factor)
94
- return loss, z_loss
95
-
96
- class FlashT5LayerNorm(nn.Module):
97
- def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False):
98
- """
99
- Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
100
- """
101
- super().__init__()
102
-
103
- if use_triton_layernorm and fast_rms_layernorm is None:
104
- raise ImportError("fast_rms_layernorm is not available")
105
-
106
- self.use_triton_layernorm = use_triton_layernorm
107
- self.weight = nn.Parameter(torch.ones(hidden_size))
108
- self.variance_epsilon = eps
109
-
110
- def forward(self, hidden_states):
111
-
112
- if self.use_triton_layernorm:
113
- return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
114
-
115
- # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
116
- # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
117
- # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
118
- # half-precision inputs is done in fp32
119
-
120
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
121
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
-
123
- # convert into half-precision if necessary
124
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
125
- hidden_states = hidden_states.to(self.weight.dtype)
126
-
127
- return self.weight * hidden_states
128
-
129
- class FlashT5DenseAct(nn.Module):
130
- def __init__(self, config: FlashT5Config):
131
- super().__init__()
132
- self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
133
- self.dropout = nn.Dropout(config.dropout_rate)
134
- self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
135
-
136
- def forward(self, hidden_states):
137
- hidden_states = self.wi(hidden_states)
138
- hidden_states = self.act(hidden_states)
139
- hidden_states = self.dropout(hidden_states)
140
- if (
141
- hasattr(self, "wo") and isinstance(self.wo.weight, torch.Tensor)  # guard: the wo projection lives in FlashT5LayerFF, not in this module
142
- and hidden_states.dtype != self.wo.weight.dtype
143
- and self.wo.weight.dtype != torch.int8
144
- ):
145
- hidden_states = hidden_states.to(self.wo.weight.dtype)
146
-
147
- return hidden_states
148
-
149
- class FlashT5DenseGatedAct(nn.Module):
150
- def __init__(self, config: FlashT5Config):
151
- super().__init__()
152
- self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
153
- self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
154
- self.dropout = nn.Dropout(config.dropout_rate)
155
- self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
156
-
157
- self.use_triton_gated_mlp = config.use_triton_gated_mlp
158
- if self.use_triton_gated_mlp and gated_mlp is None:
159
- raise ImportError("gated_mlp is not available")
160
- self.use_gelu_act = config.use_gelu_act
161
-
162
- def forward(self, hidden_states):
163
-
164
- if self.use_triton_gated_mlp:
165
- return gated_mlp(hidden_states, self.wi_0.weight, self.wi_1.weight, self.use_gelu_act)
166
-
167
- hidden_act = self.act(self.wi_0(hidden_states))
168
- hidden_linear = self.wi_1(hidden_states)
169
- hidden_states = hidden_act * hidden_linear
170
- hidden_states = self.dropout(hidden_states)
171
-
172
- return hidden_states
173
-
174
- class FlashT5LayerFF(nn.Module):
175
- def __init__(self, config: FlashT5Config):
176
- super().__init__()
177
- if config.use_glu_mlp:
178
- self.act = FlashT5DenseGatedAct(config)
179
- else:
180
- self.act = FlashT5DenseAct(config)
181
-
182
- self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
183
- self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
184
- self.dropout = nn.Dropout(config.dropout_rate)
185
-
186
- def forward(self, hidden_states):
187
- forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
188
- forwarded_states = self.act(forwarded_states)
189
- forwarded_states = self.wo(forwarded_states)
190
- hidden_states = hidden_states + self.dropout(forwarded_states)
191
- return hidden_states
192
-
193
-
194
- class FlashT5Attention(nn.Module, ModuleUtilsMixin):
195
- def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False):
196
- super().__init__()
197
- self.is_decoder = config.is_decoder
198
- self.has_positional_encoding = has_positional_encoding
199
- self.is_causal = is_causal
200
- self.relative_attention_num_buckets = config.relative_attention_num_buckets
201
- self.relative_attention_max_distance = config.relative_attention_max_distance
202
- self.d_model = config.d_model
203
- self.key_value_proj_dim = config.d_kv
204
- self.n_heads = config.num_heads
205
- self.p_dropout = config.attention_dropout_rate
206
- self.inner_dim = self.n_heads * self.key_value_proj_dim
207
- self.use_flash_attention = config.use_flash_attention
208
- self.position_encoding_type = config.position_encoding_type
209
- self.max_sequence_length = config.max_sequence_length
210
- self.softmax_scale = 1.0/math.sqrt(self.n_heads)
211
- self.use_full_bias_size = config.use_full_bias_size
212
-
213
- if self.use_flash_attention == "triton" and flash_attention_triton is None:
214
- raise ImportError("flash_attention_triton is not available")
215
- elif self.use_flash_attention == "fa2" and flash_attn_func is None:
216
- raise ImportError("Flash Attention 2 is not available")
217
-
218
- assert (self.p_dropout == 0.0) or (self.use_flash_attention != "triton"), "Triton attention does not support dropout"
219
-
220
- self.pe_encoding = None
221
- if self.position_encoding_type == "ALiBi" and has_positional_encoding:
222
- # build alibi matrix with an upper bound on seq length
223
- self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length, self.n_heads, config.alibi_mode, config.use_randomized_position_encoding)
224
- elif self.position_encoding_type == "t5" and has_positional_encoding:
225
- self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets, self.relative_attention_max_distance, self.n_heads, self.max_sequence_length, config.use_randomized_position_encoding)
226
- elif self.position_encoding_type == "RoPE":
227
- self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction), self.max_sequence_length, config.rotary_base, config.rotary_interleaved, config.rotary_scale_base, config.use_randomized_position_encoding)
228
-
229
- self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False)
230
- self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False)
231
- self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False)
232
- self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
233
-
234
- def forward(
235
- self,
236
- hidden_states,
237
- mask=None,
238
- key_value_states=None,
239
- position_bias=None,
240
- ):
241
- """
242
- Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
243
- """
244
- # Input is (batch_size, seq_length, dim)
245
- # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
246
- batch_size, seq_length = hidden_states.shape[:2]
247
- key_length = seq_length if key_value_states is None else key_value_states.shape[1]
248
- q = self.Wq(hidden_states)
249
- if key_value_states is None:
250
- k = self.Wk(hidden_states)
251
- v = self.Wv(hidden_states)
252
- else:
253
- k = self.Wk(key_value_states)
254
- v = self.Wv(key_value_states)
255
-
256
- q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
257
- k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
258
- v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
259
-
260
- if position_bias is None and self.pe_encoding is not None:
261
- q, k, v, position_bias = self.pe_encoding(q, k, v)
262
-
263
- if position_bias is not None and self.use_full_bias_size and (self.use_flash_attention == "fa2" or self.use_flash_attention == "triton"):
264
- position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1]).contiguous()
265
-
266
- if self.use_flash_attention == "fa2":
267
- output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, attn_bias=position_bias, causal=self.is_causal)
268
- elif self.use_flash_attention == "triton":
269
- q = q.permute(0, 2, 1, 3)
270
- k = k.permute(0, 2, 1, 3)
271
- v = v.permute(0, 2, 1, 3)
272
- output = flash_attention_triton(q, k, v, position_bias, self.is_causal, self.softmax_scale)
273
- output = output.permute(0, 2, 1, 3)
274
- else: # fall back to the reference attention implementation
275
- q = q.permute(0, 2, 1, 3)
276
- k = k.permute(0, 2, 1, 3)
277
- v = v.permute(0, 2, 1, 3)
278
- output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal)
279
- output = output.permute(0, 2, 1, 3)
280
-
281
- output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim))
282
- return (output, position_bias)
283
-
284
-
285
- class FlashT5LayerSelfAttention(nn.Module):
286
- def __init__(self, config, has_positional_encoding=False):
287
- super().__init__()
288
- self.self_attention = FlashT5Attention(config, has_positional_encoding=has_positional_encoding, is_causal=config.is_decoder)
289
- self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
290
- self.dropout = nn.Dropout(config.dropout_rate)
291
-
292
- def forward(
293
- self,
294
- hidden_states,
295
- attention_mask=None,
296
- position_bias=None,
297
- ):
298
- normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
299
- attention_output = self.self_attention(
300
- normed_hidden_states,
301
- mask=attention_mask,
302
- position_bias=position_bias,
303
- )
304
- hidden_states = hidden_states + self.dropout(attention_output[0])
305
- outputs = (hidden_states,) + attention_output[1:]
306
- return outputs
307
-
308
-
309
- class FlashT5LayerCrossAttention(nn.Module):
310
- def __init__(self, config):
311
- super().__init__()
312
- self.cross_attention = FlashT5Attention(config, has_positional_encoding=False)
313
- self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
314
- self.dropout = nn.Dropout(config.dropout_rate)
315
-
316
- def forward(
317
- self,
318
- hidden_states,
319
- key_value_states,
320
- attention_mask=None,
321
- position_bias=None,
322
- ):
323
- normed_hidden_states = self.layer_norm(hidden_states)
324
- attention_output = self.cross_attention(
325
- normed_hidden_states,
326
- mask=attention_mask,
327
- key_value_states=key_value_states,
328
- position_bias=position_bias,
329
- )
330
- layer_output = hidden_states + self.dropout(attention_output[0])
331
- outputs = (layer_output,) + attention_output[1:]
332
- return outputs
333
-
334
-
335
- class FlashT5Block(nn.Module):
336
- def __init__(self, config, has_positional_encoding=False):
337
- super().__init__()
338
- self.is_decoder = config.is_decoder
339
-
340
- self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding)
341
-
342
- if self.is_decoder:
343
- self.cross_attention_layer = FlashT5LayerCrossAttention(config)
344
-
345
- self.ff_layer = FlashT5LayerFF(config)
346
-
347
- def forward(
348
- self,
349
- hidden_states,
350
- attention_mask=None,
351
- position_bias=None,
352
- encoder_hidden_states=None,
353
- encoder_attention_mask=None,
354
- encoder_decoder_position_bias=None,
355
- ):
356
- self_attention_outputs = self.self_attention_layer(
357
- hidden_states,
358
- attention_mask=attention_mask,
359
- position_bias=position_bias,
360
- )
361
- hidden_states = self_attention_outputs[0]
362
- attention_outputs = self_attention_outputs[1:] # Relative position weights
363
-
364
- if self.is_decoder and encoder_hidden_states is not None:
365
- cross_attention_outputs = self.cross_attention_layer(
366
- hidden_states,
367
- key_value_states=encoder_hidden_states,
368
- attention_mask=encoder_attention_mask,
369
- position_bias=encoder_decoder_position_bias,
370
- )
371
- hidden_states = cross_attention_outputs[0]
372
-
373
- # Keep relative position weights
374
- attention_outputs = attention_outputs + cross_attention_outputs[1:]
375
-
376
- # Apply Feed Forward layer
377
- hidden_states = self.ff_layer(hidden_states)
378
-
379
- outputs = (hidden_states,) + attention_outputs
380
- return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
381
-
382
- class FlashT5Stack(nn.Module, ModuleUtilsMixin):
383
- def __init__(self, config, embed_tokens):
384
- super().__init__()
385
- assert embed_tokens is not None
386
-
387
- self.config = config
388
- self.embed_tokens = embed_tokens
389
- self.is_decoder = config.is_decoder
390
- self.use_flash_attention = config.use_flash_attention
391
-
392
- self.block = nn.ModuleList(
393
- [FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)]
394
- )
395
-
396
- self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
397
- self.dropout = nn.Dropout(config.dropout_rate)
398
-
399
- def forward(
400
- self,
401
- input_ids=None,
402
- attention_mask=None,
403
- encoder_hidden_states=None,
404
- encoder_attention_mask=None,
405
- inputs_embeds=None,
406
- head_mask=None,
407
- cross_attn_head_mask=None,
408
- past_key_values=None,
409
- use_cache=None,
410
- output_attentions=None,
411
- output_hidden_states=None,
412
- return_dict=None) -> BaseModelOutput:
413
- input_shape = input_ids.size()
414
- batch_size, seq_length = input_shape
415
-
416
- if inputs_embeds is None:
417
- inputs_embeds = self.embed_tokens(input_ids)
418
-
419
- if torch.is_autocast_enabled() and input_ids.device.type == 'cuda':
420
- inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype())
421
-
422
- # Masking
423
- if attention_mask is None:
424
- attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool)
425
-
426
- if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
427
- encoder_seq_length = encoder_hidden_states.shape[1]
428
- encoder_attention_mask = torch.ones(
429
- batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
430
- )
431
-
432
- position_bias = None
433
- encoder_decoder_position_bias = None
434
-
435
- hidden_states = self.dropout(inputs_embeds)
436
-
437
- for _, layer_module in enumerate(self.block):
438
- layer_outputs = layer_module(
439
- hidden_states,
440
- attention_mask=attention_mask,
441
- position_bias=position_bias,
442
- encoder_hidden_states=encoder_hidden_states,
443
- encoder_attention_mask=encoder_attention_mask,
444
- encoder_decoder_position_bias=encoder_decoder_position_bias,
445
- )
446
-
447
- # We share the position biases between the layers - the first layer stores them
448
- position_bias = layer_outputs[1]
449
- if self.is_decoder and encoder_hidden_states is not None:
450
- encoder_decoder_position_bias = layer_outputs[2]
451
-
452
- hidden_states = layer_outputs[0]
453
-
454
- hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
455
- hidden_states = self.dropout(hidden_states)
456
-
457
- return BaseModelOutput(
458
- last_hidden_state=hidden_states
459
- )
460
-
461
-
462
- class FlashT5PreTrainedModel(PreTrainedModel):
463
- """
464
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
465
- models.
466
- """
467
-
468
- config_class = FlashT5Config
469
- base_model_prefix = "transformer"
470
- is_parallelizable = False
471
- supports_gradient_checkpointing = True
472
- _no_split_modules = ["FlashT5Block"]
473
- _keep_in_fp32_modules = []
474
-
475
- def _init_weights(self, module):
476
- factor = self.config.initializer_factor # Used for testing weights initialization
477
- if isinstance(module, FlashT5LayerNorm):
478
- module.weight.data.fill_(factor * 1.0)
479
- elif isinstance(module, (FlashT5ForConditionalGeneration)):
480
- module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
481
- if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
482
- module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5)
483
- elif isinstance(module, FlashT5DenseGatedAct):
484
- d_ff, d_model = module.wi_0.weight.data.size()
485
- module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
486
- module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
487
- elif isinstance(module, FlashT5LayerFF):
488
- d_model, d_ff = module.wo.weight.data.size()  # wo is Linear(d_ff, d_model), so its weight is (d_model, d_ff)
489
- module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
490
- elif isinstance(module, FlashT5Attention):
491
- d_model = self.config.d_model
492
- key_value_proj_dim = self.config.d_kv
493
- n_heads = self.config.num_heads
494
- module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
495
- module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
496
- module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
497
- module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
498
- if module.has_positional_encoding:
499
- if hasattr(module.pe_encoding, "relative_attention_bias"):
500
- module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
501
-
502
- def _shift_right(self, input_ids):
503
- decoder_start_token_id = self.config.decoder_start_token_id
504
- pad_token_id = self.config.pad_token_id
505
-
506
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
507
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
508
- shifted_input_ids[..., 0] = decoder_start_token_id
509
-
510
- # replace possible -100 values in labels by `pad_token_id`
511
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
512
-
513
- return shifted_input_ids
514
-
515
-
516
- class FlashT5Model(FlashT5PreTrainedModel):
517
-
518
- def __init__(self, config: FlashT5Config):
519
- super().__init__(config)
520
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
521
-
522
- encoder_config = copy.deepcopy(config)
523
- encoder_config.is_decoder = False
524
- encoder_config.use_cache = False
525
- encoder_config.is_encoder_decoder = False
526
- self.encoder = FlashT5Stack(encoder_config, self.shared)
527
-
528
- decoder_config = copy.deepcopy(config)
529
- decoder_config.is_decoder = True
530
- decoder_config.is_encoder_decoder = False
531
- decoder_config.num_layers = config.num_decoder_layers
532
- self.decoder = FlashT5Stack(decoder_config, self.shared)
533
-
534
- # Initialize weights and apply final processing
535
- self.post_init()
536
-
537
- # Model parallel
538
- self.model_parallel = False
539
- self.device_map = None
540
-
541
- def get_input_embeddings(self):
542
- return self.shared
543
-
544
- def set_input_embeddings(self, new_embeddings):
545
- self.shared = new_embeddings
546
- self.encoder.embed_tokens = new_embeddings  # FlashT5Stack does not define set_input_embeddings
547
- self.decoder.embed_tokens = new_embeddings
548
-
549
- def get_encoder(self):
550
- return self.encoder
551
-
552
- def get_decoder(self):
553
- return self.decoder
554
-
555
- def forward(
556
- self,
557
- input_ids: Optional[torch.LongTensor] = None,
558
- attention_mask: Optional[torch.FloatTensor] = None,
559
- decoder_input_ids: Optional[torch.LongTensor] = None,
560
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
561
- head_mask: Optional[torch.FloatTensor] = None,
562
- decoder_head_mask: Optional[torch.FloatTensor] = None,
563
- cross_attn_head_mask: Optional[torch.Tensor] = None,
564
- encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
565
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
566
- inputs_embeds: Optional[torch.Tensor] = None,
567
- decoder_inputs_embeds: Optional[torch.Tensor] = None,
568
- use_cache: Optional[bool] = None,
569
- output_attentions: Optional[bool] = None,
570
- output_hidden_states: Optional[bool] = None,
571
- return_dict: Optional[bool] = None,
572
- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
573
-
574
- # Encode if needed (training, first prediction pass)
575
- if encoder_outputs is None:
576
- encoder_outputs = self.encoder(
577
- input_ids=input_ids,
578
- attention_mask=attention_mask,
579
- inputs_embeds=inputs_embeds
580
- )
581
-
582
- hidden_states = encoder_outputs[0]
583
-
584
- # Decode
585
- decoder_outputs = self.decoder(
586
- input_ids=decoder_input_ids,
587
- attention_mask=decoder_attention_mask,
588
- inputs_embeds=decoder_inputs_embeds,
589
- encoder_hidden_states=hidden_states,
590
- encoder_attention_mask=attention_mask
591
- )
592
-
593
- return Seq2SeqModelOutput(
594
- last_hidden_state=decoder_outputs.last_hidden_state,
595
- decoder_hidden_states=decoder_outputs.hidden_states,
596
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
597
- encoder_hidden_states=encoder_outputs.hidden_states,
598
- )
599
-
600
- class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel):
601
-
602
- def __init__(self, config: FlashT5Config):
603
- super().__init__(config)
604
- config.is_encoder_decoder = False
605
- assert not config.tie_word_embeddings
606
-
607
- self.config = config
608
- self.model_dim = config.d_model
609
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
610
-
611
- encoder_config = copy.deepcopy(config)
612
- encoder_config.is_decoder = False
613
- self.encoder = FlashT5Stack(encoder_config, self.shared)
614
-
615
- decoder_config = copy.deepcopy(config)
616
- decoder_config.is_decoder = True
617
- decoder_config.num_layers = config.num_decoder_layers
618
- self.decoder = FlashT5Stack(decoder_config, self.shared)
619
-
620
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
621
-
622
- self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss,
623
- label_smoothing=config.label_smoothing,
624
- use_triton_crossentropy=config.use_triton_crossentropy)
625
-
626
- # Initialize weights and apply final processing
627
- self.post_init()
628
-
629
- def prepare_inputs_for_generation(
630
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
631
- ):
632
- # do nothing
633
- model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
634
-
635
- return model_inputs
636
-
637
- def get_input_embeddings(self):
638
- return self.shared
639
-
640
- def set_input_embeddings(self, value):
641
- self.shared = value
642
-
643
- def generate(
644
- self,
645
- input_ids: Optional[torch.LongTensor] = None,
646
- attention_mask: Optional[torch.FloatTensor] = None,
647
- max_length = 32,
648
- **kwargs,
649
- ) -> torch.LongTensor:
650
- """
651
- input_ids: B x L_encoder, int64
652
- attention_mask: B x L_encoder, int64
653
- 1 for tokens to attend to, 0 for tokens to ignore
654
-
655
- Generation:
656
- Starts with 0, ends with 1, padding is 0
657
-
658
- # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
659
- """
660
- B, _ = input_ids.size()
661
- labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
662
- encoder_outputs = None
663
-
664
- for _ in range(max_length):
665
- out = self.forward(
666
- input_ids=input_ids,
667
- attention_mask=attention_mask,
668
- decoder_input_ids=labels,
669
- encoder_outputs=encoder_outputs,
670
- )
671
- encoder_outputs = out.encoder_outputs
672
- top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
673
- labels = torch.cat([labels, top_labels], dim=-1)
674
-
675
- if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B:
676
- break
677
-
678
- labels[:, -1] = 1
679
-
680
- # Mask out the padding, i.e., all positions after the first 1 with 0
681
- B, L = labels.size()
682
- mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
683
- labels = labels.masked_fill(~mask, 0)
684
-
685
- return labels
686
-
687
- def forward(
688
- self,
689
- input_ids: Optional[torch.LongTensor] = None,
690
- attention_mask: Optional[torch.FloatTensor] = None,
691
- decoder_input_ids: Optional[torch.LongTensor] = None,
692
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
693
- labels: Optional[torch.LongTensor] = None,
694
- encoder_outputs = None,
695
- ) -> Seq2SeqLMOutput:
696
- """
697
- input_ids: B x L_encoder, int64
698
- attention_mask: B x L_encoder, int64
699
- 1 for tokens to attend to, 0 for tokens to ignore
700
- labels: B x L_decoder, int64
701
- """
702
- if encoder_outputs is None:
703
- encoder_outputs = self.encoder(
704
- input_ids=input_ids,
705
- attention_mask=attention_mask,
706
- )
707
-
708
- hidden_states = encoder_outputs[0]  # last_hidden_state; the BaseModelOutput returned by the encoder does not set .hidden_states
709
-
710
- if labels is not None and decoder_input_ids is None:
711
- decoder_input_ids = self._shift_right(labels)
712
-
713
- decoder_outputs = self.decoder(
714
- input_ids=decoder_input_ids,
715
- attention_mask=decoder_attention_mask,
716
- encoder_hidden_states=hidden_states,
717
- encoder_attention_mask=attention_mask,
718
- )
719
-
720
- sequence_output = decoder_outputs[0]
721
- lm_logits = self.lm_head(sequence_output)
722
-
723
- loss = None
724
- if labels is not None:
725
- loss, z_loss = self.loss_fct(lm_logits, labels)
726
- loss += z_loss
727
-
728
- return Seq2SeqLMOutput(
729
- loss=loss,
730
- logits=lm_logits,
731
- encoder_outputs=encoder_outputs,
732
- )
733
-
734
-
735
-
736
- class FlashT5EncoderModel(FlashT5PreTrainedModel):
737
- _tied_weights_keys = ["encoder.embed_tokens.weight"]
738
-
739
- def __init__(self, config: FlashT5Config):
740
- super().__init__(config)
741
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
742
-
743
- encoder_config = copy.deepcopy(config)
744
- encoder_config.use_cache = False
745
- encoder_config.is_encoder_decoder = False
746
- self.encoder = FlashT5Stack(encoder_config, self.shared)
747
-
748
- # Initialize weights and apply final processing
749
- self.post_init()
750
-
751
- # Model parallel
752
- self.model_parallel = False
753
- self.device_map = None
754
-
755
-
756
- def parallelize(self, device_map=None):
757
- warnings.warn(
758
- "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
759
- " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
760
- " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
761
- " 'block.1': 1, ...}",
762
- FutureWarning,
763
- )
764
- self.device_map = (
765
- get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
766
- if device_map is None
767
- else device_map
768
- )
769
- assert_device_map(self.device_map, len(self.encoder.block))
770
- self.encoder.parallelize(self.device_map)
771
- self.model_parallel = True
772
-
773
- def deparallelize(self):
774
- warnings.warn(
775
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
776
- FutureWarning,
777
- )
778
- self.encoder.deparallelize()
779
- self.encoder = self.encoder.to("cpu")
780
- self.model_parallel = False
781
- self.device_map = None
782
- torch.cuda.empty_cache()
783
-
784
- def get_input_embeddings(self):
785
- return self.shared
786
-
787
- def set_input_embeddings(self, new_embeddings):
788
- self.shared = new_embeddings
789
- self.encoder.embed_tokens = new_embeddings  # FlashT5Stack does not define set_input_embeddings
790
-
791
- def get_encoder(self):
792
- return self.encoder
793
-
794
- def _prune_heads(self, heads_to_prune):
795
- """
796
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
797
- class PreTrainedModel
798
- """
799
- for layer, heads in heads_to_prune.items():
800
- self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
801
-
802
- def forward(
803
- self,
804
- input_ids: Optional[torch.LongTensor] = None,
805
- attention_mask: Optional[torch.FloatTensor] = None,
806
- head_mask: Optional[torch.FloatTensor] = None,
807
- inputs_embeds: Optional[torch.FloatTensor] = None,
808
- output_attentions: Optional[bool] = None,
809
- output_hidden_states: Optional[bool] = None,
810
- return_dict: Optional[bool] = None,
811
- ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
812
- r"""
813
- Returns:
814
-
815
- Example:
816
-
817
- ```python
818
- >>> from transformers import AutoTokenizer, T5EncoderModel
819
-
820
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
821
- >>> model = T5EncoderModel.from_pretrained("t5-small")
822
- >>> input_ids = tokenizer(
823
- ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
824
- ... ).input_ids # Batch size 1
825
- >>> outputs = model(input_ids=input_ids)
826
- >>> last_hidden_states = outputs.last_hidden_state
827
- ```"""
828
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
829
-
830
- encoder_outputs = self.encoder(
831
- input_ids=input_ids,
832
- attention_mask=attention_mask,
833
- inputs_embeds=inputs_embeds,
834
- head_mask=head_mask,
835
- output_attentions=output_attentions,
836
- output_hidden_states=output_hidden_states,
837
- return_dict=return_dict,
838
- )
839
-
840
- return encoder_outputs
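
For reference, a minimal standalone sketch of the pure-PyTorch fallback path of FlashT5LayerNorm (RMSNorm: scale only, no mean subtraction, no bias). This is not part of the deleted file; the function name and dummy shapes below are chosen purely for illustration.

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # accumulate the squared mean in fp32 for numerical stability, as the module above does
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(weight.dtype)

x = torch.randn(2, 5, 8, dtype=torch.bfloat16)   # (batch, seq_len, d_model), dummy values
w = torch.ones(8, dtype=torch.bfloat16)          # learned scale, initialised to 1 as in FlashT5LayerNorm
print(rms_norm_reference(x, w).shape)            # torch.Size([2, 5, 8])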