import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel attention.

    The input feature map is squeezed to one value per channel by global
    average pooling, passed through a two-layer 1x1-convolution bottleneck
    (C -> C // reduction_ratio -> C) ending in a sigmoid, and the resulting
    per-channel gate is multiplied back onto the original input.

    In:  (N, C, H, W)
    Out: (N, C, H, W) -- the input rescaled channel-wise by a gate in (0, 1).

    Args:
        in_channels: Number of channels C of the input feature map.
        reduction_ratio: Factor by which the bottleneck narrows the channel
            dimension. The bottleneck width is clamped to at least one
            channel so small ``in_channels`` values still build a usable
            module.
        bias: Whether the two 1x1 convolutions use a bias term.

    Raises:
        ValueError: If ``in_channels`` or ``reduction_ratio`` is not positive.
    """

    def __init__(
        self, in_channels: int, reduction_ratio: int = 8, bias: bool = True
    ) -> None:
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be positive, got {in_channels}")
        if reduction_ratio < 1:
            raise ValueError(
                f"reduction_ratio must be positive, got {reduction_ratio}"
            )

        # Floor division yields 0 when in_channels < reduction_ratio, which
        # would create a zero-channel bottleneck; keep at least one channel.
        hidden_channels = max(1, in_channels // reduction_ratio)

        self.squeezing = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.PReLU(),
            nn.Conv2d(
                hidden_channels,
                in_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rescaled by its per-channel attention gate."""
        squeezed_x = self.squeezing(x)  # (N, C, 1, 1)
        excitation = self.excitation(squeezed_x)  # (N, C, 1, 1), in (0, 1)
        # Broadcasting over H and W applies one gate value per channel.
        return excitation * x  # (N, C, H, W)