dblasko's picture
Upload 11 files
9b9b1dc
raw
history blame contribute delete
No virus
1.22 kB
import torch
import torch.nn as nn
class ChannelAttention(nn.Module):
    """
    Squeeze-and-excitation style channel attention.

    The input is squeezed to one value per channel by global average pooling,
    passed through a 1x1-convolution bottleneck (C -> C // reduction_ratio -> C)
    with a PReLU in between, and turned into per-channel gates by a sigmoid.
    The gates are multiplied back onto the input, so the output has exactly the
    same shape as the input.

    In:  (N, C, H, W)  (unbatched (C, H, W) also works, as with nn.Conv2d)
    Out: same shape as the input

    Args:
        in_channels: Number of input channels C.
        reduction_ratio: Bottleneck reduction factor r; must be >= 1. The
            hidden width is C // r, clamped to at least 1 so that a small C
            still yields a working bottleneck instead of a zero-width layer.
        bias: Whether the two 1x1 convolutions use a bias term.

    Raises:
        ValueError: If reduction_ratio < 1.
    """

    def __init__(self, in_channels, reduction_ratio=8, bias=True):
        super().__init__()
        if reduction_ratio < 1:
            raise ValueError(
                f"reduction_ratio must be >= 1, got {reduction_ratio}"
            )
        # Clamp to 1: when in_channels < reduction_ratio the floor division is
        # 0, which builds zero-channel convolutions and collapses the attention
        # into a constant gate (sigmoid of the second conv's bias).
        hidden_channels = max(1, in_channels // reduction_ratio)

        # Global average over the spatial dims: (N, C, H, W) -> (N, C, 1, 1).
        self.squeezing = nn.AdaptiveAvgPool2d(1)
        # Submodule order is kept unchanged so state_dict keys
        # (excitation.0 / .1 / .2) remain compatible with saved checkpoints.
        self.excitation = nn.Sequential(
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.PReLU(),
            nn.Conv2d(
                hidden_channels,
                in_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by learned sigmoid gates."""
        squeezed_x = self.squeezing(x)  # (N, C, 1, 1)
        excitation = self.excitation(squeezed_x)  # (N, C, 1, 1) sigmoid gates
        # Broadcasting over H and W applies one gate per (sample, channel).
        return excitation * x