import torch
import torch.nn as nn

from model.MIRNet.Downsampling import DownsamplingModule
from model.MIRNet.DualAttentionUnit import DualAttentionUnit
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
from model.MIRNet.Upsampling import UpsamplingModule


class MultiScaleResidualBlock(nn.Module):
    """
    Multi-scale residual block: `height` parallel convolutional streams, each
    running at half the spatial resolution (and `stride` times the channels)
    of the stream above it. After each of the `width` columns of dual
    attention units, the streams exchange information through selective
    kernel feature fusion (SKFF), and the whole block is wrapped in a
    residual connection.
    """

    def __init__(self, num_features, height, width, stride, bias):
        super().__init__()
        self.num_features = num_features
        self.height = height
        self.width = width

        # Stream i holds num_features * stride**i channels at 1 / 2**i of the
        # input resolution.
        features = [int((stride**i) * num_features) for i in range(height)]
        scale = [2**i for i in range(1, height)]

        # One dual attention unit per (stream, column). The units must be
        # distinct instances: the previous `[module] * width` registered the
        # same module `width` times and silently shared weights across columns.
        self.dual_attention_units = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        DualAttentionUnit(int(num_features * stride**i))
                        for _ in range(width)
                    ]
                )
                for i in range(height)
            ]
        )

        # Final upsamplers that bring each lower stream back to full resolution.
        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up[f"{i}"] = UpsamplingModule(
                in_channels=int(num_features * stride**i),
                scaling_factor=2**i,
                stride=stride,
            )

        # Cross-scale resamplers, keyed by "<in_channels>_<scale_factor>".
        self.down = nn.ModuleDict()
        i = 0
        scale.reverse()
        for f in features:
            for s in scale[i:]:
                self.down[f"{f}_{s}"] = DownsamplingModule(f, s, stride)
            i += 1
        self.up = nn.ModuleDict()
        i = 0
        features.reverse()
        for f in features:
            for s in scale[i:]:
                self.up[f"{f}_{s}"] = UpsamplingModule(f, s, stride)
            i += 1

        self.out_conv = nn.Conv2d(
            num_features, num_features, kernel_size=3, padding=1, bias=bias
        )
        # One SKFF block per stream, each fusing `height` resampled inputs.
        self.skff_blocks = nn.ModuleList(
            [
                SelectiveKernelFeatureFusion(int(num_features * stride**i), height)
                for i in range(height)
            ]
        )

    def forward(self, x):
        # First column: build the pyramid by repeatedly downsampling the
        # running tensor by 2 and applying each stream's first attention unit.
        inp = x
        out = []
        for j in range(self.height):
            if j == 0:
                inp = self.dual_attention_units[j][0](inp)
            else:
                inp = self.dual_attention_units[j][0](
                    self.down[f"{inp.size(1)}_2"](inp)
                )
            out.append(inp)

        # Remaining columns: each stream fuses resampled copies of all
        # streams (SKFF), then passes through its next attention unit.
        for i in range(1, self.width):
            temp = []
            for j in range(self.height):
                resampled = [
                    self.select_up_down(out[k], j, k) for k in range(self.height)
                ]
                temp.append(self.skff_blocks[j](resampled))
            for j in range(self.height):
                out[j] = self.dual_attention_units[j][i](temp[j])

        # Bring every stream back to full resolution, fuse once more, and
        # close the residual connection around the whole block.
        output = [self.select_last_up(out[k], k) for k in range(self.height)]
        output = self.skff_blocks[0](output)
        output = self.out_conv(output)
        return output + x

    def select_up_down(self, tensor, j, k):
        """Resample stream k's tensor to stream j's resolution."""
        if j == k:
            return tensor
        diff = 2 ** abs(j - k)
        if j < k:
            return self.up[f"{tensor.size(1)}_{diff}"](tensor)
        return self.down[f"{tensor.size(1)}_{diff}"](tensor)

    def select_last_up(self, tensor, k):
        """Upsample stream k's tensor back to the input resolution."""
        if k == 0:
            return tensor
        return self.last_up[f"{k}"](tensor)
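

# Minimal usage sketch (not part of the module API). The hyperparameters
# below (64 features, 3 streams, 2 columns, channel growth factor 2) follow
# common MIRNet defaults but are assumptions, not values prescribed by this
# file. Spatial dimensions should be divisible by 2**(height - 1) so the
# coarsest stream downsamples cleanly.
if __name__ == "__main__":
    block = MultiScaleResidualBlock(
        num_features=64, height=3, width=2, stride=2, bias=False
    )
    x = torch.randn(1, 64, 128, 128)
    y = block(x)
    # The block is residual, so the output shape matches the input shape.
    assert y.shape == x.shape
    print(y.shape)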