from typing import Optional, Sequence, Tuple

import torch
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmengine.model import BaseModule
from torch import Tensor, nn
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig


class BasicBlock(BaseModule):
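    """Basic residual block for CENet.

    Two 3x3 convolutions with normalization and activation, plus an
    optional ``downsample`` module on the identity branch when the
    spatial size or channel count changes.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int): Stride of the first convolution. Defaults to 1.
        dilation (int): Dilation of the first convolution. Defaults to 1.
        downsample (nn.Module, optional): Module on the identity branch
            used to match shapes with the residual branch. Defaults to None.
        conv_cfg (dict, optional): Config dict for convolution layers.
            Defaults to None.
        norm_cfg (dict): Config dict for normalization layers.
            Defaults to dict(type='BN').
        act_cfg (dict): Config dict for activation layers.
            Defaults to dict(type='LeakyReLU').
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """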

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 dilation: int = 1,
                 downsample: Optional[nn.Module] = None,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 act_cfg: ConfigType = dict(type='LeakyReLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super(BasicBlock, self).__init__(init_cfg)

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
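
        # conv1 carries the stride and dilation; padding equals dilation so
        # the 3x3 kernel changes the spatial size only through the stride.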
        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)
        self.relu = build_activation_layer(act_cfg)
        self.downsample = downsample

    @property
    def norm1(self) -> nn.Module:
        """nn.Module: normalization layer after the first convolution layer."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self) -> nn.Module:
        """nn.Module: normalization layer after the second convolution layer."""
        return getattr(self, self.norm2_name)

    def forward(self, x: Tensor) -> Tensor:
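        """Forward function.

        Args:
            x (Tensor): Input feature map of shape (B, C, H, W).

        Returns:
            Tensor: Output feature map with ``planes`` channels.
        """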
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


@MODELS.register_module()
class CENet(BaseModule):
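    """CENet backbone.

    A ResNet-style encoder that keeps the stem output and every stage
    output, upsamples them all to the stem resolution, concatenates them
    along the channel dimension and fuses the result with a few 3x3
    ConvModules.

    Args:
        in_channels (int): Number of input channels. Defaults to 5.
        stem_channels (int): Number of stem output channels.
            Defaults to 128.
        num_stages (int): Number of residual stages. Defaults to 4.
        stage_blocks (Sequence[int]): Number of blocks in each stage.
            Defaults to (3, 4, 6, 3).
        out_channels (Sequence[int]): Output channels of each stage.
            Defaults to (128, 128, 128, 128).
        strides (Sequence[int]): Stride of the first block of each stage.
            Defaults to (1, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to (1, 1, 1, 1).
        fuse_channels (Sequence[int]): Output channels of the fuse layers.
            Defaults to (256, 128).
        conv_cfg (dict, optional): Config dict for convolution layers.
            Defaults to None.
        norm_cfg (dict): Config dict for normalization layers.
            Defaults to dict(type='BN').
        act_cfg (dict): Config dict for activation layers.
            Defaults to dict(type='LeakyReLU').
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """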

    def __init__(self,
                 in_channels: int = 5,
                 stem_channels: int = 128,
                 num_stages: int = 4,
                 stage_blocks: Sequence[int] = (3, 4, 6, 3),
                 out_channels: Sequence[int] = (128, 128, 128, 128),
                 strides: Sequence[int] = (1, 2, 2, 2),
                 dilations: Sequence[int] = (1, 1, 1, 1),
                 fuse_channels: Sequence[int] = (256, 128),
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 act_cfg: ConfigType = dict(type='LeakyReLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super(CENet, self).__init__(init_cfg)

        assert len(stage_blocks) == len(out_channels) == len(strides) == len(
            dilations) == num_stages, \
            'The lengths of stage_blocks, out_channels, strides and ' \
            'dilations should all be equal to num_stages'
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self._make_stem_layer(in_channels, stem_channels)

        inplanes = stem_channels
        self.res_layers = []
        for i, num_blocks in enumerate(stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = out_channels[i]
            res_layer = self.make_res_layer(
                inplanes=inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            inplanes = planes
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)
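
        # The stem output and every stage output are concatenated along the
        # channel dimension before fusion, hence the summed channel count.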
        in_channels = stem_channels + sum(out_channels)
        self.fuse_layers = []
        for i, fuse_channel in enumerate(fuse_channels):
            fuse_layer = ConvModule(
                in_channels,
                fuse_channel,
                kernel_size=3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            in_channels = fuse_channel
            layer_name = f'fuse_layer{i + 1}'
            self.add_module(layer_name, fuse_layer)
            self.fuse_layers.append(layer_name)

    def _make_stem_layer(self, in_channels: int, out_channels: int) -> None:
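        """Build the three-convolution stem.

        The first convolution runs at half of ``out_channels``; the next
        two run at the full ``out_channels``. All three use stride 1 and
        padding 1, so the stem preserves the spatial resolution.
        """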
        self.stem = nn.Sequential(
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                out_channels // 2,
                kernel_size=3,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, out_channels // 2)[1],
            build_activation_layer(self.act_cfg),
            build_conv_layer(
                self.conv_cfg,
                out_channels // 2,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, out_channels)[1],
            build_activation_layer(self.act_cfg),
            build_conv_layer(
                self.conv_cfg,
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False),
            build_norm_layer(self.norm_cfg, out_channels)[1],
            build_activation_layer(self.act_cfg))

    def make_res_layer(
            self,
            inplanes: int,
            planes: int,
            num_blocks: int,
            stride: int,
            dilation: int,
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN'),
            act_cfg: ConfigType = dict(type='LeakyReLU')
    ) -> nn.Sequential:
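        """Stack ``num_blocks`` BasicBlocks into one residual stage.

        Only the first block applies ``stride`` and, when the resolution
        or channel count changes, a 1x1 downsample on the identity path;
        the remaining blocks keep resolution and channels fixed.

        Returns:
            nn.Sequential: The assembled residual stage.
        """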
        downsample = None
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(norm_cfg, planes)[1])

        layers = []
        layers.append(
            BasicBlock(
                inplanes=inplanes,
                planes=planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))
        inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(
                BasicBlock(
                    inplanes=inplanes,
                    planes=planes,
                    stride=1,
                    dilation=dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
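        """Forward function.

        Args:
            x (Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            Tuple[Tensor, ...]: The fused feature map, followed by the
            upsampled output of each residual stage.
        """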
        x = self.stem(x)
        outs = [x]
        for layer_name in self.res_layers:
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            outs.append(x)
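
        # Upsample every stage output back to the stem resolution so all
        # feature maps can be concatenated along the channel dimension.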
        for i in range(len(outs)):
            if outs[i].shape != outs[0].shape:
                outs[i] = F.interpolate(
                    outs[i],
                    size=outs[0].size()[2:],
                    mode='bilinear',
                    align_corners=True)

        outs[0] = torch.cat(outs, dim=1)

        for layer_name in self.fuse_layers:
            fuse_layer = getattr(self, layer_name)
            outs[0] = fuse_layer(outs[0])
        return tuple(outs)
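

# Minimal smoke-test sketch, not part of the original module: it assumes
# mmdet3d and its dependencies are importable and builds the backbone with
# its default arguments. The input shape (1, 5, 64, 512) is only an
# illustrative range-image size; the model does not require it.
if __name__ == '__main__':
    model = CENet()
    model.eval()
    dummy_input = torch.randn(1, 5, 64, 512)
    with torch.no_grad():
        outs = model(dummy_input)
    # outs[0] is the fused feature map; outs[1:] are the upsampled
    # per-stage feature maps.
    for i, out in enumerate(outs):
        print(f'out[{i}]: {tuple(out.shape)}')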