""" 3D Squeeze and Excitation Modules ***************************** 3D Extensions of the following 2D squeeze and excitation blocks: 1. `Channel Squeeze and Excitation `_ 2. `Spatial Squeeze and Excitation `_ 3. `Channel and Spatial Squeeze and Excitation `_ New Project & Excite block, designed specifically for 3D inputs 'quote' Coded by -- Anne-Marie Rickmann (https://github.com/arickm) """ import torch from torch import nn as nn from torch.nn import functional as F class ChannelSELayer3D(nn.Module): """ 3D extension of Squeeze-and-Excitation (SE) block described in: *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* *Zhu et al., AnatomyNet, arXiv:arXiv:1808.05238* """ def __init__(self, num_channels, reduction_ratio=2): """ Args: num_channels (int): No of input channels reduction_ratio (int): By how much should the num_channels should be reduced """ super(ChannelSELayer3D, self).__init__() self.avg_pool = nn.AdaptiveAvgPool3d(1) num_channels_reduced = num_channels // reduction_ratio self.reduction_ratio = reduction_ratio self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True) self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() def forward(self, x): batch_size, num_channels, D, H, W = x.size() # Average along each channel squeeze_tensor = self.avg_pool(x) # channel excitation fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels))) fc_out_2 = self.sigmoid(self.fc2(fc_out_1)) output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1, 1)) return output_tensor class SpatialSELayer3D(nn.Module): """ 3D extension of SE block -- squeezing spatially and exciting channel-wise described in: *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018* """ def __init__(self, num_channels): """ Args: num_channels (int): No of input channels """ super(SpatialSELayer3D, self).__init__() self.conv = nn.Conv3d(num_channels, 1, 1) self.sigmoid = nn.Sigmoid() def forward(self, x, weights=None): """ Args: weights (torch.Tensor): weights for few shot learning x: X, shape = (batch_size, num_channels, D, H, W) Returns: (torch.Tensor): output_tensor """ # channel squeeze batch_size, channel, D, H, W = x.size() if weights: weights = weights.view(1, channel, 1, 1) out = F.conv2d(x, weights) else: out = self.conv(x) squeeze_tensor = self.sigmoid(out) # spatial excitation output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W)) return output_tensor class ChannelSpatialSELayer3D(nn.Module): """ 3D extension of concurrent spatial and channel squeeze & excitation: *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579* """ def __init__(self, num_channels, reduction_ratio=2): """ Args: num_channels (int): No of input channels reduction_ratio (int): By how much should the num_channels should be reduced """ super(ChannelSpatialSELayer3D, self).__init__() self.cSE = ChannelSELayer3D(num_channels, reduction_ratio) self.sSE = SpatialSELayer3D(num_channels) def forward(self, input_tensor): output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor)) return output_tensor