import math
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache, partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from transformers.activations import ACT2CLS, ACT2FN
from transformers.image_transforms import center_to_corners_format, corners_to_center_format
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_ninja_available,
    is_scipy_available,
    is_torch_cuda_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)
from transformers.models.rt_detr.configuration_rt_detr_resnet import RTDetrResNetConfig
from transformers.models.rt_detr.modeling_rt_detr import (
    MultiScaleDeformableAttentionFunction,  # custom CUDA kernel wrapper, assumed to be exposed by modeling_rt_detr
    RTDetrConfig,
    RTDetrDecoderOutput,
    RTDetrModelOutput,
    RTDetrObjectDetectionOutput,
    RTDetrFrozenBatchNorm2d,
    RTDetrConvEncoder,
    RTDetrConvNormLayer,
    RTDetrEncoderLayer,
    RTDetrRepVggBlock,
    RTDetrCSPRepLayer,
    RTDetrMultiscaleDeformableAttention,
    RTDetrMultiheadAttention,
    RTDetrDecoderLayer,
    RTDetrPreTrainedModel,
    RTDetrEncoder,
    RTDetrHybridEncoder,
    RTDetrDecoder,
    RTDetrModel,
    RTDetrMLPPredictionHead,
    RTDetrForObjectDetection,
)
from transformers.loss.loss_rt_detr import RTDetrLoss, RTDetrHungarianMatcher
from transformers.utils.backbone_utils import load_backbone

# TODO: define the config in its own file and import it instead:
# from .configuration_rt_detr_v2 import RTDetrV2Config


class RTDetrV2Config(RTDetrConfig):
    model_type = "rt_detr_v2"  # update the model type for the v2 variant

    def __init__(
        self,
        decoder_n_levels=3,  # number of feature levels sampled by the decoder deformable attention
        decoder_offset_scale=0.5,  # scale applied to the predicted sampling offsets
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.decoder_n_levels = decoder_n_levels
        self.decoder_offset_scale = decoder_offset_scale


class RTDetrV2ResNetConfig(RTDetrResNetConfig):
    model_type = "rt_detr_v2_resnet"


logger = logging.get_logger(__name__)


class RTDetrV2DecoderOutput(RTDetrDecoderOutput):
    pass


class RTDetrV2ModelOutput(RTDetrModelOutput):
    pass


class RTDetrV2ObjectDetectionOutput(RTDetrObjectDetectionOutput):
    pass


class RTDetrV2FrozenBatchNorm2d(RTDetrFrozenBatchNorm2d):
    pass


class RTDetrV2ConvEncoder(RTDetrConvEncoder):
    pass


class RTDetrV2ConvNormLayer(RTDetrConvNormLayer):
    pass


class RTDetrV2EncoderLayer(RTDetrEncoderLayer):
    pass


class RTDetrV2RepVggBlock(RTDetrRepVggBlock):
    pass


class RTDetrV2CSPRepLayer(RTDetrCSPRepLayer):
    pass


# New implementation of the multiscale deformable attention (v2)
def multi_scale_deformable_attention_v2(
    value: Tensor,
    value_spatial_shapes: Tensor,
    sampling_locations: Tensor,
    attention_weights: Tensor,
    num_points_list: List[int],
    method="default",
) -> Tensor:
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
    value_list = (
        value.permute(0, 2, 3, 1)
        .flatten(0, 1)
        .split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1)
    )
    # sampling_offsets [8, 480, 8, 12, 2]
    if method == "default":
        sampling_grids = 2 * sampling_locations - 1
    elif method == "discrete":
        sampling_grids = sampling_locations
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    sampling_grids = sampling_grids.split(num_points_list, dim=-2)
    sampling_value_list = []
    for level_id, (height, width) in enumerate(value_spatial_shapes):
        # batch_size, height*width, num_heads, hidden_dim
        # -> batch_size, height*width, num_heads*hidden_dim
        # -> batch_size, num_heads*hidden_dim, height*width
        # -> batch_size*num_heads, hidden_dim, height, width
        value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
        # batch_size, num_queries, num_heads, num_points, 2
        # -> batch_size, num_heads, num_queries, num_points, 2
        # -> batch_size*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[level_id]
        # batch_size*num_heads, hidden_dim, num_queries, num_points
        if method == "default":
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
            )
        elif method == "discrete":
            sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
                torch.int64
            )
            # Separate clamping for x and y coordinates
            sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
            sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)
            # Combine the clamped coordinates
            sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
            sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
            sampling_idx = (
                torch.arange(sampling_coord.shape[0], device=value.device)
                .unsqueeze(-1)
                .repeat(1, sampling_coord.shape[1])
            )
            sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
            sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
                batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
            )
        sampling_value_list.append(sampling_value_l_)
    # (batch_size, num_queries, num_heads, sum(num_points_list))
    # -> (batch_size, num_heads, num_queries, sum(num_points_list))
    # -> (batch_size*num_heads, 1, num_queries, sum(num_points_list))
    attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
        batch_size * num_heads, 1, num_queries, sum(num_points_list)
    )
    output = (
        (torch.concat(sampling_value_list, dim=-1) * attention_weights)
        .sum(-1)
        .view(batch_size, num_heads * hidden_dim, num_queries)
    )
    return output.transpose(1, 2).contiguous()
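# ------------------------------------------------------------------------------------------------
# Hypothetical smoke test (not part of the original module): a minimal sketch of the tensor shapes
# that multi_scale_deformable_attention_v2 expects and returns. All sizes below are made up and the
# helper is never called at import time.
def _smoke_test_multi_scale_deformable_attention_v2():
    torch.manual_seed(0)
    batch_size, num_heads, head_dim, num_queries = 2, 8, 32, 10
    spatial_shapes = torch.tensor([[8, 8], [4, 4], [2, 2]])  # three feature levels
    num_points_list = [4, 4, 4]  # points sampled per level
    sequence_length = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    total_points = sum(num_points_list)
    value = torch.randn(batch_size, sequence_length, num_heads, head_dim)
    sampling_locations = torch.rand(batch_size, num_queries, num_heads, total_points, 2)
    attention_weights = torch.softmax(
        torch.randn(batch_size, num_queries, num_heads, total_points), dim=-1
    )
    output = multi_scale_deformable_attention_v2(
        value, spatial_shapes, sampling_locations, attention_weights, num_points_list
    )
    # one vector of size num_heads * head_dim per query
    assert output.shape == (batch_size, num_queries, num_heads * head_dim)
    return output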
class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention):
    def __init__(self, config: RTDetrV2Config):
        super().__init__(config, config.decoder_attention_heads, config.decoder_n_points)
        self.n_levels = config.decoder_n_levels
        self.offset_scale = config.decoder_offset_scale
        # same number of sampling points on every level, kept as a flat list for the v2 attention
        n_points_list = [self.n_points for _ in range(self.n_levels)]
        self.n_points_list = n_points_list
        n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
        self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))
        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
        default_dtype = torch.get_default_dtype()
        thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        nn.init.constant_(self.attention_weights.weight.data, 0.0)
        nn.init.constant_(self.attention_weights.bias.data, 0.0)
        nn.init.xavier_uniform_(self.value_proj.weight.data)
        nn.init.constant_(self.value_proj.bias.data, 0.0)
        nn.init.xavier_uniform_(self.output_proj.weight.data)
        nn.init.constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
            raise ValueError(
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
            )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # we invert the attention_mask
            value = value.masked_fill(~attention_mask[..., None], float(0))
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
        )
        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        # batch_size, num_queries, n_heads, n_levels, n_points, 2
        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif num_coordinates == 4:
            n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
            offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
            sampling_locations = reference_points[:, :, None, :, :2] + offset
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        if self.disable_custom_kernels:
            # PyTorch implementation
            output = multi_scale_deformable_attention_v2(
                value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list
            )
        else:
            try:
                # custom kernel
                output = MultiScaleDeformableAttentionFunction.apply(
                    value,
                    spatial_shapes,
                    level_start_index,
                    sampling_locations,
                    attention_weights,
                    self.im2col_step,
                )
            except Exception:
                # PyTorch implementation
                output = multi_scale_deformable_attention_v2(
                    value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list
                )
        output = self.output_proj(output)

        return output, attention_weights
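# ------------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module, all sizes are hypothetical): how the v2
# attention turns predicted offsets into sampling locations when the reference points are 4-dim
# boxes (cx, cy, w, h). Each offset is scaled by 1 / points-per-level, by the box width/height,
# and by decoder_offset_scale, then added to the box center.
def _sketch_v2_sampling_locations():
    torch.manual_seed(0)
    batch_size, num_queries, num_heads, num_levels, num_points = 1, 2, 2, 3, 4
    offset_scale = 0.5
    total_points = num_levels * num_points
    reference_points = torch.rand(batch_size, num_queries, 1, 4)  # (cx, cy, w, h), normalized
    sampling_offsets = torch.randn(batch_size, num_queries, num_heads, total_points, 2)
    n_points_scale = torch.full((total_points, 1), 1.0 / num_points)
    offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * offset_scale
    sampling_locations = reference_points[:, :, None, :, :2] + offset
    # (batch_size, num_queries, num_heads, total_points, 2), still in normalized coordinates
    return sampling_locations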
class RTDetrV2MultiheadAttention(RTDetrMultiheadAttention):
    pass


class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
    pass


class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
    config_class = RTDetrV2Config
    base_model_prefix = "rt_detr_v2"
    main_input_name = "pixel_values"
    _no_split_modules = [r"RTDetrV2ConvEncoder", r"RTDetrV2EncoderLayer", r"RTDetrV2DecoderLayer"]


class RTDetrV2Encoder(RTDetrEncoder):
    pass


class RTDetrV2HybridEncoder(RTDetrHybridEncoder):
    pass
class RTDetrV2Decoder(RTDetrDecoder):
    pass


class RTDetrV2Model(RTDetrModel):
    pass


class RTDetrV2Loss(RTDetrLoss):
    pass


class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead):
    pass


class RTDetrV2HungarianMatcher(RTDetrHungarianMatcher):
    pass


# must inherit the new classes!
class RTDetrV2ForObjectDetection(RTDetrForObjectDetection):
    pass
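# ------------------------------------------------------------------------------------------------
# Minimal, hypothetical smoke test (not part of the original module): builds the v2 config and
# checks the two attributes added on top of RTDetrConfig. Only runs when this file is executed
# directly.
if __name__ == "__main__":
    config = RTDetrV2Config(decoder_n_levels=3, decoder_offset_scale=0.5)
    print(config.model_type, config.decoder_n_levels, config.decoder_offset_scale)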