"""Simple Mono architecture, used to pretrain the vision encoder.

An AIMv2 vision tower produces image features, a ``PixelShuffleConnector``
projects them to the language model's hidden size, and a Qwen2 decoder with a
linear LM head consumes them. The image position is marked in ``input_ids`` by
the placeholder id ``IMAGE_TOKEN_INDEX`` (-200); ids equal to ``IGNORE_INDEX``
(-100) are dropped from the input sequence before embedding.
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, Qwen2Config, Qwen2ForCausalLM, Qwen2Model
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.utils import logging

from .configuration_aimv2 import MonoConfig
from .modeling_aimv2 import AIMv2Model, PixelShuffleConnector

logger = logging.get_logger(__name__)

# Placeholder id in ``input_ids`` marking where image features are spliced in.
IMAGE_TOKEN_INDEX = -200
# Id removed from ``input_ids`` before embedding; also CrossEntropyLoss's
# default ``ignore_index``, so label positions with this value add no loss.
IGNORE_INDEX = -100


@dataclass
class MonoCausalLMOutputWithPast(ModelOutput):
    """Output container of :class:`MonoForConditionalGeneration`.

    Mirrors a causal-LM output and reserves ``image_hidden_states`` for image
    features (``forward`` currently leaves it as ``None``).
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class MonoPretrainedModel(PreTrainedModel):
    """Base class wiring ``MonoConfig`` into the HF ``PreTrainedModel`` machinery."""

    config_class = MonoConfig
    base_model_prefix = "mono"
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    supports_gradient_checkpointing = True


class MonoForConditionalGeneration(MonoPretrainedModel, GenerationMixin):
    """AIMv2 vision tower + pixel-shuffle projector + Qwen2 decoder + LM head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MonoConfig):
        # Call the PreTrainedModel chain explicitly; GenerationMixin has no state.
        MonoPretrainedModel.__init__(self, config)
        self.vision_tower = AIMv2Model(config=config.vision_config)
        self._attn_implementation = config._attn_implementation
        self._build_image_projection_layers(config)

        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pad_token_id = config.pad_token_id
        logger.info("==> pad_token_id: %s", self.pad_token_id)

        self.post_init()

    def _build_image_projection_layers(self, config):
        """Create ``self.mm_projector`` mapping vision hidden size to LM hidden size."""
        image_dim_out = config.vision_config.hidden_size
        dim_projection = config.hidden_size
        self.mm_projector = PixelShuffleConnector(image_dim_out, dim_projection)
        logger.info("==> build mm_projector: %s -> %s", image_dim_out, dim_projection)

    def get_vision_tower(self):
        """Return the AIMv2 vision encoder."""
        return self.vision_tower

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.get_input_embeddings()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
    ) -> nn.Embedding:
        """Resize the decoder's input embeddings and the LM head to a new vocabulary.

        Args:
            new_num_tokens: target vocabulary size, forwarded to
                ``Qwen2Model.resize_token_embeddings``.
            pad_to_multiple_of: forwarded unchanged as well.

        Returns:
            The decoder's (possibly new) input embedding module.
        """
        model_embeds = self.model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        new_vocab_size = model_embeds.num_embeddings

        # ``self.model`` only owns the input embeddings; the LM head lives on this
        # class, so it must be resized here or logits keep the old vocab size.
        if self.lm_head.out_features != new_vocab_size:
            self.lm_head = self._resize_lm_head(self.lm_head, new_vocab_size)

        # MonoConfig may be flat (no ``text_config``) or composite; handle both.
        text_config = getattr(self.config, "text_config", None)
        if text_config is not None:
            text_config.vocab_size = new_vocab_size
        self.config.vocab_size = new_vocab_size
        self.vocab_size = new_vocab_size
        return model_embeds

    @staticmethod
    def _resize_lm_head(old_head: nn.Linear, new_vocab_size: int) -> nn.Linear:
        """Return a bias-free copy of ``old_head`` with ``new_vocab_size`` output rows.

        Rows shared with the old head are copied; any extra rows keep
        ``nn.Linear``'s default initialization.
        """
        new_head = nn.Linear(
            old_head.in_features,
            new_vocab_size,
            bias=False,
            device=old_head.weight.device,
            dtype=old_head.weight.dtype,
        )
        num_shared = min(old_head.out_features, new_vocab_size)
        with torch.no_grad():
            new_head.weight[:num_shared] = old_head.weight[:num_shared]
        new_head.weight.requires_grad_(old_head.weight.requires_grad)
        return new_head

    def _encode_image(self, pixel_values):
        """Encode images and project them into the language model's embedding space.

        Takes the vision tower's second-to-last hidden state and runs it through
        ``mm_projector``. Callers index the result per sample as
        ``features[idx]`` (presumably (batch, num_image_tokens, hidden_size)).

        Raises:
            ValueError: if ``pixel_values`` is not 4-D (batch, C, H, W).
        """
        if pixel_values.dim() != 4:
            raise ValueError(
                f"pixel_values must be 4-D (batch, C, H, W), got shape "
                f"{tuple(pixel_values.shape)}"
            )
        vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        return self.mm_projector(image_features)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal model and optionally compute the next-token loss.

        When ``inputs_embeds`` is not given it is built from ``input_ids``.

        - If ``pixel_values`` are given, or ``input_ids`` contains negative
          placeholder ids (``IMAGE_TOKEN_INDEX`` / ``IGNORE_INDEX``), the
          sequence goes through ``_get_input_embeds_with_image``. That helper
          also rebuilds ``attention_mask`` and ``labels``.
        - Otherwise the tokens are embedded directly and the caller's
          ``attention_mask`` is kept. Cached decoding steps during generation
          take this path.

        ``cache_position`` is accepted for generation compatibility but is not
        forwarded to ``self.model``.

        Returns:
            ``MonoCausalLMOutputWithPast`` when ``return_dict`` is true;
            otherwise the tuple ``(loss, logits, *outputs[1:])``, with ``loss``
            omitted when it is ``None``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            image_features = None
            if pixel_values is not None:
                image_features = self._encode_image(pixel_values)
            if input_ids is not None:
                needs_splicing = image_features is not None or bool(
                    (input_ids < 0).any()
                )
                if needs_splicing:
                    inputs_embeds, attention_mask, labels = (
                        self._get_input_embeds_with_image(
                            input_ids, image_features, labels
                        )
                    )
                else:
                    # Plain tokens: keep the caller's mask (e.g. the full
                    # past+current mask supplied during cached generation).
                    inputs_embeds = self.get_input_embeddings()(input_ids)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float to avoid potential precision issues in the loss.
            logits = logits.float()
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n.
            if attention_mask is not None:
                # Use the 2-D attention mask to drop padded positions from the
                # shifted logits/labels; crop it in case it is longer than the
                # logits (happens with prefix tuning in peft).
                shift_attention_mask = attention_mask[
                    :, -(logits.shape[1] - 1) :
                ].to(logits.device)
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask != 0
                ].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MonoCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @staticmethod
    def _pad_batch(sequences, padding_value, padding_side="right"):
        """Stack variable-length tensors into a batch, padding along dim 0 of each."""
        if padding_side == "left":
            # Pad reversed sequences on the right, then reverse back.
            flipped = [seq.flip(0) for seq in sequences]
            padded = torch.nn.utils.rnn.pad_sequence(
                flipped, batch_first=True, padding_value=padding_value
            )
            return padded.flip(1)
        return torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=padding_value
        )

    def _get_input_embeds_with_image(
        self, input_ids, image_features, labels=None, padding_side="right"
    ):
        """Build decoder inputs by splicing image features into token embeddings.

        For every row of ``input_ids``:
          1. positions equal to ``IGNORE_INDEX`` are dropped;
          2. the single ``IMAGE_TOKEN_INDEX`` placeholder, if any, is replaced
             by ``image_features[idx]`` (cast to the embedding dtype);
          3. when ``labels`` are given, they are filtered the same way, and the
             image span is labelled ``IGNORE_INDEX`` so it adds no loss.
        Rows are then padded to a common length on ``padding_side``, using zeros
        for embeddings, ``False`` for the mask and ``IGNORE_INDEX`` for labels.

        Args:
            input_ids: 2-D id tensor, iterated row by row.
            image_features: per-sample indexable features; may be ``None`` when
                no row contains the image placeholder.
            labels: optional tensor with one row per sample, aligned with
                ``input_ids``.
            padding_side: ``"right"`` (default) or ``"left"``. Use ``"left"``
                for generation so every row ends with a real token.

        Returns:
            ``(inputs_embeds, attention_mask, labels)``. ``attention_mask`` is
            boolean, and ``labels`` is ``None`` when no labels were given.

        Raises:
            ValueError: on an unknown ``padding_side``, more than one image
                placeholder in a row, or a placeholder with no image features.
        """
        if padding_side not in ("left", "right"):
            raise ValueError(
                f"padding_side must be 'left' or 'right', got {padding_side!r}"
            )

        embed_tokens = self.get_input_embeddings()
        sample_embeds, sample_masks, sample_labels = [], [], []

        for idx, seq in enumerate(input_ids):
            keep = seq != IGNORE_INDEX
            image_positions = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]

            if image_positions.numel() == 0:
                seq_embed = embed_tokens(seq[keep])
                if labels is not None:
                    row = labels[idx]
                    sample_labels.append(row[keep.to(row.device)])
            else:
                if image_positions.numel() > 1:
                    raise ValueError(
                        f"Sample {idx} contains {image_positions.numel()} image "
                        "placeholders; only one is supported."
                    )
                if image_features is None:
                    raise ValueError(
                        f"Sample {idx} contains an image placeholder but no "
                        "image features were provided (missing pixel_values?)."
                    )
                im_pos = image_positions.item()
                before_keep = keep[:im_pos]
                after_keep = keep[im_pos + 1 :]
                before_embed = embed_tokens(seq[:im_pos][before_keep])
                after_embed = embed_tokens(seq[im_pos + 1 :][after_keep])
                image_embed = image_features[idx].to(dtype=before_embed.dtype)
                seq_embed = torch.cat([before_embed, image_embed, after_embed], dim=0)

                if labels is not None:
                    row = labels[idx]
                    image_ignore = torch.full(
                        (image_embed.shape[0],),
                        IGNORE_INDEX,
                        dtype=torch.long,
                        device=row.device,
                    )
                    sample_labels.append(
                        torch.cat(
                            (
                                row[:im_pos][before_keep.to(row.device)],
                                image_ignore,
                                row[im_pos + 1 :][after_keep.to(row.device)],
                            ),
                            dim=0,
                        )
                    )

            sample_embeds.append(seq_embed)
            sample_masks.append(
                torch.ones(seq_embed.size(0), dtype=torch.bool, device=seq.device)
            )

        inputs_embeds = self._pad_batch(sample_embeds, 0.0, padding_side)
        attn_masks = self._pad_batch(sample_masks, 0, padding_side)
        if labels is not None:
            padded_labels = self._pad_batch(sample_labels, IGNORE_INDEX, padding_side)
            return inputs_embeds, attn_masks, padded_labels
        return inputs_embeds, attn_masks, None

    @torch.no_grad()
    def generate(self, input_ids, pixel_values=None, **kwargs):
        """Generate text, optionally conditioned on an image.

        ``input_ids`` are converted to ``inputs_embeds`` here, with image
        features spliced in when ``pixel_values`` are given. The result is
        passed to ``GenerationMixin.generate``. An ``attention_mask`` in
        ``kwargs`` is honoured:

        - with images, masked-out positions are dropped before splicing;
        - without images, the mask is passed through unchanged.

        Image inputs are left-padded so every row ends on a real token.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("generate() requires `input_ids`.")
        attention_mask = kwargs.pop("attention_mask", None)

        if pixel_values is not None:
            if attention_mask is not None:
                # Reuse the helper's IGNORE_INDEX filtering to drop positions
                # the caller masked out (e.g. tokenizer padding).
                input_ids = input_ids.masked_fill(attention_mask == 0, IGNORE_INDEX)
            image_features = self._encode_image(pixel_values)
            inputs_embeds, attention_mask, _ = self._get_input_embeds_with_image(
                input_ids, image_features, padding_side="left"
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            if attention_mask is None:
                attention_mask = torch.ones(
                    inputs_embeds.size(0),
                    inputs_embeds.size(1),
                    dtype=torch.bool,
                    device=inputs_embeds.device,
                )

        return super().generate(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        **kwargs,
    ):
        """Delegate to ``GenerationMixin``, forwarding the attention mask.

        The mask must be forwarded explicitly. Because it is a named parameter
        here, it would otherwise never reach the parent implementation.
        """
        return super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # NOTE(review): seq2seq-style hook; verify that Qwen2Model actually
        # provides `shift_tokens_right` before relying on this.
        return self.model.shift_tokens_right(labels)

    def _reorder_cache(self, *args, **kwargs):
        # NOTE(review): delegates to Qwen2Model; confirm it defines
        # `_reorder_cache` in the installed transformers version.
        return self.model._reorder_cache(*args, **kwargs)