Spaces:
Paused
Paused
import torch.nn as nn | |
import torch.utils.checkpoint as cp | |
from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, | |
build_norm_layer, constant_init, kaiming_init) | |
from annotator.uniformer.mmcv.runner import load_checkpoint | |
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm | |
from annotator.uniformer.mmseg.utils import get_root_logger | |
from ..builder import BACKBONES | |
from ..utils import UpConvBlock | |
class BasicConvBlock(nn.Module):
    """Stack of plain 3x3 ``ConvModule`` layers used as a UNet conv block.

    Only the first layer can change the channel count and apply ``stride``;
    every following layer maps ``out_channels -> out_channels`` with stride 1
    and the requested ``dilation``. The first layer always uses dilation 1.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Stride of the first convolutional layer only. With
            stride=2 the block downsamples the input feature map with a
            strided convolution. Options are 1 or 2. Default: 1.
        dilation (int): Dilation rate of every convolutional layer except the
            first one, which always uses dilation 1. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (None): Deformable convolution is not implemented; must be None.
        plugins (None): Plugins are not implemented; must be None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        self.with_cp = with_cp

        layers = []
        layer_in = in_channels
        for idx in range(num_convs):
            is_first = idx == 0
            # padding == dilation keeps a 3x3 conv size-preserving, so only
            # the first layer's stride can change the spatial resolution.
            rate = 1 if is_first else dilation
            layers.append(
                ConvModule(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride if is_first else 1,
                    dilation=rate,
                    padding=rate,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            # Every layer after the first consumes the block's output width.
            layer_in = out_channels
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv stack.

        Gradient checkpointing is used only when ``with_cp`` is set and ``x``
        requires grad.
        """
        use_checkpoint = self.with_cp and x.requires_grad
        return cp.checkpoint(self.convs, x) if use_checkpoint else self.convs(x)
# Register in the mmcv upsample registry so that an upsample config with
# type='DeconvModule' can be looked up; UPSAMPLE_LAYERS is imported by this
# file for exactly this purpose.
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet.

    This module upsamples a feature map in the decoder of UNet by
    ``scale_factor`` (2X by default). It uses a ``ConvTranspose2d``
    followed by a normalization layer and an activation layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the deconvolution layer. Must be
            >= ``scale_factor`` and differ from it by an even number.
            Default: 4.
        scale_factor (int): Upsampling factor, used as the deconvolution
            stride. Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        # For ConvTranspose2d with stride s, kernel k and padding
        # p = (k - s) / 2: H_out = (H_in - 1) * s - 2p + k = H_in * s,
        # i.e. an exact scale_factor upsample. That requires k - s to be
        # non-negative and even.
        assert (kernel_size - scale_factor >= 0) and \
            (kernel_size - scale_factor) % 2 == 0, \
            'kernel_size should be greater than or equal to scale_factor '\
            'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'

        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        # build_norm_layer returns a (name, layer) pair; only the layer is used.
        _, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        # NOTE: the attribute keeps its historical spelling ("upsamping");
        # renaming it would change this module's state_dict keys.
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Upsample ``x``.

        Gradient checkpointing is used only when ``with_cp`` is set and ``x``
        requires grad.
        """
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.deconv_upsamping, x)
        else:
            out = self.deconv_upsamping(x)
        return out
# Register in the mmcv upsample registry so that an upsample config with
# type='InterpConv' (UNet's default upsample_cfg in this file) can be looked
# up; UPSAMPLE_LAYERS is imported by this file for exactly this purpose.
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. It can be one interpolation upsample layer followed
    by one convolutional layer (conv_first=False) or one convolutional layer
    followed by one interpolation upsample layer (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether the convolutional layer or the
            interpolation upsample layer comes first. Default: False, i.e.
            the interpolation upsample layer is followed by the
            convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsample_cfg (dict): Keyword arguments forwarded to ``nn.Upsample``.
            Default: dict(
            scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super(InterpConv, self).__init__()
        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # upsample_cfg is unpacked, never mutated, so the shared default
        # dict is safe here.
        upsample = nn.Upsample(**upsample_cfg)
        # Both layers live in one Sequential so checkpointing (with_cp)
        # covers the whole upsample step.
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Upsample ``x``.

        Gradient checkpointing is used only when ``with_cp`` is set and ``x``
        requires grad.
        """
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out
# Register in the BACKBONES registry so configs can build it with
# type='UNet'; BACKBONES is imported by this file for exactly this purpose.
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.

    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[bool]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (None): Deformable convolution is not implemented; must be None.
        plugins (None): Plugins are not implemented; must be None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        # Encoder settings have one entry per stage; decoder settings and
        # downsamples have one entry per stage transition (num_stages - 1).
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                # MaxPool is only used when the stage does not already
                # downsample with a strided convolution.
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                # The matching decoder stage must upsample whenever this
                # encoder stage downsampled (by stride or by MaxPool).
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append(nn.Sequential(*enc_conv_block))
            # Stage i's output width is the next stage's input width.
            in_channels = base_channels * 2**i

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns:
            list: ``num_stages`` feature maps. Entry 0 is the deepest encoder
            output; each following entry is the output of one decoder stage,
            going from the deepest to the shallowest.
        """
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        # decoder[i] fuses encoder stage i's output (skip) with the running
        # decoder feature coming from stage i + 1.
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Set training mode, keeping BatchNorm layers in eval mode when
        ``norm_eval`` is True.

        Returns:
            UNet: ``self``, matching the ``nn.Module.train`` contract so calls
            such as ``model = model.train()`` keep working.
        """
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _check_input_divisible(self, x):
        """Assert that the last two dims of ``x`` (height, width) are
        divisible by 2 ** (number of encoder stages that downsample)."""
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')