from functools import partial import torch from torch import nn as nn from torch.nn import functional as F from medomni.models.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D import ipdb def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d): """ Create a list of modules with together constitute a single conv layer with non-linearity and optional batchnorm/groupnorm. Args: in_channels (int): number of input channels out_channels (int): number of output channels kernel_size(int or tuple): size of the convolving kernel order (string): order of things, e.g. 'cr' -> conv + ReLU 'gcr' -> groupnorm + conv + ReLU 'cl' -> conv + LeakyReLU 'ce' -> conv + ELU 'bcr' -> batchnorm + conv + ReLU num_groups (int): number of groups for the GroupNorm padding (int or tuple): add zero-padding added to all three sides of the input is3d (bool): is3d (bool): if True use Conv3d, otherwise use Conv2d Return: list of tuple (name, module) """ assert 'c' in order, "Conv layer MUST be present" assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' modules = [] for i, char in enumerate(order): if char == 'r': modules.append(('ReLU', nn.ReLU(inplace=True))) elif char == 'l': modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True))) elif char == 'e': modules.append(('ELU', nn.ELU(inplace=True))) elif char == 'c': # add learnable bias only in the absence of batchnorm/groupnorm bias = not ('g' in order or 'b' in order) if is3d: conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) else: conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) modules.append(('conv', conv)) elif char == 'g': is_before_conv = i < order.index('c') if is_before_conv: num_channels = in_channels else: num_channels = out_channels # use only one group if the given number of groups is greater than the number of channels if num_channels < num_groups: num_groups = 1 assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) elif char == 'b': is_before_conv = i < order.index('c') if is3d: bn = nn.BatchNorm3d else: bn = nn.BatchNorm2d if is_before_conv: modules.append(('batchnorm', bn(in_channels))) else: modules.append(('batchnorm', bn(out_channels))) else: raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") return modules class SingleConv(nn.Sequential): """ Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order of operations can be specified via the `order` parameter Args: in_channels (int): number of input channels out_channels (int): number of output channels kernel_size (int or tuple): size of the convolving kernel order (string): determines the order of layers, e.g. 'cr' -> conv + ReLU 'crg' -> conv + ReLU + groupnorm 'cl' -> conv + LeakyReLU 'ce' -> conv + ELU num_groups (int): number of groups for the GroupNorm padding (int or tuple): add zero-padding is3d (bool): if True use Conv3d, otherwise use Conv2d """ def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1, is3d=True): super(SingleConv, self).__init__() for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d): self.add_module(name, module) class DoubleConv(nn.Sequential): """ A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). We use (Conv3d+ReLU+GroupNorm3d) by default. This can be changed however by providing the 'order' argument, e.g. in order to change to Conv3d+BatchNorm3d+ELU use order='cbe'. Use padded convolutions to make sure that the output (H_out, W_out) is the same as (H_in, W_in), so that you don't have to crop in the decoder path. Args: in_channels (int): number of input channels out_channels (int): number of output channels encoder (bool): if True we're in the encoder path, otherwise we're in the decoder kernel_size (int or tuple): size of the convolving kernel order (string): determines the order of layers, e.g. 'cr' -> conv + ReLU 'crg' -> conv + ReLU + groupnorm 'cl' -> conv + LeakyReLU 'ce' -> conv + ELU num_groups (int): number of groups for the GroupNorm padding (int or tuple): add zero-padding added to all three sides of the input is3d (bool): if True use Conv3d instead of Conv2d layers """ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, padding=1, is3d=True): super(DoubleConv, self).__init__() if encoder: # we're in the encoder path conv1_in_channels = in_channels conv1_out_channels = out_channels // 2 if conv1_out_channels < in_channels: conv1_out_channels = in_channels conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels else: # we're in the decoder path, decrease the number of channels in the 1st convolution conv1_in_channels, conv1_out_channels = in_channels, out_channels conv2_in_channels, conv2_out_channels = out_channels, out_channels # conv1 self.add_module('SingleConv1', SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups, padding=padding, is3d=is3d)) # conv2 self.add_module('SingleConv2', SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups, padding=padding, is3d=is3d)) class ResNetBlock(nn.Module): """ Residual block that can be used instead of standard DoubleConv in the Encoder module. Motivated by: https://arxiv.org/pdf/1706.00120.pdf Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. """ def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs): super(ResNetBlock, self).__init__() if in_channels != out_channels: # conv1x1 for increasing the number of channels if is3d: self.conv1 = nn.Conv3d(in_channels, out_channels, 1) else: self.conv1 = nn.Conv2d(in_channels, out_channels, 1) else: self.conv1 = nn.Identity() # residual block self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups, is3d=is3d) # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual n_order = order for c in 'rel': n_order = n_order.replace(c, '') self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, num_groups=num_groups, is3d=is3d) # create non-linearity separately if 'l' in order: self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) elif 'e' in order: self.non_linearity = nn.ELU(inplace=True) else: self.non_linearity = nn.ReLU(inplace=True) def forward(self, x): # apply first convolution to bring the number of channels to out_channels residual = self.conv1(x) # residual block out = self.conv2(residual) out = self.conv3(out) out += residual out = self.non_linearity(out) return out class ResNetBlockSE(ResNetBlock): def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, se_module='scse', **kwargs): super(ResNetBlockSE, self).__init__( in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups, **kwargs) assert se_module in ['scse', 'cse', 'sse'] if se_module == 'scse': self.se_module = ChannelSpatialSELayer3D(num_channels=out_channels, reduction_ratio=1) elif se_module == 'cse': self.se_module = ChannelSELayer3D(num_channels=out_channels, reduction_ratio=1) elif se_module == 'sse': self.se_module = SpatialSELayer3D(num_channels=out_channels) def forward(self, x): out = super().forward(x) out = self.se_module(out) return out class Encoder(nn.Module): """ A single module from the encoder path consisting of the optional max pooling layer (one may specify the MaxPool kernel_size to be different from the standard (2,2,2), e.g. if the volumetric data is anisotropic (make sure to use complementary scale_factor in the decoder path) followed by a basic module (DoubleConv or ResNetBlock). Args: in_channels (int): number of input channels out_channels (int): number of output channels conv_kernel_size (int or tuple): size of the convolving kernel apply_pooling (bool): if True use MaxPool3d before DoubleConv pool_kernel_size (int or tuple): the size of the window pool_type (str): pooling layer: 'max' or 'avg' basic_module(nn.Module): either ResNetBlock or DoubleConv conv_layer_order (string): determines the order of layers in `DoubleConv` module. See `DoubleConv` for more info. num_groups (int): number of groups for the GroupNorm padding (int or tuple): add zero-padding added to all three sides of the input is3d (bool): use 3d or 2d convolutions/pooling operation """ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr', num_groups=8, padding=1, is3d=True): super(Encoder, self).__init__() assert pool_type in ['max', 'avg'] if apply_pooling: if pool_type == 'max': if is3d: self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) else: self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size) else: if is3d: self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) else: self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size) else: self.pooling = None self.basic_module = basic_module(in_channels, out_channels, encoder=True, kernel_size=conv_kernel_size, order=conv_layer_order, num_groups=num_groups, padding=padding, is3d=is3d) def forward(self, x): if self.pooling is not None: x = self.pooling(x) x = self.basic_module(x) return x class Decoder(nn.Module): """ A single module for decoder path consisting of the upsampling layer (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ResNetBlock). Args: in_channels (int): number of input channels out_channels (int): number of output channels conv_kernel_size (int or tuple): size of the convolving kernel scale_factor (tuple): used as the multiplier for the image H/W/D in case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation from the corresponding encoder basic_module(nn.Module): either ResNetBlock or DoubleConv conv_layer_order (string): determines the order of layers in `DoubleConv` module. See `DoubleConv` for more info. num_groups (int): number of groups for the GroupNorm padding (int or tuple): add zero-padding added to all three sides of the input upsample (bool): should the input be upsampled """ def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1, upsample=True, is3d=True): super(Decoder, self).__init__() if upsample: if basic_module == DoubleConv: # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining self.upsampling = InterpolateUpsampling(mode=mode) # concat joining self.joining = partial(self._joining, concat=True) else: # if basic_module=ResNetBlock use transposed convolution upsampling and summation joining self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels, kernel_size=conv_kernel_size, scale_factor=scale_factor) # sum joining self.joining = partial(self._joining, concat=False) # adapt the number of in_channels for the ResNetBlock in_channels = out_channels else: # no upsampling self.upsampling = NoUpsampling() # concat joining self.joining = partial(self._joining, concat=True) self.basic_module = basic_module(in_channels, out_channels, encoder=False, kernel_size=conv_kernel_size, order=conv_layer_order, num_groups=num_groups, padding=padding, is3d=is3d) def forward(self, encoder_features, x): x = self.upsampling(encoder_features=encoder_features, x=x) x = self.joining(encoder_features, x) ipdb.set_trace() x = self.basic_module(x) return x @staticmethod def _joining(encoder_features, x, concat): if concat: return torch.cat((encoder_features, x), dim=1) else: return encoder_features + x def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, pool_kernel_size, is3d): # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)` encoders = [] for i, out_feature_num in enumerate(f_maps): if i == 0: # apply conv_coord only in the first encoder if any encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, # skip pooling in the firs encoder basic_module=basic_module, conv_layer_order=layer_order, conv_kernel_size=conv_kernel_size, num_groups=num_groups, padding=conv_padding, is3d=is3d) else: encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module, conv_layer_order=layer_order, conv_kernel_size=conv_kernel_size, num_groups=num_groups, pool_kernel_size=pool_kernel_size, padding=conv_padding, is3d=is3d) encoders.append(encoder) return nn.ModuleList(encoders) def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, is3d): # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1` decoders = [] reversed_f_maps = list(reversed(f_maps)) for i in range(len(reversed_f_maps) - 1): if basic_module == DoubleConv: in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] else: in_feature_num = reversed_f_maps[i] out_feature_num = reversed_f_maps[i + 1] decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module, conv_layer_order=layer_order, conv_kernel_size=conv_kernel_size, num_groups=num_groups, padding=conv_padding, is3d=is3d) decoders.append(decoder) return nn.ModuleList(decoders) class AbstractUpsampling(nn.Module): """ Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either interpolation or learned transposed convolution. """ def __init__(self, upsample): super(AbstractUpsampling, self).__init__() self.upsample = upsample def forward(self, encoder_features, x): # get the spatial dimensions of the output given the encoder_features output_size = encoder_features.size()[2:] # upsample the input and return return self.upsample(x, output_size) class InterpolateUpsampling(AbstractUpsampling): """ Args: mode (str): algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest' used only if transposed_conv is False """ def __init__(self, mode='nearest'): upsample = partial(self._interpolate, mode=mode) super().__init__(upsample) @staticmethod def _interpolate(x, size, mode): return F.interpolate(x, size=size, mode=mode) class TransposeConvUpsampling(AbstractUpsampling): """ Args: in_channels (int): number of input channels for transposed conv used only if transposed_conv is True out_channels (int): number of output channels for transpose conv used only if transposed_conv is True kernel_size (int or tuple): size of the convolving kernel used only if transposed_conv is True scale_factor (int or tuple): stride of the convolution used only if transposed_conv is True """ def __init__(self, in_channels=None, out_channels=None, kernel_size=3, scale_factor=(2, 2, 2)): # make sure that the output size reverses the MaxPool3d from the corresponding encoder upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, padding=1) super().__init__(upsample) class NoUpsampling(AbstractUpsampling): def __init__(self): super().__init__(self._no_upsampling) @staticmethod def _no_upsampling(x, size): return x