File size: 794 Bytes
9b9b1dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn as nn
from model.MIRNet.ChannelCompression import ChannelCompression
class SpatialAttention(nn.Module):
    """
    Spatial attention gate.

    Compresses the input to a 2-channel map with the ChannelCompression module,
    convolves that map down to a single channel, squashes it with a sigmoid and
    uses the result to rescale every channel of the original input.

    Shapes (PyTorch channels-first layout, as required by ``nn.Conv2d``):
        In:  (N, C, H, W)
        Out: (N, C, H, W) -- the (N, 1, H, W) gate broadcasts over all C channels,
             restoring the original channel count.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that ``padding = kernel_size // 2`` keeps the
            spatial size unchanged. Defaults to 5 (the original setting).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 5):
        if kernel_size < 1 or kernel_size % 2 == 0:
            # An even kernel with padding kernel_size // 2 grows H and W by one
            # pixel, so the gate could no longer broadcast against the input.
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        super().__init__()
        self.channel_compression = ChannelCompression()
        # 2 input channels: ChannelCompression is expected to emit a 2-channel map.
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rescaled by a per-pixel sigmoid gate (same shape as ``x``)."""
        x_compressed = self.channel_compression(x)  # expected (N, 2, H, W)
        x_conv = self.conv(x_compressed)  # (N, 1, H, W)
        scaling_factor = torch.sigmoid(x_conv)  # gate values in (0, 1)
        return x * scaling_factor  # broadcasts over channels -> (N, C, H, W)
|