mirnet-low-light-img-enhancement / model /MIRNet /MultiScaleResidualBlock.py
dblasko's picture
Upload 11 files
9b9b1dc
raw
history blame contribute delete
No virus
3.9 kB
import torch
import numpy as np
import torch.nn as nn
from model.MIRNet.Downsampling import DownsamplingModule
from model.MIRNet.DualAttentionUnit import DualAttentionUnit
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
from model.MIRNet.Upsampling import UpsamplingModule
class MultiScaleResidualBlock(nn.Module):
    """
    Residual block made of ``height`` parallel streams that exchange features.

    Stream ``i`` works on ``int(num_features * stride**i)`` channels. Each
    stream applies ``width`` dual attention units; before every unit after the
    first, the outputs of all streams are resampled to the target stream and
    merged with a selective-kernel feature-fusion block. At the end all
    streams are brought back to stream 0, fused, passed through a 3x3
    convolution and added to the block input.

    Args:
        num_features: Channel count of stream 0 (and of the block input and
            output, since the final convolution maps ``num_features`` to
            ``num_features`` and the result is added to the input).
        height: Number of parallel streams.
        width: Number of dual attention units applied per stream.
        stride: Channel growth factor between consecutive streams; also
            forwarded to the up/downsampling modules.
        bias: Whether the final 3x3 convolution uses a bias term.
    """

    def __init__(self, num_features, height, width, stride, bias):
        super().__init__()
        self.num_features = num_features
        self.height = height
        self.width = width

        # Channel count of every stream, from stream 0 to stream height - 1.
        features = [int((stride**i) * num_features) for i in range(height)]
        # Resampling factors, largest first: [2**(height - 1), ..., 4, 2].
        scale = [2**i for i in range(height - 1, 0, -1)]

        # NOTE(review): list multiplication repeats the *same* object, so the
        # `width` entries of a stream all refer to one DualAttentionUnit
        # instance (shared parameters). Kept as-is to preserve behaviour --
        # confirm that the sharing is intended.
        self.dual_attention_units = nn.ModuleList(
            [
                nn.ModuleList([DualAttentionUnit(features[i])] * width)
                for i in range(height)
            ]
        )

        # Final upsamplers bringing stream i (i >= 1) back to stream 0,
        # keyed by the stream index.
        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up[f"{i}"] = UpsamplingModule(
                in_channels=features[i],
                scaling_factor=2**i,
                stride=stride,
            )

        # Downsamplers keyed "<in_channels>_<factor>". Stream i only needs the
        # factors that reach a deeper stream, i.e. the last height - 1 - i
        # entries of `scale`.
        self.down = nn.ModuleDict()
        for i, f in enumerate(features):
            for s in scale[i:]:
                self.down[f"{f}_{s}"] = DownsamplingModule(f, s, stride)

        # Upsamplers keyed "<in_channels>_<factor>", built from the deepest
        # stream to the shallowest.
        self.up = nn.ModuleDict()
        for i, f in enumerate(reversed(features)):
            for s in scale[i:]:
                self.up[f"{f}_{s}"] = UpsamplingModule(f, s, stride)

        self.out_conv = nn.Conv2d(
            num_features, num_features, kernel_size=3, padding=1, bias=bias
        )

        # One fusion block per stream; cast to int like every other channel
        # computation in this constructor.
        self.skff_blocks = nn.ModuleList(
            [
                SelectiveKernelFeatureFusion(int(num_features * stride**i), height)
                for i in range(height)
            ]
        )

    def forward(self, x):
        """
        Run the block on ``x`` and return the result added to ``x``.

        ``x.size(1)`` is used as the channel count when looking up the
        resampling modules, so it has to match the keys created in
        ``__init__`` (presumably ``x`` is a ``(batch, num_features, H, W)``
        tensor -- confirm against callers).
        """
        # Build the input of every stream: stream 0 uses x directly, each
        # deeper stream is the previous stream's output downsampled by 2.
        inp = x.clone()
        out = []
        for j in range(self.height):
            if j > 0:
                inp = self.down[f"{inp.size(1)}_2"](inp)
            inp = self.dual_attention_units[j][0](inp)
            out.append(inp)

        # Remaining attention units: fuse all streams into each target stream
        # first, then apply that stream's i-th unit. All fusions are computed
        # from the previous `out` before any stream is updated.
        for i in range(1, self.width):
            fused = [
                self.skff_blocks[j](
                    [self.select_up_down(out[k], j, k) for k in range(self.height)]
                )
                for j in range(self.height)
            ]
            out = [
                self.dual_attention_units[j][i](fused[j])
                for j in range(self.height)
            ]

        # Bring every stream back to stream 0 and fuse them; the fusion block
        # of stream 0 is reused here.
        output = [self.select_last_up(out[k], k) for k in range(self.height)]
        output = self.skff_blocks[0](output)
        output = self.out_conv(output)
        return output + x

    def select_up_down(self, tensor, j, k):
        """
        Resample ``tensor`` from stream ``k`` to stream ``j``.

        Returns ``tensor`` unchanged when ``j == k``. Otherwise the module
        keyed ``"<tensor.size(1)>_<2**|j - k|>"`` is applied, taken from
        ``self.up`` when ``j < k`` and from ``self.down`` when ``j > k``.
        """
        if j == k:
            return tensor
        # Builtin abs keeps the factor a plain Python int for the key.
        diff = 2 ** abs(j - k)
        key = f"{tensor.size(1)}_{diff}"
        resample = self.up if j < k else self.down
        return resample[key](tensor)

    def select_last_up(self, tensor, k):
        """
        Bring the output of stream ``k`` back to stream 0.

        Stream 0 is returned unchanged; any other stream goes through the
        ``self.last_up`` module registered under the key ``str(k)``.
        """
        if k == 0:
            return tensor
        return self.last_up[f"{k}"](tensor)