# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule, ModuleList, constant_init, xavier_init

from mmdet.registry import MODELS
from .fpn import FPN


class ASPP(BaseModule):
    """Atrous Spatial Pyramid Pooling (ASPP) block.

    Implementation of the ASPP module used in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    One convolution is created per entry of ``dilations``. Every branch
    except the last is applied to the input feature map; the last branch is
    applied to the globally average-pooled input and its result is expanded
    to the spatial size of the preceding branch. All branch outputs pass
    through ReLU and are concatenated along the channel dimension, so the
    result has ``out_channels * len(dilations)`` channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by each branch.
        dilations (tuple[int]): Dilations of the four branches. The last
            entry must be 1. Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilations=(1, 3, 6, 1),
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        super().__init__(init_cfg)
        # The final branch is fed the 1x1 pooled map in ``forward``, so it
        # has to be the plain (dilation == 1) convolution.
        assert dilations[-1] == 1
        self.aspp = nn.ModuleList()
        for rate in dilations:
            if rate > 1:
                # Dilated 3x3 conv; padding == dilation keeps the spatial
                # size unchanged at stride 1.
                kernel, pad = 3, rate
            else:
                # Plain 1x1 conv, no padding needed.
                kernel, pad = 1, 0
            self.aspp.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel,
                    stride=1,
                    dilation=rate,
                    padding=pad,
                    bias=True))
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them channel-wise."""
        pooled = self.gap(x)
        last_branch = len(self.aspp) - 1
        # Only the last branch consumes the globally pooled features.
        branches = [
            F.relu_(conv(pooled if branch_idx == last_branch else x))
            for branch_idx, conv in enumerate(self.aspp)
        ]
        # Broadcast the pooled branch back to the size of its neighbour so
        # that it can be concatenated with the other branches.
        branches[-1] = branches[-1].expand_as(branches[-2])
        return torch.cat(branches, dim=1)


@MODELS.register_module()
class RFP(FPN):
    """RFP (Recursive Feature Pyramid)

    This is an implementation of RFP in `DetectoRS
    <https://arxiv.org/pdf/2006.02334.pdf>`_. Different from standard FPN, the
    input of RFP should be multi level features along with origin input image
    of backbone.

    Args:
        rfp_steps (int): Number of unrolled steps of RFP.
        rfp_backbone (dict): Configuration of the backbone for RFP.
        aspp_out_channels (int): Number of output channels of ASPP module.
        aspp_dilations (tuple[int]): Dilation rates of four branches.
            Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be left as None. Default: None
        **kwargs: Remaining keyword arguments, forwarded to ``FPN``.
    """

    def __init__(self,
                 rfp_steps,
                 rfp_backbone,
                 aspp_out_channels,
                 aspp_dilations=(1, 3, 6, 1),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.rfp_steps = rfp_steps
        # Be careful! Pretrained weights cannot be loaded when use
        # nn.ModuleList
        self.rfp_modules = ModuleList()
        # The first step is the plain FPN pass, so only ``rfp_steps - 1``
        # extra backbones are built.
        for _ in range(rfp_steps - 1):
            self.rfp_modules.append(MODELS.build(rfp_backbone))
        self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
                             aspp_dilations)
        # 1x1 conv producing a single-channel map; ``forward`` passes it
        # through a sigmoid to blend new and previous features.
        self.rfp_weight = nn.Conv2d(
            self.out_channels,
            1,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

    def init_weights(self):
        """Initialize the FPN convs, the RFP backbones and the fusion conv."""
        # Avoid using super().init_weights(), which may alter the default
        # initialization of the modules in self.rfp_modules that have missing
        # keys in the pretrained checkpoint.
        for conv_group in (self.lateral_convs, self.fpn_convs):
            conv_layers = (layer for layer in conv_group.modules()
                           if isinstance(layer, nn.Conv2d))
            for layer in conv_layers:
                xavier_init(layer, distribution='uniform')
        for step in range(self.rfp_steps - 1):
            self.rfp_modules[step].init_weights()
        constant_init(self.rfp_weight, 0)

    def forward(self, inputs):
        """Run the FPN pass followed by ``rfp_steps - 1`` recursive passes.

        Args:
            inputs (Sequence): The input image first, followed by one
                feature map per entry of ``self.in_channels``.

        Returns:
            The output of ``FPN.forward`` when no recursive step runs,
            otherwise a list with one fused feature map per pyramid level.
        """
        inputs = list(inputs)
        assert len(inputs) == len(self.in_channels) + 1  # +1 for input image
        img = inputs[0]
        backbone_feats = tuple(inputs[1:])
        # FPN forward
        x = super().forward(backbone_feats)
        for step in range(self.rfp_steps - 1):
            # The first level is fed back unchanged; every other level goes
            # through the shared ASPP module first.
            feedback = [x[0], *(self.rfp_aspp(level) for level in x[1:])]
            refined = self.rfp_modules[step].rfp_forward(img, feedback)
            # FPN forward
            refined = super().forward(refined)
            fused = []
            for level_idx, new_feat in enumerate(refined):
                # Per-pixel gate in (0, 1) that blends the refined feature
                # with the one from the previous step.
                gate = torch.sigmoid(self.rfp_weight(new_feat))
                fused.append(gate * new_feat + (1 - gate) * x[level_idx])
            x = fused
        return x