# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList


class ConvUpsample(BaseModule):
    """ConvUpsample performs 2x upsampling after Conv.

    There are several `ConvModule` layers. In the first few layers, upsampling
    will be applied after each layer of convolution. The number of upsampling
    must be no more than the number of ConvModule layers.

    Args:
        in_channels (int): Number of channels in the input feature map.
        inner_channels (int): Number of channels produced by the convolution.
        num_layers (int): Number of convolution layers. Default: 1.
        num_upsample (int, optional): Number of upsampling layers. Must be no
            more than ``num_layers``. Upsampling will be applied after the
            first ``num_upsample`` layers of convolution.
            Default: ``num_layers``.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        init_cfg (dict): Config dict for initialization. Default: None.
        kwargs (keyword arguments): Other arguments forwarded to every
            ``ConvModule``. Must not contain ``padding`` or ``stride``, which
            are already passed explicitly (padding=1, stride=1).
    """

    def __init__(self,
                 in_channels,
                 inner_channels,
                 num_layers=1,
                 num_upsample=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        if num_upsample is None:
            num_upsample = num_layers
        assert num_upsample <= num_layers, \
            f'num_upsample({num_upsample}) must be no more than ' \
            f'num_layers({num_layers})'
        self.num_layers = num_layers
        self.num_upsample = num_upsample
        # Attribute name ``conv`` is kept so checkpoint keys stay unchanged.
        self.conv = ModuleList()
        for _ in range(num_layers):
            # 3x3 conv with padding=1, stride=1 keeps the spatial size, so
            # only the interpolation in ``forward`` changes resolution.
            self.conv.append(
                ConvModule(
                    in_channels,
                    inner_channels,
                    3,
                    padding=1,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
            # Every layer after the first consumes the previous layer's output.
            in_channels = inner_channels

    def forward(self, x):
        """Apply the conv stack, upsampling 2x after the first
        ``num_upsample`` layers.

        Args:
            x (Tensor): Input feature map with ``in_channels`` channels.
                Presumably (N, C, H, W): bilinear ``F.interpolate`` only
                accepts 4-D input.

        Returns:
            Tensor: Output with ``inner_channels`` channels (or ``x`` itself
            when ``num_layers`` is 0), spatially scaled by
            ``2 ** min(max(num_upsample, 0), num_layers)``.
        """
        for i, conv in enumerate(self.conv):
            x = conv(x)
            if i < self.num_upsample:
                x = F.interpolate(
                    x, scale_factor=2, mode='bilinear', align_corners=False)
        return x