Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Dict, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from .conv_module import ConvModule | |
class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution module.

    See https://arxiv.org/pdf/1704.04861.pdf for details.

    This module can replace a ConvModule with the conv block replaced by two
    conv blocks: a depthwise conv block and a pointwise conv block. The
    depthwise conv block contains depthwise-conv/norm/activation layers. The
    pointwise conv block contains pointwise-conv/norm/activation layers. Note
    that the depthwise conv block will also contain norm/activation layers if
    `norm_cfg` and `act_cfg` (or `dw_norm_cfg` / `dw_act_cfg`) are specified.

    ``stride``, ``padding`` and ``dilation`` are applied to the depthwise conv
    only; the pointwise conv always uses a 1x1 kernel. ``kwargs`` are forwarded
    to both ConvModules.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``. Default: 0.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``. Default: 1.
        norm_cfg (dict, optional): Default norm config for both depthwise
            ConvModule and pointwise ConvModule. Default: None.
        act_cfg (dict): Default activation config for both depthwise
            ConvModule and pointwise ConvModule. Default: dict(type='ReLU').
        dw_norm_cfg (dict | str): Norm config of depthwise ConvModule. If it
            is 'default', it will be the same as `norm_cfg`.
            Default: 'default'.
        dw_act_cfg (dict | str): Activation config of depthwise ConvModule.
            If it is 'default', it will be the same as `act_cfg`.
            Default: 'default'.
        pw_norm_cfg (dict | str): Norm config of pointwise ConvModule. If it
            is 'default', it will be the same as `norm_cfg`.
            Default: 'default'.
        pw_act_cfg (dict | str): Activation config of pointwise ConvModule.
            If it is 'default', it will be the same as `act_cfg`.
            Default: 'default'.
        kwargs (optional): Other shared arguments for depthwise and pointwise
            ConvModule. See ConvModule for ref. ``groups`` must not be given;
            the depthwise conv sets it to ``in_channels`` itself.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 norm_cfg: Optional[Dict] = None,
                 act_cfg: Dict = dict(type='ReLU'),
                 dw_norm_cfg: Union[Dict, str] = 'default',
                 dw_act_cfg: Union[Dict, str] = 'default',
                 pw_norm_cfg: Union[Dict, str] = 'default',
                 pw_act_cfg: Union[Dict, str] = 'default',
                 **kwargs):
        super().__init__()
        # The depthwise conv fixes groups=in_channels below; a user-supplied
        # `groups` would collide with it.
        assert 'groups' not in kwargs, 'groups should not be specified'

        # Resolve 'default' placeholders to the shared configs. Each dict is
        # copied so that the module-wide `act_cfg` default (evaluated once at
        # definition time) and any caller-owned dict are never aliased by the
        # ConvModules built here.
        dw_norm_cfg = self._resolve_cfg(dw_norm_cfg, norm_cfg)
        dw_act_cfg = self._resolve_cfg(dw_act_cfg, act_cfg)
        pw_norm_cfg = self._resolve_cfg(pw_norm_cfg, norm_cfg)
        pw_act_cfg = self._resolve_cfg(pw_act_cfg, act_cfg)

        # depthwise convolution: one filter group per input channel
        self.depthwise_conv = ConvModule(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            norm_cfg=dw_norm_cfg,  # type: ignore
            act_cfg=dw_act_cfg,  # type: ignore
            **kwargs)

        # pointwise convolution: 1x1 kernel mixing channels to out_channels
        self.pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
            norm_cfg=pw_norm_cfg,  # type: ignore
            act_cfg=pw_act_cfg,  # type: ignore
            **kwargs)

    @staticmethod
    def _resolve_cfg(cfg: Union[Dict, str, None],
                     default_cfg: Optional[Dict]) -> Union[Dict, str, None]:
        """Pick ``default_cfg`` when ``cfg`` is 'default', else ``cfg``.

        Dict results are shallow-copied (via ``.copy()``) so the returned
        config never shares identity with its source; ``None`` and
        non-'default' strings are returned unchanged.
        """
        resolved = default_cfg if cfg == 'default' else cfg
        if isinstance(resolved, dict):
            return resolved.copy()
        return resolved

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise conv block, then the pointwise conv block."""
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x