import logging
import math
import os
import re
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torchvision.ops import roi_align
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    Qwen2Config,
    Qwen2ForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.utils import GenerateOutput

# NOTE: this rebinds ``logging`` to transformers' logging utilities, shadowing
# the stdlib ``logging`` imported above; ``logger`` is a transformers logger.
from transformers.utils import logging, strtobool

from .clip import CLIPVisionTower
from .convnext import ConvNextVisionEncoder

logger = logging.get_logger(__name__)

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

# NOTE(review): several token literals below are empty or look truncated
# (e.g. ">"); they may have been stripped as markup somewhere upstream --
# confirm against the tokenizer's special tokens before relying on them.
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = ""

# For Objects
DEFAULT_OBJECT_TOKEN = ">"
DEFAULT_OBJECT_FEATURE_TOKEN = ""
DEFAULT_OBJECT_INDEX = -300

# For Grounding
DEFAULT_GROUNDING_START = ""
DEFAULT_GROUNDING_END = ""
DEFAULT_GROUNDING_OBJECTS_START = ""
DEFAULT_GROUNDING_OBJECTS_END = ""


def is_fsdp_enabled():
    """Return True when running under FSDP with CPU-RAM-efficient loading.

    All of the following must hold: ``torch.distributed`` is available and
    initialized, and both the ``ACCELERATE_USE_FSDP`` and
    ``FSDP_CPU_RAM_EFFICIENT_LOADING`` environment variables parse as true.
    """
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )


class IdentityMap(nn.Module):
    """Projector that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual two-layer GELU MLP."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        # The residual is taken from the *normalized* input.
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, start_hidden_size, delay_load=False, **kwargs):
    """Build the MLP that maps vision features into the LLM embedding space.

    The projector type is fixed to ``"mlp2x_gelu"``: a Linear from
    ``start_hidden_size`` to ``config.hidden_size`` followed by one
    GELU + Linear(hidden_size, hidden_size) pair. The config's projector-type
    field is not consulted.

    Args:
        config: Model config; only ``config.hidden_size`` is read.
        start_hidden_size (int): Input feature dimension.
        delay_load (bool): Unused; kept for interface compatibility.
        **kwargs: Unused; kept for interface compatibility.

    Returns:
        nn.Module: The projector.

    Raises:
        ValueError: If the projector type is not recognised.
    """
    projector_type = "mlp2x_gelu"

    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(start_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")


def get_token_slices(input_ids: torch.Tensor):
    """Split a 1-D sequence of token ids into text / image / object slices.

    Args:
        input_ids (torch.Tensor): 1-D tensor of token ids. ``IMAGE_TOKEN_INDEX``
            marks an image placeholder, ``DEFAULT_OBJECT_INDEX`` marks an object
            placeholder, and every other value is a text token.

    Returns:
        Tuple[List[Dict[str, Any]], bool]:
            - A list of ``{"type": ..., "span": [start, end]}`` dicts in
              sequence order. ``type`` is ``"text"``, ``"image"`` or
              ``"object"``, and ``span`` is a half-open index range of Python
              ints. Each placeholder gets its own length-1 slice, and runs of
              text tokens between placeholders are merged into one slice.
            - Whether at least one object placeholder is present.
    """
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0

    # All placeholder positions in sequence order, with their token ids.
    special_indices, _ = torch.sort(torch.cat((image_indices, object_indices)))
    special_tokens = input_ids[special_indices]

    slices = []
    start_idx = 0
    for idx, token in zip(special_indices.tolist(), special_tokens.tolist()):
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx]})
        slices.append({"type": type_map[token], "span": [idx, idx + 1]})
        start_idx = idx + 1
    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})
    return slices, has_object


class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded first sequence ends with ``stop_word``.

    Carriage returns and newlines are removed from the decoded text before
    the comparison.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        # Only the first sequence in the batch is inspected.
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace("\r", "").replace("\n", "")
        return cur_text[-self.length :] == self.stop_word


def get_stop_criteria(
    tokenizer,
    stop_words=None,
):
    """Build a ``StoppingCriteriaList`` with one stop-word criterion per word.

    Args:
        tokenizer: Tokenizer used to decode generated ids.
        stop_words (Optional[Iterable[str]]): Words to stop on. ``None``
            (the default) means no stop words.
    """
    stop_criteria = StoppingCriteriaList()
    for word in stop_words or ():
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria


def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate a sine/cosine position embedding from box coordinates.

    Args:
        pos_tensor (torch.Tensor): shape ``[batch_size, N, 2]`` (``[cx, cy]``)
            or ``[batch_size, N, 4]`` (``[cx, cy, w, h]``). Values are
            presumably normalized to ``[0, 1]`` (they are scaled by ``2*pi``).
        dim_of_pos_feats (int): Embedding size per coordinate. It should be
            even so that the interleaved sin/cos halves line up.

    Returns:
        torch.Tensor: shape ``[batch_size, N, 2 * dim_of_pos_feats]`` for
        2-D input, or ``[batch_size, N, 4 * dim_of_pos_feats]`` for 4-D input.
        Per-coordinate embeddings are concatenated in the order y, x
        (then w, h).

    Raises:
        ValueError: If the last dimension of ``pos_tensor`` is not 2 or 4.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos


class MultiLevelROIVisualPrompt(nn.Module):
    """Pool one feature vector per box from multi-level feature maps.

    Args:
        output_size (int): RoI Align output size, passed to
            ``torchvision.ops.roi_align``.
        channel_per_level (List[int]): Channels of each feature level. Stored
            for reference only; ``forward`` does not read it. Defaults to
            ``[192, 384, 768, 1536]``.
        spatail_scale (float): Scale mapping box coordinates onto the largest
            feature map (passed to ``roi_align`` as ``spatial_scale``). The
            misspelled keyword is kept for backward compatibility.
        add_pos_embedding (bool): Whether to add a sine positional embedding
            of each box to its pooled feature.
        pos_embedding_dim (int): Size of that positional embedding. Because it
            is added to the pooled feature, it must equal the total channel
            count of the concatenated levels.
    """

    def __init__(
        self,
        output_size: int = None,
        channel_per_level: Optional[List[int]] = None,
        spatail_scale: float = None,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super(MultiLevelROIVisualPrompt, self).__init__()
        self.output_size = output_size
        # Fresh list per instance rather than a shared mutable default.
        self.channel_per_level = (
            channel_per_level
            if channel_per_level is not None
            else [192, 384, 768, 1536]
        )
        self.spatail_scale = spatail_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def forward(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Run RoI Align over channel-concatenated multi-level features.

        Every level except level 0 is bilinearly resized to the largest
        ``(H, W)``. All levels are concatenated on channels, and each RoI
        result is average-pooled to a single vector.

        Args:
            multi_level_features (List[Tensor[N, C_i, H_i, W_i]]): Feature maps
                from the different levels. Level 0 is used at its own size, so
                it is assumed to already be the largest map.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): Box coordinates in
                absolute (x1, y1, x2, y2) format, in the layouts accepted by
                ``roi_align``. The caller's list and tensors are not modified.
                When ``add_pos_embedding`` is set, only ``boxes[0]`` is used
                for the embedding, so the list form holding a single image's
                boxes is expected.

        Returns:
            Tensor[1, K, C]: One pooled feature per RoI, where C is the sum of
            the level channels.
        """
        # Cast to float without touching the caller's container.
        if isinstance(boxes, torch.Tensor):
            boxes = boxes.float()
        else:
            boxes = [b.float() for b in boxes]

        max_height = max(feature.shape[2] for feature in multi_level_features)
        max_width = max(feature.shape[3] for feature in multi_level_features)

        # Bring every level to the same spatial size so they can be
        # concatenated along the channel axis.
        resized_levels = [multi_level_features[0].float()] + [
            F.interpolate(
                feature.float(),
                size=(max_height, max_width),
                mode="bilinear",
                align_corners=False,
            )
            for feature in multi_level_features[1:]
        ]
        concat_multi_level_feature = torch.cat(resized_levels, dim=1)

        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatail_scale,
        )

        # Average-pool each RoI: (K, C, s, s) -> (K, C) -> (1, K, C).
        out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape(
            1, out_box_feat.shape[0], out_box_feat.shape[1]
        )

        if self.add_pos_embedding:
            # Boxes are absolute xyxy. Normalise them by the image size implied
            # by the largest map and the scale, then convert to (cx, cy, w, h).
            # New tensors are built instead of editing in place: `.float()` and
            # `.to()` return the input tensor itself when no cast is needed, so
            # in-place edits would overwrite the caller's box coordinates.
            xyxy = boxes[0].to(out_box_feat.dtype)
            original_img_width = max_width / self.spatail_scale
            original_img_height = max_height / self.spatail_scale
            x1 = xyxy[:, 0] / original_img_width
            y1 = xyxy[:, 1] / original_img_height
            x2 = xyxy[:, 2] / original_img_width
            y2 = xyxy[:, 3] / original_img_height
            box_w = x2 - x1
            box_h = y2 - y1
            cxcywh = torch.stack(
                (x1 + box_w / 2, y1 + box_h / 2, box_w, box_h), dim=-1
            )
            pos_embed = gen_sineembed_for_position(
                cxcywh.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed

        return out_box_feat


class RexSeekQwenConfig(Qwen2Config):
    model_type = "rexseek_qwen"


class RexSeekQwenForCausalLM(Qwen2ForCausalLM):
    """Qwen2 causal LM with image and box (object) placeholder tokens.

    Image placeholders (``IMAGE_TOKEN_INDEX``) are replaced by projected
    features from a CLIP tower concatenated with a ConvNeXt tower. Object
    placeholders (``DEFAULT_OBJECT_INDEX``) are replaced by one projected,
    RoI-pooled ConvNeXt feature per box.
    """

    config_class = RexSeekQwenConfig

    def __init__(self, config):
        super().__init__(config)
        # Low-resolution vision encoder.
        vision_tower = getattr(
            config,
            "mm_vision_tower",
            getattr(config, "vision_tower", None),
        )
        self.vision_tower = CLIPVisionTower(
            vision_tower,
            args=config,
        )
        # High-resolution vision encoder.
        self.vision_tower_aux = ConvNextVisionEncoder()

        # Projector for image tokens. 2560 is the channel size of the
        # concatenated low/high-res features built in encode_images
        # (presumably 1024 CLIP + 1536 ConvNeXt channels).
        self.mm_projector = build_vision_projector(
            config, start_hidden_size=2560
        )
        # Projector for object tokens. 2880 = 192 + 384 + 768 + 1536, the
        # concatenated channel count pooled by ``box_encoder``.
        self.mm_object_projector = build_vision_projector(
            config, start_hidden_size=2880
        )

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Visual prompt (box) encoder over the multi-level ConvNeXt features.
        self.box_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],  # ConvNeXt Large
            spatail_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )

        # Initialize weights and apply final processing.
        self.post_init()
        logger.info("model initialized")

    def get_vision_tower(self):
        """Return the low-res vision tower, unwrapping a one-element list."""
        vision_tower = getattr(self, "vision_tower", None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def get_vision_tower_aux(self):
        """Return the high-res vision tower, unwrapping a one-element list."""
        vision_tower_aux = getattr(self, "vision_tower_aux", None)
        if isinstance(vision_tower_aux, list):
            vision_tower_aux = vision_tower_aux[0]
        return vision_tower_aux

    def get_model(self):
        return self.model

    def encode_images(self, images, images_aux):
        """Encode images with both towers and project them to the LLM space.

        Args:
            images: Input for the low-resolution (CLIP) tower.
            images_aux: Input for the high-resolution (ConvNeXt) tower.

        Returns:
            Tuple ``(image_features, visual_outputs_aux)``. ``image_features``
            is the projected channel-concatenation of both towers, with last
            dim ``config.hidden_size``. ``visual_outputs_aux`` is the aux
            tower's ``"image_features"`` output, used later for box pooling.
        """
        low_res_feat = self.get_vision_tower()(images)  # e.g. (B, 576, 1024)
        aux_output = self.get_vision_tower_aux()(images_aux)
        visual_outputs_aux = aux_output["image_features"]
        high_res_feat = aux_output["last_feat"]  # (B, C, H, W), e.g. (B, 1536, 24, 24)

        # Flatten the high-res map into H*W tokens and concatenate on channels.
        # This presumes H*W equals the number of low-res tokens.
        b, c, h, w = high_res_feat.shape
        high_res_feat = high_res_feat.view(b, c, h * w).transpose(1, 2)
        image_features = torch.cat((low_res_feat, high_res_feat), dim=-1)
        image_features = self.mm_projector(image_features)
        return image_features, visual_outputs_aux

    def encode_objects(
        self, bboxes, visual_outputs_aux, dtype, num_gt_boxes_per_image=None
    ):
        """Encode per-image box features for the object placeholder tokens.

        Args:
            bboxes: Per-image box tensors indexed by batch position, in
                absolute xyxy coordinates (assumed ``(N_i, 4)`` -- TODO confirm).
                They are not modified.
            visual_outputs_aux: Sequence of multi-level feature maps, each
                indexed by batch position (the aux tower's
                ``"image_features"``).
            dtype: dtype the pooled features are cast to before projection.
            num_gt_boxes_per_image: Optional per-image counts. Only the first
                ``num_gt_boxes_per_image[i]`` boxes of image ``i`` are used.

        Returns:
            List[Optional[torch.Tensor]]: One entry per image. Each entry is
            the projected box features of shape ``(N_i, hidden_size)``, or
            ``None`` when the image has no boxes.
        """
        bbox_visual_outputs = []
        for batch_idx, boxes in enumerate(bboxes):
            num_box = (
                num_gt_boxes_per_image[batch_idx]
                if num_gt_boxes_per_image is not None
                else len(boxes)
            )
            boxes = boxes[:num_box]
            if len(boxes) == 0:
                bbox_visual_outputs.append(None)
                continue
            multi_level_aux_features = [
                visual_output_aux[batch_idx].unsqueeze(0)
                for visual_output_aux in visual_outputs_aux
            ]
            out_vp_feat = self.box_encoder(
                multi_level_aux_features,
                [boxes],
            ).squeeze(0)
            out_vp_feat = out_vp_feat.to(dtype)
            out_vp_feat = self.mm_object_projector(out_vp_feat)
            bbox_visual_outputs.append(out_vp_feat)
        return bbox_visual_outputs

    @staticmethod
    def _object_feature(bbox_feats, sample_idx, box_idx):
        """Return box feature ``box_idx`` of sample ``sample_idx`` as a (1, C) row.

        Raises:
            ValueError: If there is no such feature. This happens when no boxes
                were given at all, the sample has no boxes, or the sample has
                fewer boxes than object placeholders.
        """
        sample_feats = None
        if bbox_feats is not None and sample_idx < len(bbox_feats):
            sample_feats = bbox_feats[sample_idx]
        num_available = 0 if sample_feats is None else sample_feats.shape[0]
        if box_idx >= num_available:
            raise ValueError(
                f"Object placeholder #{box_idx} in sample {sample_idx} has no "
                f"matching box feature ({num_available} box features available)."
            )
        return sample_feats[box_idx].unsqueeze(0)

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        pixel_values=None,
        pixel_values_aux=None,
        gt_boxes=None,
        num_gt_boxes_per_image=None,
    ):
        """Replace image/object placeholder tokens with visual embeddings.

        Padding (per ``attention_mask``) is removed from each sample. The
        remaining tokens are split into text/image/object slices, and each
        slice is embedded. The samples are then truncated to
        ``config.tokenizer_model_max_length`` (if set) and right-padded into
        one batch. Visual positions get ``IGNORE_INDEX`` labels.

        Returns:
            Tuple ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``.
            - Without ``pixel_values``, the inputs are returned unchanged with
              ``inputs_embeds=None``.
            - Otherwise ``input_ids`` is ``None`` and ``inputs_embeds`` is a
              ``(B, L, hidden_size)`` tensor.
            - ``position_ids``, ``attention_mask`` and ``labels`` are returned
              as ``None`` when they were passed as ``None``.

        Raises:
            ValueError: If an object placeholder has no matching box feature.
        """
        if pixel_values is None:
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )

        # pixel_values becomes (B, num_tokens, hidden_size) after projection.
        pixel_values, visual_outputs_aux = self.encode_images(
            pixel_values, pixel_values_aux
        )

        # Bound explicitly so object placeholders without boxes raise a clear
        # error instead of an UnboundLocalError.
        bbox_feats = None
        if gt_boxes is not None:
            bbox_feats = self.encode_objects(
                gt_boxes, visual_outputs_aux, pixel_values.dtype, num_gt_boxes_per_image
            )

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()  # padding mask, shape (B, L)
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop padding positions from every sample.
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        cur_object_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # Text-only sample. The zero-length slice adds no rows;
                # presumably it keeps the vision branch in the autograd graph.
                cur_image_features = pixel_values[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat(
                    [cur_input_embeds_1, cur_image_features[0:0]], dim=0
                )
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                cur_object_idx += 1
                continue

            cur_labels = labels[batch_idx]
            token_slices, _ = get_token_slices(cur_input_ids)
            result_input_embeddings = []
            result_output_labels = []
            cur_gt_box_idx = 0
            for token_slice in token_slices:
                slice_type = token_slice["type"]
                start, end = token_slice["span"]
                if slice_type == "text":
                    cur_input_ids_noim = cur_input_ids[start:end]
                    cur_labels_noim = cur_labels[start:end]
                    cur_input_embeds = self.get_model().embed_tokens(cur_input_ids_noim)
                    result_input_embeddings.append(cur_input_embeds)
                    result_output_labels.append(cur_labels_noim)
                elif slice_type == "image":
                    cur_input_embeds = pixel_values[cur_image_idx]
                    result_input_embeddings.append(cur_input_embeds)
                    result_output_labels.append(
                        torch.full(
                            (cur_input_embeds.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
                    cur_image_idx += 1
                elif slice_type == "object":
                    result_input_embeddings.append(
                        self._object_feature(bbox_feats, cur_object_idx, cur_gt_box_idx)
                    )
                    cur_gt_box_idx += 1
                    result_output_labels.append(
                        torch.full(
                            (1,),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
            cur_object_idx += 1
            result_input_embeddings = torch.cat(result_input_embeddings)
            result_output_labels = torch.cat(result_output_labels)
            assert len(result_output_labels) == len(result_input_embeddings)
            new_input_embeds.append(result_input_embeddings)
            new_labels.append(result_output_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(
            self.config, "tokenizer_model_max_length", None
        )
        if tokenizer_model_max_length is not None:
            new_input_embeds = [
                x[:tokenizer_model_max_length] for x in new_input_embeds
            ]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Right-pad every sample to the longest one.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros(
            (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
        )

        for i, (cur_new_embed, cur_new_labels) in enumerate(
            zip(new_input_embeds, new_labels)
        ):
            cur_len = cur_new_embed.shape[0]
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor],
        pixel_values: Optional[torch.Tensor],
        pixel_values_aux: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text, embedding image/object placeholders first.

        When ``inputs_embeds`` is not given, the prompt is embedded here.
        ``gt_boxes`` and ``num_gt_boxes_per_image`` are read from ``kwargs``
        in that case. All remaining ``kwargs`` are forwarded to the parent
        ``generate``.
        """
        if inputs_embeds is None:
            # position_ids / attention_mask are named parameters, so they never
            # appear in **kwargs. Use the caller's values directly; popping them
            # from kwargs with a None default would discard them.
            gt_boxes = kwargs.pop("gt_boxes", None)
            num_gt_boxes_per_image = kwargs.pop("num_gt_boxes_per_image", None)

            if pixel_values is not None:
                (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
                    self.prepare_inputs_labels_for_multimodal(
                        inputs,
                        position_ids,
                        attention_mask,
                        past_key_values=None,
                        labels=None,
                        pixel_values=pixel_values,
                        pixel_values_aux=pixel_values_aux,
                        gt_boxes=gt_boxes,
                        num_gt_boxes_per_image=num_gt_boxes_per_image,
                    )
                )
            else:
                inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )


AutoConfig.register("rexseek_qwen", RexSeekQwenConfig)
AutoModelForCausalLM.register(RexSeekQwenConfig, RexSeekQwenForCausalLM)