import json
import math
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip.factory import get_model_config
from open_clip.model import (CLIPTextCfg, CLIPVisionCfg, _build_text_tower,
                             _build_vision_tower,
                             convert_to_custom_text_state_dict)
from open_clip.transformer import text_global_pool
from torchvision.ops import roi_align
from transformers import (CONFIG_MAPPING, AutoConfig, AutoModel,
                          AutoModelForCausalLM, GenerationConfig,
                          PretrainedConfig, PreTrainedModel, StoppingCriteria,
                          StoppingCriteriaList)
from transformers.activations import ACT2FN
from transformers.modeling_utils import load_state_dict
from transformers.utils import logging, strtobool

from .convnext import ConvNextVisionEncoder

logger = logging.get_logger(__name__)

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
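
# Illustrative note (an assumption based on how these constants are used below, not
# original code): after tokenization, each "<image>" placeholder in the prompt becomes
# IMAGE_TOKEN_INDEX (-200) and each object-feature placeholder becomes
# DEFAULT_OBJECT_INDEX (-300), e.g.
#   input_ids = torch.tensor([1, 319, -200, 3148, -300, -300, 2])
# These negative ids never reach the embedding table; they are swapped for image and
# box features in prepare_inputs_labels_for_multimodal() further down.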

def is_fsdp_enabled():
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )
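
# Illustrative sketch (hypothetical launch command, added for clarity): the check above
# only reports FSDP as enabled when the launcher exports both variables, e.g.
#   ACCELERATE_USE_FSDP=true FSDP_CPU_RAM_EFFICIENT_LOADING=true torchrun train.py
# strtobool() accepts "1"/"true"/"yes" and similar spellings.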


def get_token_slices(input_ids: torch.Tensor) -> Tuple[List[Dict[str, Any]], bool]:
    """Get contiguous slices of tokens based on special markers in the input tensor.

    Args:
        input_ids (torch.Tensor): A 1-D tensor of token IDs where IMAGE_TOKEN_INDEX marks an
            image token, DEFAULT_OBJECT_INDEX marks an object token, and all other values are
            ordinary text tokens.

    Returns:
        Tuple[List[Dict[str, Any]], bool]: A list of dictionaries, each containing the type of
            the token slice ('text', 'image', 'object') and its span as ``[start, end)`` indices,
            plus a flag indicating whether any object token is present.
    """
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0

    special_indices = torch.cat((image_indices, object_indices))
    special_indices, _ = torch.sort(special_indices)
    special_tokens = input_ids[special_indices]

    slices = []
    start_idx = 0

    for i, idx in enumerate(special_indices):
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx.item()]})
        token_type = type_map[special_tokens[i].item()]
        slices.append({"type": token_type, "span": [idx.item(), idx.item() + 1]})
        start_idx = idx.item() + 1

    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})

    return slices, has_object
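
# Illustrative sketch (added for clarity, not executed by this module): for a prompt
# containing one image placeholder and two object placeholders,
#   ids = torch.tensor([1, 319, IMAGE_TOKEN_INDEX, 3148,
#                       DEFAULT_OBJECT_INDEX, DEFAULT_OBJECT_INDEX, 2])
#   slices, has_object = get_token_slices(ids)
# yields has_object == True and slices ==
#   [{"type": "text",   "span": [0, 2]},
#    {"type": "image",  "span": [2, 3]},
#    {"type": "text",   "span": [3, 4]},
#    {"type": "object", "span": [4, 5]},
#    {"type": "object", "span": [5, 6]},
#    {"type": "text",   "span": [6, 7]}]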


def prepare_inputs_labels_for_multimodal(
    llm,
    input_ids: torch.LongTensor = None,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    bbox_feats=None,
    extra_llm_input_embed: nn.Embedding = None,
    **kwargs,
):
    """Replace image/object placeholder tokens with visual embeddings and build padded
    ``inputs_embeds``, ``labels``, ``attention_mask`` and ``position_ids`` for the LLM."""
    # Text-only batch: nothing to splice in, forward the inputs unchanged.
    if pixel_values is None:
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": None,
            "labels": labels,
        }

    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # Drop padding positions so each sample becomes a variable-length 1-D sequence.
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_inputs_embeds = []
    new_labels = []
    cur_image_idx = 0
    cur_object_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            # No image token in this sample: append a zero-length slice of the image
            # features so the visual branch still appears in the autograd graph.
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
            cur_inputs_embeds = torch.cat(
                [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0
            )
            new_inputs_embeds.append(cur_inputs_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            cur_object_idx += 1
            continue

        cur_labels = labels[batch_idx]
        token_slices, has_object = get_token_slices(cur_input_ids)
        result_input_embeddings = []
        result_output_labels = []
        cur_gt_bbox_idx = 0
        for token_slice in token_slices:
            slice_type = token_slice["type"]
            slice_span = token_slice["span"]
            if slice_type == "text":
                cur_input_ids_noim = cur_input_ids[slice_span[0] : slice_span[1]]
                cur_labels_noim = cur_labels[slice_span[0] : slice_span[1]]
                cur_input_embeds = llm.get_input_embeddings()(cur_input_ids_noim)
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(cur_labels_noim)
            elif slice_type == "image":
                cur_input_embeds = pixel_values[cur_image_idx]
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(
                    torch.full(
                        (cur_input_embeds.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
                cur_image_idx += 1
            elif slice_type == "object":
                try:
                    result_input_embeddings.append(
                        bbox_feats[cur_object_idx][cur_gt_bbox_idx].unsqueeze(0)
                    )
                except (IndexError, TypeError) as err:
                    raise ValueError(
                        "Mismatch between the number of object tokens and the provided "
                        f"bbox_feats for sample {batch_idx} (object index {cur_gt_bbox_idx})."
                    ) from err
                cur_gt_bbox_idx += 1
                result_output_labels.append(
                    torch.full(
                        (1,),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
        cur_object_idx += 1
        result_input_embeddings = torch.cat(result_input_embeddings)
        result_output_labels = torch.cat(result_output_labels)
        assert len(result_output_labels) == len(result_input_embeddings)
        new_inputs_embeds.append(result_input_embeddings)
        new_labels.append(result_output_labels)

    max_len = max(x.shape[0] for x in new_inputs_embeds)
    batch_size = len(new_inputs_embeds)

    new_inputs_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )

    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_inputs_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        new_inputs_embeds_padded.append(
            torch.cat(
                (
                    cur_new_embed,
                    torch.zeros(
                        (max_len - cur_len, cur_new_embed.shape[1]),
                        dtype=cur_new_embed.dtype,
                        device=cur_new_embed.device,
                    ),
                ),
                dim=0,
            )
        )
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0, cur_len, dtype=position_ids.dtype, device=position_ids.device
            )

    new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    return {
        "input_ids": None,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "inputs_embeds": new_inputs_embeds,
        "labels": new_labels,
    }
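
# Illustrative sketch (hypothetical shapes, not part of the original module): for a single
# sample whose prompt holds one image placeholder, the call
#   out = prepare_inputs_labels_for_multimodal(
#       llm=llm,
#       input_ids=torch.tensor([[1, 319, IMAGE_TOKEN_INDEX, 3148]]),
#       pixel_values=image_embeds,  # e.g. [1, 576, hidden] projected image features
#   )
# returns out["inputs_embeds"] of shape [1, 3 + 576, hidden]: the placeholder row is
# replaced by the 576 image-feature rows, and, when labels are provided, the spliced
# rows are set to IGNORE_INDEX.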


class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace('\r', '').replace('\n', '')
        return cur_text[-self.length:] == self.stop_word


def get_stop_criteria(
    tokenizer,
    stop_words=[],
):
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
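
# Illustrative usage (hypothetical stop word and inputs, added for clarity):
#   stop_criteria = get_stop_criteria(tokenizer, stop_words=["</s>"])
#   output_ids = llm.generate(**inputs, stopping_criteria=stop_criteria)
# Each criterion decodes the running output and fires as soon as the text ends with its
# stop word, mirroring how generate() wires it up below.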


class DualPathFuseModule(nn.Module):
    """Fuse low-resolution (CLIP-style) features with high-resolution (ConvNeXt) features.

    The high-resolution "slow" path is projected to the low-resolution feature dimension
    with 1x1 convolutions, the low-resolution "fast" path is refined with a depthwise 7x7
    convolution, and the two are combined through a learned per-sample scalar gate.
    """

    def __init__(self, low_res_dim, high_res_dim, zero_init=True):
        super().__init__()

        self.slow_conv = nn.Conv2d(high_res_dim, high_res_dim, 1)
        self.slow_proj = nn.Conv2d(high_res_dim, low_res_dim, 1)

        self.fast_conv = nn.Conv2d(
            low_res_dim, low_res_dim, 7, padding=3, groups=low_res_dim
        )
        self.fast_proj = nn.Conv2d(low_res_dim, low_res_dim, 1)

        self.gate = nn.Sequential(
            nn.Linear(low_res_dim * 2, low_res_dim // 2),
            nn.GELU(),
            nn.Linear(low_res_dim // 2, 1),
        )

        nn.init.xavier_uniform_(self.slow_conv.weight)
        nn.init.xavier_uniform_(self.fast_conv.weight)
        nn.init.zeros_(self.slow_conv.bias)
        nn.init.zeros_(self.fast_conv.bias)
        if zero_init:
            nn.init.zeros_(self.slow_proj.weight)
            nn.init.zeros_(self.fast_proj.weight)
        else:
            nn.init.xavier_uniform_(self.slow_proj.weight)
            nn.init.xavier_uniform_(self.fast_proj.weight)
        nn.init.zeros_(self.slow_proj.bias)
        nn.init.zeros_(self.fast_proj.bias)

    def forward(self, low_res_feat, high_res_feat, sampler=None):
        b, c, h, w = high_res_feat.shape
        _, _, d = low_res_feat.shape
        high_res_feat = self.slow_proj(F.gelu(self.slow_conv(high_res_feat)))
        high_res_feat = high_res_feat.view(b, d, -1).transpose(1, 2)
        dst_size = int(math.sqrt(low_res_feat.shape[1]))
        low_res_feat = low_res_feat.transpose(1, 2).view(b, d, dst_size, dst_size)
        low_res_feat = low_res_feat + self.fast_proj(
            F.gelu(self.fast_conv(low_res_feat))
        )
        low_res_feat = low_res_feat.view(b, d, dst_size * dst_size).transpose(1, 2)
        gate = self.gate(
            torch.cat([low_res_feat, high_res_feat], -1).mean(1)
        ).unsqueeze(1)
        low_res_feat = low_res_feat + high_res_feat * gate.tanh()
        return low_res_feat
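
# Illustrative shape walk-through (hypothetical sizes, added for clarity): with CLIP patch
# tokens low_res_feat of shape [B, 576, 1024] and a ConvNeXt map high_res_feat of shape
# [B, 1536, 24, 24],
#   fuser = DualPathFuseModule(low_res_dim=1024, high_res_dim=1536)
#   fused = fuser(low_res_feat, high_res_feat)   # shape [B, 576, 1024]
# dst_size = sqrt(576) = 24, and the ConvNeXt map must satisfy h * w == number of CLIP
# tokens (24 * 24 = 576) so the gated element-wise fusion lines up token by token.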


class ProjectorConfig(PretrainedConfig):
    model_type = "projector"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        visual_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act="gelu",
        bias=True,
        **kwargs,
    ):
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias
        super().__init__(**kwargs)


class ProjectorModel(PreTrainedModel):
    _auto_class = "AutoModel"
    config_class = ProjectorConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []

    def __init__(self, config: ProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        modules = [
            nn.Linear(
                config.visual_hidden_size, config.llm_hidden_size, bias=config.bias
            )
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size, config.llm_hidden_size, bias=config.bias
                )
            )
        self.model = nn.Sequential(*modules)

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x):
        layer_outputs = self.model(x)
        return layer_outputs
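
# Illustrative sketch (hypothetical sizes, added for clarity): with depth=2 the projector
# is Linear(visual -> llm) -> GELU -> Linear(llm -> llm), e.g.
#   proj = ProjectorModel(
#       ProjectorConfig(visual_hidden_size=1024, llm_hidden_size=4096, depth=2)
#   )
#   tokens = torch.randn(1, 576, 1024)
#   out = proj(tokens)   # shape [1, 576, 4096]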


def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate a sine/cosine position embedding from a point or box tensor.

    Args:
        pos_tensor (torch.Tensor): shape [batch_size, N, 2] or [batch_size, N, 4]; the last
            dimension is [cx, cy] or [cx, cy, w, h] in normalized coordinates in range [0, 1].
        dim_of_pos_feats (int): number of embedding channels per coordinate.

    Returns:
        torch.Tensor: shape [batch_size, N, 2 * dim_of_pos_feats] for points or
            [batch_size, N, 4 * dim_of_pos_feats] for boxes.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(
            "Unknown pos_tensor shape(-1): {}".format(pos_tensor.size(-1))
        )
    return pos
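
# Added note (the standard DETR-style sine embedding computed above, stated for clarity):
# for a normalized coordinate p and channel pair index k,
#   PE[2k]     = sin(2 * pi * p / 10000 ** (2k / dim_of_pos_feats))
#   PE[2k + 1] = cos(2 * pi * p / 10000 ** (2k / dim_of_pos_feats))
# so a box (cx, cy, w, h) with dim_of_pos_feats=256 yields a 1024-dim embedding;
# MultiLevelROIVisualPrompt below passes pos_embedding_dim // 4 so the result matches
# its pooled feature width.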


class MultiLevelROIVisualPrompt(nn.Module):
    """Encode box visual prompts by RoI-aligning multi-level features from the auxiliary
    vision encoder.

    Args:
        output_size (Optional[int]): Output size of the RoI align grid. Default is None.
        channel_per_level (List[int]): Channels per feature level. Default is [192, 384, 768, 1536].
        spatail_scale (Optional[float]): Scale factor mapping box coordinates from the input
            image to the largest feature map. Default is None.
        visual_prompt_hidden_size (int): The hidden size of the visual prompt. Default is 1024.
        add_pos_embedding (bool): Whether to add a sine position embedding of the box. Default is False.
        pos_embedding_dim (int): The dimension of the position embedding. Default is 1024.
    """

    def __init__(
        self,
        output_size: int = None,
        channel_per_level: List[int] = [192, 384, 768, 1536],
        spatail_scale: float = None,
        visual_prompt_hidden_size: int = 1024,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super(MultiLevelROIVisualPrompt, self).__init__()
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatail_scale = spatail_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def __call__(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Perform the Region of Interest (RoI) Align operator on multi-level features.

        Every level is first resized to the resolution of the largest level and the levels
        are concatenated along the channel dimension before a single RoI Align; the pooled
        features are averaged over the spatial grid and optionally summed with a sine
        position embedding of the boxes.

        Args:
            multi_level_features (List[Tensor[N, C_i, H_i, W_i]]): Feature maps from different levels.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): Box coordinates in (x1, y1, x2, y2) format
                from which the regions are taken.

        Returns:
            Tensor[1, K, C]: The pooled (and optionally position-encoded) feature of each of the
                K RoIs, where C is the sum of the per-level channels.
        """
        boxes[0] = boxes[0].float()
        concat_multi_level_feature = []
        max_height = max([feature.shape[2] for feature in multi_level_features])
        max_width = max([feature.shape[3] for feature in multi_level_features])

        # Upsample every level to the largest spatial size, then concatenate on channels.
        for level, feature in enumerate(multi_level_features):
            if level != 0:
                concat_multi_level_feature.append(
                    F.interpolate(
                        feature.float(),
                        size=(max_height, max_width),
                        mode="bilinear",
                        align_corners=False,
                    )
                )
            else:
                concat_multi_level_feature.append(feature.float())
        concat_multi_level_feature = torch.cat(concat_multi_level_feature, dim=1)

        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatail_scale,
        )

        # Average over the RoI grid: [K, C, output_size, output_size] -> [1, K, C].
        out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape(
            1, out_box_feat.shape[0], out_box_feat.shape[1]
        )
        if self.add_pos_embedding:
            # Convert (x1, y1, x2, y2) in pixels to normalized (cx, cy, w, h) before encoding.
            boxes = boxes[0]
            boxes = boxes.to(out_box_feat.dtype)
            original_img_width = max_width / self.spatail_scale
            original_img_height = max_height / self.spatail_scale
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / original_img_width
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / original_img_height

            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2
            boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2
            pos_embed = gen_sineembed_for_position(
                boxes.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed

        return out_box_feat
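
# Illustrative usage (hypothetical shapes, added for clarity): with four ConvNeXt levels
# for one image and two boxes in pixel coordinates of a 768x768 input,
#   feats = [torch.randn(1, c, s, s)
#            for c, s in zip([192, 384, 768, 1536], [192, 96, 48, 24])]
#   boxes = [torch.tensor([[10.0, 20.0, 110.0, 220.0], [5.0, 5.0, 60.0, 90.0]])]
#   vp = MultiLevelROIVisualPrompt(output_size=7, spatail_scale=192 / 768,
#                                  add_pos_embedding=True, pos_embedding_dim=2880)
#   out = vp(feats, boxes)   # shape [1, 2, 192 + 384 + 768 + 1536] = [1, 2, 2880]
# which matches the visual_prompt_hidden_size=2880 used by the model below.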


class ChatRexAuxConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the ChatRexAux model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
    outputs. Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
            The config object or dictionary of the vision backbone.
        vision_aux_config (`Union[AutoConfig, dict]`, *optional*):
            The config object or dictionary of the auxiliary (high-resolution) vision backbone.
        visual_prompt_encoder_config (`Union[AutoConfig, dict]`, *optional*):
            The config object or dictionary of the visual prompt encoder.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        projector_depth (`int`, *optional*, defaults to 2):
            The number of linear layers in the multimodal projector.
        visual_prompt_hidden_size (`int`, *optional*, defaults to 2880):
            The hidden size of the visual prompt features.

    Example:

    ```python
    >>> from transformers import CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a ChatRexAux configuration
    >>> configuration = ChatRexAuxConfig(vision_config=vision_config, text_config=text_config)

    >>> # Initializing a model from that configuration
    >>> model = ChatRexAuxForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "chatrex"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        vision_aux_config=None,
        visual_prompt_encoder_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        projector_depth=2,
        visual_prompt_hidden_size=2880,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.projector_depth = projector_depth
        self.visual_prompt_hidden_size = visual_prompt_hidden_size
        self.visual_prompt_encoder_config = visual_prompt_encoder_config

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'. "
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if isinstance(vision_config, dict):
            vision_config["model_type"] = (
                vision_config["model_type"]
                if "model_type" in vision_config
                else "clip_vision_model"
            )
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )

        self.vision_config = vision_config
        self.vision_aux_config = vision_aux_config

        if isinstance(text_config, dict):
            text_config["model_type"] = (
                text_config["model_type"] if "model_type" in text_config else "llama"
            )
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config

        super().__init__(**kwargs)


class ChatRexAuxPreTrainedModel(PreTrainedModel):
    config_class = ChatRexAuxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    @property
    def _supports_sdpa(self):
        """
        Retrieve the language model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.llm._supports_sdpa


class ChatRexAuxForConditionalGeneration(ChatRexAuxPreTrainedModel):

    def __init__(self, config: ChatRexAuxConfig):
        super().__init__(config)

        # Low-resolution (CLIP-style) vision encoder.
        self.vision_encoder = AutoModel.from_config(config.vision_config)

        # High-resolution auxiliary ConvNeXt vision encoder.
        self.vision_encoder_aux = ConvNextVisionEncoder()

        # Projector from fused image features to the LLM hidden size.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.vision_config.hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.projector = ProjectorModel(projector_config)

        # Projector from visual-prompt (box) features to the LLM hidden size.
        vp_projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_prompt_hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.vp_projector = ProjectorModel(vp_projector_config)

        # Dual-path fusion of low-resolution and high-resolution image features.
        self.fuser = DualPathFuseModule(
            low_res_dim=config.vision_config.hidden_size,
            high_res_dim=1536,
        )

        # Multi-level RoI visual prompt encoder for box inputs.
        self.vp_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],
            spatail_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )

        self.gen_config = None

        self.vocab_size = config.text_config.vocab_size
        self.llm = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.post_init()

    def _prepare_data_for_llm(self, data):
        if "pixel_values" in data:
            # Low-resolution features from the main vision encoder.
            visual_outputs = self.vision_encoder(
                data["pixel_values"].to(self.vision_encoder.dtype),
                output_hidden_states=True,
            )
            if type(self.vision_encoder).__name__ in [
                "CLIPVisionModel",
                "CLIPVisionModelAnyRes",
            ]:
                # Penultimate layer, with the CLS token dropped.
                visual_outputs = visual_outputs.hidden_states[-2][:, 1:]
            elif type(self.vision_encoder).__name__ == "SiglipVisionModel":
                visual_outputs = visual_outputs.hidden_states[-2]
            else:
                raise NotImplementedError

            # High-resolution multi-level features from the auxiliary encoder.
            if self.vision_encoder_aux is not None:
                pixels_aux = []
                for pixels in data["pixel_values_aux"]:
                    if pixels.dim() == 3:
                        pixels = pixels.unsqueeze(0)
                    elif pixels.dim() == 4:
                        pixels = pixels.permute(1, 0, 2, 3)
                    pixels_aux.append(pixels)
                visual_outputs_aux = torch.cat(pixels_aux, dim=0)
                aux_output = self.vision_encoder_aux(visual_outputs_aux)
                visual_outputs_aux = aux_output["image_features"]
                last_feat = aux_output["last_feat"]

                # Fuse the two resolutions and project to the LLM hidden size.
                fuse_features = self.fuser(
                    low_res_feat=visual_outputs, high_res_feat=last_feat
                )
                pixel_values = self.projector(fuse_features)
                data["pixel_values"] = pixel_values

            # Encode box visual prompts, one feature per ground-truth box.
            bbox_visual_outputs = []
            if "gt_boxes" in data:
                for batch_idx, boxes in enumerate(data["gt_boxes"]):
                    if len(boxes) == 0:
                        bbox_visual_outputs.append(None)
                        continue
                    multi_level_aux_features = [
                        visual_output_aux[batch_idx].unsqueeze(0)
                        for visual_output_aux in visual_outputs_aux
                    ]
                    boxes = boxes.to(torch.float32)
                    out_vp_feat = self.vp_encoder(
                        multi_level_aux_features,
                        [boxes],
                    ).squeeze(0)
                    out_vp_feat = out_vp_feat.to(pixel_values.dtype)
                    out_vp_feat = self.vp_projector(out_vp_feat)
                    bbox_visual_outputs.append(out_vp_feat)

            data["bbox_feats"] = bbox_visual_outputs

        data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
        return data
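
    # Illustrative sketch (hypothetical keys and shapes, added for clarity): the data dict
    # is expected to look roughly like
    #   data = {
    #       "input_ids": torch.tensor([[1, 319, IMAGE_TOKEN_INDEX, 3148]]),
    #       "pixel_values": low_res_pixels,         # input to the CLIP-style encoder
    #       "pixel_values_aux": [high_res_pixels],  # input to the ConvNeXt encoder
    #       "gt_boxes": [torch.tensor([[10.0, 20.0, 110.0, 220.0]])],  # optional box prompts
    #   }
    # _prepare_data_for_llm() replaces "pixel_values" with projected image embeddings, adds
    # "bbox_feats", and returns LLM-ready inputs_embeds via
    # prepare_inputs_labels_for_multimodal().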

    def generate(self, data_dict: Dict[str, Any], gen_config=None, tokenizer=None):
        """Run generation on the given multimodal inputs.

        Args:
            data_dict (Dict[str, Any]): The data to perform inference on.
            gen_config (GenerationConfig, optional): Generation config; falls back to
                ``self.gen_config`` when not provided.
            tokenizer: Tokenizer used for the stop criteria and for decoding the output.

        Returns:
            str: The decoded answer.
        """
        data_dict = self._prepare_data_for_llm(data_dict)
        data_dict["inputs_embeds"] = data_dict["inputs_embeds"].to(self.llm.dtype)
        stop_criteria = get_stop_criteria(tokenizer=tokenizer, stop_words=[])
        generate_output = self.llm.generate(
            **data_dict,
            generation_config=self.gen_config if gen_config is None else gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
        )
        print("generate_output:", generate_output)
        prediction = tokenizer.decode(
            generate_output[0], skip_special_tokens=False
        ).strip()
        prediction = prediction.replace("<s>", "").replace("</s>", "").strip()
        return prediction


AutoConfig.register("chatrex", ChatRexAuxConfig)
AutoModelForCausalLM.register(ChatRexAuxConfig, ChatRexAuxForConditionalGeneration)
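
# Illustrative end-to-end usage (hypothetical checkpoint name, added for clarity): once the
# config/model pair is registered above, the model can be loaded through the auto classes:
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("<chatrex-checkpoint>", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("<chatrex-checkpoint>", trust_remote_code=True)
#   answer = model.generate(data_dict, tokenizer=tokenizer)
# where data_dict is prepared as sketched next to _prepare_data_for_llm() above.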