# coding=utf-8 from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers import PreTrainedModel from transformers.modeling_outputs import ModelOutput from modeling_phi import PhiForCausalLM, InferenceParams from processing_llava import OpenCLIPImageProcessor from configuration_llava import LlavaConfig from open_clip import create_model @dataclass class LlavaCausalLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlavaConfig): super().__init__() self.linear_1 = nn.Linear( config.vision_embed_dim, config.text_config.n_embd * config.projector_tokens_num, bias=True, ) self.act = nn.GELU() self.linear_2 = nn.Linear( config.text_config.n_embd * config.projector_tokens_num, config.text_config.n_embd * config.projector_tokens_num, bias=True, ) self.projector_tokens_num = config.projector_tokens_num def forward(self, image_features): hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = hidden_states.reshape( hidden_states.shape[0], self.projector_tokens_num, int(hidden_states.shape[1] / self.projector_tokens_num), ) return hidden_states class LlavaPreTrainedModel(PreTrainedModel): config_class = LlavaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True def __init__(self, config): super().__init__(config) def _init_weights(self, module): return @property def _supports_sdpa(self): """ Retrieve language_model's attribute to check whether the model supports SDPA or not. """ return self.language_model._supports_sdpa class LlavaForConditionalGeneration(LlavaPreTrainedModel): def __init__(self, config: LlavaConfig): super().__init__(config) clip_model = create_model(config.vision_tower_name) self.vision_model = clip_model.visual self.multi_modal_projector = LlavaMultiModalProjector(config) self.vocab_size = config.vocab_size self.language_model = PhiForCausalLM(config.text_config) self.pad_token_id = ( self.config.pad_token_id if self.config.pad_token_id is not None else -1 ) self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def get_output_embeddings(self): return self.language_model.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def set_decoder(self, decoder): self.language_model.transformer = decoder def get_decoder(self): return self.language_model.transformer def tie_weights(self): return self.language_model.tie_weights() def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None ) -> nn.Embedding: model_embeds = self.language_model.resize_token_embeddings( new_num_tokens, pad_to_multiple_of ) # update vocab size self.config.text_config.vocab_size = model_embeds.num_embeddings self.config.vocab_size = model_embeds.num_embeddings self.vocab_size = model_embeds.num_embeddings return model_embeds def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids, attention_mask, position_ids ): num_images, num_image_patches, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape left_padding = not torch.sum( input_ids[:, -1] == torch.tensor(self.pad_token_id) ) # 1. Create a mask to know where special image tokens are special_image_token_mask = input_ids == self.config.image_token_index num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) # Compute the maximum embed dimension max_embed_dim = ( num_special_image_tokens.max() * (num_image_patches - 1) ) + sequence_length batch_indices, non_image_indices = torch.where( input_ids != self.config.image_token_index ) # 2. Compute the positions where text should be written # Calculate new positions for text tokens in merged image-text sequence. # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. # `torch.cumsum` computes how each image token shifts subsequent text token positions. # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. new_token_positions = ( torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 ) nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] if left_padding: new_token_positions += nb_image_pad[:, None] # offset for left padding text_to_overwrite = new_token_positions[batch_indices, non_image_indices] # 3. Create the full embedding, already padded to the maximum position final_embedding = torch.zeros( batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device, ) final_attention_mask = torch.zeros( batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device, ) # In case the Vision model or the Language model has been offloaded to CPU, we need to manually # set the corresponding tensors into their correct target device. target_device = inputs_embeds.device batch_indices, non_image_indices, text_to_overwrite = ( batch_indices.to(target_device), non_image_indices.to(target_device), text_to_overwrite.to(target_device), ) attention_mask = attention_mask.to(target_device) # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ batch_indices, non_image_indices ] final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[ batch_indices, non_image_indices ] # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling image_to_overwrite = torch.all(final_embedding == 0, dim=-1) image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[ :, None ].to(target_device) if image_to_overwrite.sum() != image_features.shape[:-1].numel(): raise ValueError( f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." ) final_embedding[image_to_overwrite] = ( image_features.contiguous().reshape(-1, embed_dim).to(target_device) ) final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_( (final_attention_mask == 0), 1 ) return final_embedding, final_attention_mask, position_ids def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[int] = None, vision_feature_select_strategy: Optional[str] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: # 1. Extra the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) # 2. Merge text and images if pixel_values is not None and input_ids.shape[1] != 1: image_outputs = self.vision_model(pixel_values) image_features = self.multi_modal_projector(image_outputs) ( inputs_embeds, attention_mask, position_ids, ) = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, position_ids, ) # if labels is None: # labels = torch.full_like( # attention_mask, self.config.ignore_index # ).to(torch.long) else: # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of # generation with cache if ( past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1 ): # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 batch_index, non_attended_tokens = torch.where( first_layer_past_key_value.float().sum(-2) == 0 ) # Get the target length target_seqlen = first_layer_past_key_value.shape[-1] + 1 extended_attention_mask = torch.ones( ( attention_mask.shape[0], target_seqlen - attention_mask.shape[1], ), dtype=attention_mask.dtype, device=attention_mask.device, ) # Zero-out the places where we don't need to attend extended_attention_mask[batch_index, non_attended_tokens] = 0 attention_mask = torch.cat( (attention_mask, extended_attention_mask), dim=1 ) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 outputs = self.language_model( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) logits = outputs[0] loss = None if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: shift_attention_mask = attention_mask[..., 1:] shift_logits = logits[..., :-1, :][ shift_attention_mask.to(logits.device) != 0 ].contiguous() shift_labels = labels[..., 1:][ shift_attention_mask.to(labels.device) != 0 ].contiguous() else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device), ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return LlavaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs, ): if past_key_values is not None: if isinstance(past_key_values, InferenceParams): cache_length = past_key_values.max_seqlen past_length = past_key_values.seqlen_offset else: cache_length = past_length = past_key_values[0][0].shape[2] # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as # input) if ( attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1] ): input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. elif self.config.image_token_index in input_ids: input_ids = input_ids[:, input_ids.shape[1] - 1 :] # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the # older attention values, as their corresponding values are not part of the input. if cache_length < past_length and attention_mask is not None: attention_mask = attention_mask[ :, -(cache_length + input_ids.shape[1]) : ] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "pixel_values": pixel_values, } ) return model_inputs def _reorder_cache(self, *args, **kwargs): return self.language_model._reorder_cache(*args, **kwargs)