# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule

from .se_layer import SELayer


class InvertedResidual(BaseModule):
    """Inverted residual block.

    The block chains an optional 1x1 expand convolution, a depthwise
    convolution, an optional SE layer and a 1x1 linear convolution that
    has no activation. When ``stride == 1`` and
    ``in_channels == out_channels`` the (optionally drop-path'ed) result is
    added back onto the input.

    Args:
        in_channels (int): Number of channels of the block input.
        out_channels (int): Number of channels of the block output.
        mid_channels (int): Number of channels fed to (and produced by) the
            depthwise convolution.
        kernel_size (int): Kernel size of the depthwise convolution; its
            padding is ``kernel_size // 2``. Default: 3.
        stride (int): Stride of the depthwise convolution, must be 1 or 2.
            Default: 1.
        se_cfg (dict): Keyword arguments used to build the SE layer.
            Default: None, meaning the block has no SE layer.
        with_expand_conv (bool): Whether to build the expand convolution.
            When False, ``mid_channels`` must equal ``in_channels``.
            Default: True.
        conv_cfg (dict): Convolution config forwarded to every
            ``ConvModule``. Default: None, which means using conv2d.
        norm_cfg (dict): Normalization config forwarded to every
            ``ConvModule``. Default: dict(type='BN').
        act_cfg (dict): Activation config for the expand and depthwise
            convolutions. Default: dict(type='ReLU').
        drop_path_rate (float): Stochastic depth rate. A value that is not
            greater than 0 installs ``nn.Identity`` instead of ``DropPath``.
            Defaults to 0.
        with_cp (bool): Whether to run the block through
            ``torch.utils.checkpoint`` when the input requires grad.
            Checkpointing saves some memory while slowing down training.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config
            passed to ``BaseModule``. Default: None

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_path_rate=0.,
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # The identity shortcut is only valid when the block preserves both
        # the spatial resolution and the channel count.
        self.with_res_shortcut = (
            in_channels == out_channels and stride == 1)
        assert stride in [1, 2], \
            f'stride must in [1, 2]. But received {stride}.'

        self.with_cp = with_cp
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            # Without the expand conv the depthwise conv sees the raw input.
            assert mid_channels == in_channels

        # Settings shared by both 1x1 (pointwise) convolutions.
        pointwise_kwargs = dict(
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        # NOTE: sub-modules are registered in the order expand -> depthwise
        # -> se -> linear; keep it that way so parameter ordering is stable.
        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                act_cfg=act_cfg,
                **pointwise_kwargs)

        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        # The projection back to ``out_channels`` has no activation.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            act_cfg=None,
            **pointwise_kwargs)

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """

        def _run_block(inputs):
            feats = inputs
            if self.with_expand_conv:
                feats = self.expand_conv(feats)
            feats = self.depthwise_conv(feats)
            if self.with_se:
                feats = self.se(feats)
            feats = self.linear_conv(feats)

            if not self.with_res_shortcut:
                return feats
            return inputs + self.drop_path(feats)

        # Checkpointing is only used when enabled and the input needs grad.
        if not (self.with_cp and x.requires_grad):
            return _run_block(x)
        return cp.checkpoint(_run_block, x)