sammyview80 commited on
Commit
6f70afa
1 Parent(s): 7866247

Upload 8 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.9
5
+
6
+ WORKDIR /code
7
+
8
+ COPY ./requirements.txt /code/requirements.txt
9
+
10
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
11
+
12
+ COPY . .
13
+
14
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "flask_app:app"]
briarmbg.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
8
+ super(REBNCONV,self).__init__()
9
+
10
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
11
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
12
+ self.relu_s1 = nn.ReLU(inplace=True)
13
+
14
+ def forward(self,x):
15
+
16
+ hx = x
17
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
18
+
19
+ return xout
20
+
21
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
22
+ def _upsample_like(src,tar):
23
+
24
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
25
+
26
+ return src
27
+
28
+
29
+ ### RSU-7 ###
30
+ class RSU7(nn.Module):
31
+
32
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
33
+ super(RSU7,self).__init__()
34
+
35
+ self.in_ch = in_ch
36
+ self.mid_ch = mid_ch
37
+ self.out_ch = out_ch
38
+
39
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
40
+
41
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
42
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
43
+
44
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
45
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
46
+
47
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
48
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
49
+
50
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
51
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
52
+
53
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
54
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
55
+
56
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
57
+
58
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
59
+
60
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
61
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
62
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
63
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
64
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
65
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
66
+
67
+ def forward(self,x):
68
+ b, c, h, w = x.shape
69
+
70
+ hx = x
71
+ hxin = self.rebnconvin(hx)
72
+
73
+ hx1 = self.rebnconv1(hxin)
74
+ hx = self.pool1(hx1)
75
+
76
+ hx2 = self.rebnconv2(hx)
77
+ hx = self.pool2(hx2)
78
+
79
+ hx3 = self.rebnconv3(hx)
80
+ hx = self.pool3(hx3)
81
+
82
+ hx4 = self.rebnconv4(hx)
83
+ hx = self.pool4(hx4)
84
+
85
+ hx5 = self.rebnconv5(hx)
86
+ hx = self.pool5(hx5)
87
+
88
+ hx6 = self.rebnconv6(hx)
89
+
90
+ hx7 = self.rebnconv7(hx6)
91
+
92
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
93
+ hx6dup = _upsample_like(hx6d,hx5)
94
+
95
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
96
+ hx5dup = _upsample_like(hx5d,hx4)
97
+
98
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
99
+ hx4dup = _upsample_like(hx4d,hx3)
100
+
101
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
102
+ hx3dup = _upsample_like(hx3d,hx2)
103
+
104
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
105
+ hx2dup = _upsample_like(hx2d,hx1)
106
+
107
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
108
+
109
+ return hx1d + hxin
110
+
111
+
112
+ ### RSU-6 ###
113
+ class RSU6(nn.Module):
114
+
115
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
116
+ super(RSU6,self).__init__()
117
+
118
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
119
+
120
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
121
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
122
+
123
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
124
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
125
+
126
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
127
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
128
+
129
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
130
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
131
+
132
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
133
+
134
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
135
+
136
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
137
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
138
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
139
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
140
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
141
+
142
+ def forward(self,x):
143
+
144
+ hx = x
145
+
146
+ hxin = self.rebnconvin(hx)
147
+
148
+ hx1 = self.rebnconv1(hxin)
149
+ hx = self.pool1(hx1)
150
+
151
+ hx2 = self.rebnconv2(hx)
152
+ hx = self.pool2(hx2)
153
+
154
+ hx3 = self.rebnconv3(hx)
155
+ hx = self.pool3(hx3)
156
+
157
+ hx4 = self.rebnconv4(hx)
158
+ hx = self.pool4(hx4)
159
+
160
+ hx5 = self.rebnconv5(hx)
161
+
162
+ hx6 = self.rebnconv6(hx5)
163
+
164
+
165
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
166
+ hx5dup = _upsample_like(hx5d,hx4)
167
+
168
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
169
+ hx4dup = _upsample_like(hx4d,hx3)
170
+
171
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
172
+ hx3dup = _upsample_like(hx3d,hx2)
173
+
174
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
175
+ hx2dup = _upsample_like(hx2d,hx1)
176
+
177
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
178
+
179
+ return hx1d + hxin
180
+
181
+ ### RSU-5 ###
182
+ class RSU5(nn.Module):
183
+
184
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
185
+ super(RSU5,self).__init__()
186
+
187
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
188
+
189
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
190
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
191
+
192
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
193
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
194
+
195
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
196
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
197
+
198
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
199
+
200
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
201
+
202
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
203
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
204
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
205
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
206
+
207
+ def forward(self,x):
208
+
209
+ hx = x
210
+
211
+ hxin = self.rebnconvin(hx)
212
+
213
+ hx1 = self.rebnconv1(hxin)
214
+ hx = self.pool1(hx1)
215
+
216
+ hx2 = self.rebnconv2(hx)
217
+ hx = self.pool2(hx2)
218
+
219
+ hx3 = self.rebnconv3(hx)
220
+ hx = self.pool3(hx3)
221
+
222
+ hx4 = self.rebnconv4(hx)
223
+
224
+ hx5 = self.rebnconv5(hx4)
225
+
226
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
227
+ hx4dup = _upsample_like(hx4d,hx3)
228
+
229
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
230
+ hx3dup = _upsample_like(hx3d,hx2)
231
+
232
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
233
+ hx2dup = _upsample_like(hx2d,hx1)
234
+
235
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
236
+
237
+ return hx1d + hxin
238
+
239
+ ### RSU-4 ###
240
+ class RSU4(nn.Module):
241
+
242
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
243
+ super(RSU4,self).__init__()
244
+
245
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
246
+
247
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
248
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
249
+
250
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
251
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
252
+
253
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
254
+
255
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
256
+
257
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
259
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
260
+
261
+ def forward(self,x):
262
+
263
+ hx = x
264
+
265
+ hxin = self.rebnconvin(hx)
266
+
267
+ hx1 = self.rebnconv1(hxin)
268
+ hx = self.pool1(hx1)
269
+
270
+ hx2 = self.rebnconv2(hx)
271
+ hx = self.pool2(hx2)
272
+
273
+ hx3 = self.rebnconv3(hx)
274
+
275
+ hx4 = self.rebnconv4(hx3)
276
+
277
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
278
+ hx3dup = _upsample_like(hx3d,hx2)
279
+
280
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
281
+ hx2dup = _upsample_like(hx2d,hx1)
282
+
283
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
284
+
285
+ return hx1d + hxin
286
+
287
+ ### RSU-4F ###
288
+ class RSU4F(nn.Module):
289
+
290
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
291
+ super(RSU4F,self).__init__()
292
+
293
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
294
+
295
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
296
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
297
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
298
+
299
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
300
+
301
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
302
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
303
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
304
+
305
+ def forward(self,x):
306
+
307
+ hx = x
308
+
309
+ hxin = self.rebnconvin(hx)
310
+
311
+ hx1 = self.rebnconv1(hxin)
312
+ hx2 = self.rebnconv2(hx1)
313
+ hx3 = self.rebnconv3(hx2)
314
+
315
+ hx4 = self.rebnconv4(hx3)
316
+
317
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
318
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
319
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
320
+
321
+ return hx1d + hxin
322
+
323
+
324
+ class myrebnconv(nn.Module):
325
+ def __init__(self, in_ch=3,
326
+ out_ch=1,
327
+ kernel_size=3,
328
+ stride=1,
329
+ padding=1,
330
+ dilation=1,
331
+ groups=1):
332
+ super(myrebnconv,self).__init__()
333
+
334
+ self.conv = nn.Conv2d(in_ch,
335
+ out_ch,
336
+ kernel_size=kernel_size,
337
+ stride=stride,
338
+ padding=padding,
339
+ dilation=dilation,
340
+ groups=groups)
341
+ self.bn = nn.BatchNorm2d(out_ch)
342
+ self.rl = nn.ReLU(inplace=True)
343
+
344
+ def forward(self,x):
345
+ return self.rl(self.bn(self.conv(x)))
346
+
347
+
348
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
+
350
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
+ super(BriaRMBG,self).__init__()
352
+ in_ch=config["in_ch"]
353
+ out_ch=config["out_ch"]
354
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
+
357
+ self.stage1 = RSU7(64,32,64)
358
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
359
+
360
+ self.stage2 = RSU6(64,32,128)
361
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
362
+
363
+ self.stage3 = RSU5(128,64,256)
364
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
365
+
366
+ self.stage4 = RSU4(256,128,512)
367
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
368
+
369
+ self.stage5 = RSU4F(512,256,512)
370
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
371
+
372
+ self.stage6 = RSU4F(512,256,512)
373
+
374
+ # decoder
375
+ self.stage5d = RSU4F(1024,256,512)
376
+ self.stage4d = RSU4(1024,128,256)
377
+ self.stage3d = RSU5(512,64,128)
378
+ self.stage2d = RSU6(256,32,64)
379
+ self.stage1d = RSU7(128,16,64)
380
+
381
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
382
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
383
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
384
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
385
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
386
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
387
+
388
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
389
+
390
+ def forward(self,x):
391
+
392
+ hx = x
393
+
394
+ hxin = self.conv_in(hx)
395
+ #hx = self.pool_in(hxin)
396
+
397
+ #stage 1
398
+ hx1 = self.stage1(hxin)
399
+ hx = self.pool12(hx1)
400
+
401
+ #stage 2
402
+ hx2 = self.stage2(hx)
403
+ hx = self.pool23(hx2)
404
+
405
+ #stage 3
406
+ hx3 = self.stage3(hx)
407
+ hx = self.pool34(hx3)
408
+
409
+ #stage 4
410
+ hx4 = self.stage4(hx)
411
+ hx = self.pool45(hx4)
412
+
413
+ #stage 5
414
+ hx5 = self.stage5(hx)
415
+ hx = self.pool56(hx5)
416
+
417
+ #stage 6
418
+ hx6 = self.stage6(hx)
419
+ hx6up = _upsample_like(hx6,hx5)
420
+
421
+ #-------------------- decoder --------------------
422
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
423
+ hx5dup = _upsample_like(hx5d,hx4)
424
+
425
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
426
+ hx4dup = _upsample_like(hx4d,hx3)
427
+
428
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
429
+ hx3dup = _upsample_like(hx3d,hx2)
430
+
431
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
432
+ hx2dup = _upsample_like(hx2d,hx1)
433
+
434
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
435
+
436
+
437
+ #side output
438
+ d1 = self.side1(hx1d)
439
+ d1 = _upsample_like(d1,x)
440
+
441
+ d2 = self.side2(hx2d)
442
+ d2 = _upsample_like(d2,x)
443
+
444
+ d3 = self.side3(hx3d)
445
+ d3 = _upsample_like(d3,x)
446
+
447
+ d4 = self.side4(hx4d)
448
+ d4 = _upsample_like(d4,x)
449
+
450
+ d5 = self.side5(hx5d)
451
+ d5 = _upsample_like(d5,x)
452
+
453
+ d6 = self.side6(hx6)
454
+ d6 = _upsample_like(d6,x)
455
+
456
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
457
+
example_inference.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from skimage import io
2
+ import torch, os
3
+ from PIL import Image
4
+ from briarmbg import BriaRMBG
5
+ from utilities import preprocess_image, postprocess_image
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ def example_inference(im_path, transprent_bg=False):
9
+
10
+ net = BriaRMBG()
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
13
+ net.to(device)
14
+ net.eval()
15
+
16
+ # prepare input
17
+ model_input_size = [1024,1024]
18
+ orig_im = io.imread(im_path)
19
+ orig_im_size = orig_im.shape[0:2]
20
+ image = preprocess_image(orig_im, model_input_size).to(device)
21
+
22
+ # inference
23
+ result=net(image)
24
+
25
+ # post process
26
+ result_image = postprocess_image(result[0][0], orig_im_size)
27
+ bgColor = (0,0,0, 0) if transprent_bg else (255,255,255, 255)
28
+ # save result
29
+ pil_im = Image.fromarray(result_image)
30
+ no_bg_image = Image.new("RGBA", pil_im.size, bgColor)
31
+ orig_image = Image.open(im_path)
32
+ no_bg_image.paste(orig_image, mask=pil_im)
33
+ no_bg_image.save("images/rm_image.png")
34
+ return 'rm_image.png'
flask_app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from flask import Flask, flash, request, redirect, url_for, session
4
+ from flask_session import Session
5
+ from werkzeug.utils import secure_filename
6
+ from gevent.pywsgi import WSGIServer
7
+ from example_inference import example_inference
8
+ from flask import send_from_directory
9
+
10
+
11
+ UPLOAD_FOLDER = 'images'
12
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
13
+
14
+ app = Flask(__name__)
15
+ app.config["SESSION_PERMANENT"] = False
16
+ app.config["SESSION_TYPE"] = "filesystem"
17
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
18
+
19
+ Session(app)
20
+
21
+ def allowed_file(filename):
22
+ return '.' in filename and \
23
+ filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
24
+
25
+ @app.route('/rm-bg/<transparent>', methods=['GET', 'POST'])
26
+ def upload_file(transparent):
27
+ print(transparent)
28
+ transparent = True if transparent == "true" else False
29
+ print(transparent)
30
+ if request.method == 'POST':
31
+ if 'file' not in request.files:
32
+ flash('No file part')
33
+ return {"status": "Failed", "message": "Please Provide file name(file)."}
34
+ file = request.files['file']
35
+ if file.filename == '':
36
+ flash('No selected file')
37
+ return {"status": "Failed", "message": "Filename Not Found."}
38
+ if file and allowed_file(file.filename):
39
+ filename = secure_filename('normal_image.png')
40
+ file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
41
+ im_path = f"{os.path.dirname(os.path.abspath(__file__))}/images/{filename}"
42
+ rm_image_path = example_inference(im_path, transparent)
43
+ return send_from_directory(app.config["UPLOAD_FOLDER"], rm_image_path)
44
+ return {
45
+ "message": "Get Request not allowed"
46
+ }
47
+
48
+ if __name__ == '__main__':
49
+ # http_server = WSGIServer(('', 8000), app)
50
+ # http_server.serve_forever()
51
+ app.debug = True
52
+ app.run(port=8000)
images/normal_image.png ADDED
images/rm_image.png ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ numpy
5
+ typing
6
+ scikit-image
7
+ huggingface_hub
8
+ flask
9
+ Flask-Session
10
+ gunicorn
utilities.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ import numpy as np
5
+
6
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
7
+ if len(im.shape) < 3:
8
+ im = im[:, :, np.newaxis]
9
+ # orig_im_size=im.shape[0:2]
10
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
11
+ im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
12
+ image = torch.divide(im_tensor,255.0)
13
+ image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
14
+ return image
15
+
16
+
17
+ def postprocess_image(result: torch.Tensor, im_size: list)-> np.ndarray:
18
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
19
+ ma = torch.max(result)
20
+ mi = torch.min(result)
21
+ result = (result-mi)/(ma-mi)
22
+ im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
23
+ im_array = np.squeeze(im_array)
24
+ return im_array
25
+