dblasko committed on
Commit
9b9b1dc
1 Parent(s): 559af64

Upload 11 files

model/MIRNet/ChannelAttention.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ChannelAttention(nn.Module):
+     """
+     Squeezes the input down to 1x1xC, applies the excitation operation and restores the C channels through a 1x1 convolution.
+
+     In: HxWxC
+     Out: HxWxC (the attention weights are applied by multiplying them with the original input)
+     """
+
+     def __init__(self, in_channels, reduction_ratio=8, bias=True):
+         super().__init__()
+         self.squeezing = nn.AdaptiveAvgPool2d(1)
+         self.excitation = nn.Sequential(
+             nn.Conv2d(
+                 in_channels,
+                 in_channels // reduction_ratio,
+                 kernel_size=1,
+                 padding=0,
+                 bias=bias,
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 in_channels // reduction_ratio,
+                 in_channels,
+                 kernel_size=1,
+                 padding=0,
+                 bias=bias,
+             ),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         squeezed_x = self.squeezing(x)  # 1x1xC
+         excitation = self.excitation(squeezed_x)  # 1x1xC attention weights
+         return excitation * x  # HxWxC, rescaled through the mult. with the original input
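For a quick sanity check, a minimal usage sketch (the 64-channel input and 8x reduction are illustrative values, not from the commit):

import torch
from model.MIRNet.ChannelAttention import ChannelAttention

ca = ChannelAttention(in_channels=64, reduction_ratio=8)
x = torch.randn(1, 64, 32, 32)  # NCHW input
out = ca(x)
assert out.shape == x.shape  # channel attention rescales but preserves the shape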
model/MIRNet/ChannelCompression.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ChannelCompression(nn.Module):
+     """
+     Reduces the input to 2 channels by concatenating the outputs of max pooling and average pooling applied along the channel dimension.
+
+     In: HxWxC
+     Out: HxWx2
+     """
+
+     def forward(self, x):
+         return torch.cat(
+             (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
+         )
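A minimal shape check (input size illustrative):

import torch
from model.MIRNet.ChannelCompression import ChannelCompression

cc = ChannelCompression()
x = torch.randn(1, 64, 32, 32)
out = cc(x)
assert out.shape == (1, 2, 32, 32)  # one max map + one mean map across the channels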
model/MIRNet/Downsampling.py ADDED
@@ -0,0 +1,135 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ class DownsamplingBlock(nn.Module):
+     """
+     Downsamples the input to halve the spatial dimensions while doubling the channels, through two parallel conv + antialiased-downsampling branches.
+
+     In: HxWxC
+     Out: H/2 x W/2 x 2C
+     """
+
+     def __init__(self, in_channels, bias=False):
+         super().__init__()
+         # 1x1 conv + PReLU -> 3x3 conv + PReLU -> antialiased downsampling -> 1x1 conv
+         self.branch1 = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
+             nn.PReLU(),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
+             nn.PReLU(),
+             DownSample(channels=in_channels, filter_size=3, stride=2),
+             nn.Conv2d(in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias),
+         )
+         # antialiased downsampling -> 1x1 conv
+         self.branch2 = nn.Sequential(
+             DownSample(channels=in_channels, filter_size=3, stride=2),
+             nn.Conv2d(in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias),
+         )
+
+     def forward(self, x):
+         return self.branch1(x) + self.branch2(x)  # H/2 x W/2 x 2C
+
+
+ class DownsamplingModule(nn.Module):
+     """
+     Downsampling module of the network, composed of log2(scaling_factor) DownsamplingBlocks.
+
+     In: HxWxC
+     Out: H/s x W/s x sC, where s is the scaling factor
+     """
+
+     def __init__(self, in_channels, scaling_factor, stride=2):
+         super().__init__()
+         self.scaling_factor = int(np.log2(scaling_factor))
+
+         blocks = []
+         for i in range(self.scaling_factor):
+             blocks.append(DownsamplingBlock(in_channels))
+             in_channels = int(in_channels * stride)
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.blocks(x)  # H/s x W/s x sC
+
+
+ class DownSample(nn.Module):
+     """
+     Antialiased downsampling module using the blur-pooling method.
+
+     From Adobe's implementation available here: https://github.com/yilundu/improved_contrastive_divergence/blob/master/downsample.py
+     """
+
+     def __init__(
+         self, pad_type="reflect", filter_size=3, stride=2, channels=None, pad_off=0
+     ):
+         super().__init__()
+         self.filter_size = filter_size
+         self.stride = stride
+         self.pad_off = pad_off
+         self.channels = channels
+         self.pad_sizes = [
+             int(1.0 * (filter_size - 1) / 2),
+             int(np.ceil(1.0 * (filter_size - 1) / 2)),
+             int(1.0 * (filter_size - 1) / 2),
+             int(np.ceil(1.0 * (filter_size - 1) / 2)),
+         ]
+         self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+         self.off = int((self.stride - 1) / 2.0)
+
+         # Binomial filter coefficients (rows of Pascal's triangle) for the blur kernel.
+         if self.filter_size == 1:
+             a = np.array([1.0])
+         elif self.filter_size == 2:
+             a = np.array([1.0, 1.0])
+         elif self.filter_size == 3:
+             a = np.array([1.0, 2.0, 1.0])
+         elif self.filter_size == 4:
+             a = np.array([1.0, 3.0, 3.0, 1.0])
+         elif self.filter_size == 5:
+             a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
+         elif self.filter_size == 6:
+             a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
+         elif self.filter_size == 7:
+             a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
+         else:
+             raise ValueError("Filter size [%i] not supported" % self.filter_size)
+
+         filt = torch.Tensor(a[:, None] * a[None, :])
+         filt = filt / torch.sum(filt)
+         self.register_buffer(
+             "filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
+         )
+         self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+     def forward(self, x):
+         if self.filter_size == 1:
+             if self.pad_off == 0:
+                 return x[:, :, :: self.stride, :: self.stride]
+             else:
+                 return self.pad(x)[:, :, :: self.stride, :: self.stride]
+         else:
+             # Depthwise convolution with the blur kernel, then strided subsampling.
+             return F.conv2d(
+                 self.pad(x), self.filt, stride=self.stride, groups=x.shape[1]
+             )
+
+
+ def get_pad_layer(pad_type):
+     if pad_type == "reflect":
+         pad_layer = nn.ReflectionPad2d
+     elif pad_type == "replication":
+         pad_layer = nn.ReplicationPad2d
+     else:
+         raise ValueError("Pad type [%s] not recognized" % pad_type)
+     return pad_layer
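A usage sketch for the module (channel count and scaling factor are illustrative; a scaling factor of 4 stacks two halving blocks):

import torch
from model.MIRNet.Downsampling import DownsamplingModule

down = DownsamplingModule(in_channels=64, scaling_factor=4)
x = torch.randn(1, 64, 32, 32)
out = down(x)
assert out.shape == (1, 256, 8, 8)  # H/4 x W/4, channels x4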
model/MIRNet/DualAttentionUnit.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.ChannelAttention import ChannelAttention
+ from model.MIRNet.SpatialAttention import SpatialAttention
+
+
+ class DualAttentionUnit(nn.Module):
+     """
+     Combines the ChannelAttention and SpatialAttention modules.
+     (conv, PReLU, conv -> concatenate the SA & CA outputs -> conv -> skip connection from the input)
+
+     In: HxWxC
+     Out: HxWxC (the C channels are restored by the final 1x1 convolution; a skip connection adds the input to the block output)
+     """
+
+     def __init__(self, in_channels, kernel_size=3, reduction_ratio=8, bias=False):
+         super().__init__()
+         self.initial_convs = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels, kernel_size, padding=1, bias=bias),
+             nn.PReLU(),
+             nn.Conv2d(in_channels, in_channels, kernel_size, padding=1, bias=bias),
+         )
+         self.channel_attention = ChannelAttention(in_channels, reduction_ratio, bias)
+         self.spatial_attention = SpatialAttention()
+         self.final_conv = nn.Conv2d(
+             in_channels * 2, in_channels, kernel_size=1, bias=bias
+         )
+         self.in_channels = in_channels
+
+     def forward(self, x):
+         initial_convs = self.initial_convs(x)  # HxWxC
+         channel_attention = self.channel_attention(initial_convs)  # HxWxC
+         spatial_attention = self.spatial_attention(initial_convs)  # HxWxC
+         attention = torch.cat((spatial_attention, channel_attention), dim=1)  # HxWx2C
+         block_output = self.final_conv(
+             attention
+         )  # HxWxC - the 1x1 conv. restores the C channels for the skip connection
+         return x + block_output  # the addition is the skip connection from the input
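A shape check for the unit (values illustrative):

import torch
from model.MIRNet.DualAttentionUnit import DualAttentionUnit

dau = DualAttentionUnit(in_channels=64)
x = torch.randn(1, 64, 32, 32)
assert dau(x).shape == x.shape  # residual block: the shape is preserved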
model/MIRNet/MultiScaleResidualBlock.py ADDED
@@ -0,0 +1,124 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+
+ from model.MIRNet.Downsampling import DownsamplingModule
+ from model.MIRNet.DualAttentionUnit import DualAttentionUnit
+ from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
+ from model.MIRNet.Upsampling import UpsamplingModule
+
+
+ class MultiScaleResidualBlock(nn.Module):
+     """
+     Three parallel convolutional streams at different resolutions (`height` streams, `width` dual-attention stages per stream).
+     Information is exchanged across the streams through residual connections.
+     """
+
+     def __init__(self, num_features, height, width, stride, bias):
+         super().__init__()
+         self.num_features = num_features
+         self.height = height
+         self.width = width
+         features = [int((stride**i) * num_features) for i in range(height)]
+         scale = [2**i for i in range(1, height)]
+
+         # Note: each inner ModuleList repeats a single DualAttentionUnit instance,
+         # so the `width` stages of a stream share their weights.
+         self.dual_attention_units = nn.ModuleList(
+             [
+                 nn.ModuleList(
+                     [DualAttentionUnit(int(num_features * stride**i))] * width
+                 )
+                 for i in range(height)
+             ]
+         )
+         self.last_up = nn.ModuleDict()
+         for i in range(1, height):
+             self.last_up.update(
+                 {
+                     f"{i}": UpsamplingModule(
+                         in_channels=int(num_features * stride**i),
+                         scaling_factor=2**i,
+                         stride=stride,
+                     )
+                 }
+             )
+
+         # Downsampling paths between streams, keyed by "<input channels>_<scale>".
+         self.down = nn.ModuleDict()
+         i = 0
+         scale.reverse()
+         for f in features:
+             for s in scale[i:]:
+                 self.down.update({f"{f}_{s}": DownsamplingModule(f, s, stride)})
+             i += 1
+
+         # Upsampling paths between streams, keyed the same way.
+         self.up = nn.ModuleDict()
+         i = 0
+         features.reverse()
+         for f in features:
+             for s in scale[i:]:
+                 self.up.update({f"{f}_{s}": UpsamplingModule(f, s, stride)})
+             i += 1
+
+         self.out_conv = nn.Conv2d(
+             num_features, num_features, kernel_size=3, padding=1, bias=bias
+         )
+         self.skff_blocks = nn.ModuleList(
+             [
+                 SelectiveKernelFeatureFusion(num_features * stride**i, height)
+                 for i in range(height)
+             ]
+         )
+
+     def forward(self, x):
+         inp = x.clone()
+         out = []
+
+         # First stage: fill the streams top-down, downsampling by 2 between them.
+         for j in range(self.height):
+             if j == 0:
+                 inp = self.dual_attention_units[j][0](inp)
+             else:
+                 inp = self.dual_attention_units[j][0](
+                     self.down[f"{inp.size(1)}_{2}"](inp)
+                 )
+             out.append(inp)
+
+         # Remaining stages: fuse all streams with SKFF, then apply the stage's DAU.
+         for i in range(1, self.width):
+             temp = []
+             for j in range(self.height):
+                 tensors = []
+                 for k in range(self.height):
+                     tensors.append(self.select_up_down(out[k], j, k))
+                 temp.append(self.skff_blocks[j](tensors))
+
+             for j in range(self.height):
+                 out[j] = self.dual_attention_units[j][i](temp[j])
+
+         # Bring every stream back to full resolution, fuse, convolve, add the residual.
+         output = []
+         for k in range(self.height):
+             output.append(self.select_last_up(out[k], k))
+
+         output = self.skff_blocks[0](output)
+         output = self.out_conv(output)
+         return output + x
+
+     def select_up_down(self, tensor, j, k):
+         if j == k:
+             return tensor
+         else:
+             diff = 2 ** np.abs(j - k)
+             if j < k:
+                 return self.up[f"{tensor.size(1)}_{diff}"](tensor)
+             else:
+                 return self.down[f"{tensor.size(1)}_{diff}"](tensor)
+
+     def select_last_up(self, tensor, k):
+         if k == 0:
+             return tensor
+         else:
+             return self.last_up[f"{k}"](tensor)
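A usage sketch (hyperparameters mirror the defaults in model.py; the spatial size must be divisible by 2^(height-1)):

import torch
from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock

msrb = MultiScaleResidualBlock(num_features=64, height=3, width=2, stride=2, bias=False)
x = torch.randn(1, 64, 32, 32)  # 32 is divisible by 2**(3-1) = 4
assert msrb(x).shape == x.shape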
model/MIRNet/ResidualRecurrentGroup.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock
+
+
+ class ResidualRecurrentGroup(nn.Module):
+     """
+     Group of multi-scale residual blocks followed by a convolutional layer. The group output is added back to its input through a residual connection.
+     """
+
+     def __init__(
+         self, num_features, number_msrb_blocks, height, width, stride, bias=False
+     ):
+         super().__init__()
+         blocks = [
+             MultiScaleResidualBlock(num_features, height, width, stride, bias)
+             for _ in range(number_msrb_blocks)
+         ]
+         blocks.append(
+             nn.Conv2d(
+                 num_features,
+                 num_features,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=bias,
+             )
+         )
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         output = self.blocks(x)
+         return x + output  # residual connection, HxWxC feature maps
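A corresponding shape check (illustrative values):

import torch
from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup

rrg = ResidualRecurrentGroup(num_features=64, number_msrb_blocks=2, height=3, width=2, stride=2)
x = torch.randn(1, 64, 32, 32)
assert rrg(x).shape == x.shape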
model/MIRNet/SelectiveKernelFeatureFusion.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+
+
+ class SelectiveKernelFeatureFusion(nn.Module):
+     """
+     Merges the outputs of the three different resolutions through self-attention.
+
+     All three inputs are summed -> global average pooling -> channel downscaling -> the signal is passed through 3 different convs to obtain three descriptors;
+     softmax is applied across the descriptors to get 3 attention activations used to recalibrate the three input feature maps.
+     """
+
+     def __init__(self, in_channels, reduction_ratio, bias=False):
+         super().__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         conv_out_channels = max(int(in_channels / reduction_ratio), 4)
+         self.convolution = nn.Sequential(
+             nn.Conv2d(
+                 in_channels, conv_out_channels, kernel_size=1, padding=0, bias=bias
+             ),
+             nn.PReLU(),
+         )
+
+         self.attention_convs = nn.ModuleList([])
+         for i in range(3):
+             self.attention_convs.append(
+                 nn.Conv2d(
+                     conv_out_channels, in_channels, kernel_size=1, stride=1, bias=bias
+                 )
+             )
+
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x):
+         # x: list of three feature maps of identical shape (batch_size x n_features x H x W)
+         batch_size = x[0].shape[0]
+         n_features = x[0].shape[1]
+
+         x = torch.cat(
+             x, dim=1
+         )  # the three outputs of diff. res. are concatenated along the channel dimension
+         x = x.view(
+             batch_size, 3, n_features, x.shape[2], x.shape[3]
+         )  # batch_size x 3 x n_features x H x W
+
+         z = torch.sum(x, dim=1)  # batch_size x n_features x H x W
+         z = self.avg_pool(z)  # batch_size x n_features x 1 x 1
+         z = self.convolution(z)  # batch_size x conv_out_channels x 1 x 1
+
+         attention_activations = [
+             atn(z) for atn in self.attention_convs
+         ]  # 3 x (batch_size x n_features x 1 x 1)
+         attention_activations = torch.cat(
+             attention_activations, dim=1
+         )  # batch_size x 3*n_features x 1 x 1
+         attention_activations = attention_activations.view(
+             batch_size, 3, n_features, 1, 1
+         )  # batch_size x 3 x n_features x 1 x 1
+
+         attention_activations = self.softmax(
+             attention_activations
+         )  # batch_size x 3 x n_features x 1 x 1
+
+         return torch.sum(
+             x * attention_activations, dim=1
+         )  # batch_size x n_features x H x W (the three feature maps are recalibrated and summed)
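A minimal sketch of how the module is called; it expects a list of three same-shape feature maps (values illustrative):

import torch
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion

skff = SelectiveKernelFeatureFusion(in_channels=64, reduction_ratio=8)
streams = [torch.randn(1, 64, 32, 32) for _ in range(3)]
fused = skff(streams)
assert fused.shape == (1, 64, 32, 32)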
model/MIRNet/SpatialAttention.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.ChannelCompression import ChannelCompression
+
+
+ class SpatialAttention(nn.Module):
+     """
+     Reduces the input to 2 channels with the ChannelCompression module and applies a 2D convolution with 1 output channel.
+
+     In: HxWxC
+     Out: HxWxC (the original channels are restored by multiplying the attention map with the original input)
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.channel_compression = ChannelCompression()
+         self.conv = nn.Conv2d(2, 1, kernel_size=5, stride=1, padding=2)
+
+     def forward(self, x):
+         x_compressed = self.channel_compression(x)  # HxWx2
+         x_conv = self.conv(x_compressed)  # HxWx1
+         scaling_factor = torch.sigmoid(x_conv)
+         return x * scaling_factor  # HxWxC
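A shape check (illustrative input):

import torch
from model.MIRNet.SpatialAttention import SpatialAttention

sa = SpatialAttention()
x = torch.randn(1, 64, 32, 32)
assert sa(x).shape == x.shape  # the 1-channel attention map broadcasts over the channels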
model/MIRNet/Upsampling.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ class UpsamplingBlock(nn.Module):
+     """
+     Upsamples the input to double the spatial dimensions while halving the channels, through two parallel conv + bilinear-upsampling branches.
+
+     In: HxWxC
+     Out: 2H x 2W x C/2
+     """
+
+     def __init__(self, in_channels, bias=False):
+         super().__init__()
+         # 1x1 conv + PReLU -> 3x3 conv + PReLU -> bilinear upsampling -> 1x1 conv
+         self.branch1 = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
+             nn.PReLU(),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
+             nn.PReLU(),
+             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+             nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, padding=0, bias=bias),
+         )
+         # bilinear upsampling -> 1x1 conv
+         self.branch2 = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+             nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, padding=0, bias=bias),
+         )
+
+     def forward(self, x):
+         return self.branch1(x) + self.branch2(x)  # 2H x 2W x C/2
+
+
+ class UpsamplingModule(nn.Module):
+     """
+     Upsampling module of the network, composed of log2(scaling_factor) UpsamplingBlocks.
+
+     In: HxWxC
+     Out: sH x sW x C/s, where s is the scaling factor
+     """
+
+     def __init__(self, in_channels, scaling_factor, stride=2):
+         # `stride` is kept for interface parity with DownsamplingModule; each block halves the channels.
+         super().__init__()
+         self.scaling_factor = int(np.log2(scaling_factor))
+
+         blocks = []
+         for i in range(self.scaling_factor):
+             blocks.append(UpsamplingBlock(in_channels))
+             in_channels = int(in_channels // 2)
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.blocks(x)  # sH x sW x C/s
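A usage sketch mirroring the downsampling example (illustrative values; a scaling factor of 4 stacks two doubling blocks):

import torch
from model.MIRNet.Upsampling import UpsamplingModule

up = UpsamplingModule(in_channels=256, scaling_factor=4)
x = torch.randn(1, 256, 8, 8)
out = up(x)
assert out.shape == (1, 64, 32, 32)  # 4H x 4W, channels / 4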
model/MIRNet/__init__.py ADDED
File without changes
model/MIRNet/model.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup
+
+
+ class MIRNet(nn.Module):
+     """
+     Low-level features are extracted through convolution and passed to n residual recurrent groups that operate at different resolutions.
+     Their output is added to the input image for restoration.
+
+     Please refer to the documentation of the different blocks of the model in this folder for detailed explanations.
+     """
+
+     def __init__(
+         self,
+         in_channels=3,
+         out_channels=3,
+         num_features=64,
+         kernel_size=3,
+         stride=2,
+         number_msrb=2,
+         number_rrg=3,
+         height=3,
+         width=2,
+         bias=False,
+     ):
+         super().__init__()
+         self.conv_start = nn.Conv2d(
+             in_channels, num_features, kernel_size, padding=1, bias=bias
+         )
+         msrb_blocks = [
+             ResidualRecurrentGroup(
+                 num_features, number_msrb, height, width, stride, bias
+             )
+             for _ in range(number_rrg)
+         ]
+         self.msrb_blocks = nn.Sequential(*msrb_blocks)
+         self.conv_end = nn.Conv2d(
+             num_features, out_channels, kernel_size, padding=1, bias=bias
+         )
+
+     def forward(self, x):
+         output = self.conv_start(x)
+         output = self.msrb_blocks(output)
+         output = self.conv_end(output)
+         return x + output  # restored image, HxWxC
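Finally, an end-to-end sketch (assumes the repository root is on PYTHONPATH; the 128x128 size is illustrative and must be divisible by 2^(height-1)):

import torch
from model.MIRNet.model import MIRNet

net = MIRNet()  # defaults: 3 RRGs of 2 MSRBs each, 64 features, 3 resolution streams
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    restored = net(x)
assert restored.shape == x.shape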