visheratin committed
Commit cde656c
1 parent: d6d35ed

Update model files

Files changed (1)
  1. modeling_llava.py +417 -0
modeling_llava.py ADDED
@@ -0,0 +1,417 @@
+ # coding=utf-8
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import ModelOutput
+
+ from modeling_phi import PhiForCausalLM, InferenceParams
+ from processing_llava import OpenCLIPImageProcessor
+ from configuration_llava import LlavaConfig
+ from open_clip import create_model
+
+
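+ # Output container for generation: the language model's usual fields (loss,
+ # logits, KV cache, hidden states, attentions) plus an optional slot for the
+ # image hidden states.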
+ @dataclass
+ class LlavaCausalLMOutputWithPast(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     past_key_values: Optional[List[torch.FloatTensor]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+     image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
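+ # Two-layer MLP that maps a pooled image embedding to `projector_tokens_num`
+ # tokens in the text embedding space of the language model.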
+ class LlavaMultiModalProjector(nn.Module):
+     def __init__(self, config: LlavaConfig):
+         super().__init__()
+
+         self.linear_1 = nn.Linear(
+             config.vision_embed_dim,
+             config.text_config.n_embd * config.projector_tokens_num,
+             bias=True,
+         )
+         self.act = nn.GELU()
+         self.linear_2 = nn.Linear(
+             config.text_config.n_embd * config.projector_tokens_num,
+             config.text_config.n_embd * config.projector_tokens_num,
+             bias=True,
+         )
+         self.projector_tokens_num = config.projector_tokens_num
+
+     def forward(self, image_features):
+         hidden_states = self.linear_1(image_features)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.linear_2(hidden_states)
+         hidden_states = hidden_states.reshape(
+             hidden_states.shape[0],
+             self.projector_tokens_num,
+             int(hidden_states.shape[1] / self.projector_tokens_num),
+         )
+         return hidden_states
+
+
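+ # Thin PreTrainedModel base class: weight init is a no-op (weights come from a
+ # checkpoint), and SDPA support is delegated to the wrapped language model.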
+ class LlavaPreTrainedModel(PreTrainedModel):
+     config_class = LlavaConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["LlavaVisionAttention"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = True
+
+     def __init__(self, config):
+         super().__init__(config)
+
+     def _init_weights(self, module):
+         return
+
+     @property
+     def _supports_sdpa(self):
+         """
+         Retrieve language_model's attribute to check whether the model supports
+         SDPA or not.
+         """
+         return self.language_model._supports_sdpa
+
+
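+ # Full multimodal model: an open_clip vision tower encodes the image, the
+ # projector turns that embedding into pseudo text tokens, and PhiForCausalLM
+ # generates text over the merged sequence.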
+ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
+     def __init__(self, config: LlavaConfig):
+         super().__init__(config)
+         clip_model = create_model(config.vision_tower_name)
+         self.vision_model = clip_model.visual
+
+         self.multi_modal_projector = LlavaMultiModalProjector(config)
+         self.vocab_size = config.vocab_size
+         self.language_model = PhiForCausalLM(config.text_config)
+         self.pad_token_id = (
+             self.config.pad_token_id if self.config.pad_token_id is not None else -1
+         )
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.language_model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         return self.language_model.get_output_embeddings()
+
+     def set_output_embeddings(self, new_embeddings):
+         self.language_model.set_output_embeddings(new_embeddings)
+
+     def set_decoder(self, decoder):
+         self.language_model.transformer = decoder
+
+     def get_decoder(self):
+         return self.language_model.transformer
+
+     def tie_weights(self):
+         return self.language_model.tie_weights()
+
+     def resize_token_embeddings(
+         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
+     ) -> nn.Embedding:
+         model_embeds = self.language_model.resize_token_embeddings(
+             new_num_tokens, pad_to_multiple_of
+         )
+         # update vocab size
+         self.config.text_config.vocab_size = model_embeds.num_embeddings
+         self.config.vocab_size = model_embeds.num_embeddings
+         self.vocab_size = model_embeds.num_embeddings
+         return model_embeds
+
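+     # Expands each image placeholder token in `input_ids` into the projected
+     # image embeddings, rebuilding `inputs_embeds`, `attention_mask`, and
+     # `position_ids` for the longer merged sequence.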
+     def _merge_input_ids_with_image_features(
+         self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
+     ):
+         num_images, num_image_patches, embed_dim = image_features.shape
+         batch_size, sequence_length = input_ids.shape
+         left_padding = not torch.sum(
+             input_ids[:, -1] == torch.tensor(self.pad_token_id)
+         )
+         # 1. Create a mask to know where special image tokens are
+         special_image_token_mask = input_ids == self.config.image_token_index
+         num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+         # Compute the maximum embed dimension
+         max_embed_dim = (
+             num_special_image_tokens.max() * (num_image_patches - 1)
+         ) + sequence_length
+         batch_indices, non_image_indices = torch.where(
+             input_ids != self.config.image_token_index
+         )
+
+         # 2. Compute the positions where text should be written
+         # Calculate new positions for text tokens in the merged image-text sequence.
+         # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
+         # `torch.cumsum` computes how each image token shifts subsequent text token positions.
+         # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+         new_token_positions = (
+             torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
+             - 1
+         )
+         nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+         if left_padding:
+             new_token_positions += nb_image_pad[:, None]  # offset for left padding
+         text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+         # 3. Create the full embedding, already padded to the maximum position
+         final_embedding = torch.zeros(
+             batch_size,
+             max_embed_dim,
+             embed_dim,
+             dtype=inputs_embeds.dtype,
+             device=inputs_embeds.device,
+         )
+         final_attention_mask = torch.zeros(
+             batch_size,
+             max_embed_dim,
+             dtype=attention_mask.dtype,
+             device=inputs_embeds.device,
+         )
+         # In case the vision model or the language model has been offloaded to CPU, we need to manually
+         # set the corresponding tensors into their correct target device.
+         target_device = inputs_embeds.device
+         batch_indices, non_image_indices, text_to_overwrite = (
+             batch_indices.to(target_device),
+             non_image_indices.to(target_device),
+             text_to_overwrite.to(target_device),
+         )
+         attention_mask = attention_mask.to(target_device)
+
+         # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
+         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
+         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
+             batch_indices, non_image_indices
+         ]
+         final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
+             batch_indices, non_image_indices
+         ]
+
+         # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
+         image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+         image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
+             :, None
+         ].to(target_device)
+
+         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+             raise ValueError(
+                 f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
+                 f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+             )
+
+         final_embedding[image_to_overwrite] = (
+             image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+         )
+         final_attention_mask |= image_to_overwrite
+         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
+             (final_attention_mask == 0), 1
+         )
+         return final_embedding, final_attention_mask, position_ids
+
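+     # Embeds the text tokens, splices in projected image features on the first
+     # (uncached) step, then delegates to the Phi language model; the loss is the
+     # standard shifted-by-one cross-entropy when labels are given.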
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         pixel_values: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         vision_feature_layer: Optional[int] = None,
+         vision_feature_select_strategy: Optional[str] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if inputs_embeds is None:
+             # 1. Extract the input embeddings
+             inputs_embeds = self.get_input_embeddings()(input_ids)
+
+             # 2. Merge text and images
+             if pixel_values is not None and input_ids.shape[1] != 1:
+                 image_outputs = self.vision_model(pixel_values)
+
+                 image_features = self.multi_modal_projector(image_outputs)
+                 (
+                     inputs_embeds,
+                     attention_mask,
+                     position_ids,
+                 ) = self._merge_input_ids_with_image_features(
+                     image_features,
+                     inputs_embeds,
+                     input_ids,
+                     attention_mask,
+                     position_ids,
+                 )
+                 # if labels is None:
+                 #     labels = torch.full_like(
+                 #         attention_mask, self.config.ignore_index
+                 #     ).to(torch.long)
+             else:
+                 # In case input_ids.shape[1] == 1 & pixel_values != None & past_key_values != None, we are in the case of
+                 # generation with cache
+                 if (
+                     past_key_values is not None
+                     and pixel_values is not None
+                     and input_ids.shape[1] == 1
+                 ):
+                     # Retrieve the first layer to inspect the logits and mask out the hidden states
+                     # that are set to 0
+                     first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                     # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                     batch_index, non_attended_tokens = torch.where(
+                         first_layer_past_key_value.float().sum(-2) == 0
+                     )
+
+                     # Get the target length
+                     target_seqlen = first_layer_past_key_value.shape[-1] + 1
+
+                     extended_attention_mask = torch.ones(
+                         (
+                             attention_mask.shape[0],
+                             target_seqlen - attention_mask.shape[1],
+                         ),
+                         dtype=attention_mask.dtype,
+                         device=attention_mask.device,
+                     )
+
+                     # Zero-out the places where we don't need to attend
+                     extended_attention_mask[batch_index, non_attended_tokens] = 0
+
+                     attention_mask = torch.cat(
+                         (attention_mask, extended_attention_mask), dim=1
+                     )
+                     position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+         outputs = self.language_model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         logits = outputs[0]
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             if attention_mask is not None:
+                 shift_attention_mask = attention_mask[..., 1:]
+                 shift_logits = logits[..., :-1, :][
+                     shift_attention_mask.to(logits.device) != 0
+                 ].contiguous()
+                 shift_labels = labels[..., 1:][
+                     shift_attention_mask.to(labels.device) != 0
+                 ].contiguous()
+             else:
+                 shift_logits = logits[..., :-1, :].contiguous()
+                 shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1).to(shift_logits.device),
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return LlavaCausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
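+     # Trims `input_ids` and `attention_mask` against the cache before each
+     # generation step; the cache can be an `InferenceParams` object from the
+     # bundled Phi implementation or a tuple of per-layer key/value tensors.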
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         inputs_embeds=None,
+         pixel_values=None,
+         attention_mask=None,
+         **kwargs,
+     ):
+         if past_key_values is not None:
+             if isinstance(past_key_values, InferenceParams):
+                 cache_length = past_key_values.max_seqlen
+                 past_length = past_key_values.seqlen_offset
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             #     input)
+             if (
+                 attention_mask is not None
+                 and attention_mask.shape[1] > input_ids.shape[1]
+             ):
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+             elif self.config.image_token_index in input_ids:
+                 input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+             # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+             # older attention values, as their corresponding values are not part of the input.
+             if cache_length < past_length and attention_mask is not None:
+                 attention_mask = attention_mask[
+                     :, -(cache_length + input_ids.shape[1]) :
+                 ]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "pixel_values": pixel_values,
+             }
+         )
+         return model_inputs
+
+     def _reorder_cache(self, *args, **kwargs):
+         return self.language_model._reorder_cache(*args, **kwargs)