import torch
import torch.nn as nn

from model.MIRNet.Downsampling import DownsamplingModule
from model.MIRNet.DualAttentionUnit import DualAttentionUnit
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
from model.MIRNet.Upsampling import UpsamplingModule


class MultiScaleResidualBlock(nn.Module):
    """
    Three parallel convolutional streams at different resolutions. Information is exchanged through residual connexions.
    """

    def __init__(self, num_features, height, width, stride, bias):
        super().__init__()
        self.num_features = num_features
        self.height = height
        self.width = width
        # Channel count per stream, and the spatial scale factors between
        # the full-resolution stream and each lower-resolution one.
        features = [int((stride**i) * num_features) for i in range(height)]
        scale = [2**i for i in range(1, height)]

        # One dual attention unit per (stream, stage). A comprehension is
        # required here: ``[module] * width`` would register the same module
        # (shared weights) at every stage.
        self.dual_attention_units = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        DualAttentionUnit(int(num_features * stride**i))
                        for _ in range(width)
                    ]
                )
                for i in range(height)
            ]
        )
        # Upsamplers that return each lower-resolution stream to full
        # resolution before the final fusion.
        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up[f"{i}"] = UpsamplingModule(
                in_channels=int(num_features * stride**i),
                scaling_factor=2**i,
                stride=stride,
            )

        # Downsamplers keyed "<in_channels>_<scale>": from each stream to
        # every lower-resolution stream.
        self.down = nn.ModuleDict()
        scale.reverse()
        for i, f in enumerate(features):
            for s in scale[i:]:
                self.down[f"{f}_{s}"] = DownsamplingModule(f, s, stride)

        # Upsamplers keyed the same way: from each stream to every
        # higher-resolution stream.
        self.up = nn.ModuleDict()
        features.reverse()
        for i, f in enumerate(features):
            for s in scale[i:]:
                self.up[f"{f}_{s}"] = UpsamplingModule(f, s, stride)

        self.out_conv = nn.Conv2d(
            num_features, num_features, kernel_size=3, padding=1, bias=bias
        )
        # One SKFF block per stream, sized to that stream's channel count;
        # block 0 also performs the final full-resolution fusion.
        self.skff_blocks = nn.ModuleList(
            [
                SelectiveKernelFeatureFusion(num_features * stride**i, height)
                for i in range(height)
            ]
        )

    def forward(self, x):
        inp = x.clone()
        out = []

        # Build the initial feature pyramid: stream 0 at full resolution,
        # each subsequent stream downsampled 2x from the previous one.
        for j in range(self.height):
            if j == 0:
                inp = self.dual_attention_units[j][0](inp)
            else:
                inp = self.dual_attention_units[j][0](
                    self.down[f"{inp.size(1)}_2"](inp)
                )
            out.append(inp)

        # Remaining stages: fuse information across streams with SKFF, then
        # apply this stage's dual attention unit to every stream.
        for i in range(1, self.width):
            temp = []
            for j in range(self.height):
                # Resample every stream to stream j's resolution before fusing.
                tensors = [
                    self.select_up_down(out[k], j, k) for k in range(self.height)
                ]
                temp.append(self.skff_blocks[j](tensors))

            for j in range(self.height):
                out[j] = self.dual_attention_units[j][i](temp[j])

        # Bring all streams back to full resolution, fuse them, and close the
        # residual connection around the whole block.
        output = [self.select_last_up(out[k], k) for k in range(self.height)]
        output = self.skff_blocks[0](output)
        output = self.out_conv(output)
        return output + x

    def select_up_down(self, tensor, j, k):
        # Resample stream k's features to stream j's resolution.
        if j == k:
            return tensor
        diff = 2 ** abs(j - k)
        if j < k:
            return self.up[f"{tensor.size(1)}_{diff}"](tensor)
        return self.down[f"{tensor.size(1)}_{diff}"](tensor)

    def select_last_up(self, tensor, k):
        # Stream 0 is already at full resolution; upsample the others by 2**k.
        if k == 0:
            return tensor
        return self.last_up[f"{k}"](tensor)
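

# Minimal smoke test sketching how the block is wired together. The values
# below are illustrative assumptions, not fixed by this file: MIRNet-style
# defaults of 64 base channels, 3 streams, 2 stages, a channel multiplier of
# 2, and companion Downsampling/Upsampling modules that halve/double spatial
# size per factor of 2 in their scale.
if __name__ == "__main__":
    block = MultiScaleResidualBlock(
        num_features=64, height=3, width=2, stride=2, bias=False
    )
    x = torch.randn(1, 64, 128, 128)  # hypothetical NCHW input
    y = block(x)
    assert y.shape == x.shape  # the residual block preserves shape
    print(y.shape)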