abhicake committed on
Commit
a9d802a
1 Parent(s): 126ca0c

Upload 10 files

Browse files
Files changed (10) hide show
  1. MyConfig.py +13 -0
  2. MyPipe.py +76 -0
  3. briarmbg.py +458 -0
  4. config.json +25 -0
  5. example_inference.py +39 -0
  6. model.pth +3 -0
  7. preprocessor_config.json +23 -0
  8. pytorch_model.bin +3 -0
  9. requirements.txt +8 -0
  10. utilities.py +25 -0
MyConfig.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
class RMBGConfig(PretrainedConfig):
    """Configuration object carrying the channel counts of the RMBG network.

    Args:
        in_ch: Value stored as ``self.in_ch`` (default 3).
        out_ch: Value stored as ``self.out_ch`` (default 1).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "SegformerForSemanticSegmentation"

    def __init__(self, in_ch=3, out_ch=1, **kwargs):
        # Channel counts are recorded before the base class consumes the
        # remaining keyword arguments, exactly as callers already rely on.
        self.in_ch, self.out_ch = in_ch, out_ch
        super().__init__(**kwargs)
MyPipe.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, os
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ import numpy as np
5
+ from transformers import Pipeline
6
+ from transformers.image_utils import load_image
7
+ from skimage import io
8
+ from PIL import Image
9
+
10
class RMBGPipe(Pipeline):
    """Custom pipeline that turns an input image into a mask or a cut-out.

    ``preprocess`` loads and resizes the image, ``_forward`` runs the model,
    and ``postprocess`` returns either the predicted mask or an RGBA copy of
    the input pasted through that mask.
    """

    # Resize target used when the caller does not pass ``model_input_size``.
    _DEFAULT_INPUT_SIZE = (1024, 1024)

    def __init__(self, **kwargs):
        Pipeline.__init__(self, **kwargs)
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts."""
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "model_input_size" in kwargs:
            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
        if "return_mask" in kwargs:
            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, input_image, model_input_size: list = None):
        """Load ``input_image`` and build the model input.

        Args:
            input_image: Anything ``transformers.image_utils.load_image``
                accepts.
            model_input_size: ``[height, width]`` the image is resized to;
                defaults to 1024x1024 when omitted.

        Returns:
            Dict with the preprocessed tensor (already on ``self.device``),
            the original ``(height, width)`` and the loaded image.
        """
        if model_input_size is None:
            model_input_size = list(self._DEFAULT_INPUT_SIZE)
        # Load once and carry the loaded image forward, so ``postprocess``
        # does not have to read (or download) the source a second time.
        loaded_image = load_image(input_image)
        orig_im = np.array(loaded_image)
        orig_im_size = orig_im.shape[0:2]
        preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
        inputs = {
            "preprocessed_image": preprocessed_image,
            "orig_im_size": orig_im_size,
            "input_image": loaded_image,
        }
        return inputs

    def _forward(self, inputs):
        """Run the model on the preprocessed tensor and attach its output."""
        result = self.model(inputs.pop("preprocessed_image"))
        inputs["result"] = result
        return inputs

    def postprocess(self, inputs, return_mask: bool = False):
        """Convert the model output into a PIL image.

        Args:
            inputs: Dict produced by ``_forward``.
            return_mask: When true, return the mask itself instead of the
                cut-out.

        Returns:
            The mask as a PIL image, or an RGBA image holding the input
            pasted through the mask onto a fully transparent canvas.
        """
        result = inputs.pop("result")
        orig_im_size = inputs.pop("orig_im_size")
        input_image = inputs.pop("input_image")
        # result[0][0] is the first entry of the model's first output list.
        result_image = self.postprocess_image(result[0][0], orig_im_size)
        pil_im = Image.fromarray(result_image)
        if return_mask:
            return pil_im
        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        # load_image also accepts an already-loaded PIL image, so this works
        # for dicts built by ``preprocess`` and for ones carrying a path/URL.
        input_image = load_image(input_image)
        no_bg_image.paste(input_image, mask=pil_im)
        return no_bg_image

    # utilities functions
    def preprocess_image(self, im: np.ndarray, model_input_size: list = None) -> torch.Tensor:
        """Resize an HxW[xC] array and shift it for the model.

        The array is converted to a float32 tensor in channel-first layout,
        resized bilinearly to ``model_input_size`` (1024x1024 when omitted),
        divided by 255 and normalised with mean 0.5 and std 1.0.
        """
        if model_input_size is None:
            model_input_size = list(self._DEFAULT_INPUT_SIZE)
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
        image = torch.divide(im_tensor, 255.0)
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize the prediction to ``im_size`` and convert it to uint8.

        Values are min-max stretched before being scaled by 255. A constant
        prediction has no range to stretch, so it is clamped to [0, 1]
        instead of being divided by zero.
        """
        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        value_range = ma - mi
        if value_range > 0:
            result = (result - mi) / value_range
        else:
            result = torch.clamp(result, 0.0, 1.0)
        im_array = (result * 255).permute(1, 2, 0).cpu().detach().numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array
briarmbg.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel
5
+ from .MyConfig import RMBGConfig
6
+
7
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()
        # Padding equals the dilation rate, so at stride 1 the 3x3 kernel
        # leaves the spatial size unchanged.
        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            padding=1 * dirate,
            dilation=1 * dirate,
            stride=stride,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
21
+
22
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
23
+ def _upsample_like(src,tar):
24
+
25
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
26
+
27
+ return src
28
+
29
+
30
+ ### RSU-7 ###
31
class RSU7(nn.Module):
    """Residual U-block with a seven-level encoder (RSU-7).

    The input is first projected to ``out_ch`` channels. That projection
    passes through six convolutions separated by five 2x max-pools, a
    dilated bottom convolution, and a decoder that upsamples and
    concatenates the matching encoder feature at every level. The decoder
    output is added to the projection.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the encoder/decoder.
        out_ch: Channels of the returned tensor.
        img_size: Not used by this block; kept so callers that pass it
            continue to work.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Stride-1 projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom of the U: dilated instead of pooled.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convolutions take the concatenation of two mid_ch tensors.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoder output added to the input projection."""
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
111
+
112
+
113
+ ### RSU-6 ###
114
class RSU6(nn.Module):
    """Residual U-block with a six-level encoder (RSU-6).

    Five convolutions separated by four 2x max-pools, a dilated bottom
    convolution, and a five-step decoder whose output is added to the
    input projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom convolution.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes two concatenated mid_ch tensors.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        bottom = self.rebnconv6(enc5)

        dec = self.rebnconv5d(torch.cat((bottom, enc5), 1))
        dec = self.rebnconv4d(torch.cat((_upsample_like(dec, enc4), enc4), 1))
        dec = self.rebnconv3d(torch.cat((_upsample_like(dec, enc3), enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
181
+
182
+ ### RSU-5 ###
183
class RSU5(nn.Module):
    """Residual U-block with a five-level encoder (RSU-5).

    Four convolutions separated by three 2x max-pools, a dilated bottom
    convolution, and a four-step decoder whose output is added to the
    input projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom convolution.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes two concatenated mid_ch tensors.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        bottom = self.rebnconv5(enc4)

        dec = self.rebnconv4d(torch.cat((bottom, enc4), 1))
        dec = self.rebnconv3d(torch.cat((_upsample_like(dec, enc3), enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
239
+
240
+ ### RSU-4 ###
241
class RSU4(nn.Module):
    """Residual U-block with a four-level encoder (RSU-4).

    Three convolutions separated by two 2x max-pools, a dilated bottom
    convolution, and a three-step decoder whose output is added to the
    input projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom convolution.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes two concatenated mid_ch tensors.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        bottom = self.rebnconv4(enc3)

        dec = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
287
+
288
+ ### RSU-4F ###
289
class RSU4F(nn.Module):
    """Residual U-block without pooling (RSU-4F).

    Uses dilation rates 1, 2, 4 and 8 on the way down and 4, 2, 1 on the
    way up instead of pooling and upsampling; the decoder output is added
    to the input projection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder with shrinking dilation; inputs are concatenated pairs.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)
        bottom = self.rebnconv4(enc3)

        dec = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec = self.rebnconv2d(torch.cat((dec, enc2), 1))
        dec = self.rebnconv1d(torch.cat((dec, enc1), 1))

        return dec + residual
323
+
324
+
325
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch norm and an in-place ReLU."""

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1):
        super().__init__()

        conv_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
        }
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        activated = self.conv(x)
        activated = self.bn(activated)
        return self.rl(activated)
347
+
348
+
349
class BriaRMBG(PreTrainedModel):
    """Nested U-shaped segmentation network built from RSU blocks.

    A stride-2 input convolution feeds six encoder stages separated by 2x
    max-pools; five decoder stages each consume the upsampled deeper feature
    concatenated with the matching encoder feature. Six 3x3 "side"
    convolutions turn decoder features into ``out_ch``-channel maps.

    Args:
        config: Provides ``in_ch`` (input channels) and ``out_ch`` (channels
            of every side output). A fresh ``RMBGConfig()`` is created when
            omitted.
    """

    config_class = RMBGConfig

    def __init__(self, config: RMBGConfig = None):
        # Create the default per instance: a default built in the signature
        # is evaluated once and would be shared by every model.
        if config is None:
            config = RMBGConfig()
        super().__init__(config)
        in_ch = config.in_ch
        out_ch = config.out_ch
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Not used in forward(); kept so the module tree stays unchanged.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Run the network.

        Returns:
            A pair ``(side_outputs, features)``: ``side_outputs`` is a list of
            six sigmoid maps, each resized to the spatial size of ``x``;
            ``features`` is the list ``[hx1d, hx2d, hx3d, hx4d, hx5d, hx6]``
            of decoder/bottom feature maps.
        """
        hx = x

        hxin = self.conv_in(hx)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, each resized to the input resolution
        d1 = _upsample_like(self.side1(hx1d), x)
        d2 = _upsample_like(self.side2(hx2d), x)
        d3 = _upsample_like(self.side3(hx3d), x)
        d4 = _upsample_like(self.side4(hx4d), x)
        d5 = _upsample_like(self.side5(hx5d), x)
        d6 = _upsample_like(self.side6(hx6), x)

        side_outputs = [torch.sigmoid(d) for d in (d1, d2, d3, d4, d5, d6)]
        return side_outputs, [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
458
+
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "briaai/RMBG-1.4",
3
+ "architectures": [
4
+ "BriaRMBG"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "MyConfig.RMBGConfig",
8
+ "AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
9
+ },
10
+ "custom_pipelines": {
11
+ "image-segmentation": {
12
+ "impl": "MyPipe.RMBGPipe",
13
+ "pt": [
14
+ "AutoModelForImageSegmentation"
15
+ ],
16
+ "tf": [],
17
+ "type": "image"
18
+ }
19
+ },
20
+ "in_ch": 3,
21
+ "model_type": "SegformerForSemanticSegmentation",
22
+ "out_ch": 1,
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.38.0.dev0"
25
+ }
example_inference.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from skimage import io
2
+ import torch, os
3
+ from PIL import Image
4
+ from briarmbg import BriaRMBG
5
+ from utilities import preprocess_image, postprocess_image
6
+ from huggingface_hub import hf_hub_download
7
+
8
def example_inference():
    """Remove the background of ``example_input.jpg`` next to this script.

    Loads the pretrained ``briaai/RMBG-1.4`` weights, predicts a mask for
    the image and writes the cut-out to ``example_image_no_bg.png`` in the
    current working directory.
    """
    im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # from_pretrained builds the model itself; constructing a randomly
    # initialised BriaRMBG() first would only be thrown away.
    net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
    net.to(device)
    net.eval()

    # prepare input
    model_input_size = [1024, 1024]
    orig_im = io.imread(im_path)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size).to(device)

    # inference; gradients are not needed, so do not record them
    with torch.no_grad():
        result = net(image)

    # post process
    result_image = postprocess_image(result[0][0], orig_im_size)

    # save result
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    with Image.open(im_path) as orig_image:
        no_bg_image.paste(orig_image, mask=pil_im)
    no_bg_image.save("example_image_no_bg.png")


if __name__ == "__main__":
    example_inference()
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0869e907ec6909e71fed3d19847716bf8d4e9dec1e48f1b67b1cbdc3a1ac952
3
+ size 134
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "do_pad": false,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "feature_extractor_type": "ImageFeatureExtractor",
12
+ "image_std": [
13
+ 1,
14
+ 1,
15
+ 1
16
+ ],
17
+ "resample": 2,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "width": 1024,
21
+ "height": 1024
22
+ }
23
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaa8141adbc209cb12aa69347179a72eef72736b6729ea9a726fd5a8577d53a7
3
+ size 134
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ numpy
5
+ typing
6
+ scikit-image
7
+ huggingface_hub
8
+ transformers>=4.39.1
utilities.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ import numpy as np
5
+
6
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Turn an image array into the tensor the model is fed with.

    Args:
        im: Image as an ``(H, W)`` or ``(H, W, C)`` array.
        model_input_size: ``[height, width]`` to resize to.

    Returns:
        Float tensor of shape ``(1, 3, height, width)``: resized bilinearly,
        truncated to integer pixel values, divided by 255 and shifted by
        the 0.5 mean used in ``normalize``.
    """
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    # normalize() below uses three-channel statistics, so bring every input
    # to exactly three channels first.
    channels = im.shape[2]
    if channels < 3:
        # Replicate the first channel; assumes a second channel, if present,
        # is alpha -- TODO confirm for non-image inputs.
        im = np.repeat(im[:, :, :1], 3, axis=2)
    elif channels > 3:
        # Assumes channels past the third are alpha and can be dropped.
        im = im[:, :, :3]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # The uint8 cast truncates the interpolated values to whole numbers.
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image
15
+
16
+
17
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a prediction to ``im_size`` and convert it to a uint8 array.

    Args:
        result: 4-D tensor accepted by bilinear ``F.interpolate``; the
            leading dimension is squeezed away after resizing.
        im_size: Target ``(height, width)``.

    Returns:
        ``uint8`` array with singleton dimensions removed. Values are
        min-max stretched to 0..255; a constant prediction has no range to
        stretch, so it is clamped to [0, 1] and scaled instead of being
        divided by zero.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # Constant input: (result - mi) / 0 would yield NaN.
        result = torch.clamp(result, 0.0, 1.0)
    im_array = (result * 255).permute(1, 2, 0).cpu().detach().numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array
25
+