rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
3.18 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
import torch.nn as nn
from torch import Tensor
from mmdet.registry import MODELS
from ..layers import (ConditionalDetrTransformerDecoder,
DetrTransformerEncoder, SinePositionalEncoding)
from .detr import DETR
@MODELS.register_module()
class ConditionalDETR(DETR):
    r"""Conditional DETR detector.

    Implements `Conditional DETR for Fast Training Convergence
    <https://arxiv.org/abs/2108.06152>`_. The code is adapted from the
    `official github repo <https://github.com/Atten4Vis/ConditionalDETR>`_.
    """

    def _init_layers(self) -> None:
        """Build every layer other than the backbone, neck and bbox_head.

        On entry ``self.positional_encoding``, ``self.encoder`` and
        ``self.decoder`` hold keyword-argument mappings; each one is
        replaced in place by the module constructed from it.
        """
        # Each attribute starts out as a config mapping and is overwritten
        # with the module built from that mapping.
        pos_enc_cfg = self.positional_encoding
        self.positional_encoding = SinePositionalEncoding(**pos_enc_cfg)

        encoder_cfg = self.encoder
        self.encoder = DetrTransformerEncoder(**encoder_cfg)

        decoder_cfg = self.decoder
        self.decoder = ConditionalDetrTransformerDecoder(**decoder_cfg)

        # NOTE: embed_dims flows from the innermost module outwards. In DETR
        # the chain is: self_attn -> first encoder layer -> encoder ->
        # detector, so the detector reads it back from the built encoder.
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

        # The positional encoding's num_feats must be exactly half of
        # embed_dims.
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, (
            f'embed_dims should be exactly 2 times of num_feats. '
            f'Found {self.embed_dims} and {num_feats}.')

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Run the Transformer decoder and pack its outputs for the head.

        Args:
            query (Tensor): Decoder input queries, with shape
                (bs, num_queries, dim).
            query_pos (Tensor): Positional queries of the decoder inputs,
                with shape (bs, num_queries, dim).
            memory (Tensor): Output embeddings of the Transformer encoder,
                with shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor padding mask of the memory,
                with shape (bs, num_feat_points).
            memory_pos (Tensor): Positional embeddings of the memory, with
                shape (bs, num_feat_points, dim).

        Returns:
            dict: Decoder outputs keyed by name.

            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim)
            - references (Tensor): Has shape
              (bs, num_queries, 2)
        """
        # The encoder memory serves as the decoder's key; the decoder
        # returns a (hidden_states, references) pair.
        decoder_outputs = self.decoder(
            query=query,
            key=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask)
        hidden_states, references = decoder_outputs
        return {'hidden_states': hidden_states, 'references': references}