Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Dict, Tuple | |
from mmengine.model import uniform_init | |
from torch import Tensor, nn | |
from mmdet.registry import MODELS | |
from ..layers import SinePositionalEncoding | |
from ..layers.transformer import (DABDetrTransformerDecoder, | |
DABDetrTransformerEncoder, inverse_sigmoid) | |
from .detr import DETR | |
class DABDETR(DETR):
    r"""Implementation of `DAB-DETR:
    Dynamic Anchor Boxes are Better Queries for DETR.

    <https://arxiv.org/abs/2201.12329>`_.

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DAB-DETR>`_.

    Args:
        with_random_refpoints (bool): Whether to randomly initialize query
            embeddings and not update them during training.
            Defaults to False.
        num_patterns (int): Inspired by Anchor-DETR. Defaults to 0.
    """

    def __init__(self,
                 *args,
                 with_random_refpoints: bool = False,
                 num_patterns: int = 0,
                 **kwargs) -> None:
        # Both options are stored before the parent constructor runs.
        # NOTE(review): presumably the parent constructor triggers
        # `_init_layers()`, which reads `self.num_patterns` - confirm
        # against the DETR base class before reordering.
        self.with_random_refpoints = with_random_refpoints
        assert isinstance(num_patterns, int), \
            f'num_patterns should be int but {num_patterns}.'
        self.num_patterns = num_patterns
        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        # On entry these three attributes hold keyword-argument mappings;
        # each one is replaced in place by the module built from it.
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DABDetrTransformerEncoder(**self.encoder)
        self.decoder = DABDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_dim = self.decoder.query_dim

        # One learnable anchor of size `query_dim` per query.
        self.query_embedding = nn.Embedding(self.num_queries, self.query_dim)
        # Pattern embeddings only exist when at least one is requested.
        if self.num_patterns > 0:
            self.patterns = nn.Embedding(self.num_patterns, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    @staticmethod
    def _mask_center_grad(grad: Tensor) -> Tensor:
        """Return a copy of ``grad`` whose first two columns are zeroed."""
        # Autograd hooks must not modify their argument in place.
        masked = grad.clone()
        masked[:, :2] = 0
        return masked

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components.

        When ``with_random_refpoints`` is set, the query embedding is
        re-initialized with ``uniform_init``, its first two columns are
        mapped through ``inverse_sigmoid``, and the gradient of those two
        columns is zeroed so that back-propagation does not update them.
        """
        super(DABDETR, self).init_weights()
        if not self.with_random_refpoints:
            return

        uniform_init(self.query_embedding)
        refpoints = self.query_embedding.weight
        refpoints.data[:, :2] = inverse_sigmoid(refpoints.data[:, :2])

        # Setting `requires_grad = False` on `weight.data[:, :2]` has no
        # effect: the slice is a temporary detached tensor, and a slice of
        # a parameter cannot be frozen on its own. Zeroing the matching
        # gradient columns with a hook is what actually keeps the values
        # from being updated. Updates that do not depend on the gradient
        # (e.g. decoupled weight decay in the optimizer) are not blocked.
        already_masked = getattr(refpoints, '_center_grad_masked', False)
        if refpoints.requires_grad and not already_masked:
            refpoints.register_hook(self._mask_center_grad)
            # The marker lives on the parameter, next to the hook, so a
            # repeated `init_weights()` call does not stack hooks.
            refpoints._center_grad_masked = True

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos'
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is empty for this detector.
        """
        batch_size = memory.size(0)
        # (num_queries, query_dim) -> (bs, num_queries, query_dim)
        query_pos = self.query_embedding.weight
        query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)

        # Same `> 0` test as `_init_layers`, which is the only place
        # `self.patterns` gets created.
        if self.num_patterns > 0:
            # (P, C) -> (P, 1, 1, C) -> (P, Q, bs, C) -> (P * Q, bs, C)
            # -> (bs, P * Q, C), with P = num_patterns, Q = num_queries
            # and C = embed_dims.
            query = self.patterns.weight[:, None, None, :]\
                .repeat(1, self.num_queries, batch_size, 1)\
                .view(-1, batch_size, self.embed_dims)\
                .permute(1, 0, 2)
            # Tile the anchors once per pattern so that both tensors have
            # P * Q entries along dim 1.
            query_pos = query_pos.repeat(1, self.num_patterns, 1)
        else:
            # Without patterns the content queries start as zeros.
            query = query_pos.new_zeros(batch_size, self.num_queries,
                                        self.embed_dims)

        decoder_inputs_dict = dict(
            query_pos=query_pos, query=query, memory=memory)
        head_inputs_dict = dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs.
                As built by `pre_decoder()` it has shape
                (bs, num_queries, query_dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.
        """
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask,
            # iterative refinement for anchor boxes
            reg_branches=self.bbox_head.fc_reg)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict