visheratin
commited on
Commit
•
ed8f61a
1
Parent(s):
304a0a4
Update model files
Browse files- modeling_llava.py +13 -118
modeling_llava.py
CHANGED
@@ -9,8 +9,7 @@ from torch import nn
|
|
9 |
from transformers import PreTrainedModel
|
10 |
from transformers.modeling_outputs import ModelOutput
|
11 |
|
12 |
-
from modeling_phi import PhiForCausalLM
|
13 |
-
from processing_llava import OpenCLIPImageProcessor
|
14 |
from configuration_llava import LlavaConfig
|
15 |
from open_clip import create_model
|
16 |
|
@@ -22,7 +21,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
|
|
22 |
past_key_values: Optional[List[torch.FloatTensor]] = None
|
23 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
24 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
25 |
-
|
26 |
|
27 |
|
28 |
class LlavaMultiModalProjector(nn.Module):
|
@@ -214,14 +213,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|
214 |
def forward(
|
215 |
self,
|
216 |
input_ids: torch.LongTensor = None,
|
217 |
-
|
218 |
attention_mask: Optional[torch.Tensor] = None,
|
219 |
position_ids: Optional[torch.LongTensor] = None,
|
220 |
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
221 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
222 |
-
vision_feature_layer: Optional[int] = None,
|
223 |
-
vision_feature_select_strategy: Optional[str] = None,
|
224 |
-
labels: Optional[torch.LongTensor] = None,
|
225 |
use_cache: Optional[bool] = None,
|
226 |
output_attentions: Optional[bool] = None,
|
227 |
output_hidden_states: Optional[bool] = None,
|
@@ -242,14 +238,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|
242 |
)
|
243 |
|
244 |
if inputs_embeds is None:
|
245 |
-
# 1. Extra the input embeddings
|
246 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
247 |
-
|
248 |
-
# 2. Merge text and images
|
249 |
-
if pixel_values is not None and input_ids.shape[1] != 1:
|
250 |
-
image_outputs = self.vision_model(pixel_values)
|
251 |
-
|
252 |
-
image_features = self.multi_modal_projector(image_outputs)
|
253 |
(
|
254 |
inputs_embeds,
|
255 |
attention_mask,
|
@@ -261,46 +251,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|
261 |
attention_mask,
|
262 |
position_ids,
|
263 |
)
|
264 |
-
# if labels is None:
|
265 |
-
# labels = torch.full_like(
|
266 |
-
# attention_mask, self.config.ignore_index
|
267 |
-
# ).to(torch.long)
|
268 |
-
else:
|
269 |
-
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
270 |
-
# generation with cache
|
271 |
-
if (
|
272 |
-
past_key_values is not None
|
273 |
-
and pixel_values is not None
|
274 |
-
and input_ids.shape[1] == 1
|
275 |
-
):
|
276 |
-
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
277 |
-
# that are set to 0
|
278 |
-
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
279 |
-
|
280 |
-
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
281 |
-
batch_index, non_attended_tokens = torch.where(
|
282 |
-
first_layer_past_key_value.float().sum(-2) == 0
|
283 |
-
)
|
284 |
-
|
285 |
-
# Get the target length
|
286 |
-
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
287 |
-
|
288 |
-
extended_attention_mask = torch.ones(
|
289 |
-
(
|
290 |
-
attention_mask.shape[0],
|
291 |
-
target_seqlen - attention_mask.shape[1],
|
292 |
-
),
|
293 |
-
dtype=attention_mask.dtype,
|
294 |
-
device=attention_mask.device,
|
295 |
-
)
|
296 |
-
|
297 |
-
# Zero-out the places where we don't need to attend
|
298 |
-
extended_attention_mask[batch_index, non_attended_tokens] = 0
|
299 |
-
|
300 |
-
attention_mask = torch.cat(
|
301 |
-
(attention_mask, extended_attention_mask), dim=1
|
302 |
-
)
|
303 |
-
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
304 |
|
305 |
outputs = self.language_model(
|
306 |
input_ids=None,
|
@@ -316,37 +266,17 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|
316 |
|
317 |
logits = outputs[0]
|
318 |
|
319 |
-
loss = None
|
320 |
-
if labels is not None:
|
321 |
-
# Shift so that tokens < n predict n
|
322 |
-
if attention_mask is not None:
|
323 |
-
shift_attention_mask = attention_mask[..., 1:]
|
324 |
-
shift_logits = logits[..., :-1, :][
|
325 |
-
shift_attention_mask.to(logits.device) != 0
|
326 |
-
].contiguous()
|
327 |
-
shift_labels = labels[..., 1:][
|
328 |
-
shift_attention_mask.to(labels.device) != 0
|
329 |
-
].contiguous()
|
330 |
-
else:
|
331 |
-
shift_logits = logits[..., :-1, :].contiguous()
|
332 |
-
shift_labels = labels[..., 1:].contiguous()
|
333 |
-
# Flatten the tokens
|
334 |
-
loss_fct = nn.CrossEntropyLoss()
|
335 |
-
loss = loss_fct(
|
336 |
-
shift_logits.view(-1, shift_logits.size(-1)),
|
337 |
-
shift_labels.view(-1).to(shift_logits.device),
|
338 |
-
)
|
339 |
|
340 |
if not return_dict:
|
341 |
output = (logits,) + outputs[1:]
|
342 |
-
return
|
343 |
|
344 |
return LlavaCausalLMOutputWithPast(
|
345 |
-
loss=loss,
|
346 |
logits=logits,
|
347 |
past_key_values=outputs.past_key_values,
|
348 |
hidden_states=outputs.hidden_states,
|
349 |
attentions=outputs.attentions,
|
|
|
350 |
)
|
351 |
|
352 |
def prepare_inputs_for_generation(
|
@@ -354,49 +284,15 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|
354 |
input_ids,
|
355 |
past_key_values=None,
|
356 |
inputs_embeds=None,
|
357 |
-
pixel_values=None,
|
358 |
attention_mask=None,
|
|
|
359 |
**kwargs,
|
360 |
):
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
cache_length = past_length = past_key_values[0][0].shape[2]
|
367 |
-
|
368 |
-
# Keep only the unprocessed tokens:
|
369 |
-
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
370 |
-
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
|
371 |
-
# input)
|
372 |
-
if (
|
373 |
-
attention_mask is not None
|
374 |
-
and attention_mask.shape[1] > input_ids.shape[1]
|
375 |
-
):
|
376 |
-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
377 |
-
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
378 |
-
# input_ids based on the past_length.
|
379 |
-
elif past_length < input_ids.shape[1]:
|
380 |
-
input_ids = input_ids[:, past_length:]
|
381 |
-
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
382 |
-
elif self.config.image_token_index in input_ids:
|
383 |
-
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
384 |
-
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
385 |
-
# older attention values, as their corresponding values are not part of the input.
|
386 |
-
if cache_length < past_length and attention_mask is not None:
|
387 |
-
attention_mask = attention_mask[
|
388 |
-
:, -(cache_length + input_ids.shape[1]) :
|
389 |
-
]
|
390 |
-
|
391 |
-
position_ids = kwargs.get("position_ids", None)
|
392 |
-
if attention_mask is not None and position_ids is None:
|
393 |
-
# create position_ids on the fly for batch generation
|
394 |
-
position_ids = attention_mask.long().cumsum(-1) - 1
|
395 |
-
position_ids.masked_fill_(attention_mask == 0, 1)
|
396 |
-
if past_key_values:
|
397 |
-
position_ids = position_ids[:, -input_ids.shape[1] :]
|
398 |
-
|
399 |
-
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
400 |
if inputs_embeds is not None and past_key_values is None:
|
401 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
402 |
else:
|
@@ -404,11 +300,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|
404 |
|
405 |
model_inputs.update(
|
406 |
{
|
407 |
-
"position_ids": position_ids,
|
408 |
"past_key_values": past_key_values,
|
409 |
"use_cache": kwargs.get("use_cache"),
|
410 |
"attention_mask": attention_mask,
|
411 |
-
"
|
412 |
}
|
413 |
)
|
414 |
return model_inputs
|
|
|
9 |
from transformers import PreTrainedModel
|
10 |
from transformers.modeling_outputs import ModelOutput
|
11 |
|
12 |
+
from modeling_phi import PhiForCausalLM
|
|
|
13 |
from configuration_llava import LlavaConfig
|
14 |
from open_clip import create_model
|
15 |
|
|
|
21 |
past_key_values: Optional[List[torch.FloatTensor]] = None
|
22 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
23 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
24 |
+
image_features: Optional[torch.FloatTensor] = None
|
25 |
|
26 |
|
27 |
class LlavaMultiModalProjector(nn.Module):
|
|
|
213 |
def forward(
|
214 |
self,
|
215 |
input_ids: torch.LongTensor = None,
|
216 |
+
image_features: torch.FloatTensor = None,
|
217 |
attention_mask: Optional[torch.Tensor] = None,
|
218 |
position_ids: Optional[torch.LongTensor] = None,
|
219 |
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
220 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
|
|
221 |
use_cache: Optional[bool] = None,
|
222 |
output_attentions: Optional[bool] = None,
|
223 |
output_hidden_states: Optional[bool] = None,
|
|
|
238 |
)
|
239 |
|
240 |
if inputs_embeds is None:
|
|
|
241 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
242 |
+
if image_features is not None and input_ids.shape[1] != 1:
|
|
|
|
|
|
|
|
|
|
|
243 |
(
|
244 |
inputs_embeds,
|
245 |
attention_mask,
|
|
|
251 |
attention_mask,
|
252 |
position_ids,
|
253 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
outputs = self.language_model(
|
256 |
input_ids=None,
|
|
|
266 |
|
267 |
logits = outputs[0]
|
268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
if not return_dict:
|
271 |
output = (logits,) + outputs[1:]
|
272 |
+
return output
|
273 |
|
274 |
return LlavaCausalLMOutputWithPast(
|
|
|
275 |
logits=logits,
|
276 |
past_key_values=outputs.past_key_values,
|
277 |
hidden_states=outputs.hidden_states,
|
278 |
attentions=outputs.attentions,
|
279 |
+
image_features=image_features,
|
280 |
)
|
281 |
|
282 |
def prepare_inputs_for_generation(
|
|
|
284 |
input_ids,
|
285 |
past_key_values=None,
|
286 |
inputs_embeds=None,
|
|
|
287 |
attention_mask=None,
|
288 |
+
image_features=None,
|
289 |
**kwargs,
|
290 |
):
|
291 |
+
res = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, attention_mask, **kwargs)
|
292 |
+
input_ids = res["input_ids"]
|
293 |
+
past_key_values = res["past_key_values"]
|
294 |
+
attention_mask = res["attention_mask"]
|
295 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
if inputs_embeds is not None and past_key_values is None:
|
297 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
298 |
else:
|
|
|
300 |
|
301 |
model_inputs.update(
|
302 |
{
|
|
|
303 |
"past_key_values": past_key_values,
|
304 |
"use_cache": kwargs.get("use_cache"),
|
305 |
"attention_mask": attention_mask,
|
306 |
+
"image_features": image_features,
|
307 |
}
|
308 |
)
|
309 |
return model_inputs
|