mirnet-low-light-img-enhancement / model /MIRNet /SelectiveKernelFeatureFusion.py
dblasko's picture
Upload 11 files
9b9b1dc
raw
history blame contribute delete
No virus
2.45 kB
import torch
import torch.nn as nn
class SelectiveKernelFeatureFusion(nn.Module):
    """
    Fuse several same-shaped feature maps with per-channel attention over branches.

    The branch feature maps are summed, globally average-pooled, squeezed by a
    1x1 convolution followed by PReLU, and then expanded by one 1x1 convolution
    per branch. A softmax taken across the branch axis turns those descriptors
    into per-channel weights, and the output is the weighted sum of the inputs.

    Args:
        in_channels: Channel count of every input feature map.
        reduction_ratio: Divisor applied to ``in_channels`` to size the squeezed
            descriptor; the squeezed width is never smaller than 4 channels.
        bias: Whether the 1x1 convolutions carry a bias term.
        n_branches: Number of feature maps fused per call. Defaults to 3, the
            value that used to be hard-coded.

    Raises:
        ValueError: If ``n_branches`` is smaller than 1.
    """

    def __init__(self, in_channels, reduction_ratio, bias=False, n_branches=3):
        super().__init__()
        if n_branches < 1:
            raise ValueError(f"n_branches must be >= 1, got {n_branches}")

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Squeezed descriptor width, floored at 4 so that a large reduction
        # ratio cannot collapse the bottleneck.
        conv_out_channels = max(int(in_channels / reduction_ratio), 4)

        # Attribute names and creation order are kept as they were so that
        # existing state_dict checkpoints keep loading.
        self.convolution = nn.Sequential(
            nn.Conv2d(
                in_channels, conv_out_channels, kernel_size=1, padding=0, bias=bias
            ),
            nn.PReLU(),
        )
        self.attention_convs = nn.ModuleList(
            [
                nn.Conv2d(
                    conv_out_channels, in_channels, kernel_size=1, stride=1, bias=bias
                )
                for _ in range(n_branches)
            ]
        )
        # Applied to a (batch, branches, channels, 1, 1) tensor, so dim=1
        # normalises across branches independently for every channel.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        Recalibrate and merge the branch feature maps.

        Args:
            x: Sequence (list or tuple) holding one tensor per branch, each of
                shape (batch_size, in_channels, H, W).

        Returns:
            Tensor of shape (batch_size, in_channels, H, W): the attention-weighted
            sum of the input feature maps.

        Raises:
            ValueError: If the number of tensors in ``x`` differs from the number
                of attention heads this module was built with.
            RuntimeError: Raised by ``torch.stack`` when the tensors in ``x`` do
                not all share the same shape.
        """
        # Read the branch count from the registered heads rather than from a
        # separate attribute, so module objects created before the branch count
        # became configurable behave the same way.
        expected_branches = len(self.attention_convs)
        if len(x) != expected_branches:
            raise ValueError(
                f"expected {expected_branches} feature maps, got {len(x)}"
            )

        # torch.stack refuses mismatched shapes, whereas concatenating on the
        # channel axis and reshaping could silently mix channels across branches.
        features = torch.stack(tuple(x), dim=1)  # batch x branches x C x H x W

        pooled = self.avg_pool(features.sum(dim=1))  # batch x C x 1 x 1
        squeezed = self.convolution(pooled)  # batch x conv_out_channels x 1 x 1

        attention = torch.stack(
            [head(squeezed) for head in self.attention_convs], dim=1
        )  # batch x branches x C x 1 x 1
        attention = self.softmax(attention)  # weights sum to 1 over the branches

        # Broadcast the per-channel weights over H x W and merge the branches.
        return (features * attention).sum(dim=1)  # batch x C x H x W