frodos committed on
Commit
e9311f3
1 Parent(s): 9e86865

Add handler

Browse files
Files changed (3) hide show
  1. handler.py +64 -0
  2. isnet.pth +3 -0
  3. model.py +609 -0
handler.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from io import BytesIO
3
+ import base64
4
+ from model import ISNetDIS
5
+ import torch
6
+ import os
7
+ from PIL import Image
8
+ from torchvision.transforms import Compose, Normalize, functional
9
+
10
+
11
def process_image(image: torch.Tensor):
    """Normalize a 3-channel image tensor and add a leading batch axis.

    Each channel is normalized with mean 0.5 and std 0.5, then the result
    is unsqueezed at dim 0 so it can be fed to the network as a batch of one.
    """
    normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return normalize(image).unsqueeze(0)
15
+
16
+
17
def get_model(device="cpu"):
    """Build an ISNetDIS network with the bundled weights.

    The checkpoint ``isnet.pth`` is read from the directory containing this
    file, mapped onto ``device``, and loaded into a fresh model, which is
    moved to ``device`` and returned in evaluation mode.
    """
    checkpoint_path = os.path.join(os.path.dirname(__file__), "isnet.pth")
    net = ISNetDIS()
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state)
    return net.to(device).eval()
25
+
26
+
27
class EndpointHandler():
    """Inference handler that cuts the foreground out of a base64 image.

    The model is loaded once at construction time and reused for every call.
    """

    def __init__(self, path=""):
        # ``path`` is accepted (and ignored) because the Inference Endpoints
        # toolkit passes the repository directory to the constructor; the
        # weights are located relative to this file by ``get_model``.
        self._model = get_model()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run background removal on one request.

        Args:
            data: request payload. ``data["inputs"]`` (or ``data`` itself when
                that key is absent) is either a dict holding the base64 image
                under ``"image"`` or the base64 string directly.

        Returns:
            ``{"status": 200, "image": <base64 PNG>}`` where the PNG is RGBA:
            background pixels are white and the alpha channel is the mask.

        Raises:
            KeyError: if the inputs dict has no ``"image"`` entry.
        """
        inputs = data.pop("inputs", data)
        encoded = inputs["image"] if isinstance(inputs, dict) else inputs

        # Force three channels: the normalization in process_image is defined
        # for RGB only, so RGBA / grayscale / palette uploads would fail.
        image = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
        t = functional.pil_to_tensor(image).float().divide(255.0)
        arr = process_image(t)

        # Reuse the model loaded in __init__ (no per-request checkpoint load)
        # and skip autograd bookkeeping during inference.
        with torch.no_grad():
            side_outputs = self._model(arr)[0]

        # First side output, first batch item -> (1, H, W).
        pred_val = side_outputs[0][0, :, :, :]
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        # Min-max rescale; the clamp avoids 0/0 on a constant prediction.
        pred_val = (pred_val - mi) / (ma - mi).clamp_min(1e-8)

        msk = torch.gt(pred_val, 0.1)
        # Keep foreground pixels, paint the rest white.
        w = torch.where(msk, t, torch.ones_like(t))
        # Append the mask as alpha channel; cast so both tensors share a dtype.
        w = torch.cat([w, msk.to(w.dtype)], dim=0)

        # No squeeze here: it would drop a spatial axis of a 1-pixel-wide image.
        img2 = functional.to_pil_image(w)

        stream = BytesIO()
        img2.save(stream, format="png")
        res = {"status": 200,
               "image": base64.b64encode(stream.getvalue()).decode("utf8")
               }
        return res
59
+
60
+
61
if __name__ == "__main__":
    # Local smoke test: encode an image file and run it through the handler.
    # (Calling the handler with an empty payload can never succeed because
    # the "image" entry is mandatory.)
    import sys

    if len(sys.argv) != 2:
        sys.exit("usage: python handler.py <image-file>")
    with open(sys.argv[1], "rb") as fh:
        payload = base64.b64encode(fh.read()).decode("utf8")
    h = EndpointHandler()
    v = h({"inputs": {"image": payload}})
    print(v)
isnet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea0889743a78391b48d6b7c40b4def963ee329cb10934c75aa32481dc5af9c61
3
+ size 176597693
model.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import torch.nn.functional as F
5
+
6
# Binary cross-entropy averaged over all elements. ``reduction='mean'`` is the
# current spelling of the deprecated ``size_average=True``.
bce_loss = nn.BCELoss(reduction='mean')


def muti_loss_fusion(preds, target):
    """Sum the BCE loss of every prediction against ``target``.

    Args:
        preds: sequence of 4-D probability maps.
        target: 4-D ground-truth map. When a prediction's spatial size
            differs, the target is bilinearly resized to that prediction.

    Returns:
        ``(loss0, loss)``: the loss of the first prediction alone and the sum
        over all predictions. Both are ``0.0`` when ``preds`` is empty.
    """
    loss0 = 0.0
    loss = 0.0

    for i, pred in enumerate(preds):
        if pred.shape[2] != target.shape[2] or pred.shape[3] != target.shape[3]:
            ref = F.interpolate(target, size=pred.size()[2:], mode='bilinear', align_corners=True)
        else:
            ref = target
        loss = loss + bce_loss(pred, ref)
        if i == 0:
            loss0 = loss
    return loss0, loss
24
+
25
+
26
# Feature-matching criteria. ``reduction='mean'`` replaces the deprecated
# ``size_average=True`` and selects the same averaging.
fea_loss = nn.MSELoss(reduction='mean')
kl_loss = nn.KLDivLoss(reduction='mean')
l1_loss = nn.L1Loss(reduction='mean')
smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')


def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
    """BCE loss over ``preds`` plus a feature-matching term over ``dfs``/``fs``.

    Args:
        preds: sequence of 4-D probability maps.
        target: 4-D ground-truth map, bilinearly resized to a prediction
            whenever their spatial sizes differ.
        dfs: sequence of feature tensors compared against ``fs``.
        fs: sequence of reference feature tensors, indexed in step with ``dfs``.
        mode: ``'MSE'``, ``'KL'`` (KL divergence between channel softmaxes),
            ``'MAE'`` or ``'SmoothL1'``. Any other value adds no feature term.

    Returns:
        ``(loss0, loss)``: the BCE loss of the first prediction alone, and the
        total of all BCE terms plus all feature terms.
    """
    loss0 = 0.0
    loss = 0.0

    for i, pred in enumerate(preds):
        if pred.shape[2] != target.shape[2] or pred.shape[3] != target.shape[3]:
            ref = F.interpolate(target, size=pred.size()[2:], mode='bilinear', align_corners=True)
        else:
            ref = target
        loss = loss + bce_loss(pred, ref)
        if i == 0:
            loss0 = loss

    # Pick the feature criterion once instead of re-testing ``mode`` per item.
    if mode == 'MSE':
        feature_term = fea_loss
    elif mode == 'KL':
        def feature_term(df, f):
            return kl_loss(F.log_softmax(df, dim=1), F.softmax(f, dim=1))
    elif mode == 'MAE':
        feature_term = l1_loss
    elif mode == 'SmoothL1':
        feature_term = smooth_l1_loss
    else:
        feature_term = None

    if feature_term is not None:
        for i, df in enumerate(dfs):
            # Feature losses act as additional constraints on top of the BCE terms.
            loss = loss + feature_term(df, fs[i])

    return loss0, loss
62
+
63
+
64
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalization and ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate; also used as the padding.
        stride: convolution stride.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=dirate, dilation=dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
77
+
78
+
79
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
80
+ def _upsample_like(src, tar):
81
+ src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
82
+
83
+ return src
84
+
85
+
86
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block with five pooling steps.

    An input convolution feeds an encoder of REBNCONV layers separated by
    2x2 max pooling, a dilated bottom layer, and a decoder that upsamples and
    concatenates each level with its encoder counterpart. The decoder output
    is added to the input convolution's output.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels used inside the block.
        out_ch: channels of the output tensor.
        img_size: not used by this block; kept so existing callers that pass
            it keep working.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder layers take the concatenation of two mid_ch feature maps.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoder output plus the input-convolution residual."""
        hxin = self.rebnconvin(x)

        # Encoder: each level is computed on the pooled previous level.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(self.pool5(hx5))

        hx7 = self.rebnconv7(hx6)

        # Decoder: upsample to the matching encoder level and fuse by concat.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
167
+
168
+
169
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block with four pooling steps.

    Same layout as the deeper RSU blocks: input convolution, REBNCONV encoder
    with 2x2 max pooling between levels, a dilated bottom layer, and a decoder
    that upsamples and concatenates with the encoder features. The result is
    added to the input convolution's output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        def halve():
            # 2x2 max pooling, stride 2, output size rounded up.
            return nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = halve()

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = halve()

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = halve()

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = halve()

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        fused = mid_ch * 2  # decoder inputs are two concatenated mid_ch maps
        self.rebnconv5d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(fused, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoder output plus the input-convolution residual."""
        residual = self.rebnconvin(x)

        # Encoder path.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))
        bottom = self.rebnconv6(e5)

        # Decoder path.
        d5 = self.rebnconv5d(torch.cat((bottom, e5), 1))
        d4 = self.rebnconv4d(torch.cat((_upsample_like(d5, e4), e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
235
+
236
+
237
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block with three pooling steps.

    Input convolution, REBNCONV encoder with 2x2 max pooling between levels,
    a dilated bottom layer, and a decoder that upsamples and concatenates with
    the encoder features; the result is added to the input convolution's
    output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        def halve():
            # 2x2 max pooling, stride 2, output size rounded up.
            return nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = halve()

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = halve()

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = halve()

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        fused = mid_ch * 2  # decoder inputs are two concatenated mid_ch maps
        self.rebnconv4d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(fused, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoder output plus the input-convolution residual."""
        residual = self.rebnconvin(x)

        # Encoder path.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        bottom = self.rebnconv5(e4)

        # Decoder path.
        d4 = self.rebnconv4d(torch.cat((bottom, e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
293
+
294
+
295
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block with two pooling steps.

    Input convolution, REBNCONV encoder with 2x2 max pooling between levels,
    a dilated bottom layer, and a decoder that upsamples and concatenates with
    the encoder features; the result is added to the input convolution's
    output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        def halve():
            # 2x2 max pooling, stride 2, output size rounded up.
            return nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = halve()

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = halve()

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        fused = mid_ch * 2  # decoder inputs are two concatenated mid_ch maps
        self.rebnconv3d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(fused, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(fused, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoder output plus the input-convolution residual."""
        residual = self.rebnconvin(x)

        # Encoder path.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        bottom = self.rebnconv4(e3)

        # Decoder path.
        d3 = self.rebnconv3d(torch.cat((bottom, e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
341
+
342
+
343
### RSU-4F ###
class RSU4F(nn.Module):
    """Residual U-block without pooling.

    Uses dilation rates 1, 2, 4 and 8 on the way in and 4, 2, 1 on the way
    out, so no pooling or upsampling takes place in this block. Each outgoing
    layer receives the concatenation of the previous output and the matching
    incoming feature map; the result is added to the input convolution's
    output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        fused = mid_ch * 2  # decoder inputs are two concatenated mid_ch maps
        self.rebnconv3d = REBNCONV(fused, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(fused, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(fused, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoder output plus the input-convolution residual."""
        residual = self.rebnconvin(x)

        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(e1)
        e3 = self.rebnconv3(e2)
        bottom = self.rebnconv4(e3)

        d3 = self.rebnconv3d(torch.cat((bottom, e3), 1))
        d2 = self.rebnconv2d(torch.cat((d3, e2), 1))
        d1 = self.rebnconv1d(torch.cat((d2, e1), 1))

        return d1 + residual
377
+
378
+
379
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch normalization and ReLU.

    All arguments after ``out_ch`` are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super().__init__()

        self.conv = nn.Conv2d(in_ch,
                              out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.rl(self.bn(self.conv(x)))
401
+
402
+
403
class ISNetGTEncoder(nn.Module):
    """Six-stage RSU encoder with one side output per stage.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of every side output.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()

        # Stride-2 stem: the stages run at half the input resolution.
        self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1)

        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 64, 512)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss(self, preds, targets):
        """Return ``muti_loss_fusion(preds, targets)``."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        """Run the encoder.

        Returns:
            A pair ``(side_outputs, features)``: six sigmoid maps, each resized
            to the spatial size of ``x``, and the six stage feature maps they
            were computed from.
        """
        hxin = self.conv_in(x)

        hx1 = self.stage1(hxin)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        features = [hx1, hx2, hx3, hx4, hx5, hx6]
        side_convs = (self.side1, self.side2, self.side3,
                      self.side4, self.side5, self.side6)

        # torch.sigmoid replaces the deprecated F.sigmoid.
        side_outputs = [torch.sigmoid(_upsample_like(conv(feat), x))
                        for conv, feat in zip(side_convs, features)]

        return side_outputs, features
490
+
491
+
492
class ISNetDIS(nn.Module):
    """Six-stage RSU encoder and five-stage RSU decoder with six side outputs.

    Args:
        in_ch: channels of the input image tensor.
        out_ch: channels of every side output.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Stride-2 stem: the stages run at half the input resolution.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Not called in forward(); kept so the module's attributes are unchanged.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
        """Return ``muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)``."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    def compute_loss(self, preds, targets):
        """Return ``muti_loss_fusion(preds, targets)``."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        """Run the network.

        Returns:
            A pair ``(side_outputs, features)``: six sigmoid maps, each resized
            to the spatial size of ``x``, and the feature maps they were
            computed from (five decoder stages plus the deepest encoder stage).
        """
        hxin = self.conv_in(x)

        # -------------------- encoder --------------------
        hx1 = self.stage1(hxin)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        features = [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        side_convs = (self.side1, self.side2, self.side3,
                      self.side4, self.side5, self.side6)

        # torch.sigmoid replaces the deprecated F.sigmoid.
        side_outputs = [torch.sigmoid(_upsample_like(conv(feat), x))
                        for conv, feat in zip(side_convs, features)]

        return side_outputs, features