makitanikaze commited on
Commit
72f87d1
·
1 Parent(s): 4a17a3f

Upload modeling_p5.py

Browse files
Files changed (1) hide show
  1. modeling_p5.py +456 -0
modeling_p5.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from transformers.models.t5.modeling_t5 import (
4
+ T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention,
5
+ T5PreTrainedModel, T5ForConditionalGeneration
6
+ )
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import CrossEntropyLoss
11
+
12
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
13
+ import copy
14
+
15
+ from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
16
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
17
+ from transformers.utils import logging
18
+ from transformers import BeamScorer, BeamSearchScorer
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ # The encoder for input token sequence
23
+ class JointEncoder(T5Stack):
24
+ def __init__(self, config, embed_tokens=None):
25
+ super(T5Stack, self).__init__(config)
26
+ self.config = config
27
+
28
+ self.embed_tokens = embed_tokens
29
+ self.is_decoder = self.config.is_decoder
30
+ assert self.config.is_decoder is False
31
+
32
+ self.block = nn.ModuleList(
33
+ [T5Block(config, has_relative_attention_bias=(i == 0))
34
+ for i in range(config.num_layers)]
35
+ )
36
+ self.final_layer_norm = T5LayerNorm(
37
+ config.d_model, eps=config.layer_norm_epsilon)
38
+ self.dropout = nn.Dropout(config.dropout_rate)
39
+
40
+ ## Set maximum 512 whole words in a source text
41
+ self.whole_word_embeddings = nn.Embedding(
42
+ 512, config.d_model ## config.d_model is 768 for base
43
+ )
44
+ self.init_weights()
45
+ self.model_parallel = False
46
+ self.device_map = None
47
+
48
+ def set_input_embeddings(self, new_embeddings):
49
+ self.embed_tokens = new_embeddings
50
+
51
+ def forward(
52
+ self,
53
+ input_ids=None,
54
+ whole_word_ids=None,
55
+ attention_mask=None,
56
+ inputs_embeds=None,
57
+ head_mask=None,
58
+ past_key_values=None,
59
+ use_cache=None,
60
+ output_attentions=None,
61
+ output_hidden_states=None,
62
+ return_dict=None,
63
+ ):
64
+
65
+ if inputs_embeds is None:
66
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
67
+ inputs_embeds = self.embed_tokens(input_ids) ### embedding step - add HERE ###
68
+ if whole_word_ids is not None:
69
+ whole_word_embeds = self.whole_word_embeddings(whole_word_ids)
70
+ assert whole_word_embeds.shape[-1] == inputs_embeds.shape[-1]
71
+ inputs_embeds = inputs_embeds + whole_word_embeds
72
+
73
+ B, L = inputs_embeds.size()[:-1]
74
+
75
+ if attention_mask is None:
76
+ attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
77
+
78
+ # ourselves in which case we just need to make it broadcastable to all heads.
79
+ extended_attention_mask = self.get_extended_attention_mask(
80
+ attention_mask,
81
+ (B, L),
82
+ inputs_embeds.device)
83
+
84
+ # initialize past_key_values with `None` if past does not exist
85
+ if past_key_values is None:
86
+ past_key_values = [None] * len(self.block)
87
+
88
+ # Prepare head mask if needed
89
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
90
+ present_key_value_states = () if use_cache else None
91
+ all_hidden_states = () if output_hidden_states else None
92
+ all_attentions = () if output_attentions else None
93
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
94
+
95
+ hidden_states = self.dropout(inputs_embeds)
96
+
97
+ if self.config.num_layers > 0:
98
+
99
+ assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias
100
+
101
+ seq_length = L
102
+ q_len = seq_length
103
+ k_len = seq_length
104
+
105
+ # [1, n_heads, Q_len, K_len]
106
+ text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(
107
+ L, L)
108
+ num_heads = text_position_bias.size(1)
109
+ position_bias = text_position_bias.new_zeros(
110
+ 1, num_heads, seq_length, seq_length)
111
+ position_bias[:, :, :L, :L] = text_position_bias
112
+
113
+ position_bias = position_bias + extended_attention_mask
114
+
115
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
116
+ layer_head_mask = head_mask[i]
117
+ layer_outputs = layer_module(
118
+ hidden_states,
119
+ attention_mask=extended_attention_mask,
120
+ position_bias=position_bias,
121
+ encoder_hidden_states=None,
122
+ encoder_attention_mask=None,
123
+ encoder_decoder_position_bias=None,
124
+ # head_mask=head_mask[i],
125
+ layer_head_mask=layer_head_mask,
126
+ past_key_value=past_key_value,
127
+ use_cache=use_cache,
128
+ output_attentions=output_attentions,
129
+ )
130
+
131
+ # layer_outputs is a tuple with:
132
+ # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
133
+ hidden_states, present_key_value_state = layer_outputs[:2]
134
+
135
+ # We share the position biases between the layers - the first layer store them
136
+ # layer_outputs = hidden-states, key-value-states (self-attention weights),
137
+ # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
138
+
139
+ # position_bias = layer_outputs[2]
140
+
141
+ # append next layer key value states
142
+ if use_cache:
143
+ present_key_value_states = present_key_value_states + \
144
+ (present_key_value_state,)
145
+
146
+ hidden_states = self.final_layer_norm(hidden_states)
147
+ hidden_states = self.dropout(hidden_states)
148
+
149
+ # Add last layer
150
+ if output_hidden_states:
151
+ all_hidden_states = all_hidden_states + (hidden_states,)
152
+
153
+ if not return_dict:
154
+ return tuple(
155
+ v
156
+ for v in [
157
+ hidden_states,
158
+ present_key_value_states,
159
+ all_hidden_states,
160
+ all_attentions,
161
+ all_cross_attentions,
162
+ ]
163
+ if v is not None
164
+ )
165
+ return BaseModelOutputWithPastAndCrossAttentions(
166
+ last_hidden_state=hidden_states,
167
+ past_key_values=present_key_value_states,
168
+ hidden_states=all_hidden_states,
169
+ attentions=all_attentions,
170
+ cross_attentions=all_cross_attentions,
171
+ )
172
+
173
+
174
+ class P5(T5ForConditionalGeneration):
175
+ _keys_to_ignore_on_load_missing = [
176
+ r"encoder\.embed_tokens\.weight",
177
+ r"decoder\.embed_tokens\.weight",
178
+ r"lm_head\.weight",
179
+ ]
180
+ _keys_to_ignore_on_load_unexpected = [
181
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
182
+ ]
183
+
184
+ def __init__(self, config):
185
+ super(T5ForConditionalGeneration, self).__init__(config)
186
+
187
+ self.config = config
188
+
189
+ self.model_dim = config.d_model
190
+
191
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
192
+
193
+ encoder_config = copy.deepcopy(config)
194
+ encoder_config.is_decoder = False
195
+ encoder_config.use_cache = False
196
+ encoder_config.is_encoder_decoder = False
197
+
198
+ self.encoder = JointEncoder(encoder_config, self.shared)
199
+
200
+ decoder_config = copy.deepcopy(config)
201
+ decoder_config.is_decoder = True
202
+ decoder_config.is_encoder_decoder = False
203
+
204
+ self.decoder = T5Stack(decoder_config, self.shared)
205
+
206
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
207
+
208
+ self.init_weights()
209
+
210
+ self.model_parallel = False
211
+ self.device_map = None
212
+
213
+ def set_input_embeddings(self, new_embeddings):
214
+ self.shared = new_embeddings
215
+ self.encoder.set_input_embeddings(new_embeddings)
216
+ self.decoder.set_input_embeddings(new_embeddings)
217
+
218
+ def extend_vocab(self, vocab_size):
219
+
220
+ new_shared = nn.Embedding(vocab_size, self.config.d_model)
221
+ old_weight = self.shared.weight.data.detach().clone()
222
+ old_vocab_size = old_weight.size(0)
223
+ new_shared.weight.data[:old_vocab_size, :] = old_weight
224
+ self.shared = new_shared
225
+
226
+ new_lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
227
+ old_weight = self.lm_head.weight.data.detach().clone()
228
+ old_vocab_size = old_weight.size(0)
229
+ new_lm_head.weight.data[:old_vocab_size, :] = old_weight
230
+ self.lm_head = new_lm_head
231
+
232
+ self.encoder.embed_tokens = self.shared
233
+ self.decoder.embed_tokens = self.shared
234
+
235
+ self.lm_head.weight = self.shared.weight
236
+
237
+ self.config.vocab_size = vocab_size
238
+ self.encoder.config.vocab_size = vocab_size
239
+ self.decoder.config.vocab_size = vocab_size
240
+
241
+ def forward(
242
+ self,
243
+ input_ids=None,
244
+ whole_word_ids=None,
245
+ attention_mask=None,
246
+ encoder_outputs=None,
247
+ decoder_input_ids=None,
248
+ decoder_attention_mask=None,
249
+ past_key_values=None,
250
+ use_cache=None,
251
+ labels=None,
252
+ inputs_embeds=None,
253
+ decoder_inputs_embeds=None,
254
+ head_mask=None,
255
+ output_attentions=None,
256
+ output_hidden_states=None,
257
+ return_dict=None,
258
+ reduce_loss=False,
259
+
260
+ return_hidden_state=False,
261
+
262
+ **kwargs,
263
+ ):
264
+
265
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
266
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
267
+
268
+ if encoder_outputs is None:
269
+ encoder_outputs = self.encoder(
270
+ input_ids=input_ids,
271
+ whole_word_ids=whole_word_ids,
272
+ attention_mask=attention_mask,
273
+ inputs_embeds=inputs_embeds,
274
+ head_mask=head_mask,
275
+ output_attentions=output_attentions,
276
+ output_hidden_states=output_hidden_states,
277
+ return_dict=return_dict,
278
+ )
279
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
280
+ encoder_outputs = BaseModelOutput(
281
+ last_hidden_state=encoder_outputs[0],
282
+ hidden_states=encoder_outputs[1] if len(
283
+ encoder_outputs) > 1 else None,
284
+ attentions=encoder_outputs[2] if len(
285
+ encoder_outputs) > 2 else None,
286
+ )
287
+
288
+ hidden_states = encoder_outputs[0]
289
+
290
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
291
+ # get decoder inputs from shifting lm labels to the right
292
+ decoder_input_ids = self._shift_right(labels)
293
+
294
+ # If decoding with past key value states, only the last tokens
295
+ # should be given as an input
296
+ if past_key_values is not None:
297
+ assert labels is None, "Decoder should not use cached key value states when training."
298
+ if decoder_input_ids is not None:
299
+ decoder_input_ids = decoder_input_ids[:, -1:]
300
+ if decoder_inputs_embeds is not None:
301
+ decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
302
+
303
+ if attention_mask is None:
304
+ attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=hidden_states.dtype, device=hidden_states.device)
305
+ encoder_attention_mask = attention_mask
306
+
307
+ # Decode
308
+ decoder_outputs = self.decoder(
309
+ input_ids=decoder_input_ids,
310
+ attention_mask=decoder_attention_mask,
311
+ inputs_embeds=decoder_inputs_embeds,
312
+ past_key_values=past_key_values,
313
+
314
+ encoder_hidden_states=hidden_states,
315
+ encoder_attention_mask=encoder_attention_mask,
316
+
317
+ head_mask=head_mask,
318
+ use_cache=use_cache,
319
+ output_attentions=output_attentions,
320
+ output_hidden_states=output_hidden_states,
321
+ return_dict=return_dict,
322
+ )
323
+
324
+ sequence_output = decoder_outputs[0]
325
+
326
+ assert self.config.tie_word_embeddings is True
327
+
328
+ if self.config.tie_word_embeddings:
329
+ sequence_output = sequence_output * (self.model_dim ** -0.5)
330
+
331
+ if return_hidden_state:
332
+ return sequence_output
333
+
334
+ lm_logits = self.lm_head(sequence_output)
335
+
336
+ loss = None
337
+ if labels is not None:
338
+ if reduce_loss:
339
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
340
+ else:
341
+ loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
342
+ loss = loss_fct(
343
+ lm_logits.view(-1, lm_logits.size(-1)),
344
+ labels.view(-1))
345
+
346
+ return P5Seq2SeqLMOutput(
347
+ loss=loss,
348
+ logits=lm_logits,
349
+ past_key_values=decoder_outputs.past_key_values,
350
+ decoder_last_hidden_state=decoder_outputs.last_hidden_state,
351
+ decoder_hidden_states=decoder_outputs.hidden_states,
352
+ )
353
+
354
+ def prepare_inputs_for_generation(
355
+ self, input_ids, past=None, attention_mask=None, use_cache=None,
356
+ encoder_outputs=None,
357
+ **kwargs):
358
+
359
+ if past is not None:
360
+ input_ids = input_ids[:, -1:]
361
+
362
+ output = {
363
+ "decoder_input_ids": input_ids,
364
+ "past_key_values": past,
365
+ "encoder_outputs": encoder_outputs,
366
+ "attention_mask": attention_mask,
367
+ "use_cache": use_cache,
368
+ }
369
+
370
+ return output
371
+
372
+ @staticmethod
373
+ def _expand_inputs_for_generation(
374
+ input_ids: torch.LongTensor,
375
+ expand_size: int = 1,
376
+ is_encoder_decoder: bool = False,
377
+ attention_mask: torch.LongTensor = None,
378
+ encoder_outputs: ModelOutput = None,
379
+ **model_kwargs
380
+ ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
381
+ expanded_return_idx = (
382
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1,
383
+ expand_size).view(-1).to(input_ids.device)
384
+ )
385
+ input_ids = input_ids.index_select(0, expanded_return_idx)
386
+
387
+ if "token_type_ids" in model_kwargs:
388
+ token_type_ids = model_kwargs["token_type_ids"]
389
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(
390
+ 0, expanded_return_idx)
391
+
392
+ if attention_mask is not None:
393
+ model_kwargs["attention_mask"] = attention_mask.index_select(
394
+ 0, expanded_return_idx)
395
+
396
+ if is_encoder_decoder:
397
+ assert encoder_outputs is not None
398
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
399
+ 0, expanded_return_idx
400
+ )
401
+ model_kwargs["encoder_outputs"] = encoder_outputs
402
+
403
+ return input_ids, model_kwargs
404
+
405
+
406
+ @dataclass
407
+ class P5Seq2SeqLMOutput(ModelOutput):
408
+ """
409
+ Base class for sequence-to-sequence language models outputs.
410
+
411
+ Args:
412
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
413
+ Languaged modeling loss.
414
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
415
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
416
+ past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
417
+ List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
418
+ :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
419
+
420
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
421
+ used (see ``past_key_values`` input) to speed up sequential decoding.
422
+ decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
423
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
424
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
425
+
426
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
427
+ decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
428
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
429
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
430
+
431
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
432
+ self-attention heads.
433
+ encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
434
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
435
+ encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
436
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
437
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
438
+
439
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
440
+ encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
441
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
442
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
443
+
444
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
445
+ self-attention heads.
446
+ """
447
+
448
+ loss: Optional[torch.FloatTensor] = None
449
+ logits: torch.FloatTensor = None
450
+ past_key_values: Optional[List[torch.FloatTensor]] = None
451
+ decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
452
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
453
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
454
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
455
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
456
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None