Andres Felipe Ruiz-Hurtado commited on
Commit
9f3ae4a
·
1 Parent(s): 173edf9
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
bgremover.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2 as cv
2
+ import numpy as np
3
+ from PIL import Image
4
+ import glob
5
+ import pathlib
6
+
7
+ import sys
8
+
9
+ import u2net_utils
10
+
11
+ import os
12
+ from skimage import io, transform
13
+ import torch
14
+ import torchvision
15
+ from torch.autograd import Variable
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from torchvision import transforms#, utils
20
+ # import torch.optim as optim
21
+
22
+ from u2net_utils.data_loader import RescaleT
23
+ from u2net_utils.data_loader import ToTensor
24
+ from u2net_utils.data_loader import ToTensorLab
25
+ from u2net_utils.data_loader import SalObjDataset
26
+
27
+ from u2net_utils.model import U2NET # full size version 173.6 MB
28
+ from u2net_utils.model import U2NETP # small version u2net 4.7 MB
29
+
30
+ from torchvision import models
31
+
32
+
33
+ import onnxruntime as ort
34
+ import cv2 as cv
35
+ import numpy as np
36
+ from torchvision.transforms import v2 as transforms
37
+
38
+ # MODEL_PATH = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_gpu\models\u2net.pth"
39
+ # MODEL_PATH = r"D:\CIAT\catalogue\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models"
40
+ # MODEL_PATH = r"D:\local_mydata\models\spidermites\best_models"
41
+
42
+ MODEL_PATH = "./models"
43
+
44
+ #************************
45
+ # from loguru import logger
46
+ # from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
47
+ # import subprocess
48
+
49
+ # # Grounding DINO
50
+ # import GroundingDINO.groundingdino.datasets.transforms as T
51
+ # from GroundingDINO.groundingdino.models import build_model
52
+ # from GroundingDINO.groundingdino.util import box_ops
53
+ # from GroundingDINO.groundingdino.util.slconfig import SLConfig
54
+ # from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
55
+
56
+ # from huggingface_hub import hf_hub_download
57
+
58
+ import gc
59
+
60
+ def clear():
61
+ gc.collect()
62
+ torch.cuda.empty_cache()
63
+
64
+ # normalize the predicted SOD probability map
65
+ def normPRED(d):
66
+ ma = torch.max(d)
67
+ mi = torch.min(d)
68
+
69
+ dn = (d-mi)/(ma-mi)
70
+
71
+ return dn
72
+
73
+ class BackgroundRemover():
74
+
75
+ def __init__(self):
76
+
77
+
78
+ #Load model
79
+ #model_dir = "/workspace/u2net.pth"
80
+ #model_dir = "D:/local_mydata/models/u2net.pth"
81
+ model_dir = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_gpu\models\u2net.pth"
82
+ model_dir = os.path.join(MODEL_PATH, "u2net.pth")
83
+
84
+ ## Load model
85
+ net = U2NET(3,1)
86
+
87
+ if torch.cuda.is_available():
88
+ net.load_state_dict(torch.load(model_dir))
89
+ net.cuda()
90
+ else:
91
+ net.load_state_dict(torch.load(model_dir, map_location='cpu'))
92
+ net.eval()
93
+
94
+ self.net = net
95
+
96
+ def remove_background(self, filepath_image):
97
+
98
+ img_name_list = [filepath_image]
99
+
100
+ test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
101
+ lbl_name_list = [],
102
+ transform=transforms.Compose([RescaleT(320),
103
+ ToTensorLab(flag=0)])
104
+ )
105
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
106
+ batch_size=1,
107
+ shuffle=False,
108
+ num_workers=1)
109
+
110
+ net = self.net
111
+
112
+ for i_test, data_test in enumerate(test_salobj_dataloader):
113
+
114
+ print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
115
+
116
+ inputs_test = data_test['image']
117
+ inputs_test = inputs_test.type(torch.FloatTensor)
118
+
119
+ if torch.cuda.is_available():
120
+ inputs_test = Variable(inputs_test.cuda())
121
+ else:
122
+ inputs_test = Variable(inputs_test)
123
+
124
+ d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
125
+
126
+ # normalization
127
+ pred = d1[:,0,:,:]
128
+ pred = normPRED(pred)
129
+
130
+ # save results to test_results folder
131
+ #if not os.path.exists(prediction_dir):
132
+ # os.makedirs(prediction_dir, exist_ok=True)
133
+ #save_output(img_name_list[i_test],pred,prediction_dir)
134
+
135
+ predict = pred
136
+ predict = predict.squeeze()
137
+ #mask_torch.permute(1, 2, 0).detach().cpu().numpy()
138
+ predict_np = predict.cpu().data.numpy()
139
+
140
+ img = cv.imread(filepath_image)
141
+ w = img.shape[1]
142
+ h = img.shape[0]
143
+
144
+ #im = Image.fromarray(predict_np*255).convert('RGB')
145
+ #image = io.imread(filepath_image)
146
+ #imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
147
+
148
+ imo = cv.resize(predict_np, (w,h), cv.INTER_LINEAR )
149
+
150
+ #del d1,d2,d3,d4,d5,d6,d7
151
+ return imo
152
+
153
+ def remove_background_save(self, path_in, path_out, path_out_mask = None):
154
+
155
+ print("remove_background_save")
156
+
157
+ mask_torch = self.remove_background(path_in)
158
+ mask = mask_torch*255
159
+ mask = mask.astype(np.uint8)
160
+
161
+ img = cv.imread(path_in)
162
+ mask0 = mask#cv.UMat(cv.imread(mask,0))
163
+ #127
164
+ #200
165
+ ret,binary_mask = cv.threshold(mask0,80,255,cv.THRESH_BINARY)
166
+ binary_mask = np.uint8(binary_mask)
167
+ res = cv.bitwise_and(img,img, mask = binary_mask)
168
+
169
+ cv.imwrite(path_out, res)
170
+
171
+ if not (path_out_mask == None):
172
+ cv.imwrite(path_out_mask, mask)
173
+
174
+ def remove_background_dir(self, path_in, path_out):
175
+
176
+ img_name_list = glob.glob(os.path.join(path_in, "*.jpg"))
177
+
178
+ for img_name in img_name_list:
179
+
180
+ img_name_output = img_name.replace(path_in, path_out)
181
+
182
+ if not os.path.exists(img_name_output):
183
+ self.remove_background_save(img_name, img_name_output)
184
+ print(img_name.replace(path_in, path_out))
185
+
186
+ def remove_background_gradio(self, np_image):
187
+
188
+ w = np_image.shape[1]
189
+ h = np_image.shape[0]
190
+
191
+ #image = torch.tensor(np_image)
192
+ #image = image.permute(2,0,1)
193
+
194
+ image = np_image#Image.fromarray(np_image)
195
+ imidx = np.array([0])
196
+ #label = "test"
197
+
198
+ #***
199
+ label_3 = np.zeros(image.shape)
200
+
201
+ label = np.zeros(label_3.shape[0:2])
202
+ if(3==len(label_3.shape)):
203
+ label = label_3[:,:,0]
204
+ elif(2==len(label_3.shape)):
205
+ label = label_3
206
+
207
+ if(3==len(image.shape) and 2==len(label.shape)):
208
+ label = label[:,:,np.newaxis]
209
+ elif(2==len(image.shape) and 2==len(label.shape)):
210
+ image = image[:,:,np.newaxis]
211
+ label = label[:,:,np.newaxis]
212
+ #***
213
+
214
+ sample = {'imidx':imidx, 'image':image, 'label':label}
215
+ print(image.shape)
216
+ print(label.shape)
217
+
218
+
219
+ eval_transform = transforms.Compose([RescaleT(320),ToTensorLab(flag=0)])
220
+ #eval_transform = transforms.Compose([RescaleT(320)])
221
+ #eval_transform = transforms.Compose([RescaleT(320)])
222
+ #eval_transform = transforms.Compose([ToTensorLab(flag=0)])
223
+ #eval_transform = transforms.Compose([transforms.Resize(320)
224
+ # , transforms.ToTensor()])
225
+ #eval_transform = transforms.Compose([transforms.Resize(320)])
226
+
227
+ test_salobj_dataloader = DataLoader(sample,
228
+ batch_size=1,
229
+ shuffle=False,
230
+ num_workers=1)
231
+
232
+ sample = eval_transform(sample)
233
+
234
+ net = self.net
235
+
236
+ #for i_test, data_test in enumerate(test_salobj_dataloader):
237
+
238
+ #device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
239
+
240
+ #x = eval_transform(sample)
241
+ #x = x[:3, ...].to(device)
242
+
243
+ inputs_test = sample['image']
244
+ inputs_test = inputs_test.type(torch.FloatTensor)
245
+ inputs_test = inputs_test.unsqueeze(0)
246
+
247
+ print(inputs_test.shape)
248
+
249
+ if torch.cuda.is_available():
250
+ inputs_test = Variable(inputs_test.cuda())
251
+ else:
252
+ inputs_test = Variable(inputs_test)
253
+
254
+
255
+ d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
256
+
257
+ # normalization
258
+ pred = d1[:,0,:,:]
259
+ pred = normPRED(pred)
260
+
261
+ predict = pred
262
+ predict = predict.squeeze()
263
+ #mask_torch.permute(1, 2, 0).detach().cpu().numpy()
264
+ predict_np = predict.cpu().data.numpy()
265
+
266
+ imo = cv.resize(predict_np, (w,h), cv.INTER_LINEAR )
267
+
268
+ mask = imo*255
269
+ mask = mask.astype(np.uint8)
270
+ mask0 = mask#cv.UMat(cv.imread(mask,0))
271
+ #127
272
+ #200
273
+ ret,binary_mask = cv.threshold(mask0,80,255,cv.THRESH_BINARY)
274
+ #ret,binary_mask = cv.threshold(mask0,233,255,cv.THRESH_BINARY)
275
+ binary_mask = np.uint8(binary_mask)
276
+ res = cv.bitwise_and(np_image,np_image, mask = binary_mask)
277
+
278
+ return mask, res
279
+
280
+ def apply_mask(self, input, mask, threshold):
281
+
282
+ mask = cv.cvtColor(mask, cv.COLOR_BGR2GRAY)
283
+ ret,binary_mask = cv.threshold(mask,threshold,255,cv.THRESH_BINARY)
284
+ #binary_mask = np.uint8(binary_mask)
285
+ #binary_mask = mask
286
+ print("apply mask")
287
+ print(input.shape)
288
+ print(input.dtype)
289
+ print(binary_mask.shape)
290
+ print(binary_mask.dtype)
291
+ res = cv.bitwise_and(input,input, mask = binary_mask)
292
+
293
+ # foreground_alpha = mask.astype(np.float32) / 255.0
294
+ # # Create a new image to store the result with same size and type as foreground
295
+ # blended_image = np.zeros_like(input)
296
+
297
+ # # Loop through each pixel and apply alpha based on mask value
298
+ # for channel in range(3): # Loop through BGR channels
299
+ # blended_image[:, :, channel] = input[:, :, channel] * foreground_alpha
300
+
301
+
302
+ return res, binary_mask
303
+
304
+
305
+ def get_transform(train = True):
306
+ transforms_list = []
307
+ #if train:
308
+ # transforms.append(T.RandomHorizontalFlip(0.5))
309
+ transforms_list.append(transforms.Resize(256))
310
+ transforms_list.append(transforms.CenterCrop(256))
311
+ #transforms_list.append(transforms.ToDtype(torch.float, scale=True))
312
+ transforms_list.append(transforms.ToTensor())
313
+ #transforms_list.append(transforms.ToDtype(torch.float32, scale=True))
314
+ transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
315
+
316
+ return transforms.Compose(transforms_list)
317
+
318
+ class DamageClassifier():
319
+
320
+ def __init__(self):
321
+
322
+ self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
323
+ self.model_name =""
324
+
325
+
326
+ def initialize(self, model_name):
327
+
328
+ #Load model
329
+
330
+ if model_name == "Resnet18":
331
+
332
+ model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\resnet18_SpidermitesModel.pth"
333
+ model_filepath = os.path.join(MODEL_PATH, "resnet18_SpidermitesModel.pth")
334
+ model = models.resnet18(weights='IMAGENET1K_V1')
335
+
336
+ if model_name == "Resnet152":
337
+
338
+ model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet152_SpidermitesModel_44_44.pth"
339
+ model_filepath = os.path.join(MODEL_PATH, "short_resnet152_SpidermitesModel_44_44.pth")
340
+ model = models.resnet152(weights='IMAGENET1K_V1')
341
+
342
+ if model_name == "Googlenet":
343
+
344
+ model_filepath = r"\\catalogue.cgiarad.org\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\regnet_x_32gf_SpidermitesModel.pth"
345
+ model_filepath = model_filepath = os.path.join(MODEL_PATH, "regnet_x_32gf_SpidermitesModel.pth")
346
+ model = models.regnet_x_32gf(weights='IMAGENET1K_V1')
347
+
348
+ if model_name == "Regnet32":
349
+
350
+ model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\short_resnet18_SpidermitesModel.pth"
351
+ model_filepath = model_filepath = os.path.join(MODEL_PATH, "short_resnet18_SpidermitesModel.pth")
352
+ model = models.resnet18(weights='IMAGENET1K_V1')
353
+
354
+ #Add fully connected layer at the end with num_classes as output
355
+ num_ftrs = model.fc.in_features
356
+ model.fc = nn.Linear(num_ftrs, 4)
357
+
358
+ if torch.cuda.is_available():
359
+ model.load_state_dict(torch.load(model_filepath))
360
+ model.cuda()
361
+ else:
362
+ model.load_state_dict(torch.load(model_filepath, map_location='cpu'))
363
+ model.eval()
364
+
365
+ self.model = model
366
+ self.model_name = model_name
367
+
368
+ return
369
+
370
+
371
+ def inference(self, np_image, model_name):
372
+
373
+ if model_name == "Regnet":
374
+
375
+ model_filepath = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\1.Data\16. Spidermites_AdrianK\best_models\regnet_x_32gf_SpidermitesModel.onnx"
376
+ model_filepath = model_filepath = os.path.join(MODEL_PATH, "regnet_x_32gf_SpidermitesModel.onnx")
377
+ ort_sess = ort.InferenceSession(model_filepath
378
+ ,providers=ort.get_available_providers()
379
+ )
380
+
381
+ transforms_list = []
382
+ transforms_list.append(transforms.ToTensor())
383
+ transforms_list.append(transforms.Resize(512))
384
+ transforms_list.append(transforms.CenterCrop(512))
385
+ transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
386
+
387
+ apply_t = transforms.Compose(transforms_list)
388
+
389
+ img = apply_t(np_image)
390
+
391
+ imgs = np.array([img.numpy()])
392
+
393
+ outputs = ort_sess.run(None, {'input': [img.numpy()]})
394
+
395
+ np_res = outputs[0][0]
396
+
397
+
398
+ final_res = {'0-(No damage)': np_res[0]
399
+ ,'1-3-(Moderately damaged)': np_res[1]
400
+ ,'4-7-(Damaged)': np_res[2]
401
+ ,'8-10-(Severely damaged)': np_res[3]}
402
+
403
+ return final_res
404
+
405
+ else:
406
+
407
+ if self.model_name != model_name:
408
+ self.initialize(model_name)
409
+
410
+ with torch.no_grad():
411
+
412
+ print("inference")
413
+ print(np_image.shape)
414
+
415
+ pil_image = Image.fromarray(np_image.astype('uint8'))
416
+ data_transforms = get_transform(train = False)
417
+
418
+ img = data_transforms(pil_image)
419
+
420
+ inputs = img.to(self.device)
421
+
422
+ outputs = self.model(inputs.unsqueeze(0))
423
+ #_, preds = torch.max(outputs, 1)
424
+
425
+ print(outputs)
426
+
427
+ _, preds = torch.max(outputs, 1)
428
+ print(preds)
429
+
430
+ m = nn.Softmax(dim=1)
431
+ res = m(outputs)
432
+ print(res)
433
+
434
+ np_res = res[0].cpu().numpy()
435
+ print(np_res)
436
+
437
+ final_res = {'0-(No damage)': np_res[0]
438
+ ,'1-3-(Moderately damaged)': np_res[1]
439
+ ,'4-7-(Damaged)': np_res[2]
440
+ ,'8-10-(Severely damaged)': np_res[3]}
441
+
442
+ return final_res
443
+
444
+ class ColorCheckerDetector():
445
+
446
+ def __init__(self):
447
+
448
+ return
449
+
450
+ def process(self, np_image_mask, np_image):
451
+
452
+ ret,binary_mask = cv.threshold(np_image_mask,80,255,cv.THRESH_BINARY)
453
+ binary_mask_C = cv.cvtColor(binary_mask, cv.COLOR_BGR2GRAY) #change to single channel
454
+ (contours, hierarchy) = cv.findContours(binary_mask_C, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
455
+
456
+ main_contour = contours[0]
457
+
458
+ # compute the center of the contour
459
+ moments = cv.moments(main_contour)
460
+ cx = int(moments["m10"] / moments["m00"])
461
+ cy = int(moments["m01"] / moments["m00"])
462
+
463
+ # Bounding rect
464
+ bb_x,bb_y,bb_w,bb_h = cv.boundingRect(binary_mask_C)
465
+
466
+ # Min Bounding rect
467
+ rect = cv.minAreaRect(main_contour)
468
+ box = cv.boxPoints(rect)
469
+ box = np.int64(box)
470
+
471
+ # Fitting line
472
+ rows,cols = binary_mask_C.shape[:2]
473
+ #[vx,vy,x,y] = cv.fitLine(main_contour, cv.DIST_L2,0,0.01,0.01)
474
+ [vx,vy,x,y] = cv.fitLine(box, cv.DIST_L2,0,0.01,0.01)
475
+ lefty = int((-x*vy/vx) + y)
476
+ righty = int(((cols-x)*vy/vx)+y)
477
+ point1 = (cols-1,righty)
478
+ point2 = (0,lefty)
479
+ angle = np.arctan2(np.abs(righty-lefty),cols)
480
+
481
+ # rotation matrix
482
+ M_rot = cv.getRotationMatrix2D((cx, cy), -angle*180.0/np.pi, 1.0)
483
+ rotated = cv.warpAffine(np_image, M_rot, (binary_mask.shape[1], binary_mask.shape[0]))
484
+
485
+ #perspective transform
486
+ input_pts = box.astype(np.float32)
487
+ maxHeight = 200
488
+ maxWidth = 290
489
+ output_pts = np.float32([[0, 0],
490
+ [maxWidth - 1, 0],
491
+ [maxWidth - 1, maxHeight - 1] ,
492
+ [0, maxHeight - 1]]
493
+ )
494
+ M_per = cv.getPerspectiveTransform(input_pts,output_pts)
495
+ corrected = cv.warpPerspective(np_image,M_per,(maxWidth, maxHeight),flags=cv.INTER_LINEAR)
496
+
497
+ res = cv.drawContours(np_image, main_contour, -1, (255,255,0), 5)
498
+ res = cv.rectangle(res,(bb_x,bb_y),(bb_x+bb_w,bb_y+bb_h),(0,255,0),5)
499
+ res = cv.drawContours(res,[box],0,(0,0,255),5)
500
+ res = cv.line(res,(cols-1,righty),(0,lefty),(0,0,255),5)
501
+
502
+ return [res, rotated, corrected]
503
+
504
+
505
+
506
+
507
+ class BatchProcessor():
508
+
509
+ def __init__(self):
510
+ return
511
+
512
+ def batch_process(self, input_dir, output_dir, output_suffixes = ["output"], format="jpg", pattern='**/*.tiff', processing_fc=None, output_format = None):
513
+
514
+ if processing_fc == None:
515
+ print("Processing function is None")
516
+ return
517
+ else:
518
+
519
+ if output_format == None:
520
+ output_format = format
521
+
522
+ # Get list of files in folder and subfolders
523
+ pattern = '**/*.' + format
524
+ files = glob.glob(pattern, root_dir=input_dir, recursive=True)
525
+
526
+ for file in files:
527
+
528
+ filepath = os.path.join(input_dir, file)
529
+ basename = os.path.basename(filepath)
530
+ parent_dir = os.path.dirname(filepath)
531
+ extra_path = file.replace(basename,"")
532
+ output_dir = os.path.join(output_dir, extra_path)
533
+
534
+ # Create output filepath list
535
+ output_filepaths = []
536
+ for suffix in output_suffixes:
537
+ output_filepaths.append(os.path.join(output_dir, basename.replace("." + format, "_" + suffix + "." + output_format)))
538
+
539
+ if not os.path.exists(output_filepaths[0]):# Process only if first output file does not exist
540
+
541
+ if not os.path.exists(output_dir): # Create subfolders if necessary
542
+ pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
543
+
544
+
545
+ processing_fc(filepath, output_filepaths) # Process and save file
546
+
547
+ print(file)
548
+ print(output_filepaths[0])
549
+ print("****")
550
+
551
+
552
+ class Segmentor():
553
+
554
+ def __init__(self):
555
+
556
+ self.sam_predictor = None
557
+ self.groundingdino_model = None
558
+ #self.sam_checkpoint = './sam_vit_h_4b8939.pth'
559
+ #self.sam_checkpoint = r"\\CATALOGUE.CGIARAD.ORG\AcceleratedBreedingInitiative\4.Scripts\AndresRuiz\local_mydata_backup\model\sam_vit_h_4b8939.pth"
560
+ self.sam_checkpoint = r"D:\local_mydev\Grounded-Segment-Anything\sam_vit_h_4b8939.pth"
561
+
562
+
563
+ # self.config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
564
+ # self.ckpt_repo_id = "ShilongLiu/GroundingDINO"
565
+ # self.ckpt_filename = "groundingdino_swint_ogc.pth"
566
+
567
+ self.config_file = r"D:\local_mydev\gsam\GroundingDINO\groundingdino\config\GroundingDINO_SwinT_OGC.py"
568
+ self.ckpt_repo_id = "ShilongLiu/GroundingDINO"
569
+ self.ckpt_filename = "groundingdino_swint_ogc.pth"
570
+
571
+ self.device ='cpu'
572
+
573
+ self.load_sam_model(self.device)
574
+ self.load_groundingdino_model(self.device)
575
+
576
+ return
577
+
578
+ def get_sam_vit_h_4b8939(self):
579
+ return
580
+ # if not os.path.exists('./sam_vit_h_4b8939.pth'):
581
+ # logger.info(f"get sam_vit_h_4b8939.pth...")
582
+ # result = subprocess.run(['wget', '-nv', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
583
+ # print(f'wget sam_vit_h_4b8939.pth result = {result}')
584
+
585
+ def load_sam_model(self, device):
586
+
587
+ sam_checkpoint = self.sam_checkpoint
588
+
589
+ # initialize SAM
590
+ self.get_sam_vit_h_4b8939()
591
+ logger.info(f"initialize SAM model...")
592
+ sam_device = device
593
+ sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
594
+ self.sam_predictor = SamPredictor(sam_model)
595
+ self.sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
596
+
597
+ def get_grounding_output(self, model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
598
+ caption = caption.lower()
599
+ caption = caption.strip()
600
+ if not caption.endswith("."):
601
+ caption = caption + "."
602
+ model = model.to(device)
603
+ image = image.to(device)
604
+ with torch.no_grad():
605
+ outputs = model(image[None], captions=[caption])
606
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
607
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
608
+ logits.shape[0]
609
+
610
+ # filter output
611
+ logits_filt = logits.clone()
612
+ boxes_filt = boxes.clone()
613
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
614
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
615
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
616
+ logits_filt.shape[0]
617
+
618
+ # get phrase
619
+ tokenlizer = model.tokenizer
620
+ tokenized = tokenlizer(caption)
621
+ # build pred
622
+ pred_phrases = []
623
+ for logit, box in zip(logits_filt, boxes_filt):
624
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
625
+ if with_logits:
626
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
627
+ else:
628
+ pred_phrases.append(pred_phrase)
629
+
630
+ return boxes_filt, pred_phrases
631
+
632
+ def load_model_hf(self, model_config_path, repo_id, filename, device='cpu'):
633
+ args = SLConfig.fromfile(model_config_path)
634
+ model = build_model(args)
635
+ args.device = device
636
+
637
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
638
+ checkpoint = torch.load(cache_file, map_location=device)
639
+ print(checkpoint['model'])
640
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
641
+ print("Model loaded from {} \n => {}".format(cache_file, log))
642
+ _ = model.eval()
643
+ return model
644
+
645
+ def load_groundingdino_model(self, device):
646
+
647
+ config_file = self.config_file
648
+ ckpt_repo_id = self.ckpt_repo_id
649
+ ckpt_filename = self.ckpt_filename
650
+
651
+
652
+ # initialize groundingdino model
653
+ logger.info(f"initialize groundingdino model...")
654
+ self.groundingdino_model = self.load_model_hf(config_file, ckpt_repo_id, ckpt_filename, device=device) #'cpu')
655
+ logger.info(f"initialize groundingdino model...{type(self.groundingdino_model)}")
656
+
657
+ def show_mask(self, mask, random_color=False):
658
+ if random_color:
659
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
660
+ else:
661
+ color = np.array([30/255, 144/255, 255/255, 0.6])
662
+ color = np.array([1.0, 0, 0, 1.0])
663
+ h, w = mask.shape[-2:]
664
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
665
+
666
+ return mask_image
667
+
668
+
669
+ def process(self, np_image, text_prompt):
670
+
671
+ results = []
672
+ results.append(np_image)
673
+ #results.append(np_image)
674
+
675
+ sam_predictor = self.sam_predictor
676
+ groundingdino_model = self.groundingdino_model
677
+
678
+ image = np_image
679
+ #text_prompt = text_prompt.strip()
680
+
681
+ box_threshold = 0.3
682
+ text_threshold = 0.25
683
+ size = image.shape
684
+ H, W = size[1], size[0]
685
+
686
+ # RUN grounding dino model
687
+ groundingdino_device = 'cpu'
688
+
689
+ #image_dino = torch.from_numpy(image)
690
+ image_dino = Image.fromarray(image)
691
+ transform = T.Compose(
692
+ [
693
+ T.RandomResize([800], max_size=1333),
694
+ T.ToTensor(),
695
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
696
+ ]
697
+ )
698
+ print(image.shape)
699
+ image_dino, _ = transform(image_dino, None) # 3, h, w
700
+
701
+ boxes_filt, pred_phrases =self.get_grounding_output(
702
+ groundingdino_model, image_dino, text_prompt, box_threshold, text_threshold, device=groundingdino_device
703
+ )
704
+
705
+ if sam_predictor:
706
+ sam_predictor.set_image(image)
707
+
708
+ if sam_predictor:
709
+
710
+
711
+ for i in range(boxes_filt.size(0)):
712
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
713
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
714
+ boxes_filt[i][2:] += boxes_filt[i][:2]
715
+
716
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
717
+
718
+
719
+ masks, _, _, _ = sam_predictor.predict_torch(
720
+ point_coords = None,
721
+ point_labels = None,
722
+ boxes = transformed_boxes,
723
+ multimask_output = False,
724
+ )
725
+
726
+ print("RESULTS*************")
727
+ print(len(masks))
728
+
729
+ # results = []
730
+
731
+ for mask in masks:
732
+ print(type(mask))
733
+ print(mask.shape)
734
+ #mask_img = mask.cpu().data.numpy()
735
+ mask_img =self.show_mask(mask.cpu().numpy())
736
+ print(type(mask_img))
737
+ print(mask_img.shape)
738
+ results.append(mask_img)
739
+ # results.append(mask.cpu().numpy())
740
+
741
+ return results
742
+ #assert sam_checkpoint, 'sam_checkpoint is not found!'
743
+
744
+ return None
main.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from bgremover import BackgroundRemover
3
+ from bgremover import DamageClassifier
4
+ from bgremover import clear
5
+ from bgremover import ColorCheckerDetector
6
+ from bgremover import Segmentor
7
+ import rasterio
8
+ import os
9
+ from PIL import Image
10
+ from gradio_client import Client
11
+
12
+ PRELOAD_MODELS = False
13
+
14
+ if PRELOAD_MODELS:
15
+ backgroundRemover = BackgroundRemover()
16
+ damage_classifier = DamageClassifier()
17
+ segmentor = Segmentor()
18
+
19
+ def process(input_img):
20
+
21
+ if PRELOAD_MODELS:
22
+ global backgroundRemover
23
+ else:
24
+ backgroundRemover = BackgroundRemover()
25
+
26
+ output_mask, output_img = backgroundRemover.remove_background_gradio(input_img)
27
+
28
+
29
+ return [output_img, output_mask]
30
+
31
+ def process_classification(input_img, model_name):
32
+
33
+ if PRELOAD_MODELS:
34
+ global damage_classifier
35
+ else:
36
+ damage_classifier = DamageClassifier()
37
+
38
+ res = damage_classifier.inference(input_img, model_name)
39
+
40
+ #return {'No damage': 0.1, 'Moderately damaged': 0.1,'Damaged': 0.7, 'Severy damaged': 0.1}
41
+ return res
42
+
43
+
44
+ def segment_plant(threshold, input_im, im_mask):
45
+
46
+ if PRELOAD_MODELS:
47
+ global backgroundRemover
48
+ else:
49
+ backgroundRemover = BackgroundRemover()
50
+
51
+ print("segment plant", threshold)
52
+
53
+ res, mask = backgroundRemover.apply_mask(input_im, im_mask, threshold)
54
+
55
+ return res, mask
56
+
57
+ def rectangle(im, im_mask):
58
+
59
+ colorCheckerDetector = ColorCheckerDetector()
60
+
61
+
62
+ return colorCheckerDetector.process(im_mask, im)
63
+
64
+ def get_file_content(file):
65
+ with rasterio.open(file) as src:
66
+ # Read the image data
67
+ image_data = src.read()
68
+ image = Image.fromarray((image_data[0] * 255).astype(np.uint8))
69
+ return (gr.Image(value=image, type="pil"))
70
+
71
+ def on_img_color_load(input):
72
+ print("on_img_color_load")
73
+ print(input)
74
+
75
+ def run_anything_task(input_image):
76
+
77
+ text_prompt = "color-checker"
78
+ task_type = "inpainting"
79
+
80
+ #text_prompt = "rocket"
81
+
82
+ if PRELOAD_MODELS:
83
+ global segmentor
84
+ else:
85
+ segmentor = Segmentor()
86
+
87
+ return segmentor.process(input_image, text_prompt)
88
+
89
+ with gr.Blocks(title="Phenotyping pipeline") as demo:
90
+
91
+ gr.Markdown(
92
+ """
93
+ # Phenotyping pipeline
94
+ Modular phenotyping pipeline.
95
+ """)
96
+
97
+ input_im = gr.Image(render=False)
98
+ im_result = gr.Image(render=False)
99
+ im_mask = gr.Image(render=False)
100
+ im_masked = gr.Image(render=False)
101
+
102
+ im_color = gr.Image(render=False)
103
+ im_color_orginal = gr.Image(render=False)
104
+ im_color.change(on_img_color_load, im_color)
105
+
106
+ im_color_checker_mask = gr.Image(render=False)
107
+
108
+
109
+
110
+ with gr.Tab("Damage Classification"):
111
+
112
+ model_option = gr.Dropdown(
113
+ ["Regnet", "Resnet18", "Resnet152", "Googlenet"]
114
+ , label="Classification model"
115
+ , info="The classification model to use for inference"
116
+ , value="Regnet"
117
+ )
118
+
119
+ gr.Interface(fn=process_classification
120
+ , inputs= [input_im, model_option]
121
+ , outputs="label"
122
+ , examples = [
123
+ ["183_Week_1_(28th_Aug_-_1st_Sept.)_2023_nd.jpg"]
124
+ ,["20_WEEK_5_(_FIELD_A)_md.jpg"]
125
+ ,["30_WEEK_5_(_FIELD_A)_damaged.jpg"]
126
+ ,["25_WEEK_4_(_Field_A)_sd.jpg"]
127
+ #,["30_WEEK_4_(_Field_A)_sd.jpg"]
128
+ ]
129
+ )
130
+ #gr.Button("Classify")
131
+
132
+ with gr.Tab("Color Checker detection"):
133
+
134
+ #gr.Interface(fn=process_classification, inputs= input_im, outputs="label" )
135
+ #gr.Button("Classify")
136
+ gr.Interface(fn=run_anything_task, inputs= input_im, outputs=gr.Gallery() )
137
+
138
+ with gr.Tab("Color Calibration"):
139
+
140
+ #gr.Interface(fn=process_classification, inputs= input_im, outputs="label" )
141
+ #gr.Button("Classify")
142
+ gr.Interface(fn=rectangle
143
+ , inputs= [input_im, im_color_checker_mask]
144
+ , outputs=gr.Gallery()
145
+ , examples = [["264_WEEK_5_(_FIELD_A).jpg","264_mask.jpg"]]
146
+ )
147
+ gr.Button("Calibrate")
148
+
149
+ with gr.Tab("Plant segmentation"):
150
+
151
+ with gr.Column(scale=1):
152
+ #gr.Interface(fn=process, inputs= gr.Image(), outputs=[im_result, "image"] )
153
+ gr.Interface(fn=process, inputs= input_im, outputs=[im_result, im_mask] )
154
+
155
+ slider_thresh = gr.Slider(minimum=0, maximum=255, value=100, step=1, label="Threshold"
156
+ , info="Segmentation threshold", interactive=True)
157
+ slider_thresh.release(fn=segment_plant, inputs = [slider_thresh, input_im, im_mask], outputs = [gr.Image(), gr.Image()])
158
+
159
+ #button = gr.Button("Clip")
160
+ #button.click()
161
+ #gr.Image(value=im_masked)
162
+
163
+ # with gr.Tab("Damage segmentation"):
164
+
165
+ # gr.Button("Damage")
166
+
167
+ # with gr.Tab("Batch processing"):
168
+
169
+ # gr.Button("Run")
170
+
171
+ # with gr.Tab("Batch processing"):
172
+
173
+ # gr.Interface(fn=run_anything_task, inputs= input_im, outputs= gr.Gallery())
174
+
175
+ #with gr.Tab("Tests"):
176
+
177
+ # gr.Markdown("# Preview Images:")
178
+ # with gr.Group(visible=True):
179
+ # with gr.Row(visible=True):
180
+ # preview = gr.FileExplorer( scale = 1,
181
+ # glob = "*.tif",
182
+ # value = ["./"],
183
+ # file_count = "single",
184
+ # root_dir = "./",
185
+ # elem_id = "file",
186
+ # every= 1,
187
+ # interactive=True
188
+ # )
189
+
190
+ # #image = gr.Image(type="pil")
191
+ # image = gr.Image()
192
+ # preview.change(get_file_content, preview, image)
193
+
194
+
195
+
196
+
197
+
198
+ if __name__ == "__main__":
199
+ #demo.launch(show_api=False)
200
+ #client = Client(demo)
201
+ #demo.launch(show_api=True, server_name="0.0.0.0", server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7861)))
202
+ demo.launch(allowed_paths=["30_WEEK_5_(_FIELD_A)_damaged.jpg"],server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7861)), share=True)
203
+
204
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ opencv-python
4
+ pillow
5
+ scikit-image
6
+ scikit-learn
7
+ torch
8
+ torchvision
9
+ gradio
u2net_utils/__init__.py ADDED
File without changes
u2net_utils/data_loader.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data loader
2
+ from __future__ import print_function, division
3
+ import glob
4
+ import torch
5
+ from skimage import io, transform, color
6
+ import numpy as np
7
+ import random
8
+ import math
9
+ import matplotlib.pyplot as plt
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms, utils
12
+ from PIL import Image
13
+
14
+ #==========================dataset load==========================
15
+ class RescaleT(object):
16
+
17
+ def __init__(self,output_size):
18
+ assert isinstance(output_size,(int,tuple))
19
+ self.output_size = output_size
20
+
21
+ def __call__(self,sample):
22
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
23
+
24
+ h, w = image.shape[:2]
25
+
26
+ if isinstance(self.output_size,int):
27
+ if h > w:
28
+ new_h, new_w = self.output_size*h/w,self.output_size
29
+ else:
30
+ new_h, new_w = self.output_size,self.output_size*w/h
31
+ else:
32
+ new_h, new_w = self.output_size
33
+
34
+ new_h, new_w = int(new_h), int(new_w)
35
+
36
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
37
+ # img = transform.resize(image,(new_h,new_w),mode='constant')
38
+ # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
39
+
40
+ img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
41
+ lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
42
+
43
+ return {'imidx':imidx, 'image':img,'label':lbl}
44
+
45
+ class Rescale(object):
46
+
47
+ def __init__(self,output_size):
48
+ assert isinstance(output_size,(int,tuple))
49
+ self.output_size = output_size
50
+
51
+ def __call__(self,sample):
52
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
53
+
54
+ if random.random() >= 0.5:
55
+ image = image[::-1]
56
+ label = label[::-1]
57
+
58
+ h, w = image.shape[:2]
59
+
60
+ if isinstance(self.output_size,int):
61
+ if h > w:
62
+ new_h, new_w = self.output_size*h/w,self.output_size
63
+ else:
64
+ new_h, new_w = self.output_size,self.output_size*w/h
65
+ else:
66
+ new_h, new_w = self.output_size
67
+
68
+ new_h, new_w = int(new_h), int(new_w)
69
+
70
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
71
+ img = transform.resize(image,(new_h,new_w),mode='constant')
72
+ lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
73
+
74
+ return {'imidx':imidx, 'image':img,'label':lbl}
75
+
76
+ class RandomCrop(object):
77
+
78
+ def __init__(self,output_size):
79
+ assert isinstance(output_size, (int, tuple))
80
+ if isinstance(output_size, int):
81
+ self.output_size = (output_size, output_size)
82
+ else:
83
+ assert len(output_size) == 2
84
+ self.output_size = output_size
85
+ def __call__(self,sample):
86
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
87
+
88
+ if random.random() >= 0.5:
89
+ image = image[::-1]
90
+ label = label[::-1]
91
+
92
+ h, w = image.shape[:2]
93
+ new_h, new_w = self.output_size
94
+
95
+ top = np.random.randint(0, h - new_h)
96
+ left = np.random.randint(0, w - new_w)
97
+
98
+ image = image[top: top + new_h, left: left + new_w]
99
+ label = label[top: top + new_h, left: left + new_w]
100
+
101
+ return {'imidx':imidx,'image':image, 'label':label}
102
+
103
+ class ToTensor(object):
104
+ """Convert ndarrays in sample to Tensors."""
105
+
106
+ def __call__(self, sample):
107
+
108
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
109
+
110
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
111
+ tmpLbl = np.zeros(label.shape)
112
+
113
+ image = image/np.max(image)
114
+ if(np.max(label)<1e-6):
115
+ label = label
116
+ else:
117
+ label = label/np.max(label)
118
+
119
+ if image.shape[2]==1:
120
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
121
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
122
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
123
+ else:
124
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
125
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
126
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
127
+
128
+ tmpLbl[:,:,0] = label[:,:,0]
129
+
130
+
131
+ tmpImg = tmpImg.transpose((2, 0, 1))
132
+ tmpLbl = label.transpose((2, 0, 1))
133
+
134
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
135
+
136
+ class ToTensorLab(object):
137
+ """Convert ndarrays in sample to Tensors."""
138
+ def __init__(self,flag=0):
139
+ self.flag = flag
140
+
141
+ def __call__(self, sample):
142
+
143
+ imidx, image, label =sample['imidx'], sample['image'], sample['label']
144
+
145
+ tmpLbl = np.zeros(label.shape)
146
+
147
+ if(np.max(label)<1e-6):
148
+ label = label
149
+ else:
150
+ label = label/np.max(label)
151
+
152
+ # change the color space
153
+ if self.flag == 2: # with rgb and Lab colors
154
+ tmpImg = np.zeros((image.shape[0],image.shape[1],6))
155
+ tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
156
+ if image.shape[2]==1:
157
+ tmpImgt[:,:,0] = image[:,:,0]
158
+ tmpImgt[:,:,1] = image[:,:,0]
159
+ tmpImgt[:,:,2] = image[:,:,0]
160
+ else:
161
+ tmpImgt = image
162
+ tmpImgtl = color.rgb2lab(tmpImgt)
163
+
164
+ # nomalize image to range [0,1]
165
+ tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
166
+ tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
167
+ tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
168
+ tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
169
+ tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
170
+ tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
171
+
172
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
173
+
174
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
175
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
176
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
177
+ tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
178
+ tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
179
+ tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
180
+
181
+ elif self.flag == 1: #with Lab color
182
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
183
+
184
+ if image.shape[2]==1:
185
+ tmpImg[:,:,0] = image[:,:,0]
186
+ tmpImg[:,:,1] = image[:,:,0]
187
+ tmpImg[:,:,2] = image[:,:,0]
188
+ else:
189
+ tmpImg = image
190
+
191
+ tmpImg = color.rgb2lab(tmpImg)
192
+
193
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
194
+
195
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
196
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
197
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
198
+
199
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
200
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
201
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
202
+
203
+ else: # with rgb color
204
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
205
+ image = image/np.max(image)
206
+ if image.shape[2]==1:
207
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
208
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
209
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
210
+ else:
211
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
212
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
213
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
214
+
215
+ tmpLbl[:,:,0] = label[:,:,0]
216
+
217
+
218
+ tmpImg = tmpImg.transpose((2, 0, 1))
219
+ tmpLbl = label.transpose((2, 0, 1))
220
+
221
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
222
+
223
+ class SalObjDataset(Dataset):
224
+ def __init__(self,img_name_list,lbl_name_list,transform=None):
225
+ # self.root_dir = root_dir
226
+ # self.image_name_list = glob.glob(image_dir+'*.png')
227
+ # self.label_name_list = glob.glob(label_dir+'*.png')
228
+ self.image_name_list = img_name_list
229
+ self.label_name_list = lbl_name_list
230
+ self.transform = transform
231
+
232
+ def __len__(self):
233
+ return len(self.image_name_list)
234
+
235
+ def __getitem__(self,idx):
236
+
237
+ # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
238
+ # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
239
+
240
+ image = io.imread(self.image_name_list[idx])
241
+ imname = self.image_name_list[idx]
242
+ imidx = np.array([idx])
243
+
244
+ if(0==len(self.label_name_list)):
245
+ label_3 = np.zeros(image.shape)
246
+ else:
247
+ label_3 = io.imread(self.label_name_list[idx])
248
+
249
+ label = np.zeros(label_3.shape[0:2])
250
+ if(3==len(label_3.shape)):
251
+ label = label_3[:,:,0]
252
+ elif(2==len(label_3.shape)):
253
+ label = label_3
254
+
255
+ if(3==len(image.shape) and 2==len(label.shape)):
256
+ label = label[:,:,np.newaxis]
257
+ elif(2==len(image.shape) and 2==len(label.shape)):
258
+ image = image[:,:,np.newaxis]
259
+ label = label[:,:,np.newaxis]
260
+
261
+ sample = {'imidx':imidx, 'image':image, 'label':label}
262
+
263
+ if self.transform:
264
+ sample = self.transform(sample)
265
+
266
+ return sample
u2net_utils/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .u2net import U2NET
2
+ from .u2net import U2NETP
u2net_utils/model/u2net.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
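Unlike RSU-7 through RSU-4, RSU-4F drops max-pooling and replaces downsampling with progressively dilated convolutions (dirate 1, 2, 4, 8), so every intermediate feature keeps the input resolution; that is why it sits at the deepest stages, where further pooling would leave almost no spatial extent. A quick shape check, sketched under the assumption of the deepest-stage configuration (512 in, 256 mid, 512 out) and a small 10x10 feature map:

    import torch

    blk = RSU4F(in_ch=512, mid_ch=256, out_ch=512)
    y = blk(torch.randn(1, 512, 10, 10))
    print(y.shape)   # torch.Size([1, 512, 10, 10]) -- resolution is preserved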
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
421
+
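A minimal usage sketch for the full model (illustrative, not part of the commit; a 320x320 RGB input is assumed): the forward pass returns seven probability maps, the fused prediction d0 followed by the six side outputs, all at the input resolution.

    import torch

    net = U2NET(in_ch=3, out_ch=1)
    net.eval()
    with torch.no_grad():
        d0, d1, d2, d3, d4, d5, d6 = net(torch.randn(1, 3, 320, 320))
    print(d0.shape)   # torch.Size([1, 1, 320, 320])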
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
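U2NETP keeps the same six-stage topology but fixes every stage to 64 output channels with 16 mid channels, which shrinks the model dramatically. A rough, illustrative size comparison (the parameter counts in the comments are approximate figures for these default configurations):

    import torch

    count = lambda m: sum(p.numel() for p in m.parameters())
    print(count(U2NET(3, 1)))    # roughly 44M parameters
    print(count(U2NETP(3, 1)))   # roughly 1.1M parameters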
u2net_utils/model/u2net_refactor.py ADDED
@@ -0,0 +1,168 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ __all__ = ['U2NET_full', 'U2NET_lite']
7
+
8
+
9
+ def _upsample_like(x, size):
10
+ return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)
11
+
12
+
13
+ def _size_map(x, height):
14
+ # {height: size} for Upsample
15
+ size = list(x.shape[-2:])
16
+ sizes = {}
17
+ for h in range(1, height):
18
+ sizes[h] = size
19
+ size = [math.ceil(w / 2) for w in size]
20
+ return sizes
21
+
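_size_map pre-computes the spatial size expected at each depth of the recursive U-structure so the decoder can upsample back with nn.Upsample; the ceil matches the ceil_mode=True pooling. A worked example, assuming a 320x320 input and height 7 as used by the first encoder stage:

    import torch

    sizes = _size_map(torch.zeros(1, 3, 320, 320), 7)
    # {1: [320, 320], 2: [160, 160], 3: [80, 80], 4: [40, 40], 5: [20, 20], 6: [10, 10]}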
22
+
23
+ class REBNCONV(nn.Module):
24
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
25
+ super(REBNCONV, self).__init__()
26
+
27
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
28
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
29
+ self.relu_s1 = nn.ReLU(inplace=True)
30
+
31
+ def forward(self, x):
32
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
33
+
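Because padding always equals the dilation for the 3x3 kernel, REBNCONV preserves spatial resolution for any dilate value; this is what lets the dilated RSU-4F variants stack convolutions without shrinking the feature map. A quick illustrative check:

    import torch

    blk = REBNCONV(3, 16, dilate=4)
    print(blk(torch.randn(1, 3, 64, 64)).shape)   # torch.Size([1, 16, 64, 64])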
34
+
35
+ class RSU(nn.Module):
36
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
37
+ super(RSU, self).__init__()
38
+ self.name = name
39
+ self.height = height
40
+ self.dilated = dilated
41
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
42
+
43
+ def forward(self, x):
44
+ sizes = _size_map(x, self.height)
45
+ x = self.rebnconvin(x)
46
+
47
+ # U-Net like symmetric encoder-decoder structure
48
+ def unet(x, height=1):
49
+ if height < self.height:
50
+ x1 = getattr(self, f'rebnconv{height}')(x)
51
+ if not self.dilated and height < self.height - 1:
52
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
53
+ else:
54
+ x2 = unet(x1, height + 1)
55
+
56
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
57
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
58
+ else:
59
+ return getattr(self, f'rebnconv{height}')(x)
60
+
61
+ return x + unet(x)
62
+
63
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
64
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
65
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
66
+
67
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
68
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
69
+
70
+ for i in range(2, height):
71
+ dilate = 1 if not dilated else 2 ** (i - 1)
72
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
73
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
74
+
75
+ dilate = 2 if not dilated else 2 ** (height - 1)
76
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
77
+
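This single parameterized RSU class reproduces the hand-written RSU7 through RSU4F blocks from u2net.py: height controls the recursion depth, and dilated=True swaps pooling for exponentially growing dilation rates. A sketch, assuming the encoder stage-1 configuration (height 7, 3 in, 32 mid, 64 out):

    import torch

    rsu = RSU('En_1', height=7, in_ch=3, mid_ch=32, out_ch=64)
    y = rsu(torch.randn(1, 3, 320, 320))
    print(y.shape)   # torch.Size([1, 64, 320, 320]) -- the residual x + unet(x) keeps the input size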
78
+
79
+ class U2NET(nn.Module):
80
+ def __init__(self, cfgs, out_ch):
81
+ super(U2NET, self).__init__()
82
+ self.out_ch = out_ch
83
+ self._make_layers(cfgs)
84
+
85
+ def forward(self, x):
86
+ sizes = _size_map(x, self.height)
87
+ maps = [] # storage for maps
88
+
89
+ # side saliency map
90
+ def unet(x, height=1):
91
+ if height < 6:
92
+ x1 = getattr(self, f'stage{height}')(x)
93
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
94
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
95
+ side(x, height)
96
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
97
+ else:
98
+ x = getattr(self, f'stage{height}')(x)
99
+ side(x, height)
100
+ return _upsample_like(x, sizes[height - 1])
101
+
102
+ def side(x, h):
103
+ # side output saliency map (before sigmoid)
104
+ x = getattr(self, f'side{h}')(x)
105
+ x = _upsample_like(x, sizes[1])
106
+ maps.append(x)
107
+
108
+ def fuse():
109
+ # fuse saliency probability maps
110
+ maps.reverse()
111
+ x = torch.cat(maps, 1)
112
+ x = getattr(self, 'outconv')(x)
113
+ maps.insert(0, x)
114
+ return [torch.sigmoid(x) for x in maps]
115
+
116
+ unet(x)
117
+ maps = fuse()
118
+ return maps
119
+
120
+ def _make_layers(self, cfgs):
121
+ self.height = int((len(cfgs) + 1) / 2)
122
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
123
+ for k, v in cfgs.items():
124
+ # build rsu block
125
+ self.add_module(k, RSU(v[0], *v[1]))
126
+ if v[2] > 0:
127
+ # build side layer
128
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
129
+ # build fuse layer
130
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
131
+
132
+
133
+ def U2NET_full():
134
+ full = {
135
+ # cfgs for building RSUs and sides
136
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
137
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
138
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
139
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
140
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
141
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
142
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
143
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
144
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
145
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
146
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
147
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
148
+ }
149
+ return U2NET(cfgs=full, out_ch=1)
150
+
151
+
152
+ def U2NET_lite():
153
+ lite = {
154
+ # cfgs for building RSUs and sides
155
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
156
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
157
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
158
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
159
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
160
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
161
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
162
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
163
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
164
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
165
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
166
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
167
+ }
168
+ return U2NET(cfgs=lite, out_ch=1)
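A usage sketch for the two factory functions (illustrative; the input size is an assumption): both return a U2NET whose forward pass yields a list of seven sigmoid maps with the fused prediction first, mirroring the tuple returned by the original u2net.py models.

    import torch

    net = U2NET_lite()
    net.eval()
    with torch.no_grad():
        maps = net(torch.randn(1, 3, 320, 320))
    print(len(maps), maps[0].shape)   # 7 torch.Size([1, 1, 320, 320])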