|
import torch |
|
import torch.nn as nn |
|
|
|
from model.MIRNet.ChannelCompression import ChannelCompression |
|
|
|
|
|
class SpatialAttention(nn.Module):
    """
    Spatial attention gate that rescales every spatial position of its input.

    Reduces the input to 2 channels with the ChannelCompression module, applies
    a 2D convolution with 1 output channel, and squashes the result with a
    sigmoid to obtain one scaling factor in (0, 1) per spatial position.

    In:  N x C x H x W (channels-first layout, as required by nn.Conv2d)
    Out: N x C x H x W (original channels are restored by multiplying the
         single-channel attention map with the original input)

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that "same" padding keeps H and W
            unchanged. Defaults to 5, the previously hard-coded value.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 5) -> None:
        super().__init__()

        # An even kernel with padding = kernel_size // 2 would grow H and W by
        # one, making the final broadcast multiply in forward() fail.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        self.channel_compression = ChannelCompression()
        # in_channels=2 matches the 2-channel output of ChannelCompression;
        # padding = kernel_size // 2 preserves the spatial resolution.
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by its per-position attention map."""
        x_compressed = self.channel_compression(x)
        x_conv = self.conv(x_compressed)
        scaling_factor = torch.sigmoid(x_conv)
        # scaling_factor has a single channel and broadcasts over all
        # channels of x.
        return x * scaling_factor
|
|