|
|
|
import torch.nn as nn |
|
|
|
from .conv_module import ConvModule |
|
|
|
|
|
class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution module.

    See https://arxiv.org/pdf/1704.04861.pdf for details.

    This is a drop-in replacement for a ``ConvModule`` whose single conv
    block is factored into two stacked ``ConvModule`` blocks:

    1. a depthwise block (``in_channels`` -> ``in_channels`` with
       ``groups=in_channels``) that uses ``kernel_size``, ``stride``,
       ``padding`` and ``dilation``;
    2. a pointwise block (``in_channels`` -> ``out_channels`` with kernel
       size 1).

    Each block contains conv/norm/activation layers. Unless overridden
    through ``dw_norm_cfg`` / ``dw_act_cfg``, the depthwise block inherits
    ``norm_cfg`` and ``act_cfg``. It will therefore contain norm/activation
    layers whenever those configs are specified.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the depthwise convolving
            kernel. Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the depthwise convolution.
            Same as that in ``nn._ConvNd``. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input of the depthwise convolution. Same as that in
            ``nn._ConvNd``. Default: 0.
        dilation (int | tuple[int]): Spacing between kernel elements of the
            depthwise convolution. Same as that in ``nn._ConvNd``.
            Default: 1.
        norm_cfg (dict | None): Fallback norm config shared by the depthwise
            and pointwise ConvModules. Default: None.
        act_cfg (dict | None): Fallback activation config shared by the
            depthwise and pointwise ConvModules. Default: dict(type='ReLU').
        dw_norm_cfg (dict | str | None): Norm config of the depthwise
            ConvModule. The string 'default' means "use ``norm_cfg``".
            Default: 'default'.
        dw_act_cfg (dict | str | None): Activation config of the depthwise
            ConvModule. The string 'default' means "use ``act_cfg``".
            Default: 'default'.
        pw_norm_cfg (dict | str | None): Norm config of the pointwise
            ConvModule. The string 'default' means "use ``norm_cfg``".
            Default: 'default'.
        pw_act_cfg (dict | str | None): Activation config of the pointwise
            ConvModule. The string 'default' means "use ``act_cfg``".
            Default: 'default'.
        kwargs (optional): Extra keyword arguments forwarded unchanged to
            both the depthwise and the pointwise ConvModule. See ConvModule
            for reference. ``groups`` is not allowed here.

    Raises:
        AssertionError: If ``groups`` is passed through ``kwargs`` (checked
            with ``assert``, so the check is skipped under ``python -O``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 dw_norm_cfg='default',
                 dw_act_cfg='default',
                 pw_norm_cfg='default',
                 pw_act_cfg='default',
                 **kwargs):
        super(DepthwiseSeparableConvModule, self).__init__()
        # The depthwise block sets ``groups`` itself (to ``in_channels``),
        # so it must not also arrive through the shared kwargs.
        assert 'groups' not in kwargs, 'groups should not be specified'

        def pick(override, fallback):
            # The literal string 'default' means "inherit the shared config".
            if override != 'default':
                return override
            return fallback

        # NOTE(review): the ``act_cfg`` default dict is created once, when
        # the function is defined, and is shared by every instance built
        # without an explicit ``act_cfg``. It is forwarded as-is, which is
        # only safe if ConvModule never mutates it. Confirm.
        depthwise_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            norm_cfg=pick(dw_norm_cfg, norm_cfg),
            act_cfg=pick(dw_act_cfg, act_cfg),
            **kwargs)
        pointwise_options = dict(
            norm_cfg=pick(pw_norm_cfg, norm_cfg),
            act_cfg=pick(pw_act_cfg, act_cfg),
            **kwargs)

        # Construction and registration order (depthwise, then pointwise)
        # is kept, so submodule/state_dict ordering is unchanged.
        self.depthwise_conv = ConvModule(
            in_channels, in_channels, kernel_size, **depthwise_options)
        self.pointwise_conv = ConvModule(
            in_channels, out_channels, 1, **pointwise_options)

    def forward(self, x):
        """Run ``x`` through the depthwise block, then the pointwise block."""
        return self.pointwise_conv(self.depthwise_conv(x))
|
|