from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput

from .configuration_aimv2 import MonoConfig
from .modeling_aimv2 import AIMv2Model, PixelShuffleConnector

"""
Simple architecture for Mono, used to pretrain the vision encoder.
"""


@dataclass
class MonoCausalLMOutputWithPast(ModelOutput):
    """CausalLM output that also reserves a slot for projected image features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class MonoPretrainedModel(PreTrainedModel):
    config_class = MonoConfig
    base_model_prefix = "mono"

    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    supports_gradient_checkpointing = True


class MonoForConditionalGeneration(MonoPretrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MonoConfig):
        super().__init__(config)

        self.vision_tower = AIMv2Model(config=config.vision_config)
        self._attn_implementation = config._attn_implementation

        self._build_image_projection_layers(config)

        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.pad_token_id = config.pad_token_id
        print(f"==> pad_token_id: {self.pad_token_id}")
        self.post_init()
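
    # Data flow of the modules wired above: pixel_values -> vision_tower
    # (AIMv2) -> penultimate hidden state -> mm_projector (PixelShuffleConnector)
    # -> image embeddings, which are spliced into the Qwen2 token-embedding
    # sequence in place of the image placeholder token
    # (see _get_input_embeds_with_image below).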

    def _build_image_projection_layers(self, config):
        image_dim_out = config.vision_config.hidden_size
        dim_projection = config.hidden_size

        self.mm_projector = PixelShuffleConnector(image_dim_out, dim_projection)
        print(f"==> build mm_projector: {image_dim_out} -> {dim_projection}")

    def get_vision_tower(self):
        return self.vision_tower

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
    ) -> nn.Embedding:
        model_embeds = self.model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )

        # Keep every copy of the vocab size in sync with the resized embeddings.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _encode_image(self, pixel_values):
        # pixel_values: (batch, channels, height, width). Use the penultimate
        # hidden state of the vision tower, then project it into the language
        # model's embedding space.
        x = self.vision_tower(pixel_values, output_hidden_states=True)
        x = x.hidden_states[-2]

        x = self.mm_projector(x)

        return x

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MonoCausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        image_features = None
        if inputs_embeds is None:
            if pixel_values is not None:
                image_features = self._encode_image(pixel_values)

            if input_ids is not None:
                # Replace the image placeholder with projected image features
                # and rebuild the attention mask and labels to the new lengths.
                inputs_embeds, attention_mask, labels = (
                    self._get_input_embeds_with_image(input_ids, image_features, labels)
                )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast for numerical stability before computing cross-entropy.
            logits = logits.float()
            labels = labels.to(logits.device)

            if attention_mask is not None:
                # Align the mask with the shifted logits/labels so padded
                # positions drop out of the loss.
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask != 0
                ].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MonoCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
        )
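
    # Worked example of the shifted loss above, for one sequence [a, b, c, d]:
    # the logits at positions 0..2 are scored against labels [b, c, d], i.e.
    # each position is trained to predict the next token; positions masked out
    # by the attention mask (or labeled -100) do not contribute to the loss.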

    def _get_input_embeds_with_image(self, input_ids, image_features, labels=None):
        """Splice projected image features into the token-embedding sequence.

        Each sequence's image placeholder token (id -200) is replaced by that
        sample's image features; padding ids (-100) are dropped from the input
        ids, and the attention mask and labels are rebuilt to match the new
        lengths. Assumes at most one image placeholder per sequence.
        """
        batch_size = input_ids.size(0)
        processed_embeds = []
        processed_masks = []
        labels_ignored_im = []

        max_seq_len = 0
        for idx in range(batch_size):
            seq = input_ids[idx]
            im_pos = (seq == -200).nonzero(as_tuple=True)[0]

            if im_pos.numel() > 0:
                # Single-image assumption: .item() would raise if a sequence
                # contained more than one placeholder.
                im_pos = im_pos.item()
                before = seq[:im_pos]
                after = seq[im_pos + 1 :]

                # Drop padding ids before embedding.
                before = before[before != -100]
                after = after[after != -100]

                before_embed = self.get_input_embeddings()(before)
                after_embed = self.get_input_embeddings()(after)

                seq_embed = torch.cat(
                    [before_embed, image_features[idx], after_embed], dim=0
                )
                new_seq_len = seq_embed.size(0)

                if labels is not None:
                    # The image span contributes no targets: fill it with -100
                    # so CrossEntropyLoss ignores those positions.
                    image_token_ignore = torch.full(
                        (image_features[idx].shape[0],),
                        -100,
                        dtype=torch.long,
                        device=labels.device,
                    )
                    labels_ignored_im.append(
                        torch.cat(
                            (
                                labels[idx][:im_pos],
                                image_token_ignore,
                                labels[idx][im_pos + 1 :],
                            ),
                            dim=0,
                        )
                    )
            else:
                # Text-only sample: drop padding ids and embed what remains.
                valid_tokens = seq[seq != -100]
                seq_embed = self.get_input_embeddings()(valid_tokens)
                new_seq_len = seq_embed.size(0)

                if labels is not None:
                    # Keep one labels entry per sample, filtered the same way
                    # as the tokens, so the padding step below stays aligned.
                    labels_ignored_im.append(labels[idx][seq != -100])

            if new_seq_len > max_seq_len:
                max_seq_len = new_seq_len

            processed_embeds.append(seq_embed)
            attn_mask = torch.ones(new_seq_len, dtype=torch.bool, device=seq.device)
            processed_masks.append(attn_mask)

        inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            processed_embeds, batch_first=True, padding_value=0.0
        )
        attn_masks = torch.nn.utils.rnn.pad_sequence(
            processed_masks, batch_first=True, padding_value=0
        )
        if labels is not None:
            # Pad rather than stack: the spliced label lengths vary per sample,
            # and the -100 padding value is ignored by CrossEntropyLoss.
            labels_ignored_im = torch.nn.utils.rnn.pad_sequence(
                labels_ignored_im, batch_first=True, padding_value=-100
            )
            return inputs_embeds, attn_masks, labels_ignored_im
        return inputs_embeds, attn_masks, None
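
    # Example of the splicing performed above (token ids are hypothetical;
    # -200 marks the image placeholder):
    #   input_ids[i]     = [bos, -200, t1, t2]
    #   inputs_embeds[i] = [emb(bos), *image_features[i], emb(t1), emb(t2)]
    #   labels[i]        = [l0, -100 x n_image_tokens, l2, l3]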

    @torch.no_grad()
    def generate(self, input_ids, pixel_values=None, **kwargs):
        if pixel_values is not None:
            image_features = self._encode_image(pixel_values)

            inputs_embeds, attention_mask, _ = self._get_input_embeds_with_image(
                input_ids, image_features
            )
        else:
            # Text-only path; input_ids is required here, otherwise there is
            # nothing to embed.
            inputs_embeds = self.get_input_embeddings()(input_ids)
            attention_mask = torch.ones(
                inputs_embeds.size(0),
                inputs_embeds.size(1),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )

        return super().generate(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        **kwargs,
    ):
        # Defer to GenerationMixin, forwarding the attention mask explicitly
        # so it is not silently dropped from the model inputs.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )
        return model_inputs

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Assumes the wrapped language model provides a shift_tokens_right
        # helper; decoder-only generation does not normally call this.
        return self.model.shift_tokens_right(labels)

    def _reorder_cache(self, *args, **kwargs):
        # Delegate beam-search cache reordering to the wrapped language model.
        return self.model._reorder_cache(*args, **kwargs)