import torch
import torch.nn as nn

from model.MIRNet.ChannelAttention import ChannelAttention
from model.MIRNet.SpatialAttention import SpatialAttention


class DualAttentionUnit(nn.Module):
    """
    Combines the ChannelAttention and SpatialAttention modules.
    (conv, PReLU, conv -> feed the result to both SA & CA -> concatenate their outputs -> 1x1 conv -> skip connection from the input)

    In: HxWxC
    Out: HxWxC (the 1x1 conv restores the original channel count, and the block output is added to the input through the skip connection)
    """
    def __init__(self, in_channels, kernel_size=3, reduction_ratio=8, bias=False):
        super().__init__()
        # Two shape-preserving convolutions with a PReLU in between extract the
        # features fed to both attention branches ("same" padding for odd kernel sizes).
        self.initial_convs = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size, padding=kernel_size // 2, bias=bias),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size, padding=kernel_size // 2, bias=bias),
        )
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio, bias)
        self.spatial_attention = SpatialAttention()
        # 1x1 convolution mapping the concatenated 2C attention outputs back to C channels.
        self.final_conv = nn.Conv2d(
            in_channels * 2, in_channels, kernel_size=1, bias=bias
        )
        self.in_channels = in_channels

    def forward(self, x):
        initial_features = self.initial_convs(x)  # HxWxC
        channel_attention = self.channel_attention(initial_features)  # HxWxC
        spatial_attention = self.spatial_attention(initial_features)  # HxWxC
        attention = torch.cat((spatial_attention, channel_attention), dim=1)  # HxWx2C
        block_output = self.final_conv(attention)  # HxWxC - the 1x1 conv restores the C channels for the skip connection
        return x + block_output  # the addition is the skip connection from the input
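

# Minimal usage sketch (not part of the original module): it assumes the
# ChannelAttention and SpatialAttention modules imported above are available and
# shape-preserving, as the docstring describes. Running this file directly
# checks that the block returns a tensor with the same NxCxHxW shape as its input.
if __name__ == "__main__":
    block = DualAttentionUnit(in_channels=64)
    dummy = torch.randn(1, 64, 128, 128)  # NxCxHxW
    out = block(dummy)
    assert out.shape == dummy.shape, f"Expected {dummy.shape}, got {out.shape}"
    print(f"DualAttentionUnit output shape: {tuple(out.shape)}")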