import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast

from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

try:
    from einops import rearrange
except ImportError:
    rearrange = None
from torch import nn
from monkey_model.modeling_qwen import QWenModel,QWenPreTrainedModel,QWenLMHeadModel

# Hardware capability flags, evaluated once at import time.
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
# fp16 is only considered supported on devices with compute capability >= 7.
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
logger = logging.get_logger(__name__)


class MonkeyModel(QWenModel):
    """QWenModel variant that extracts image references from ``input_ids``.

    When the prompt contains the ``image_start_id`` token, the token spans that
    follow it are decoded to strings, encoded by ``self.visual`` into local
    (per-window) and global features, and passed to the parent ``forward`` as
    its trailing ``images`` argument.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Build image features (if any) and delegate to ``QWenModel.forward``.

        Image features are only computed when there is no ``past_key_values``
        and ``input_ids`` contains ``config.visual['image_start_id']``;
        otherwise ``images`` is ``None``. All arguments are forwarded
        positionally to the parent, with ``images`` appended last.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        # `input_ids` may legitimately be None (e.g. only `inputs_embeds` is
        # given); without the explicit guard `torch.any(None == id)` raises a
        # TypeError because the comparison yields a plain bool.
        if (
            past_key_values is None
            and input_ids is not None
            and torch.any(input_ids == self.config.visual['image_start_id'])
        ):
            image_start_id = self.config.visual['image_start_id']
            # Marker positions as (row indices, column indices) tuples.
            # `image_start_id + 1` is presumably the image-end token -- confirm
            # against the tokenizer.
            bos_pos = torch.where(input_ids == image_start_id)
            eos_pos = torch.where(input_ids == image_start_id + 1)
            # Every start marker must be paired with an end marker in the same row.
            assert (bos_pos[0] == eos_pos[0]).all()
            img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)

            image_paths = []
            for i, a, b in img_pos:
                # Tokens strictly between the markers (the slice also drops the
                # token immediately before the end marker).
                image = input_ids[i][a + 1 : b - 1].tolist()
                # Cut at the first `image_start_id + 2` token (presumably
                # padding -- confirm); `index` raises ValueError if it is absent.
                image = image[: image.index(image_start_id + 2)]
                # The remaining token ids are treated as raw bytes of a UTF-8 string.
                image_paths.append(bytes(image).decode('utf-8'))

            windows, images_448 = self.visual.encode(image_paths)

            # Encode each window patch, giving every patch its own running `idx`.
            patch_list = []
            lora_idx = 0
            for col in windows:
                for image_patch in col:
                    patch_list.append(self.visual(image_patch, idx=lora_idx))
                    lora_idx += 1

            global_feat = self.visual(images_448)
            local_feat = torch.cat(patch_list, dim=1)
            images = torch.cat([local_feat, global_feat], dim=1)
            # NOTE(review): this compares the tensor with itself and can never
            # fail; it was presumably meant to check against the number of
            # decoded image strings -- confirm before tightening.
            assert images.shape[0] == len(images)
        else:
            images = None

        return super().forward(
            input_ids,
            past_key_values,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            encoder_hidden_states,
            encoder_attention_mask,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            images,
        )


class MonkeyLMHeadModel(QWenLMHeadModel):
    """Causal-LM head model whose transformer backbone is a ``MonkeyModel``."""

    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        """Pick the compute precision, then build the backbone and LM head.

        At most one of ``config.bf16`` / ``config.fp16`` / ``config.fp32`` may
        be set. When none is set, the precision is chosen from the detected
        hardware support (bf16, then fp16, then fp32) and written back to
        ``config``.

        Raises:
            AssertionError: if more than one precision flag is true.
        """
        super().__init__(config)
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

        if autoset_precision:
            if SUPPORT_BF16:
                logger.warning(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warning(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        # Advisory warnings only: the requested precision is still honoured.
        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warning("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warning("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warning("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warning("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

        self.transformer = MonkeyModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Cast both sub-modules to the selected reduced precision, if any.
        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()