# aimv2-4b-ve / modeling_mono.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from torch import nn
import torch.nn.functional as F
from .configuration_aimv2 import MonoConfig
from .modeling_aimv2 import AIMv2Model, PixelShuffleConnector
from transformers.generation import GenerationMixin
"""
Simple arch of Mono, used for pretrain vision encoder.
"""
@dataclass
class MonoCausalLMOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class MonoPretrainedModel(PreTrainedModel):
config_class = MonoConfig
base_model_prefix = "mono"
# main_input_name = "pixel_values"
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_cache_class = True
supports_gradient_checkpointing = True
class MonoForConditionalGeneration(MonoPretrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MonoConfig):
        MonoPretrainedModel.__init__(self, config)
self.vision_tower = AIMv2Model(config=config.vision_config)
self._attn_implementation = config._attn_implementation
self._build_image_projection_layers(config)
self.model = Qwen2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.pad_token_id = config.pad_token_id
print(f"==> pad_token_id: {self.pad_token_id}")
self.post_init()
def _build_image_projection_layers(self, config):
image_dim_out = config.vision_config.hidden_size
dim_projection = config.hidden_size
# self.mm_projector = nn.Linear(image_dim_out, dim_projection)
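        # PixelShuffleConnector comes from modeling_aimv2.py in this repo. The usual
        # role of such a connector (an assumption here; see that file for the exact
        # behavior) is to fold neighboring patch tokens into the channel dimension
        # (pixel shuffle) to shrink the visual token count, then project to the
        # language model's hidden size.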
self.mm_projector = PixelShuffleConnector(image_dim_out, dim_projection)
print(f"==> build mm_projector: {image_dim_out} -> {dim_projection}")
def get_vision_tower(self):
return self.vision_tower
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
) -> nn.Embedding:
model_embeds = self.model.resize_token_embeddings(
new_num_tokens, pad_to_multiple_of
)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
    def _encode_image(self, pixel_values):
        # pixel_values: (batch_size, C, H, W)
        batch_size, C, H, W = pixel_values.shape
        x = self.vision_tower(pixel_values, output_hidden_states=True)
        # Use the second-to-last hidden state as the visual features.
        x = x.hidden_states[-2]
        # Project (and downsample) to the language model's hidden size.
        x = self.mm_projector(x)
        return x
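    # Shape sketch for _encode_image (numbers are illustrative, not fixed by this file):
    #   pixel_values        : (B, 3, H, W)
    #   hidden_states[-2]   : (B, num_patches, vision_hidden_size)
    #   mm_projector output : (B, num_image_tokens, config.hidden_size)
    #   where num_image_tokens <= num_patches after the pixel-shuffle downsampling.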
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position=None,
    ) -> Union[Tuple, MonoCausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
image_features = None
if inputs_embeds is None:
if pixel_values is not None:
# (batch_size, num_image_tokens, hidden_size)
image_features = self._encode_image(pixel_values)
if input_ids is not None:
inputs_embeds, attention_mask, labels = (
self._get_input_embeds_with_image(input_ids, image_features, labels)
)
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
logits.device
)
shift_logits = logits[..., :-1, :][
shift_attention_mask != 0
].contiguous()
# print(f"shift_logits: {shift_logits.shape}")
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
# print(f"shift_labels: {shift_labels.shape}")
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
        return MonoCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
        )
    def _get_input_embeds_with_image(self, input_ids, image_features, labels=None):
        # For each sequence:
        #   1. split at the image placeholder token (-200) and splice the projected
        #      image features in at that position;
        #   2. drop tokens marked with -100 (intentional input_ids padding);
        #   3. build a matching boolean attention mask;
        #   4. if labels are given, insert -100 over the image span so it is
        #      ignored by the loss.
batch_size = input_ids.size(0)
processed_embeds = []
processed_masks = []
labels_ignored_im = []
max_seq_len = 0
for idx in range(batch_size):
seq = input_ids[idx]
im_pos = (seq == -200).nonzero(as_tuple=True)[0]
if im_pos.numel() > 0:
im_pos = im_pos.item()
before = seq[:im_pos]
after = seq[im_pos + 1 :]
                # Drop -100 tokens (input_ids may be intentionally padded with -100)
before = before[before != -100]
after = after[after != -100]
# Get embeddings for before and after
before_embed = self.get_input_embeddings()(before)
after_embed = self.get_input_embeddings()(after)
# Concatenate before, image features, and after
seq_embed = torch.cat(
[before_embed, image_features[idx], after_embed], dim=0
)
new_seq_len = seq_embed.size(0)
                # If labels are given, replace the single image token with a run of
                # -100s matching the number of image tokens.
if labels is not None:
image_token_ignore = torch.full(
(image_features[idx].shape[0],),
-100,
dtype=torch.long,
device=labels.device,
)
labels_ignored_im.append(
torch.cat(
(
labels[idx][:im_pos],
image_token_ignore,
labels[idx][im_pos + 1 :],
),
dim=0,
)
)
            else:
                # No image placeholder in this sequence: just drop -100 tokens
                valid_tokens = seq[seq != -100]
                seq_embed = self.get_input_embeddings()(valid_tokens)
                new_seq_len = seq_embed.size(0)
                if labels is not None:
                    # Keep this sample's labels so the batch stays aligned below.
                    labels_ignored_im.append(labels[idx])
# Update the maximum sequence length
if new_seq_len > max_seq_len:
max_seq_len = new_seq_len
processed_embeds.append(seq_embed)
attn_mask = torch.ones(new_seq_len, dtype=torch.bool, device=seq.device)
processed_masks.append(attn_mask)
# rest embedding is 0, rest mask is False, just padding it
inputs_embeds = torch.nn.utils.rnn.pad_sequence(
processed_embeds, batch_first=True, padding_value=0.0
)
attn_masks = torch.nn.utils.rnn.pad_sequence(
processed_masks, batch_first=True, padding_value=0
)
        if labels is not None:
            # Label sequences can differ in length after splicing in image tokens,
            # so pad with the ignore index instead of stacking.
            labels_ignored_im = torch.nn.utils.rnn.pad_sequence(
                labels_ignored_im, batch_first=True, padding_value=-100
            )
            return inputs_embeds, attn_masks, labels_ignored_im
        return inputs_embeds, attn_masks, None
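    # Toy illustration of the splice above (hypothetical ids; K = number of image tokens):
    #   input_ids[i] = [ 1, 15, -200, 27, 2 ]
    #   embeds out   = [ emb(1), emb(15), img_1 .. img_K, emb(27), emb(2) ]   # length 4 + K
    #   labels[i]    = [ -100, -100, -200, 27, 2 ]
    #   labels out   = [ -100, -100, -100 x K, 27, 2 ]                        # image span ignored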
    @torch.no_grad()
    def generate(self, input_ids, pixel_values=None, **kwargs):
        if pixel_values is not None:
            # Splice the projected image features into the text embeddings.
            image_features = self._encode_image(pixel_values)
            inputs_embeds, attention_mask, _ = self._get_input_embeds_with_image(
                input_ids, image_features
            )
        else:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `pixel_values` must be provided.")
            inputs_embeds = self.get_input_embeddings()(input_ids)
            attention_mask = torch.ones(
                inputs_embeds.size(0),
                inputs_embeds.size(1),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        return super().generate(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
attention_mask=None,
**kwargs,
):
        # Rely on the parent implementation to trim input_ids when a cache is in use,
        # and forward attention_mask explicitly so it is not dropped.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )
        return model_inputs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self.model.shift_tokens_right(labels)
def _reorder_cache(self, *args, **kwargs):
return self.model._reorder_cache(*args, **kwargs)
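

# Minimal smoke-test sketch. The checkpoint path, tokenizer, image resolution, and
# prompt format below are assumptions for illustration, not guarantees made by this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "./"  # hypothetical local path to an aimv2-4b-ve checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = MonoForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

    # Build a prompt with a single -200 image placeholder in front of the text.
    prompt_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
    image_slot = torch.full((prompt_ids.shape[0], 1), -200, dtype=prompt_ids.dtype)
    input_ids = torch.cat([image_slot, prompt_ids], dim=1)

    # Dummy image tensor; the real resolution comes from the image processor.
    pixel_values = torch.randn(1, 3, 448, 448, dtype=model.dtype)

    output_ids = model.generate(input_ids, pixel_values=pixel_values, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))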