|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule |
|
|
|
|
|
class DarknetBottleneck(nn.Module):
    """The basic bottleneck block used in Darknet.

    The block chains two conv modules: a 1x1 ``ConvModule`` mapping
    ``in_channels`` to the hidden width, followed by a 3x3 conv module
    (``stride=1``, ``padding=1``) mapping the hidden width to
    ``out_channels``. The block input is added to the result only when
    ``add_identity`` is set and ``in_channels == out_channels``.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Ratio of the hidden channels to
            ``out_channels``; the hidden width is
            ``int(out_channels * expansion)``. Default: 0.5
        add_identity (bool): Whether to add identity to the out. The
            shortcut is only applied when ``in_channels`` equals
            ``out_channels``. Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the 3x3 stage. Default: False
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
        init_cfg (dict, optional): Accepted but not used by this
            implementation. Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=0.5,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish'),
                 init_cfg=None):
        super().__init__()
        hidden_channels = int(out_channels * expansion)

        # Only the 3x3 stage switches to the depthwise separable variant;
        # the 1x1 stage is always a plain ConvModule.
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        # Both stages share the same conv/norm/activation configuration.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.conv1 = ConvModule(
            in_channels, hidden_channels, 1, **shared_cfg)
        self.conv2 = conv(
            hidden_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            **shared_cfg)

        # The shortcut needs matching channel counts, so the flag is
        # disabled whenever the input and output widths differ.
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x):
        """Apply both conv stages and, if enabled, add the input back.

        Args:
            x (Tensor): Input of the block.

        Returns:
            Tensor: ``conv2(conv1(x))``, plus ``x`` when the shortcut is
            enabled.
        """
        out = self.conv2(self.conv1(x))

        if self.add_identity:
            return out + x
        return out
|
|
|
|
|
class CSPLayer(nn.Module):
    """Cross Stage Partial Layer.

    The input is fed to two parallel 1x1 conv modules. The "main" branch is
    further processed by ``num_blocks`` :class:`DarknetBottleneck` blocks
    (built with ``expansion=1.0``), while the "short" branch is left as is.
    Both branches are concatenated along the channel dimension (``dim=1``)
    and fused by a final 1x1 conv module.

    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer; each branch has
            ``int(out_channels * expand_ratio)`` channels. Default: 0.5
        num_blocks (int): Number of blocks. Default: 1
        add_identity (bool): Whether to add identity in blocks.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Default: False
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish')
        init_cfg (dict, optional): Accepted but not used by this
            implementation. Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio=0.5,
                 num_blocks=1,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish'),
                 init_cfg=None):
        super().__init__()
        mid_channels = int(out_channels * expand_ratio)

        # Every conv module in this layer shares the same configuration.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Submodules are created in a fixed order (main, short, final,
        # blocks) so parameter registration order stays stable.
        self.main_conv = ConvModule(
            in_channels, mid_channels, 1, **shared_cfg)
        self.short_conv = ConvModule(
            in_channels, mid_channels, 1, **shared_cfg)
        # The fusion conv consumes the concatenation of both branches.
        self.final_conv = ConvModule(
            2 * mid_channels, out_channels, 1, **shared_cfg)

        # Keyword arguments keep this call robust against positional
        # reordering of the DarknetBottleneck signature.
        self.blocks = nn.Sequential(*[
            DarknetBottleneck(
                mid_channels,
                mid_channels,
                expansion=1.0,
                add_identity=add_identity,
                use_depthwise=use_depthwise,
                **shared_cfg) for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Run both branches, concatenate them and fuse the result.

        Args:
            x (Tensor): Input of the layer.

        Returns:
            Tensor: Output of the final conv applied to the concatenation
            of the main branch (first) and the short branch (second).
        """
        x_short = self.short_conv(x)

        x_main = self.blocks(self.main_conv(x))

        # Main branch first, short branch second, along the channel axis.
        x_final = torch.cat((x_main, x_short), dim=1)
        return self.final_conv(x_final)