import json
import logging
import math
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip.factory import get_model_config, load_state_dict
from open_clip.model import (CLIPTextCfg, CLIPVisionCfg, _build_text_tower,
                             _build_vision_tower,
                             convert_to_custom_text_state_dict)
from open_clip.transformer import text_global_pool
from torch import nn
from torchvision.ops import roi_align
from transformers import (CONFIG_MAPPING, AutoConfig, AutoModel,
                          AutoModelForCausalLM, GenerationConfig,
                          PretrainedConfig, PreTrainedModel, StoppingCriteria,
                          StoppingCriteriaList)
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
# NOTE(review): the next two imports rebind names imported above:
# `load_state_dict` (previously from open_clip) and `logging` (previously the
# stdlib module). The later binding wins; `logger` below relies on
# transformers' `logging.get_logger`.
from transformers.modeling_utils import load_state_dict
from transformers.utils import logging, strtobool

from .convnext import ConvNextVisionEncoder

logger = logging.get_logger(__name__)

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
# Negative sentinel ids written into input_ids in place of an image / object;
# get_token_slices() splits sequences on them.
IMAGE_TOKEN_INDEX = -200
# NOTE(review): several token strings below are empty or look truncated
# (e.g. ">"); they may have lost angle-bracketed content somewhere upstream.
# Confirm them against the tokenizer's special tokens before relying on them.
DEFAULT_IMAGE_TOKEN = ""

# For Objects
DEFAULT_OBJECT_TOKEN = ">"
DEFAULT_OBJECT_FEATURE_TOKEN = ""
DEFAULT_OBJECT_INDEX = -300

# For Grounding
DEFAULT_GROUNDING_START = ""
DEFAULT_GROUNDING_END = ""
DEFAULT_GROUNDING_OBJECTS_START = ""
DEFAULT_GROUNDING_OBJECTS_END = ""


def is_fsdp_enabled():
    """Return True when running distributed under FSDP with RAM-efficient loading.

    All of the following must hold: ``torch.distributed`` is available and
    initialized, and both the ``ACCELERATE_USE_FSDP`` and
    ``FSDP_CPU_RAM_EFFICIENT_LOADING`` environment variables parse as true.
    """
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )


def get_token_slices(input_ids: torch.Tensor):
    """Split a 1-D sequence of token ids into text / image / object segments.

    Args:
        input_ids (torch.Tensor): 1-D tensor of token ids. ``IMAGE_TOKEN_INDEX``
            marks an image placeholder, ``DEFAULT_OBJECT_INDEX`` marks an object
            placeholder, and every other value is an ordinary text token.

    Returns:
        Tuple[List[Dict[str, Any]], bool]:
            - A list of ``{"type": ..., "span": [start, end]}`` dicts in sequence
              order. ``type`` is ``"text"``, ``"image"`` or ``"object"`` and
              ``span`` is a half-open index range. Every placeholder gets its own
              one-token slice; runs of text tokens between them are merged.
            - ``True`` if at least one object placeholder is present.
    """
    # map each placeholder id to the slice type it produces
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    # find the positions of the placeholder tokens
    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0

    # merge and sort all placeholder positions so slices come out in order
    special_indices, _ = torch.sort(torch.cat((image_indices, object_indices)))
    special_tokens = input_ids[special_indices]

    slices = []
    start_idx = 0
    for idx_tensor, token in zip(special_indices, special_tokens):
        idx = idx_tensor.item()
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx]})
        slices.append({"type": type_map[token.item()], "span": [idx, idx + 1]})
        start_idx = idx + 1

    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})

    return slices, has_object


def _get_object_embedding(bbox_feats, sample_idx: int, box_idx: int) -> torch.Tensor:
    """Return ``bbox_feats[sample_idx][box_idx]`` with a leading axis of size 1.

    Raises:
        ValueError: if ``bbox_feats`` has no entry for this sample (missing, or
            ``None`` because the sample had no boxes) or too few rows for
            ``box_idx``.
    """
    sample_feats = None
    if bbox_feats is not None and sample_idx < len(bbox_feats):
        sample_feats = bbox_feats[sample_idx]
    if sample_feats is None or box_idx >= len(sample_feats):
        available = None if sample_feats is None else tuple(sample_feats.shape)
        raise ValueError(
            f"Object placeholder #{box_idx} of sample {sample_idx} has no matching "
            f"box feature (bbox_feats entry shape: {available})."
        )
    return sample_feats[box_idx].unsqueeze(0)


def prepare_inputs_labels_for_multimodal(
    llm,
    input_ids: torch.LongTensor = None,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    bbox_feats=None,
    extra_llm_input_embed: nn.Embedding = None,
    **kwargs,
):
    """Replace image/object placeholder ids with visual embeddings for the LLM.

    Each sample first has its padding removed (using ``attention_mask``). Then:

    - its text tokens are embedded with ``llm.get_input_embeddings()``;
    - every ``IMAGE_TOKEN_INDEX`` placeholder is replaced by the next entry of
      ``pixel_values``;
    - every ``DEFAULT_OBJECT_INDEX`` placeholder is replaced by the next row of
      that sample's ``bbox_feats`` entry.

    The resulting sequences are right-padded with zeros to a common length.

    Args:
        llm: Language model exposing ``get_input_embeddings()``.
        input_ids: (B, L) token ids containing placeholder ids.
        position_ids: Optional. When given, it is rebuilt as ``0..len-1`` per
            sample; when ``None``, ``None`` is returned.
        attention_mask: Optional (B, L) mask. Tokens where it is false are
            dropped before embedding.
        past_key_values: Passed through unchanged.
        labels: Optional (B, L) labels. Positions of visual embeddings are set
            to ``IGNORE_INDEX``.
        pixel_values: Projected image embeddings, consumed one entry per image
            placeholder. A sample without any image placeholder also consumes
            one entry.
        bbox_feats: Per-sample object embeddings. ``bbox_feats[i][k]`` fills the
            k-th object placeholder of sample ``i``.
        extra_llm_input_embed: Unused.
        **kwargs: Ignored.

    Returns:
        dict: with keys ``input_ids``, ``position_ids``, ``attention_mask``,
        ``past_key_values``, ``inputs_embeds`` and ``labels``.

        - If ``pixel_values`` is None, the inputs are returned unchanged with
          ``inputs_embeds=None``.
        - Otherwise ``input_ids`` is None and ``inputs_embeds`` is a
          (B, max_len, hidden) tensor.

    Raises:
        ValueError: if an object placeholder has no matching ``bbox_feats`` row.
    """
    if pixel_values is None:
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": None,
            "labels": labels,
        }

    # Remember what the caller passed so we know which outputs to return as None.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- TODO: double check
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    embed_tokens = llm.get_input_embeddings()
    new_inputs_embeds = []
    new_labels = []
    cur_image_idx = 0
    cur_object_idx = 0  # advances once per sample, i.e. indexes bbox_feats by sample
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            # Text-only sample: still consume one pixel_values entry and append an
            # empty slice of it. Presumably this keeps the visual branch in the
            # autograd graph (a common LLaVA-style trick).
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds = torch.cat(
                [embed_tokens(cur_input_ids), cur_pixel_values[0:0]], dim=0
            )
            new_inputs_embeds.append(cur_inputs_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            cur_object_idx += 1
            continue

        cur_labels = labels[batch_idx]
        token_slices, _ = get_token_slices(cur_input_ids)
        result_input_embeddings = []
        result_output_labels = []
        cur_gt_box_idx = 0  # next object placeholder within this sample
        for token_slice in token_slices:
            slice_type = token_slice["type"]
            start, end = token_slice["span"]
            if slice_type == "text":
                result_input_embeddings.append(embed_tokens(cur_input_ids[start:end]))
                result_output_labels.append(cur_labels[start:end])
            elif slice_type == "image":
                cur_input_embeds = pixel_values[cur_image_idx]
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(
                    torch.full(
                        (cur_input_embeds.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
                cur_image_idx += 1
            elif slice_type == "object":
                result_input_embeddings.append(
                    _get_object_embedding(bbox_feats, cur_object_idx, cur_gt_box_idx)
                )
                cur_gt_box_idx += 1
                result_output_labels.append(
                    torch.full(
                        (1,),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
        cur_object_idx += 1

        result_input_embeddings = torch.cat(result_input_embeddings)
        result_output_labels = torch.cat(result_output_labels)
        assert len(result_output_labels) == len(result_input_embeddings)
        new_inputs_embeds.append(result_input_embeddings)
        new_labels.append(result_output_labels)

    # Combine them: right-pad every sample to the longest sequence in the batch.
    max_len = max(x.shape[0] for x in new_inputs_embeds)
    batch_size = len(new_inputs_embeds)

    new_inputs_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )

    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_inputs_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        padding = torch.zeros(
            (max_len - cur_len, cur_new_embed.shape[1]),
            dtype=cur_new_embed.dtype,
            device=cur_new_embed.device,
        )
        new_inputs_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0, cur_len, dtype=position_ids.dtype, device=position_ids.device
            )

    new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

    new_labels = None if _labels is None else new_labels_padded
    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
    if _position_ids is None:
        position_ids = None

    return {
        "input_ids": None,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "inputs_embeds": new_inputs_embeds,
        "labels": new_labels,
    }


class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded text ends with ``stop_word``.

    Only the first sequence of the batch (``input_ids[0]``) is decoded. Carriage
    returns and newlines are removed before comparing.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace('\r', '').replace('\n', '')
        return cur_text[-self.length:] == self.stop_word


def get_stop_criteria(
    tokenizer,
    stop_words: Optional[List[str]] = None,
):
    """Build a ``StoppingCriteriaList`` with one ``StopWordStoppingCriteria`` per word.

    Args:
        tokenizer: Tokenizer used to decode generated ids.
        stop_words: Words that end generation. ``None`` or empty yields an empty list.

    Returns:
        StoppingCriteriaList: the assembled criteria.
    """
    stop_criteria = StoppingCriteriaList()
    for word in stop_words or []:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria


class DualPathFuseModule(nn.Module):
    """Fuse a high-resolution feature map into low-resolution visual tokens.

    The two paths work as follows:

    - High-resolution path: a 1x1 conv, GELU, then a 1x1 projection down to
      ``low_res_dim`` channels.
    - Low-resolution path: the tokens are reshaped to a square grid and
      refined by a residual depthwise 7x7 conv branch.

    A per-sample scalar gate, computed from the mean of both token sets and
    squashed with ``tanh``, then scales the high-resolution tokens before they
    are added to the low-resolution tokens.

    With ``zero_init=True`` both output projections start at zero, so the
    module initially returns ``low_res_feat`` unchanged.
    """

    def __init__(self, low_res_dim, high_res_dim, zero_init=True):
        super().__init__()
        self.slow_conv = nn.Conv2d(high_res_dim, high_res_dim, 1)
        self.slow_proj = nn.Conv2d(high_res_dim, low_res_dim, 1)

        self.fast_conv = nn.Conv2d(
            low_res_dim, low_res_dim, 7, padding=3, groups=low_res_dim
        )
        self.fast_proj = nn.Conv2d(low_res_dim, low_res_dim, 1)

        self.gate = nn.Sequential(
            nn.Linear(low_res_dim * 2, low_res_dim // 2),
            nn.GELU(),
            nn.Linear(low_res_dim // 2, 1),
        )

        nn.init.xavier_uniform_(self.slow_conv.weight)
        nn.init.xavier_uniform_(self.fast_conv.weight)
        nn.init.zeros_(self.slow_conv.bias)
        nn.init.zeros_(self.fast_conv.bias)
        if zero_init:
            nn.init.zeros_(self.slow_proj.weight)
            nn.init.zeros_(self.fast_proj.weight)
        else:
            nn.init.xavier_uniform_(self.slow_proj.weight)
            nn.init.xavier_uniform_(self.fast_proj.weight)
        nn.init.zeros_(self.slow_proj.bias)
        nn.init.zeros_(self.fast_proj.bias)

    def forward(self, low_res_feat, high_res_feat, sampler=None):
        """Fuse ``high_res_feat`` into ``low_res_feat``.

        Args:
            low_res_feat (torch.Tensor): (B, N, D) tokens. N must be a perfect
                square and D must equal ``low_res_dim``.
            high_res_feat (torch.Tensor): (B, C, H, W) map. C must equal
                ``high_res_dim``, and H * W must equal N so both token sets align.
            sampler: Unused; kept for call-site compatibility.

        Returns:
            torch.Tensor: fused tokens of shape (B, N, D).
        """
        b, c, h, w = high_res_feat.shape  # e.g. (2, 1536, 24, 24)
        _, _, d = low_res_feat.shape  # e.g. (2, 576, 1024)
        high_res_feat = self.slow_proj(
            F.gelu(self.slow_conv(high_res_feat))
        )  # (B, D, H, W)
        high_res_feat = high_res_feat.view(b, d, -1).transpose(1, 2)  # (B, H*W, D)
        dst_size = int(math.sqrt(low_res_feat.shape[1]))  # grid side, e.g. 24
        low_res_feat = low_res_feat.transpose(1, 2).view(
            b, d, dst_size, dst_size
        )  # (B, D, s, s)
        low_res_feat = low_res_feat + self.fast_proj(
            F.gelu(self.fast_conv(low_res_feat))
        )
        low_res_feat = low_res_feat.view(b, d, dst_size * dst_size).transpose(
            1, 2
        )  # (B, N, D)
        gate = self.gate(
            torch.cat([low_res_feat, high_res_feat], -1).mean(1)
        ).unsqueeze(1)  # (B, 1, 1)
        low_res_feat = low_res_feat + high_res_feat * gate.tanh()
        return low_res_feat


class ProjectorConfig(PretrainedConfig):
    """Configuration for ``ProjectorModel``.

    Args:
        visual_hidden_size (int): Input feature width.
        llm_hidden_size (int): Output feature width.
        depth (int): Number of linear layers (an activation between each pair).
        hidden_act (str): Key into ``transformers.activations.ACT2FN``.
        bias (bool): Whether the linear layers use a bias.
    """

    model_type = "projector"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        visual_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act="gelu",
        bias=True,
        **kwargs,
    ):
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias
        super().__init__(**kwargs)


class ProjectorModel(PreTrainedModel):
    """MLP mapping visual features to the LLM hidden size.

    The layers are ``Linear(visual -> llm)`` followed by ``depth - 1`` repetitions
    of ``[activation, Linear(llm -> llm)]``, stored in ``self.model``.
    """

    _auto_class = "AutoModel"
    config_class = ProjectorConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []

    def __init__(self, config: ProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        modules = [
            nn.Linear(
                config.visual_hidden_size, config.llm_hidden_size, bias=config.bias
            )
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size, config.llm_hidden_size, bias=config.bias
                )
            )
        self.model = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register a forward hook that marks ``self.model``'s output as requiring grad."""

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x):
        layer_outputs = self.model(x)
        return layer_outputs


def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate sine/cosine position embeddings from normalized coordinates.

    Args:
        pos_tensor (torch.Tensor): Either [batch_size, N, 2] holding (x, y), or
            [batch_size, N, 4] holding (cx, cy, w, h). Values are multiplied by
            2*pi, so they are expected to be normalized to [0, 1].
        dim_of_pos_feats (int): Number of features produced per coordinate.
            Must be even, so the sin and cos halves have equal length.

    Returns:
        torch.Tensor: one of
            - [batch_size, N, 2 * dim_of_pos_feats] for 2-D input, ordered (y, x);
            - [batch_size, N, 4 * dim_of_pos_feats] for 4-D input, ordered
              (y, x, w, h).

    Raises:
        ValueError: if the last dimension of ``pos_tensor`` is neither 2 nor 4.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)

    def _embed(coord):
        # (B, N) -> (B, N, dim_of_pos_feats): interleaved sin (even) / cos (odd)
        pos = (coord * scale)[:, :, None] / dim_t
        return torch.stack(
            (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
        ).flatten(2)

    num_coords = pos_tensor.size(-1)
    if num_coords == 2:
        return torch.cat(
            (_embed(pos_tensor[:, :, 1]), _embed(pos_tensor[:, :, 0])), dim=2
        )
    if num_coords == 4:
        return torch.cat(
            tuple(_embed(pos_tensor[:, :, i]) for i in (1, 0, 2, 3)), dim=2
        )
    raise ValueError("Unknown pos_tensor shape(-1):{}".format(num_coords))


class MultiLevelROIVisualPrompt(nn.Module):
    """Pool RoI features from a multi-level feature pyramid as visual prompts.

    Processing steps:

    1. All levels are resized to the largest spatial size and concatenated
       along the channel axis.
    2. Each box is RoI-aligned and average-pooled to one vector.
    3. Optionally, a sine position embedding of the box is added.

    Args:
        output_size (Optional[int]): ``output_size`` passed to ``roi_align``.
        channel_per_level (Optional[List[int]]): Channels of each level. Stored
            on the module but not used in the computation. Defaults to
            ``[192, 384, 768, 1536]``.
        spatail_scale (Optional[float]): ``spatial_scale`` passed to
            ``roi_align``, i.e. the ratio of feature-map size to box-coordinate
            space. The parameter and attribute keep their historical spelling
            for compatibility.
        visual_prompt_hidden_size (int): Accepted for compatibility; unused.
        add_pos_embedding (bool): Whether to add a box position embedding.
        pos_embedding_dim (int): Width of the position embedding. When
            ``add_pos_embedding`` is True, ``4 * (pos_embedding_dim // 4)`` must
            equal the concatenated channel count.
    """

    def __init__(
        self,
        output_size: int = None,
        channel_per_level: Optional[List[int]] = None,
        spatail_scale: float = None,
        visual_prompt_hidden_size: int = 1024,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super(MultiLevelROIVisualPrompt, self).__init__()
        if channel_per_level is None:
            channel_per_level = [192, 384, 768, 1536]
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatail_scale = spatail_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def forward(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Pool one feature vector per box from multi-level feature maps.

        Args:
            multi_level_features (List[Tensor[N, C_i, H_i, W_i]]): Feature maps
                from different levels. Level 0 is used at its own size, so it
                should already be the largest.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): Boxes in (x1, y1, x2, y2)
                format, in the un-normalized box-coordinate space. The tensor form
                carries the batch index in column 0 (``roi_align`` convention).
                The caller's boxes are not modified.

        Returns:
            Tensor[1, K, C]: one feature per box, where C is the sum of the level
            channel counts.
        """
        # Cast to float32 into new objects; never write into the caller's boxes.
        if isinstance(boxes, torch.Tensor):
            boxes = boxes.float()
        else:
            boxes = [box.float() for box in boxes]

        max_height = max(feature.shape[2] for feature in multi_level_features)
        max_width = max(feature.shape[3] for feature in multi_level_features)

        # interpolate every level except level 0 to the same size
        resized_levels = [multi_level_features[0].float()]
        resized_levels.extend(
            F.interpolate(
                feature.float(),
                size=(max_height, max_width),
                mode="bilinear",
                align_corners=False,
            )
            for feature in multi_level_features[1:]
        )
        concat_multi_level_feature = torch.cat(resized_levels, dim=1)

        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatail_scale,
        )

        # Average pooling: (K, C, s, s) -> (K, C) -> (1, K, C)
        out_box_feat = out_box_feat.mean(dim=(2, 3)).unsqueeze(0)

        if self.add_pos_embedding:
            # Boxes are xyxy and un-normalized: normalize by the image size implied
            # by the feature size and spatial scale, then convert to (cx, cy, w, h).
            # roi_align concatenates list inputs in order, so cat matches its rows.
            if isinstance(boxes, torch.Tensor):
                box_coords = boxes[:, 1:]
            else:
                box_coords = torch.cat(boxes, dim=0)
            coords = box_coords.to(out_box_feat.dtype)
            original_img_width = max_width / self.spatail_scale
            original_img_height = max_height / self.spatail_scale
            x1 = coords[:, 0] / original_img_width
            y1 = coords[:, 1] / original_img_height
            x2 = coords[:, 2] / original_img_width
            y2 = coords[:, 3] / original_img_height
            box_w = x2 - x1
            box_h = y2 - y1
            cxcywh = torch.stack(
                (x1 + box_w / 2, y1 + box_h / 2, box_w, box_h), dim=-1
            )
            pos_embed = gen_sineembed_for_position(
                cxcywh.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed

        return out_box_feat


class ChatRexAuxConfig(PretrainedConfig):
    r"""
    Configuration class for [`ChatRexAuxForConditionalGeneration`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*):
            Config (or dict) of the low-resolution vision backbone. A dict is resolved through `CONFIG_MAPPING` using
            its `"model_type"` key (default `"clip_vision_model"`). `None` builds a CLIP vision config with
            `hidden_size=1024`, `patch_size=14` and `image_size=336`.
        vision_aux_config (*optional*):
            Stored as-is on the config; not read by the code in this module.
        visual_prompt_encoder_config (*optional*):
            Stored as-is on the config; not read by the code in this module.
        text_config (`Union[AutoConfig, dict]`, *optional*):
            Config (or dict) of the language model. A dict is resolved through `CONFIG_MAPPING` using its
            `"model_type"` key (default `"llama"`). `None` builds a default Llama config.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            Stored on the config (the code in this module uses `IMAGE_TOKEN_INDEX` instead).
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            Stored on the config. The projectors built in this module use `ProjectorConfig`'s default activation.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            Must be `"default"` or `"full"`; any other value raises `ValueError`.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            Stored on the config (the code in this module reads `hidden_states[-2]` directly).
        projector_depth (`int`, *optional*, defaults to 2):
            Number of linear layers in the vision and visual-prompt projectors.
        visual_prompt_hidden_size (`int`, *optional*, defaults to 2880):
            Input width of the visual-prompt projector.

    Example:

    ```python
    >>> configuration = ChatRexAuxConfig()
    >>> model = ChatRexAuxForConditionalGeneration(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "chatrex"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        vision_aux_config=None,
        visual_prompt_encoder_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        projector_depth=2,
        visual_prompt_hidden_size=2880,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.projector_depth = projector_depth
        self.visual_prompt_hidden_size = visual_prompt_hidden_size
        self.visual_prompt_encoder_config = visual_prompt_encoder_config

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'. "
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if isinstance(vision_config, dict):
            # copy so the caller's dict is not mutated
            vision_config = {
                **vision_config,
                "model_type": vision_config.get("model_type", "clip_vision_model"),
            }
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )
        self.vision_config = vision_config
        self.vision_aux_config = vision_aux_config

        if isinstance(text_config, dict):
            # copy so the caller's dict is not mutated
            text_config = {
                **text_config,
                "model_type": text_config.get("model_type", "llama"),
            }
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config

        super().__init__(**kwargs)


class ChatRexAuxPreTrainedModel(PreTrainedModel):
    """Base class wiring ``ChatRexAuxConfig`` into the transformers model API.

    No custom ``_init_weights`` is defined. Like the LLaVA port this was adapted
    from, the model is meant for inference and fine-tuning from pretrained
    weights, not training from scratch. The original initialization code lives
    in https://github.com/haotian-liu/LLaVA/tree/main/llava.
    """

    config_class = ChatRexAuxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # NOTE(review): no module of this name is defined in this file; likely a
    # leftover from the LLaVA port — confirm it is still wanted.
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    @property
    def _supports_sdpa(self):
        """Whether the wrapped language model supports SDPA attention.

        ``ChatRexAuxForConditionalGeneration`` stores its LLM as ``self.llm``.
        ``self.language_model`` (the LLaVA naming) is accepted as a fallback, and
        ``False`` is returned if neither has been built yet.
        """
        language_model = getattr(self, "llm", None)
        if language_model is None:
            language_model = getattr(self, "language_model", None)
        if language_model is None:
            return False
        return language_model._supports_sdpa


class ChatRexAuxForConditionalGeneration(ChatRexAuxPreTrainedModel):
    """ChatRex model combining the following components with a causal LLM:

    - a low-resolution vision encoder;
    - an auxiliary high-resolution ConvNeXt encoder;
    - a fuser;
    - projectors;
    - a box (visual-prompt) encoder.
    """

    def __init__(self, config: ChatRexAuxConfig):
        super().__init__(config)
        # low-resolution vision encoder
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        # high-resolution (auxiliary) vision encoder
        self.vision_encoder_aux = ConvNextVisionEncoder()

        # projects (fused) vision tokens into the LLM embedding space
        projector_config = ProjectorConfig(
            visual_hidden_size=config.vision_config.hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.projector = ProjectorModel(projector_config)

        # projects pooled box (visual-prompt) features into the LLM embedding space
        vp_projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_prompt_hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.vp_projector = ProjectorModel(vp_projector_config)

        # fuses the aux encoder's last feature map into the low-resolution tokens
        self.fuser = DualPathFuseModule(
            low_res_dim=config.vision_config.hidden_size,
            high_res_dim=1536,
        )

        # Visual prompt encoder: RoI-pools box features from the aux pyramid.
        # NOTE(review): pos_embedding_dim=2880 equals sum(channel_per_level) and
        # the config's default visual_prompt_hidden_size — presumably intentional.
        self.vp_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],
            spatail_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )

        # default generation config used by generate() when gen_config is None
        self.gen_config = None

        self.vocab_size = config.text_config.vocab_size
        self.llm = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.post_init()

    def _prepare_data_for_llm(self, data):
        """Encode images/boxes in ``data`` and build LLM-ready inputs.

        When ``data`` contains ``"pixel_values"``:

        1. The second-to-last hidden state of the low-resolution encoder is
           taken. For CLIP models, the first token is dropped.
        2. If the auxiliary encoder exists, ``data["pixel_values_aux"]`` is
           encoded and its last feature map is fused into those tokens.
        3. The resulting tokens are projected and stored in
           ``data["pixel_values"]``.
        4. For each sample in ``data["gt_boxes"]`` (if present), box features
           are pooled from the auxiliary multi-level features, projected, and
           collected into ``data["bbox_feats"]``. Samples without boxes get
           ``None``.

        Finally, ``prepare_inputs_labels_for_multimodal`` builds the LLM inputs.
        ``data`` is modified in place before being consumed.

        Raises:
            NotImplementedError: for unsupported low-resolution encoder types.
            ValueError: if boxes are given but there is no auxiliary encoder.
        """
        if "pixel_values" in data:
            visual_outputs = self.vision_encoder(
                data["pixel_values"].to(self.vision_encoder.dtype),
                output_hidden_states=True,
            )
            encoder_name = type(self.vision_encoder).__name__
            if encoder_name in ["CLIPVisionModel", "CLIPVisionModelAnyRes"]:
                # second-to-last layer without the first (presumably CLS) token
                visual_outputs = visual_outputs.hidden_states[-2][:, 1:]
            elif encoder_name == "SiglipVisionModel":
                visual_outputs = visual_outputs.hidden_states[-2]
            else:
                raise NotImplementedError

            # Without an aux encoder the low-res tokens go to the projector unfused
            # (the fuser preserves the vision hidden size, so widths still match).
            fuse_features = visual_outputs
            visual_outputs_aux = None
            if self.vision_encoder_aux is not None:
                pixels_aux = []
                for pixels in data["pixel_values_aux"]:
                    if pixels.dim() == 3:
                        pixels = pixels.unsqueeze(0)
                    elif pixels.dim() == 4:
                        # NOTE(review): swaps the first two axes; presumably the
                        # pipeline yields (C, N, H, W) — confirm.
                        pixels = pixels.permute(1, 0, 2, 3)
                    pixels_aux.append(pixels)
                aux_input = torch.cat(pixels_aux, dim=0)  # e.g. (2, 3, 768, 768)
                aux_output = self.vision_encoder_aux(aux_input)
                visual_outputs_aux = aux_output["image_features"]
                last_feat = aux_output["last_feat"]  # e.g. (B, 1536, 24, 24)
                fuse_features = self.fuser(
                    low_res_feat=visual_outputs, high_res_feat=last_feat
                )  # e.g. (2, 576, 1024)
            pixel_values = self.projector(fuse_features)
            data["pixel_values"] = pixel_values

            # extract visual prompt features
            bbox_visual_outputs = []
            if "gt_boxes" in data:
                for batch_idx, boxes in enumerate(data["gt_boxes"]):
                    if len(boxes) == 0:
                        bbox_visual_outputs.append(None)
                        continue
                    if visual_outputs_aux is None:
                        raise ValueError(
                            "gt_boxes require the auxiliary vision encoder "
                            "(vision_encoder_aux is None)."
                        )
                    # this sample's slice of every pyramid level, batch dim kept
                    multi_level_aux_features = [
                        level_feat[batch_idx].unsqueeze(0)
                        for level_feat in visual_outputs_aux
                    ]
                    boxes = boxes.to(torch.float32)
                    out_vp_feat = self.vp_encoder(
                        multi_level_aux_features,
                        [boxes],
                    ).squeeze(0)
                    out_vp_feat = out_vp_feat.to(pixel_values.dtype)
                    out_vp_feat = self.vp_projector(out_vp_feat)
                    bbox_visual_outputs.append(out_vp_feat)
            data["bbox_feats"] = bbox_visual_outputs

        data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
        return data

    def generate(self, data_dict: Dict[str, Any], gen_config=None, tokenizer=None):
        """Run multimodal generation and return the decoded text.

        Args:
            data_dict (Dict[str, Any]): Model inputs, e.g. ``input_ids``,
                ``attention_mask``, ``pixel_values``, ``pixel_values_aux`` and
                ``gt_boxes``. It is modified in place by the preprocessing.
            gen_config: Generation config. Falls back to ``self.gen_config``
                when None.
            tokenizer: Tokenizer providing ``bos_token_id`` and used for
                decoding. Required.

        Returns:
            str: The decoded first output sequence, whitespace-stripped.

        Raises:
            ValueError: if ``tokenizer`` is None.
        """
        if tokenizer is None:
            raise ValueError(
                "`tokenizer` is required to build stopping criteria and decode the output."
            )
        data_dict = self._prepare_data_for_llm(data_dict)
        if data_dict.get("inputs_embeds") is not None:
            data_dict["inputs_embeds"] = data_dict["inputs_embeds"].to(self.llm.dtype)
        stop_criteria = get_stop_criteria(
            tokenizer=tokenizer, stop_words=[]
        )
        generate_output = self.llm.generate(
            **data_dict,
            generation_config=self.gen_config if gen_config is None else gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
        )
        logger.debug("generate_output: %s", generate_output)
        prediction = tokenizer.decode(
            generate_output[0], skip_special_tokens=False
        ).strip()
        # NOTE(review): replacing "" with "" is a no-op; these literals look like
        # special tokens (BOS/EOS?) that lost their text — confirm and restore.
        prediction = prediction.replace("", "").replace("", "").strip()
        return prediction


AutoConfig.register("chatrex", ChatRexAuxConfig)
AutoModelForCausalLM.register(ChatRexAuxConfig, ChatRexAuxForConditionalGeneration)