not-lain committed on
Commit
51185ce
1 Parent(s): bce4a86

Create isnet.py

Browse files
Files changed (1) hide show
  1. isnet.py +619 -0
isnet.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is borrowed from
2
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision import models
8
+
9
# Binary cross-entropy applied directly to logits, averaged over all elements.
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")


def muti_loss_fusion(preds, target):
    """Sum the BCE-with-logits loss of every prediction against ``target``.

    Args:
        preds: sequence of logit tensors; dimensions 2 and 3 are read as the
            spatial size of each prediction.
        target: ground-truth tensor. Whenever its spatial size differs from a
            prediction's, it is resized to that prediction with bilinear
            interpolation (``align_corners=True``) before the loss is taken.

    Returns:
        ``(loss0, loss)`` where ``loss0`` is the loss of ``preds[0]`` alone and
        ``loss`` is the sum over all predictions. Both are the float ``0.0``
        when ``preds`` is empty.
    """
    first_term = 0.0
    total = 0.0

    for idx, pred in enumerate(preds):
        same_size = (
            pred.shape[2] == target.shape[2] and pred.shape[3] == target.shape[3]
        )
        if same_size:
            reference = target
        else:
            reference = F.interpolate(
                target, size=pred.size()[2:], mode="bilinear", align_corners=True
            )
        total = total + bce_loss(pred, reference)
        if idx == 0:
            # Remember the first prediction's loss separately.
            first_term = total

    return first_term, total
27
+
28
+
29
# Feature-matching criteria selected by the ``mode`` argument of
# ``muti_loss_fusion_kl``; every one of them uses reduction="mean".
fea_loss = nn.MSELoss(reduction="mean")  # mode == "MSE"
# NOTE(review): the PyTorch docs state that reduction="mean" on KLDivLoss does
# not return the true KL divergence and recommend "batchmean". Left unchanged
# here because switching would rescale the "KL" loss term - confirm intent.
kl_loss = nn.KLDivLoss(reduction="mean")  # mode == "KL"
l1_loss = nn.L1Loss(reduction="mean")  # mode == "MAE"
smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")  # mode == "SmoothL1"
33
+
34
+
35
def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
    """Combine the per-prediction BCE loss with a feature-matching loss.

    Args:
        preds: sequence of logit tensors; dimensions 2 and 3 are read as the
            spatial size of each prediction.
        target: ground-truth tensor, bilinearly resized (``align_corners=True``)
            to a prediction's spatial size whenever the two differ.
        dfs: sequence of feature tensors to be constrained.
        fs: sequence of reference feature tensors, indexed in parallel with
            ``dfs``.
        mode: which feature criterion to add per ``(dfs[i], fs[i])`` pair:
            ``"MSE"``, ``"KL"`` (log-softmax of ``dfs[i]`` vs. softmax of
            ``fs[i]`` along dim 1), ``"MAE"`` or ``"SmoothL1"``. Any other
            value adds no feature term.

    Returns:
        ``(loss0, loss)``: ``loss0`` is the BCE loss of ``preds[0]`` alone;
        ``loss`` is the BCE sum over all predictions plus the feature terms.
    """
    first_term = 0.0
    total = 0.0

    # Supervision term: BCE-with-logits for every prediction.
    for idx, pred in enumerate(preds):
        same_size = (
            pred.shape[2] == target.shape[2] and pred.shape[3] == target.shape[3]
        )
        if same_size:
            reference = target
        else:
            reference = F.interpolate(
                target, size=pred.size()[2:], mode="bilinear", align_corners=True
            )
        total = total + bce_loss(pred, reference)
        if idx == 0:
            first_term = total

    # Feature term: one extra constraint per (dfs[i], fs[i]) pair.
    for idx, df in enumerate(dfs):
        fs_i = fs[idx]
        if mode == "MSE":
            extra = fea_loss(df, fs_i)
        elif mode == "KL":
            extra = kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
        elif mode == "MAE":
            extra = l1_loss(df, fs_i)
        elif mode == "SmoothL1":
            extra = smooth_l1_loss(df, fs_i)
        else:
            # Unrecognised mode: nothing is added for this pair.
            continue
        total = total + extra

    return first_term, total
65
+
66
+
67
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate of the convolution; the padding is set to the
            same value.
        stride: stride of the convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            stride=stride,
            padding=1 * dirate,
            dilation=1 * dirate,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        convolved = self.conv_s1(x)
        normalised = self.bn_s1(convolved)
        return self.relu_s1(normalised)
82
+
83
+
84
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
85
+ def _upsample_like(src, tar):
86
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
87
+
88
+ return src
89
+
90
+
91
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block with seven convolution levels.

    An input ``REBNCONV`` maps ``in_ch`` to ``out_ch``. The encoder then runs
    six ``REBNCONV`` layers with five 2x max-poolings between them, followed
    by a dilated (``dirate=2``) bottom layer. The decoder concatenates each
    level with its encoder counterpart along the channel axis, upsampling to
    the counterpart's spatial size. The decoder output is added to the output
    of the input convolution.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
        img_size: not read anywhere in this block; kept so existing callers
            that pass it keep working.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input convolution; REBNCONV defaults to stride 1, so the spatial
        # size is unchanged here.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: pooling uses ceil_mode=True, so odd sizes round up.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: every layer takes the concatenation of two mid_ch tensors.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on ``x`` and return ``decoder_output + input_conv``."""
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(self.pool5(hx5))

        hx7 = self.rebnconv7(hx6)

        # Decoder path: upsample to the skip tensor's size, then concatenate.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx5d = self.rebnconv5d(torch.cat((_upsample_like(hx6d, hx5), hx5), 1))
        hx4d = self.rebnconv4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        # Residual connection around the whole U-structure.
        return hx1d + hxin
171
+
172
+
173
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block with six convolution levels.

    Five ``REBNCONV`` encoder layers are separated by four 2x max-poolings
    and followed by a dilated (``dirate=2``) bottom layer. The decoder
    concatenates each level with its encoder counterpart (upsampled to the
    counterpart's size). The result is added to the input convolution output.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on ``x`` and return ``decoder_output + input_conv``."""
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))

        hx6 = self.rebnconv6(hx5)

        # Decoder path: upsample to the skip tensor's size, then concatenate.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx4d = self.rebnconv4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
238
+
239
+
240
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block with five convolution levels.

    Four ``REBNCONV`` encoder layers are separated by three 2x max-poolings
    and followed by a dilated (``dirate=2``) bottom layer. The decoder
    concatenates each level with its encoder counterpart (upsampled to the
    counterpart's size). The result is added to the input convolution output.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on ``x`` and return ``decoder_output + input_conv``."""
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))

        hx5 = self.rebnconv5(hx4)

        # Decoder path: upsample to the skip tensor's size, then concatenate.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
295
+
296
+
297
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block with four convolution levels.

    Three ``REBNCONV`` encoder layers are separated by two 2x max-poolings
    and followed by a dilated (``dirate=2``) bottom layer. The decoder
    concatenates each level with its encoder counterpart (upsampled to the
    counterpart's size). The result is added to the input convolution output.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on ``x`` and return ``decoder_output + input_conv``."""
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))

        hx4 = self.rebnconv4(hx3)

        # Decoder path: upsample to the skip tensor's size, then concatenate.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
342
+
343
+
344
### RSU-4F ###
class RSU4F(nn.Module):
    """Residual U-block that uses dilation instead of pooling.

    There is no pooling or upsampling in this block. The encoder layers use
    dilation rates 1, 2, 4 and 8, and the decoder mirrors them with rates
    4, 2 and 1 while concatenating the matching encoder output along the
    channel axis. The result is added to the input convolution output.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder with shrinking dilation.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on ``x`` and return ``decoder_output + input_conv``."""
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)

        # Decoder path: feed each result straight into the next concatenation.
        decoded = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        decoded = self.rebnconv2d(torch.cat((decoded, hx2), 1))
        decoded = self.rebnconv1d(torch.cat((decoded, hx1), 1))

        return decoded + hxin
377
+
378
+
379
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch norm and an in-place ReLU.

    All arguments other than ``in_ch``/``out_ch`` are forwarded unchanged to
    ``nn.Conv2d``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        dilation: convolution dilation.
        groups: number of convolution groups.
    """

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super().__init__()

        conv_kwargs = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
        }
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        features = self.conv(x)
        features = self.bn(features)
        return self.rl(features)
406
+
407
+
408
class ISNetGTEncoder(nn.Module):
    """Six-stage encoder of RSU blocks with one side output per stage.

    A strided ``myrebnconv`` halves the input resolution. Six RSU stages
    follow, separated by 2x max-poolings (``ceil_mode=True``). Each stage
    feeds a 3x3 convolution that produces an ``out_ch``-channel map, which is
    upsampled back to the spatial size of the network input.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels of every side output.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()

        # Stride-2 stem: the first stage already runs at reduced resolution.
        self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1)

        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 64, 512)

        # One 3x3 side head per stage; input widths follow the stage outputs.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    @staticmethod
    def compute_loss(args):
        """Unpack ``args`` as ``(preds, targets)`` and return ``muti_loss_fusion``'s result."""
        preds, targets = args
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        """Return ``(side_outputs, stage_features)`` for input ``x``.

        ``side_outputs`` holds the six side maps, each resized to the spatial
        size of ``x``; no sigmoid is applied to them. ``stage_features`` holds
        the six raw stage outputs ``[hx1, ..., hx6]``.
        """
        # Encoder stages; the stem output goes straight into stage 1.
        hx1 = self.stage1(self.conv_in(x))
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        stage_features = [hx1, hx2, hx3, hx4, hx5, hx6]
        side_heads = (
            self.side1,
            self.side2,
            self.side3,
            self.side4,
            self.side5,
            self.side6,
        )

        # Side outputs, each brought back to the input resolution.
        side_outputs = [
            _upsample_like(head(feature), x)
            for head, feature in zip(side_heads, stage_features)
        ]

        return side_outputs, stage_features
497
+
498
+
499
class ISNetDIS(nn.Module):
    """Encoder-decoder segmentation network built from RSU blocks.

    A stride-2 convolution forms the stem. Six encoder stages follow,
    separated by 2x max-poolings (``ceil_mode=True``). Five decoder stages
    each consume the upsampled deeper feature concatenated with the matching
    encoder feature. Six 3x3 side heads produce ``out_ch``-channel maps that
    are upsampled to the spatial size of the network input.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels of every side output.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Not used by forward(); kept so the module's attribute set is
        # unchanged for code that inspects it.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Encoder.
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # Decoder: input widths are the sum of the two concatenated features.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side heads, one per decoder level plus the deepest encoder stage.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    @staticmethod
    def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"):
        """Forward all arguments to ``muti_loss_fusion_kl`` and return its result."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    @staticmethod
    def compute_loss(args):
        """Compute the training loss from a packed argument sequence.

        Args:
            args: either ``(ds, dfs, labels)``, in which case only the
                prediction loss from ``muti_loss_fusion`` is returned and
                ``dfs`` is ignored, or ``(ds, dfs, labels, fs)``, in which
                case ``muti_loss_fusion_kl`` is used with ``mode="MSE"``.

        Returns:
            The ``(loss0, loss)`` pair produced by the selected loss function.
        """
        if len(args) == 3:
            ds, _dfs, labels = args
            return muti_loss_fusion(ds, labels)

        ds, dfs, labels, fs = args
        return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")

    def forward(self, x):
        """Return ``(side_outputs, decoder_features)`` for input ``x``.

        ``side_outputs`` holds six maps resized to the spatial size of ``x``;
        no sigmoid is applied to them. ``decoder_features`` is
        ``[hx1d, hx2d, hx3d, hx4d, hx5d, hx6]``.
        """
        hxin = self.conv_in(x)
        # The stem output goes straight into stage 1. The former
        # `self.pool_in(hxin)` call was removed because its result was never
        # read.

        # -------------------- encoder --------------------
        hx1 = self.stage1(hxin)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((_upsample_like(hx6, hx5), hx5), 1))
        hx4d = self.stage4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.stage3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        decoder_features = [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        side_heads = (
            self.side1,
            self.side2,
            self.side3,
            self.side4,
            self.side5,
            self.side6,
        )

        # Side outputs, each brought back to the input resolution.
        side_outputs = [
            _upsample_like(head(feature), x)
            for head, feature in zip(side_heads, decoder_features)
        ]

        return side_outputs, decoder_features