ahmed-masry commited on
Commit
2ec6397
·
verified ·
1 Parent(s): 53fc190

Update chartinstruct_flant5_modeling.py

Browse files
Files changed (1) hide show
  1. chartinstruct_flant5_modeling.py +610 -609
chartinstruct_flant5_modeling.py CHANGED
@@ -1,609 +1,610 @@
1
- from typing import List, Optional, Tuple, Union
2
- from dataclasses import dataclass
3
- import copy, os
4
- import torch
5
- import torch.nn as nn
6
- from torch.nn import CrossEntropyLoss
7
- from transformers import AutoConfig, AutoModelForSeq2SeqLM, \
8
- T5Config, T5Model, T5ForConditionalGeneration
9
-
10
- from transformers.models.t5.modeling_t5 import T5Stack
11
- from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
12
- from transformers.utils import ModelOutput
13
- from transformers import DonutSwinModel, DonutImageProcessor, DonutSwinConfig
14
- from abc import ABC, abstractmethod
15
- import re
16
-
17
- from transformers import T5PreTrainedModel
18
- from transformers.models.t5.modeling_t5 import T5Block, T5LayerNorm
19
-
20
-
21
- @dataclass
22
- class BaseModelOutputWithPastAndCrossAttentionsWithAttentionMask(ModelOutput):
23
- """
24
- Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
25
-
26
- Args:
27
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
28
- Sequence of hidden-states at the output of the last layer of the model.
29
-
30
- If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
31
- hidden_size)` is output.
32
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
33
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
34
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
35
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
36
- encoder_sequence_length, embed_size_per_head)`.
37
-
38
- Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
39
- `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
40
- input) to speed up sequential decoding.
41
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
42
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
43
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
44
-
45
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
46
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
47
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
48
- sequence_length)`.
49
-
50
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
51
- heads.
52
- cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
53
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
54
- sequence_length)`.
55
-
56
- Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
57
- weighted average in the cross-attention heads.
58
- """
59
-
60
- last_hidden_state: torch.FloatTensor = None
61
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
62
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
63
- attentions: Optional[Tuple[torch.FloatTensor]] = None
64
- cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
65
- attention_mask: Optional[torch.LongTensor] = None
66
-
67
- class LlavaT5Config(T5Config):
68
- model_type = "llava_t5"
69
-
70
-
71
-
72
- class LlavaT5Stack(T5PreTrainedModel):
73
- config_class = LlavaT5Config
74
-
75
- def __init__(self, config, embed_tokens=None):
76
- super().__init__(config)
77
-
78
- self.embed_tokens = embed_tokens
79
- self.is_decoder = config.is_decoder
80
-
81
- self.block = nn.ModuleList(
82
- [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
83
- )
84
- self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
85
- self.dropout = nn.Dropout(config.dropout_rate)
86
-
87
- ## Vision
88
- self.vision_tower = DonutSwinModel(config=config.vision_config)
89
- self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
90
- self.pad_token_id = 0
91
- self.image_token_index = 32100
92
- ##
93
-
94
- # Initialize weights and apply final processing
95
- self.post_init()
96
- # Model parallel
97
- self.model_parallel = False
98
- self.device_map = None
99
- self.gradient_checkpointing = False
100
-
101
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask):
102
- num_images, num_image_patches, embed_dim = image_features.shape
103
- batch_size, sequence_length = input_ids.shape
104
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
105
- # 1. Create a mask to know where special image tokens are
106
- special_image_token_mask = input_ids == self.image_token_index
107
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
108
- # Compute the maximum embed dimension
109
- max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
110
- batch_indices, non_image_indices = torch.where(input_ids != self.image_token_index)
111
-
112
- # 2. Compute the positions where text should be written
113
- # Calculate new positions for text tokens in merged image-text sequence.
114
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
115
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
116
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
117
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
118
- nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
119
- if left_padding:
120
- new_token_positions += nb_image_pad[:, None] # offset for left padding
121
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
122
-
123
- # 3. Create the full embedding, already padded to the maximum position
124
- final_embedding = torch.zeros(
125
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
126
- )
127
- final_attention_mask = torch.zeros(
128
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
129
- )
130
-
131
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
132
- # set the corresponding tensors into their correct target device.
133
- target_device = inputs_embeds.device
134
- batch_indices, non_image_indices, text_to_overwrite = (
135
- batch_indices.to(target_device),
136
- non_image_indices.to(target_device),
137
- text_to_overwrite.to(target_device),
138
- )
139
- attention_mask = attention_mask.to(target_device)
140
-
141
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
142
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
143
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
144
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
145
-
146
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
147
- image_to_overwrite = torch.full(
148
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
149
- )
150
- image_to_overwrite[batch_indices, text_to_overwrite] = False
151
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
152
-
153
- if image_to_overwrite.sum() != image_features.shape[:-1].numel():
154
- raise ValueError(
155
- f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
156
- f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
157
- )
158
-
159
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
160
- final_attention_mask |= image_to_overwrite
161
-
162
- # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
163
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
164
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
165
-
166
- final_embedding[batch_indices, indices_to_mask] = 0
167
-
168
- return final_embedding, final_attention_mask
169
-
170
- def forward(
171
- self,
172
- input_ids=None,
173
- attention_mask=None,
174
- pixel_values=None,
175
- encoder_hidden_states=None,
176
- encoder_attention_mask=None,
177
- inputs_embeds=None,
178
- head_mask=None,
179
- cross_attn_head_mask=None,
180
- past_key_values=None,
181
- use_cache=None,
182
- output_attentions=None,
183
- output_hidden_states=None,
184
- return_dict=None,
185
- ):
186
- # Model parallel
187
- if self.model_parallel:
188
- torch.cuda.set_device(self.first_device)
189
- self.embed_tokens = self.embed_tokens.to(self.first_device)
190
- use_cache = use_cache if use_cache is not None else self.config.use_cache
191
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
192
- output_hidden_states = (
193
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
194
- )
195
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
196
-
197
- if input_ids is not None and inputs_embeds is not None:
198
- err_msg_prefix = "decoder_" if self.is_decoder else ""
199
- raise ValueError(
200
- f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
201
- )
202
- elif input_ids is not None:
203
- input_shape = input_ids.size()
204
- input_ids = input_ids.view(-1, input_shape[-1])
205
- elif inputs_embeds is not None:
206
- input_shape = inputs_embeds.size()[:-1]
207
- else:
208
- err_msg_prefix = "decoder_" if self.is_decoder else ""
209
- raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
210
-
211
- if inputs_embeds is None:
212
- if self.embed_tokens is None:
213
- raise ValueError("You have to initialize the model with valid token embeddings")
214
- inputs_embeds = self.embed_tokens(input_ids)
215
-
216
- ### Multimodal
217
- vision_feature_layer = -1
218
- vision_feature_select_strategy = "default"
219
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
220
- # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
221
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
222
-
223
- if vision_feature_select_strategy == "default":
224
- selected_image_feature = selected_image_feature[:, 1:]
225
- elif vision_feature_select_strategy == "full":
226
- selected_image_feature = selected_image_feature
227
- else:
228
- raise ValueError(
229
- f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
230
- )
231
-
232
- image_features = self.mm_projector(selected_image_feature)
233
- inputs_embeds = inputs_embeds.to(image_features.dtype)
234
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
235
- image_features, inputs_embeds, input_ids, attention_mask
236
- )
237
- input_shape = inputs_embeds.size()[:-1]
238
- #################
239
-
240
- batch_size, seq_length = input_shape
241
-
242
- # required mask seq length can be calculated via length of past
243
- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
244
-
245
- if use_cache is True:
246
- if not self.is_decoder:
247
- raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
248
-
249
- # initialize past_key_values with `None` if past does not exist
250
- if past_key_values is None:
251
- past_key_values = [None] * len(self.block)
252
-
253
- if attention_mask is None:
254
- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
255
-
256
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
257
- # ourselves in which case we just need to make it broadcastable to all heads.
258
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
259
-
260
- # If a 2D or 3D attention mask is provided for the cross-attention
261
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
262
- if self.is_decoder and encoder_hidden_states is not None:
263
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
264
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
265
- if encoder_attention_mask is None:
266
- encoder_attention_mask = torch.ones(
267
- encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
268
- )
269
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
270
- else:
271
- encoder_extended_attention_mask = None
272
-
273
- if self.gradient_checkpointing and self.training:
274
- if use_cache:
275
- # logger.warning_once(
276
- # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
277
- # )
278
- use_cache = False
279
-
280
- # Prepare head mask if needed
281
- head_mask = self.get_head_mask(head_mask, self.config.num_layers)
282
- cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
283
- present_key_value_states = () if use_cache else None
284
- all_hidden_states = () if output_hidden_states else None
285
- all_attentions = () if output_attentions else None
286
- all_cross_attentions = () if (output_attentions and self.is_decoder) else None
287
- position_bias = None
288
- encoder_decoder_position_bias = None
289
-
290
- hidden_states = self.dropout(inputs_embeds)
291
-
292
- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
293
- layer_head_mask = head_mask[i]
294
- cross_attn_layer_head_mask = cross_attn_head_mask[i]
295
- # Model parallel
296
- if self.model_parallel:
297
- torch.cuda.set_device(hidden_states.device)
298
- # Ensure that attention_mask is always on the same device as hidden_states
299
- if attention_mask is not None:
300
- attention_mask = attention_mask.to(hidden_states.device)
301
- if position_bias is not None:
302
- position_bias = position_bias.to(hidden_states.device)
303
- if encoder_hidden_states is not None:
304
- encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
305
- if encoder_extended_attention_mask is not None:
306
- encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
307
- if encoder_decoder_position_bias is not None:
308
- encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
309
- if layer_head_mask is not None:
310
- layer_head_mask = layer_head_mask.to(hidden_states.device)
311
- if cross_attn_layer_head_mask is not None:
312
- cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
313
- if output_hidden_states:
314
- all_hidden_states = all_hidden_states + (hidden_states,)
315
-
316
- if self.gradient_checkpointing and self.training:
317
- layer_outputs = self._gradient_checkpointing_func(
318
- layer_module.forward,
319
- hidden_states,
320
- extended_attention_mask,
321
- position_bias,
322
- encoder_hidden_states,
323
- encoder_extended_attention_mask,
324
- encoder_decoder_position_bias,
325
- layer_head_mask,
326
- cross_attn_layer_head_mask,
327
- None, # past_key_value is always None with gradient checkpointing
328
- use_cache,
329
- output_attentions,
330
- )
331
- else:
332
- layer_outputs = layer_module(
333
- hidden_states,
334
- attention_mask=extended_attention_mask,
335
- position_bias=position_bias,
336
- encoder_hidden_states=encoder_hidden_states,
337
- encoder_attention_mask=encoder_extended_attention_mask,
338
- encoder_decoder_position_bias=encoder_decoder_position_bias,
339
- layer_head_mask=layer_head_mask,
340
- cross_attn_layer_head_mask=cross_attn_layer_head_mask,
341
- past_key_value=past_key_value,
342
- use_cache=use_cache,
343
- output_attentions=output_attentions,
344
- )
345
-
346
- # layer_outputs is a tuple with:
347
- # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
348
- if use_cache is False:
349
- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
350
-
351
- hidden_states, present_key_value_state = layer_outputs[:2]
352
-
353
- # We share the position biases between the layers - the first layer store them
354
- # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
355
- # (cross-attention position bias), (cross-attention weights)
356
- position_bias = layer_outputs[2]
357
- if self.is_decoder and encoder_hidden_states is not None:
358
- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
359
- # append next layer key value states
360
- if use_cache:
361
- present_key_value_states = present_key_value_states + (present_key_value_state,)
362
-
363
- if output_attentions:
364
- all_attentions = all_attentions + (layer_outputs[3],)
365
- if self.is_decoder:
366
- all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
367
-
368
- # Model Parallel: If it's the last layer for that device, put things on the next device
369
- if self.model_parallel:
370
- for k, v in self.device_map.items():
371
- if i == v[-1] and "cuda:" + str(k) != self.last_device:
372
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
373
-
374
- hidden_states = self.final_layer_norm(hidden_states)
375
- hidden_states = self.dropout(hidden_states)
376
-
377
- # Add last layer
378
- if output_hidden_states:
379
- all_hidden_states = all_hidden_states + (hidden_states,)
380
-
381
- if not return_dict:
382
- return tuple(
383
- v
384
- for v in [
385
- hidden_states,
386
- present_key_value_states,
387
- all_hidden_states,
388
- all_attentions,
389
- all_cross_attentions,
390
- ]
391
- if v is not None
392
- )
393
- return BaseModelOutputWithPastAndCrossAttentionsWithAttentionMask(
394
- last_hidden_state=hidden_states,
395
- past_key_values=present_key_value_states,
396
- hidden_states=all_hidden_states,
397
- attentions=all_attentions,
398
- cross_attentions=all_cross_attentions,
399
- attention_mask=attention_mask,
400
- )
401
-
402
-
403
- class LlavaT5ForConditionalGeneration(T5ForConditionalGeneration):
404
- config_class = LlavaT5Config
405
-
406
- def __init__(self, config):
407
- super(T5ForConditionalGeneration, self).__init__(config)
408
-
409
- self.model_dim = config.d_model
410
-
411
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
412
-
413
- encoder_config = copy.deepcopy(config)
414
- encoder_config.is_decoder = False
415
- encoder_config.use_cache = False
416
- encoder_config.is_encoder_decoder = False
417
- self.encoder = LlavaT5Stack(encoder_config, self.shared)
418
-
419
- decoder_config = copy.deepcopy(config)
420
- decoder_config.is_decoder = True
421
- decoder_config.is_encoder_decoder = False
422
- decoder_config.num_layers = config.num_decoder_layers
423
- self.decoder = T5Stack(decoder_config, self.shared)
424
-
425
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
426
-
427
- # Initialize weights and apply final processing
428
- self.post_init()
429
-
430
- # Model parallel
431
- self.model_parallel = False
432
- self.device_map = None
433
-
434
- def get_model(self):
435
- return self.encoder
436
- def get_encoder(self):
437
- return self.encoder
438
- def get_decoder(self):
439
- return self.decoder
440
-
441
- def forward(
442
- self,
443
- input_ids: torch.LongTensor = None,
444
- attention_mask: Optional[torch.Tensor] = None,
445
- past_key_values: Optional[List[torch.FloatTensor]] = None,
446
- inputs_embeds: Optional[torch.FloatTensor] = None,
447
- labels: Optional[torch.LongTensor] = None,
448
- use_cache: Optional[bool] = None,
449
- output_attentions: Optional[bool] = None,
450
- output_hidden_states: Optional[bool] = None,
451
- pixel_values: Optional[torch.FloatTensor] = None,
452
- return_dict: Optional[bool] = None,
453
-
454
- decoder_input_ids: Optional[torch.LongTensor] = None,
455
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
456
- head_mask: Optional[torch.FloatTensor] = None,
457
- decoder_head_mask: Optional[torch.FloatTensor] = None,
458
- cross_attn_head_mask: Optional[torch.Tensor] = None,
459
- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
460
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
461
-
462
- ) -> Union[Tuple, Seq2SeqLMOutput]:
463
-
464
- use_cache = use_cache if use_cache is not None else self.config.use_cache
465
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
466
-
467
-
468
- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
469
- if head_mask is not None and decoder_head_mask is None:
470
- if self.config.num_layers == self.config.num_decoder_layers:
471
- #warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
472
- decoder_head_mask = head_mask
473
-
474
- if encoder_outputs is not None:
475
- attention_mask = encoder_outputs.attention_mask
476
-
477
- # Encode if needed (training, first prediction pass)
478
- if encoder_outputs is None:
479
- # Convert encoder inputs in embeddings if needed
480
- encoder_outputs = self.encoder(
481
- input_ids=input_ids,
482
- attention_mask=attention_mask,
483
- pixel_values=pixel_values,
484
- inputs_embeds=inputs_embeds,
485
- head_mask=head_mask,
486
- output_attentions=output_attentions,
487
- output_hidden_states=output_hidden_states,
488
- return_dict=return_dict,
489
- )
490
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
491
- encoder_outputs = BaseModelOutput(
492
- last_hidden_state=encoder_outputs[0],
493
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
494
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
495
- )
496
-
497
-
498
- hidden_states = encoder_outputs[0]
499
-
500
- if self.model_parallel:
501
- torch.cuda.set_device(self.decoder.first_device)
502
-
503
- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
504
- # get decoder inputs from shifting lm labels to the right
505
- decoder_input_ids = self._shift_right(labels)
506
-
507
- # Set device for model parallelism
508
- if self.model_parallel:
509
- torch.cuda.set_device(self.decoder.first_device)
510
- hidden_states = hidden_states.to(self.decoder.first_device)
511
- if decoder_input_ids is not None:
512
- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
513
- if attention_mask is not None:
514
- attention_mask = attention_mask.to(self.decoder.first_device)
515
- if decoder_attention_mask is not None:
516
- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
517
-
518
-
519
- decoder_outputs = self.decoder(
520
- input_ids=decoder_input_ids,
521
- attention_mask=decoder_attention_mask,
522
- inputs_embeds=decoder_inputs_embeds,
523
- past_key_values=past_key_values,
524
- encoder_hidden_states=hidden_states,
525
- encoder_attention_mask=attention_mask,
526
- head_mask=decoder_head_mask,
527
- cross_attn_head_mask=cross_attn_head_mask,
528
- use_cache=use_cache,
529
- output_attentions=output_attentions,
530
- output_hidden_states=output_hidden_states,
531
- return_dict=return_dict,
532
- )
533
- sequence_output = decoder_outputs[0]
534
-
535
- # Set device for model parallelism
536
- if self.model_parallel:
537
- torch.cuda.set_device(self.encoder.first_device)
538
- self.lm_head = self.lm_head.to(self.encoder.first_device)
539
- sequence_output = sequence_output.to(self.lm_head.weight.device)
540
-
541
- if self.config.tie_word_embeddings:
542
- # Rescale output before projecting on vocab
543
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
544
- sequence_output = sequence_output * (self.model_dim**-0.5)
545
-
546
- lm_logits = self.lm_head(sequence_output)
547
-
548
- loss = None
549
- if labels is not None:
550
- loss_fct = CrossEntropyLoss(ignore_index=-100)
551
- # move labels to correct device to enable PP
552
- labels = labels.to(lm_logits.device)
553
- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
554
- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
555
-
556
- if not return_dict:
557
- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
558
- return ((loss,) + output) if loss is not None else output
559
-
560
- return Seq2SeqLMOutput(
561
- loss=loss,
562
- logits=lm_logits,
563
- past_key_values=decoder_outputs.past_key_values,
564
- decoder_hidden_states=decoder_outputs.hidden_states,
565
- decoder_attentions=decoder_outputs.attentions,
566
- cross_attentions=decoder_outputs.cross_attentions,
567
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
568
- encoder_hidden_states=encoder_outputs.hidden_states,
569
- encoder_attentions=encoder_outputs.attentions,
570
- )
571
-
572
- def prepare_inputs_for_generation(
573
- self,
574
- input_ids,
575
- past_key_values=None,
576
- attention_mask=None,
577
- head_mask=None,
578
- decoder_head_mask=None,
579
- decoder_attention_mask=None,
580
- cross_attn_head_mask=None,
581
- use_cache=None,
582
- encoder_outputs=None,
583
- **kwargs,
584
- ):
585
- # cut decoder_input_ids if past_key_values is used
586
- if past_key_values is not None:
587
- past_length = past_key_values[0][0].shape[2]
588
-
589
- # Some generation methods already pass only the last input ID
590
- if input_ids.shape[1] > past_length:
591
- remove_prefix_length = past_length
592
- else:
593
- # Default to old behavior: keep only final ID
594
- remove_prefix_length = input_ids.shape[1] - 1
595
-
596
- input_ids = input_ids[:, remove_prefix_length:]
597
-
598
- return {
599
- "decoder_input_ids": input_ids,
600
- "past_key_values": past_key_values,
601
- "encoder_outputs": encoder_outputs,
602
- "attention_mask": attention_mask,
603
- "head_mask": head_mask,
604
- "decoder_head_mask": decoder_head_mask,
605
- "decoder_attention_mask": decoder_attention_mask,
606
- "cross_attn_head_mask": cross_attn_head_mask,
607
- "use_cache": use_cache,
608
- "pixel_values": kwargs.get("pixel_values", None),
609
- }
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ from dataclasses import dataclass
3
+ import copy, os
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import AutoConfig, AutoModelForSeq2SeqLM, \
8
+ T5Config, T5Model, T5ForConditionalGeneration
9
+
10
+ from transformers.models.t5.modeling_t5 import T5Stack
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
12
+ from transformers.utils import ModelOutput
13
+ from transformers import DonutSwinModel, DonutImageProcessor, DonutSwinConfig
14
+ from abc import ABC, abstractmethod
15
+ import re
16
+
17
+ from transformers import T5PreTrainedModel
18
+ from transformers.models.t5.modeling_t5 import T5Block, T5LayerNorm
19
+
20
+
21
+ @dataclass
22
+ class BaseModelOutputWithPastAndCrossAttentionsWithAttentionMask(ModelOutput):
23
+ """
24
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
25
+
26
+ Args:
27
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
28
+ Sequence of hidden-states at the output of the last layer of the model.
29
+
30
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
31
+ hidden_size)` is output.
32
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
33
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
34
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
35
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
36
+ encoder_sequence_length, embed_size_per_head)`.
37
+
38
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
39
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
40
+ input) to speed up sequential decoding.
41
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
42
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
43
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
44
+
45
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
46
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
47
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
48
+ sequence_length)`.
49
+
50
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
51
+ heads.
52
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
53
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
54
+ sequence_length)`.
55
+
56
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
57
+ weighted average in the cross-attention heads.
58
+ """
59
+
60
+ last_hidden_state: torch.FloatTensor = None
61
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
62
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
63
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
64
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
65
+ attention_mask: Optional[torch.LongTensor] = None
66
+
67
+ class LlavaT5Config(T5Config):
68
+ model_type = "llava_t5"
69
+
70
+
71
+
72
+ class LlavaT5Stack(T5PreTrainedModel):
73
+ config_class = LlavaT5Config
74
+
75
+ def __init__(self, config, embed_tokens=None):
76
+ super().__init__(config)
77
+
78
+ self.embed_tokens = embed_tokens
79
+ self.is_decoder = config.is_decoder
80
+
81
+ self.block = nn.ModuleList(
82
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
83
+ )
84
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
85
+ self.dropout = nn.Dropout(config.dropout_rate)
86
+
87
+ ## Vision
88
+ vision_config = DonutSwinConfig(**config.vision_config)
89
+ self.vision_tower = DonutSwinModel(config=vision_config)
90
+ self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
91
+ self.pad_token_id = 0
92
+ self.image_token_index = 32100
93
+ ##
94
+
95
+ # Initialize weights and apply final processing
96
+ self.post_init()
97
+ # Model parallel
98
+ self.model_parallel = False
99
+ self.device_map = None
100
+ self.gradient_checkpointing = False
101
+
102
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask):
103
+ num_images, num_image_patches, embed_dim = image_features.shape
104
+ batch_size, sequence_length = input_ids.shape
105
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
106
+ # 1. Create a mask to know where special image tokens are
107
+ special_image_token_mask = input_ids == self.image_token_index
108
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
109
+ # Compute the maximum embed dimension
110
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
111
+ batch_indices, non_image_indices = torch.where(input_ids != self.image_token_index)
112
+
113
+ # 2. Compute the positions where text should be written
114
+ # Calculate new positions for text tokens in merged image-text sequence.
115
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
116
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
117
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
118
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
119
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
120
+ if left_padding:
121
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
122
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
123
+
124
+ # 3. Create the full embedding, already padded to the maximum position
125
+ final_embedding = torch.zeros(
126
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
127
+ )
128
+ final_attention_mask = torch.zeros(
129
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
130
+ )
131
+
132
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
133
+ # set the corresponding tensors into their correct target device.
134
+ target_device = inputs_embeds.device
135
+ batch_indices, non_image_indices, text_to_overwrite = (
136
+ batch_indices.to(target_device),
137
+ non_image_indices.to(target_device),
138
+ text_to_overwrite.to(target_device),
139
+ )
140
+ attention_mask = attention_mask.to(target_device)
141
+
142
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
143
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
144
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
145
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
146
+
147
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
148
+ image_to_overwrite = torch.full(
149
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
150
+ )
151
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
152
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
153
+
154
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
155
+ raise ValueError(
156
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
157
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
158
+ )
159
+
160
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
161
+ final_attention_mask |= image_to_overwrite
162
+
163
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
164
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
165
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
166
+
167
+ final_embedding[batch_indices, indices_to_mask] = 0
168
+
169
+ return final_embedding, final_attention_mask
170
+
171
+ def forward(
172
+ self,
173
+ input_ids=None,
174
+ attention_mask=None,
175
+ pixel_values=None,
176
+ encoder_hidden_states=None,
177
+ encoder_attention_mask=None,
178
+ inputs_embeds=None,
179
+ head_mask=None,
180
+ cross_attn_head_mask=None,
181
+ past_key_values=None,
182
+ use_cache=None,
183
+ output_attentions=None,
184
+ output_hidden_states=None,
185
+ return_dict=None,
186
+ ):
187
+ # Model parallel
188
+ if self.model_parallel:
189
+ torch.cuda.set_device(self.first_device)
190
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
191
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
192
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
193
+ output_hidden_states = (
194
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
195
+ )
196
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
197
+
198
+ if input_ids is not None and inputs_embeds is not None:
199
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
200
+ raise ValueError(
201
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
202
+ )
203
+ elif input_ids is not None:
204
+ input_shape = input_ids.size()
205
+ input_ids = input_ids.view(-1, input_shape[-1])
206
+ elif inputs_embeds is not None:
207
+ input_shape = inputs_embeds.size()[:-1]
208
+ else:
209
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
210
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
211
+
212
+ if inputs_embeds is None:
213
+ if self.embed_tokens is None:
214
+ raise ValueError("You have to initialize the model with valid token embeddings")
215
+ inputs_embeds = self.embed_tokens(input_ids)
216
+
217
+ ### Multimodal
218
+ vision_feature_layer = -1
219
+ vision_feature_select_strategy = "default"
220
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
221
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
222
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
223
+
224
+ if vision_feature_select_strategy == "default":
225
+ selected_image_feature = selected_image_feature[:, 1:]
226
+ elif vision_feature_select_strategy == "full":
227
+ selected_image_feature = selected_image_feature
228
+ else:
229
+ raise ValueError(
230
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
231
+ )
232
+
233
+ image_features = self.mm_projector(selected_image_feature)
234
+ inputs_embeds = inputs_embeds.to(image_features.dtype)
235
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
236
+ image_features, inputs_embeds, input_ids, attention_mask
237
+ )
238
+ input_shape = inputs_embeds.size()[:-1]
239
+ #################
240
+
241
+ batch_size, seq_length = input_shape
242
+
243
+ # required mask seq length can be calculated via length of past
244
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
245
+
246
+ if use_cache is True:
247
+ if not self.is_decoder:
248
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
249
+
250
+ # initialize past_key_values with `None` if past does not exist
251
+ if past_key_values is None:
252
+ past_key_values = [None] * len(self.block)
253
+
254
+ if attention_mask is None:
255
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
256
+
257
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
258
+ # ourselves in which case we just need to make it broadcastable to all heads.
259
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
260
+
261
+ # If a 2D or 3D attention mask is provided for the cross-attention
262
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
263
+ if self.is_decoder and encoder_hidden_states is not None:
264
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
265
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
266
+ if encoder_attention_mask is None:
267
+ encoder_attention_mask = torch.ones(
268
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
269
+ )
270
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
271
+ else:
272
+ encoder_extended_attention_mask = None
273
+
274
+ if self.gradient_checkpointing and self.training:
275
+ if use_cache:
276
+ # logger.warning_once(
277
+ # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
278
+ # )
279
+ use_cache = False
280
+
281
+ # Prepare head mask if needed
282
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
283
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
284
+ present_key_value_states = () if use_cache else None
285
+ all_hidden_states = () if output_hidden_states else None
286
+ all_attentions = () if output_attentions else None
287
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
288
+ position_bias = None
289
+ encoder_decoder_position_bias = None
290
+
291
+ hidden_states = self.dropout(inputs_embeds)
292
+
293
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
294
+ layer_head_mask = head_mask[i]
295
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
296
+ # Model parallel
297
+ if self.model_parallel:
298
+ torch.cuda.set_device(hidden_states.device)
299
+ # Ensure that attention_mask is always on the same device as hidden_states
300
+ if attention_mask is not None:
301
+ attention_mask = attention_mask.to(hidden_states.device)
302
+ if position_bias is not None:
303
+ position_bias = position_bias.to(hidden_states.device)
304
+ if encoder_hidden_states is not None:
305
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
306
+ if encoder_extended_attention_mask is not None:
307
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
308
+ if encoder_decoder_position_bias is not None:
309
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
310
+ if layer_head_mask is not None:
311
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
312
+ if cross_attn_layer_head_mask is not None:
313
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
314
+ if output_hidden_states:
315
+ all_hidden_states = all_hidden_states + (hidden_states,)
316
+
317
+ if self.gradient_checkpointing and self.training:
318
+ layer_outputs = self._gradient_checkpointing_func(
319
+ layer_module.forward,
320
+ hidden_states,
321
+ extended_attention_mask,
322
+ position_bias,
323
+ encoder_hidden_states,
324
+ encoder_extended_attention_mask,
325
+ encoder_decoder_position_bias,
326
+ layer_head_mask,
327
+ cross_attn_layer_head_mask,
328
+ None, # past_key_value is always None with gradient checkpointing
329
+ use_cache,
330
+ output_attentions,
331
+ )
332
+ else:
333
+ layer_outputs = layer_module(
334
+ hidden_states,
335
+ attention_mask=extended_attention_mask,
336
+ position_bias=position_bias,
337
+ encoder_hidden_states=encoder_hidden_states,
338
+ encoder_attention_mask=encoder_extended_attention_mask,
339
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
340
+ layer_head_mask=layer_head_mask,
341
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
342
+ past_key_value=past_key_value,
343
+ use_cache=use_cache,
344
+ output_attentions=output_attentions,
345
+ )
346
+
347
+ # layer_outputs is a tuple with:
348
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
349
+ if use_cache is False:
350
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
351
+
352
+ hidden_states, present_key_value_state = layer_outputs[:2]
353
+
354
+ # We share the position biases between the layers - the first layer store them
355
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
356
+ # (cross-attention position bias), (cross-attention weights)
357
+ position_bias = layer_outputs[2]
358
+ if self.is_decoder and encoder_hidden_states is not None:
359
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
360
+ # append next layer key value states
361
+ if use_cache:
362
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
363
+
364
+ if output_attentions:
365
+ all_attentions = all_attentions + (layer_outputs[3],)
366
+ if self.is_decoder:
367
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
368
+
369
+ # Model Parallel: If it's the last layer for that device, put things on the next device
370
+ if self.model_parallel:
371
+ for k, v in self.device_map.items():
372
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
373
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
374
+
375
+ hidden_states = self.final_layer_norm(hidden_states)
376
+ hidden_states = self.dropout(hidden_states)
377
+
378
+ # Add last layer
379
+ if output_hidden_states:
380
+ all_hidden_states = all_hidden_states + (hidden_states,)
381
+
382
+ if not return_dict:
383
+ return tuple(
384
+ v
385
+ for v in [
386
+ hidden_states,
387
+ present_key_value_states,
388
+ all_hidden_states,
389
+ all_attentions,
390
+ all_cross_attentions,
391
+ ]
392
+ if v is not None
393
+ )
394
+ return BaseModelOutputWithPastAndCrossAttentionsWithAttentionMask(
395
+ last_hidden_state=hidden_states,
396
+ past_key_values=present_key_value_states,
397
+ hidden_states=all_hidden_states,
398
+ attentions=all_attentions,
399
+ cross_attentions=all_cross_attentions,
400
+ attention_mask=attention_mask,
401
+ )
402
+
403
+
404
+ class LlavaT5ForConditionalGeneration(T5ForConditionalGeneration):
405
+ config_class = LlavaT5Config
406
+
407
+ def __init__(self, config):
408
+ super(T5ForConditionalGeneration, self).__init__(config)
409
+
410
+ self.model_dim = config.d_model
411
+
412
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
413
+
414
+ encoder_config = copy.deepcopy(config)
415
+ encoder_config.is_decoder = False
416
+ encoder_config.use_cache = False
417
+ encoder_config.is_encoder_decoder = False
418
+ self.encoder = LlavaT5Stack(encoder_config, self.shared)
419
+
420
+ decoder_config = copy.deepcopy(config)
421
+ decoder_config.is_decoder = True
422
+ decoder_config.is_encoder_decoder = False
423
+ decoder_config.num_layers = config.num_decoder_layers
424
+ self.decoder = T5Stack(decoder_config, self.shared)
425
+
426
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
427
+
428
+ # Initialize weights and apply final processing
429
+ self.post_init()
430
+
431
+ # Model parallel
432
+ self.model_parallel = False
433
+ self.device_map = None
434
+
435
+ def get_model(self):
436
+ return self.encoder
437
+ def get_encoder(self):
438
+ return self.encoder
439
+ def get_decoder(self):
440
+ return self.decoder
441
+
442
+ def forward(
443
+ self,
444
+ input_ids: torch.LongTensor = None,
445
+ attention_mask: Optional[torch.Tensor] = None,
446
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
447
+ inputs_embeds: Optional[torch.FloatTensor] = None,
448
+ labels: Optional[torch.LongTensor] = None,
449
+ use_cache: Optional[bool] = None,
450
+ output_attentions: Optional[bool] = None,
451
+ output_hidden_states: Optional[bool] = None,
452
+ pixel_values: Optional[torch.FloatTensor] = None,
453
+ return_dict: Optional[bool] = None,
454
+
455
+ decoder_input_ids: Optional[torch.LongTensor] = None,
456
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
457
+ head_mask: Optional[torch.FloatTensor] = None,
458
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
459
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
460
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
461
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
462
+
463
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
464
+
465
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
466
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
467
+
468
+
469
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
470
+ if head_mask is not None and decoder_head_mask is None:
471
+ if self.config.num_layers == self.config.num_decoder_layers:
472
+ #warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
473
+ decoder_head_mask = head_mask
474
+
475
+ if encoder_outputs is not None:
476
+ attention_mask = encoder_outputs.attention_mask
477
+
478
+ # Encode if needed (training, first prediction pass)
479
+ if encoder_outputs is None:
480
+ # Convert encoder inputs in embeddings if needed
481
+ encoder_outputs = self.encoder(
482
+ input_ids=input_ids,
483
+ attention_mask=attention_mask,
484
+ pixel_values=pixel_values,
485
+ inputs_embeds=inputs_embeds,
486
+ head_mask=head_mask,
487
+ output_attentions=output_attentions,
488
+ output_hidden_states=output_hidden_states,
489
+ return_dict=return_dict,
490
+ )
491
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
492
+ encoder_outputs = BaseModelOutput(
493
+ last_hidden_state=encoder_outputs[0],
494
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
495
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
496
+ )
497
+
498
+
499
+ hidden_states = encoder_outputs[0]
500
+
501
+ if self.model_parallel:
502
+ torch.cuda.set_device(self.decoder.first_device)
503
+
504
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
505
+ # get decoder inputs from shifting lm labels to the right
506
+ decoder_input_ids = self._shift_right(labels)
507
+
508
+ # Set device for model parallelism
509
+ if self.model_parallel:
510
+ torch.cuda.set_device(self.decoder.first_device)
511
+ hidden_states = hidden_states.to(self.decoder.first_device)
512
+ if decoder_input_ids is not None:
513
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
514
+ if attention_mask is not None:
515
+ attention_mask = attention_mask.to(self.decoder.first_device)
516
+ if decoder_attention_mask is not None:
517
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
518
+
519
+
520
+ decoder_outputs = self.decoder(
521
+ input_ids=decoder_input_ids,
522
+ attention_mask=decoder_attention_mask,
523
+ inputs_embeds=decoder_inputs_embeds,
524
+ past_key_values=past_key_values,
525
+ encoder_hidden_states=hidden_states,
526
+ encoder_attention_mask=attention_mask,
527
+ head_mask=decoder_head_mask,
528
+ cross_attn_head_mask=cross_attn_head_mask,
529
+ use_cache=use_cache,
530
+ output_attentions=output_attentions,
531
+ output_hidden_states=output_hidden_states,
532
+ return_dict=return_dict,
533
+ )
534
+ sequence_output = decoder_outputs[0]
535
+
536
+ # Set device for model parallelism
537
+ if self.model_parallel:
538
+ torch.cuda.set_device(self.encoder.first_device)
539
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
540
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
541
+
542
+ if self.config.tie_word_embeddings:
543
+ # Rescale output before projecting on vocab
544
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
545
+ sequence_output = sequence_output * (self.model_dim**-0.5)
546
+
547
+ lm_logits = self.lm_head(sequence_output)
548
+
549
+ loss = None
550
+ if labels is not None:
551
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
552
+ # move labels to correct device to enable PP
553
+ labels = labels.to(lm_logits.device)
554
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
555
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
556
+
557
+ if not return_dict:
558
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
559
+ return ((loss,) + output) if loss is not None else output
560
+
561
+ return Seq2SeqLMOutput(
562
+ loss=loss,
563
+ logits=lm_logits,
564
+ past_key_values=decoder_outputs.past_key_values,
565
+ decoder_hidden_states=decoder_outputs.hidden_states,
566
+ decoder_attentions=decoder_outputs.attentions,
567
+ cross_attentions=decoder_outputs.cross_attentions,
568
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
569
+ encoder_hidden_states=encoder_outputs.hidden_states,
570
+ encoder_attentions=encoder_outputs.attentions,
571
+ )
572
+
573
+ def prepare_inputs_for_generation(
574
+ self,
575
+ input_ids,
576
+ past_key_values=None,
577
+ attention_mask=None,
578
+ head_mask=None,
579
+ decoder_head_mask=None,
580
+ decoder_attention_mask=None,
581
+ cross_attn_head_mask=None,
582
+ use_cache=None,
583
+ encoder_outputs=None,
584
+ **kwargs,
585
+ ):
586
+ # cut decoder_input_ids if past_key_values is used
587
+ if past_key_values is not None:
588
+ past_length = past_key_values[0][0].shape[2]
589
+
590
+ # Some generation methods already pass only the last input ID
591
+ if input_ids.shape[1] > past_length:
592
+ remove_prefix_length = past_length
593
+ else:
594
+ # Default to old behavior: keep only final ID
595
+ remove_prefix_length = input_ids.shape[1] - 1
596
+
597
+ input_ids = input_ids[:, remove_prefix_length:]
598
+
599
+ return {
600
+ "decoder_input_ids": input_ids,
601
+ "past_key_values": past_key_values,
602
+ "encoder_outputs": encoder_outputs,
603
+ "attention_mask": attention_mask,
604
+ "head_mask": head_mask,
605
+ "decoder_head_mask": decoder_head_mask,
606
+ "decoder_attention_mask": decoder_attention_mask,
607
+ "cross_attn_head_mask": cross_attn_head_mask,
608
+ "use_cache": use_cache,
609
+ "pixel_values": kwargs.get("pixel_values", None),
610
+ }