52Hz commited on
Commit
0ad3230
1 Parent(s): 30fb6c4

Create CMFNet.py

Browse files
Files changed (1) hide show
  1. model/CMFNet.py +191 -0
model/CMFNet.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from model.block import SAB, CAB, PAB, conv, SAM, conv3x3, conv_down
4
+
5
+ ##########################################################################
6
+ ## U-Net
7
+ bn = 2 # block number-1
8
+
9
+ class Encoder(nn.Module):
10
+ def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, block):
11
+ super(Encoder, self).__init__()
12
+ if block == 'CAB':
13
+ self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
14
+ self.encoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
15
+ self.encoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
16
+ elif block == 'PAB':
17
+ self.encoder_level1 = [PAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
18
+ self.encoder_level2 = [PAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
19
+ self.encoder_level3 = [PAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
20
+ elif block == 'SAB':
21
+ self.encoder_level1 = [SAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
22
+ self.encoder_level2 = [SAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
23
+ self.encoder_level3 = [SAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
24
+ self.encoder_level1 = nn.Sequential(*self.encoder_level1)
25
+ self.encoder_level2 = nn.Sequential(*self.encoder_level2)
26
+ self.encoder_level3 = nn.Sequential(*self.encoder_level3)
27
+ self.down12 = DownSample(n_feat, scale_unetfeats)
28
+ self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)
29
+
30
+ def forward(self, x):
31
+ enc1 = self.encoder_level1(x)
32
+ x = self.down12(enc1)
33
+ enc2 = self.encoder_level2(x)
34
+ x = self.down23(enc2)
35
+ enc3 = self.encoder_level3(x)
36
+ return [enc1, enc2, enc3]
37
+
38
+ class Decoder(nn.Module):
39
+ def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, block):
40
+ super(Decoder, self).__init__()
41
+ if block == 'CAB':
42
+ self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
43
+ self.decoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
44
+ self.decoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
45
+ elif block == 'PAB':
46
+ self.decoder_level1 = [PAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
47
+ self.decoder_level2 = [PAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
48
+ self.decoder_level3 = [PAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
49
+ elif block == 'SAB':
50
+ self.decoder_level1 = [SAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
51
+ self.decoder_level2 = [SAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
52
+ self.decoder_level3 = [SAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
53
+ self.decoder_level1 = nn.Sequential(*self.decoder_level1)
54
+ self.decoder_level2 = nn.Sequential(*self.decoder_level2)
55
+ self.decoder_level3 = nn.Sequential(*self.decoder_level3)
56
+ if block == 'CAB':
57
+ self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
58
+ self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
59
+ if block == 'PAB':
60
+ self.skip_attn1 = PAB(n_feat, kernel_size, reduction, bias=bias, act=act)
61
+ self.skip_attn2 = PAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
62
+ if block == 'SAB':
63
+ self.skip_attn1 = SAB(n_feat, kernel_size, reduction, bias=bias, act=act)
64
+ self.skip_attn2 = SAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
65
+ self.up21 = SkipUpSample(n_feat, scale_unetfeats)
66
+ self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)
67
+
68
+ def forward(self, outs):
69
+ enc1, enc2, enc3 = outs
70
+ dec3 = self.decoder_level3(enc3)
71
+ x = self.up32(dec3, self.skip_attn2(enc2))
72
+ dec2 = self.decoder_level2(x)
73
+ x = self.up21(dec2, self.skip_attn1(enc1))
74
+ dec1 = self.decoder_level1(x)
75
+ return [dec1, dec2, dec3]
76
+
77
+ ##########################################################################
78
+ ##---------- Resizing Modules ----------
79
+ class DownSample(nn.Module):
80
+ def __init__(self, in_channels, s_factor):
81
+ super(DownSample, self).__init__()
82
+ self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
83
+ nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False))
84
+
85
+ def forward(self, x):
86
+ x = self.down(x)
87
+ return x
88
+
89
+ class UpSample(nn.Module):
90
+ def __init__(self, in_channels, s_factor):
91
+ super(UpSample, self).__init__()
92
+ self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
93
+ nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False))
94
+
95
+ def forward(self, x):
96
+ x = self.up(x)
97
+ return x
98
+
99
+ class SkipUpSample(nn.Module):
100
+ def __init__(self, in_channels, s_factor):
101
+ super(SkipUpSample, self).__init__()
102
+ self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
103
+ nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False))
104
+
105
+ def forward(self, x, y):
106
+ x = self.up(x)
107
+ x = x + y
108
+ return x
109
+
110
+ ##########################################################################
111
+ # Mixed Residual Module
112
+ class Mix(nn.Module):
113
+ def __init__(self, m=1):
114
+ super(Mix, self).__init__()
115
+ w = nn.Parameter(torch.FloatTensor([m]), requires_grad=True)
116
+ w = nn.Parameter(w, requires_grad=True)
117
+ self.w = w
118
+ self.mix_block = nn.Sigmoid()
119
+
120
+ def forward(self, fea1, fea2, feat3):
121
+ factor = self.mix_block(self.w)
122
+ other = (1 - factor)/2
123
+ output = fea1 * other.expand_as(fea1) + fea2 * factor.expand_as(fea2) + feat3 * other.expand_as(feat3)
124
+ return output, factor
125
+
126
+ ##########################################################################
127
+ # Architecture
128
+ class CMFNet(nn.Module):
129
+ def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, kernel_size=3, reduction=4, bias=False):
130
+ super(CMFNet, self).__init__()
131
+
132
+ p_act = nn.PReLU()
133
+ self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat // 2, kernel_size, bias=bias), p_act,
134
+ conv(n_feat // 2, n_feat, kernel_size, bias=bias))
135
+ self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat // 2, kernel_size, bias=bias), p_act,
136
+ conv(n_feat // 2, n_feat, kernel_size, bias=bias))
137
+ self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat // 2, kernel_size, bias=bias), p_act,
138
+ conv(n_feat // 2, n_feat, kernel_size, bias=bias))
139
+
140
+ self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'CAB')
141
+ self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'CAB')
142
+
143
+ self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'PAB')
144
+ self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'PAB')
145
+
146
+ self.stage3_encoder = Encoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'SAB')
147
+ self.stage3_decoder = Decoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'SAB')
148
+
149
+ self.sam1o = SAM(n_feat, kernel_size=3, bias=bias)
150
+ self.sam2o = SAM(n_feat, kernel_size=3, bias=bias)
151
+ self.sam3o = SAM(n_feat, kernel_size=3, bias=bias)
152
+
153
+ self.mix = Mix(1)
154
+ self.add123 = conv(out_c, out_c, kernel_size, bias=bias)
155
+ self.concat123 = conv(n_feat*3, n_feat, kernel_size, bias=bias)
156
+ self.tail = conv(n_feat, out_c, kernel_size, bias=bias)
157
+
158
+
159
+ def forward(self, x):
160
+ ## Compute Shallow Features
161
+ shallow1 = self.shallow_feat1(x)
162
+ shallow2 = self.shallow_feat2(x)
163
+ shallow3 = self.shallow_feat3(x)
164
+
165
+ ## Enter the UNet-CAB
166
+ x1 = self.stage1_encoder(shallow1)
167
+ x1_D = self.stage1_decoder(x1)
168
+ ## Apply SAM
169
+ x1_out, x1_img = self.sam1o(x1_D[0], x)
170
+
171
+ ## Enter the UNet-PAB
172
+ x2 = self.stage2_encoder(shallow2)
173
+ x2_D = self.stage2_decoder(x2)
174
+ ## Apply SAM
175
+ x2_out, x2_img = self.sam2o(x2_D[0], x)
176
+
177
+ ## Enter the UNet-SAB
178
+ x3 = self.stage3_encoder(shallow3)
179
+ x3_D = self.stage3_decoder(x3)
180
+ ## Apply SAM
181
+ x3_out, x3_img = self.sam3o(x3_D[0], x)
182
+
183
+ ## Aggregate SAM features of Stage 1, Stage 2 and Stage 3
184
+ mix_r = self.mix(x1_img, x2_img, x3_img)
185
+ mixed_img = self.add123(mix_r[0])
186
+
187
+ ## Concat SAM features of Stage 1, Stage 2 and Stage 3
188
+ concat_feat = self.concat123(torch.cat([x1_out, x2_out, x3_out], 1))
189
+ x_final = self.tail(concat_feat)
190
+
191
+ return x_final + mixed_img