AI-Cyber commited on
Commit
0223881
·
1 Parent(s): 8d7921b

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils.py +387 -0
  2. visual_utils.py +435 -0
utils.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import shutil
4
+
5
+ import torch
6
+ import numpy as np
7
+ from torch.optim import SGD, Adam, AdamW
8
+ from tensorboardX import SummaryWriter
9
+
10
+ import sod_metric
11
+ class Averager():
12
+
13
+ def __init__(self):
14
+ self.n = 0.0
15
+ self.v = 0.0
16
+
17
+ def add(self, v, n=1.0):
18
+ self.v = (self.v * self.n + v * n) / (self.n + n)
19
+ self.n += n
20
+
21
+ def item(self):
22
+ return self.v
23
+
24
+
25
+ class Timer():
26
+
27
+ def __init__(self):
28
+ self.v = time.time()
29
+
30
+ def s(self):
31
+ self.v = time.time()
32
+
33
+ def t(self):
34
+ return time.time() - self.v
35
+
36
+
37
+ def time_text(t):
38
+ if t >= 3600:
39
+ return '{:.1f}h'.format(t / 3600)
40
+ elif t >= 60:
41
+ return '{:.1f}m'.format(t / 60)
42
+ else:
43
+ return '{:.1f}s'.format(t)
44
+
45
+
46
+ _log_path = None
47
+
48
+
49
+ def set_log_path(path):
50
+ global _log_path
51
+ _log_path = path
52
+
53
+
54
+ def log(obj, filename='log.txt'):
55
+ print(obj)
56
+ if _log_path is not None:
57
+ with open(os.path.join(_log_path, filename), 'a') as f:
58
+ print(obj, file=f)
59
+
60
+
61
+ def ensure_path(path, remove=True):
62
+ basename = os.path.basename(path.rstrip('/'))
63
+ if os.path.exists(path):
64
+ if remove and (basename.startswith('_')
65
+ or input('{} exists, remove? (y/[n]): '.format(path)) == 'y'):
66
+ shutil.rmtree(path)
67
+ os.makedirs(path, exist_ok=True)
68
+ else:
69
+ os.makedirs(path, exist_ok=True)
70
+
71
+
72
+ def set_save_path(save_path, remove=True):
73
+ ensure_path(save_path, remove=remove)
74
+ set_log_path(save_path)
75
+ writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
76
+ return log, writer
77
+
78
+
79
+ def compute_num_params(model, text=False):
80
+ tot = int(sum([np.prod(p.shape) for p in model.parameters()]))
81
+ if text:
82
+ if tot >= 1e6:
83
+ return '{:.1f}M'.format(tot / 1e6)
84
+ else:
85
+ return '{:.1f}K'.format(tot / 1e3)
86
+ else:
87
+ return tot
88
+
89
+
90
+ def make_optimizer(param_list, optimizer_spec, load_sd=False):
91
+ Optimizer = {
92
+ 'sgd': SGD,
93
+ 'adam': Adam,
94
+ 'adamw': AdamW
95
+ }[optimizer_spec['name']]
96
+ optimizer = Optimizer(param_list, **optimizer_spec['args'])
97
+ if load_sd:
98
+ optimizer.load_state_dict(optimizer_spec['sd'])
99
+ return optimizer
100
+
101
+
102
+ def make_coord(shape, ranges=None, flatten=True):
103
+ """ Make coordinates at grid centers.
104
+ """
105
+ coord_seqs = []
106
+ for i, n in enumerate(shape):
107
+ if ranges is None:
108
+ v0, v1 = -1, 1
109
+ else:
110
+ v0, v1 = ranges[i]
111
+ r = (v1 - v0) / (2 * n)
112
+ seq = v0 + r + (2 * r) * torch.arange(n).float()
113
+ coord_seqs.append(seq)
114
+ ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
115
+ # if flatten:
116
+ # ret = ret.view(-1, ret.shape[-1])
117
+
118
+ return ret
119
+
120
+
121
+
122
+ def calc_cod(y_pred, y_true):
123
+ batchsize = y_true.shape[0]
124
+
125
+ metric_FM = sod_metric.Fmeasure()
126
+ metric_WFM = sod_metric.WeightedFmeasure()
127
+ metric_SM = sod_metric.Smeasure()
128
+ metric_EM = sod_metric.Emeasure()
129
+ metric_MAE = sod_metric.MAE()
130
+ with torch.no_grad():
131
+ assert y_pred.shape == y_true.shape
132
+
133
+ for i in range(batchsize):
134
+ true, pred = \
135
+ y_true[i, 0].cpu().data.numpy() * 255, y_pred[i, 0].cpu().data.numpy() * 255
136
+
137
+ metric_FM.step(pred=pred, gt=true)
138
+ metric_WFM.step(pred=pred, gt=true)
139
+ metric_SM.step(pred=pred, gt=true)
140
+ metric_EM.step(pred=pred, gt=true)
141
+ metric_MAE.step(pred=pred, gt=true)
142
+
143
+ fm = metric_FM.get_results()["fm"]
144
+ wfm = metric_WFM.get_results()["wfm"]
145
+ sm = metric_SM.get_results()["sm"]
146
+ em = metric_EM.get_results()["em"]["curve"].mean()
147
+ mae = metric_MAE.get_results()["mae"]
148
+
149
+ return sm, em, wfm, mae
150
+
151
+
152
+ from sklearn.metrics import precision_recall_curve
153
+
154
+
155
+ def calc_f1(y_pred,y_true):
156
+ batchsize = y_true.shape[0]
157
+ with torch.no_grad():
158
+ print(y_pred.shape)
159
+ print(y_true.shape)
160
+ assert y_pred.shape == y_true.shape
161
+ f1, auc = 0, 0
162
+ y_true = y_true.cpu().numpy()
163
+ y_pred = y_pred.cpu().numpy()
164
+ for i in range(batchsize):
165
+ true = y_true[i].flatten()
166
+ true = true.astype(np.int)
167
+ pred = y_pred[i].flatten()
168
+
169
+ precision, recall, thresholds = precision_recall_curve(true, pred)
170
+
171
+ # auc
172
+ auc += roc_auc_score(true, pred)
173
+ # auc += roc_auc_score(np.array(true>0).astype(np.int), pred)
174
+ f1 += max([(2 * p * r) / (p + r+1e-10) for p, r in zip(precision, recall)])
175
+
176
+ return f1/batchsize, auc/batchsize, np.array(0), np.array(0)
177
+
178
+ def calc_fmeasure(y_pred,y_true):
179
+ batchsize = y_true.shape[0]
180
+
181
+ mae, preds, gts = [], [], []
182
+ with torch.no_grad():
183
+ for i in range(batchsize):
184
+ gt_float, pred_float = \
185
+ y_true[i, 0].cpu().data.numpy(), y_pred[i, 0].cpu().data.numpy()
186
+
187
+ # # MAE
188
+ mae.append(np.sum(cv2.absdiff(gt_float.astype(float), pred_float.astype(float))) / (
189
+ pred_float.shape[1] * pred_float.shape[0]))
190
+ # mae.append(np.mean(np.abs(pred_float - gt_float)))
191
+ #
192
+ pred = np.uint8(pred_float * 255)
193
+ gt = np.uint8(gt_float * 255)
194
+
195
+ pred_float_ = np.where(pred > min(1.5 * np.mean(pred), 255), np.ones_like(pred_float),
196
+ np.zeros_like(pred_float))
197
+ gt_float_ = np.where(gt > min(1.5 * np.mean(gt), 255), np.ones_like(pred_float),
198
+ np.zeros_like(pred_float))
199
+
200
+ preds.extend(pred_float_.ravel())
201
+ gts.extend(gt_float_.ravel())
202
+
203
+ RECALL = recall_score(gts, preds)
204
+ PERC = precision_score(gts, preds)
205
+
206
+ fmeasure = (1 + 0.3) * PERC * RECALL / (0.3 * PERC + RECALL)
207
+ MAE = np.mean(mae)
208
+
209
+ return fmeasure, MAE, np.array(0), np.array(0)
210
+
211
+ from sklearn.metrics import roc_auc_score,recall_score,precision_score
212
+ import cv2
213
+ def calc_ber(y_pred, y_true):
214
+ batchsize = y_true.shape[0]
215
+ y_pred, y_true = y_pred.permute(0, 2, 3, 1).squeeze(-1), y_true.permute(0, 2, 3, 1).squeeze(-1)
216
+ with torch.no_grad():
217
+ assert y_pred.shape == y_true.shape
218
+ pos_err, neg_err, ber = 0, 0, 0
219
+ y_true = y_true.cpu().numpy()
220
+ y_pred = y_pred.cpu().numpy()
221
+ for i in range(batchsize):
222
+ true = y_true[i].flatten()
223
+ pred = y_pred[i].flatten()
224
+
225
+ TP, TN, FP, FN, BER, ACC = get_binary_classification_metrics(pred * 255,
226
+ true * 255, 125)
227
+ pos_err += (1 - TP / (TP + FN)) * 100
228
+ neg_err += (1 - TN / (TN + FP)) * 100
229
+
230
+ return pos_err / batchsize, neg_err / batchsize, (pos_err + neg_err) / 2 / batchsize, np.array(0)
231
+
232
+ def get_binary_classification_metrics(pred, gt, threshold=None):
233
+ if threshold is not None:
234
+ gt = (gt > threshold)
235
+ pred = (pred > threshold)
236
+ TP = np.logical_and(gt, pred).sum()
237
+ TN = np.logical_and(np.logical_not(gt), np.logical_not(pred)).sum()
238
+ FN = np.logical_and(gt, np.logical_not(pred)).sum()
239
+ FP = np.logical_and(np.logical_not(gt), pred).sum()
240
+ BER = cal_ber(TN, TP, FN, FP)
241
+ ACC = cal_acc(TN, TP, FN, FP)
242
+ return TP, TN, FP, FN, BER, ACC
243
+
244
+ def cal_ber(tn, tp, fn, fp):
245
+ return 0.5*(fp/(tn+fp) + fn/(fn+tp))
246
+
247
+ def cal_acc(tn, tp, fn, fp):
248
+ return (tp + tn) / (tp + tn + fp + fn)
249
+
250
+ def _sigmoid(x):
251
+ return 1 / (1 + np.exp(-x))
252
+
253
+
254
+ def _eval_pr(y_pred, y, num):
255
+ prec, recall = torch.zeros(num), torch.zeros(num)
256
+ thlist = torch.linspace(0, 1 - 1e-10, num)
257
+ for i in range(num):
258
+ y_temp = (y_pred >= thlist[i]).float()
259
+ tp = (y_temp * y).sum()
260
+ prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() +
261
+ 1e-20)
262
+ return prec, recall
263
+
264
+ def _S_object(pred, gt):
265
+ fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
266
+ bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
267
+ o_fg = _object(fg, gt)
268
+ o_bg = _object(bg, 1 - gt)
269
+ u = gt.mean()
270
+ Q = u * o_fg + (1 - u) * o_bg
271
+ return Q
272
+
273
+ def _object(pred, gt):
274
+ temp = pred[gt == 1]
275
+ x = temp.mean()
276
+ sigma_x = temp.std()
277
+ score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
278
+
279
+ return score
280
+
281
+ def _S_region(pred, gt):
282
+ X, Y = _centroid(gt)
283
+ gt1, gt2, gt3, gt4, w1, w2, w3, w4 = _divideGT(gt, X, Y)
284
+ p1, p2, p3, p4 = _dividePrediction(pred, X, Y)
285
+ Q1 = _ssim(p1, gt1)
286
+ Q2 = _ssim(p2, gt2)
287
+ Q3 = _ssim(p3, gt3)
288
+ Q4 = _ssim(p4, gt4)
289
+ Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
290
+ return Q
291
+
292
+ def _centroid(gt):
293
+ rows, cols = gt.size()[-2:]
294
+ gt = gt.view(rows, cols)
295
+ if gt.sum() == 0:
296
+ X = torch.eye(1) * round(cols / 2)
297
+ Y = torch.eye(1) * round(rows / 2)
298
+ else:
299
+ total = gt.sum()
300
+ i = torch.from_numpy(np.arange(0, cols)).float().cuda()
301
+ j = torch.from_numpy(np.arange(0, rows)).float().cuda()
302
+ X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20)
303
+ Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20)
304
+ return X.long(), Y.long()
305
+
306
+
307
+ def _divideGT(gt, X, Y):
308
+ h, w = gt.size()[-2:]
309
+ area = h * w
310
+ gt = gt.view(h, w)
311
+ LT = gt[:Y, :X]
312
+ RT = gt[:Y, X:w]
313
+ LB = gt[Y:h, :X]
314
+ RB = gt[Y:h, X:w]
315
+ X = X.float()
316
+ Y = Y.float()
317
+ w1 = X * Y / area
318
+ w2 = (w - X) * Y / area
319
+ w3 = X * (h - Y) / area
320
+ w4 = 1 - w1 - w2 - w3
321
+ return LT, RT, LB, RB, w1, w2, w3, w4
322
+
323
+
324
+ def _dividePrediction(pred, X, Y):
325
+ h, w = pred.size()[-2:]
326
+ pred = pred.view(h, w)
327
+ LT = pred[:Y, :X]
328
+ RT = pred[:Y, X:w]
329
+ LB = pred[Y:h, :X]
330
+ RB = pred[Y:h, X:w]
331
+ return LT, RT, LB, RB
332
+
333
+
334
+ def _ssim(pred, gt):
335
+ gt = gt.float()
336
+ h, w = pred.size()[-2:]
337
+ N = h * w
338
+ x = pred.mean()
339
+ y = gt.mean()
340
+ sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
341
+ sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
342
+ sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
343
+
344
+ aplha = 4 * x * y * sigma_xy
345
+ beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
346
+
347
+ if aplha != 0:
348
+ Q = aplha / (beta + 1e-20)
349
+ elif aplha == 0 and beta == 0:
350
+ Q = 1.0
351
+ else:
352
+ Q = 0
353
+ return Q
354
+
355
+ def _eval_e(y_pred, y, num):
356
+ score = torch.zeros(num)
357
+ thlist = torch.linspace(0, 1 - 1e-10, num)
358
+ for i in range(num):
359
+ y_pred_th = (y_pred >= thlist[i]).float()
360
+ fm = y_pred_th - y_pred_th.mean()
361
+ gt = y - y.mean()
362
+ align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
363
+ enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
364
+ score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
365
+ return score
366
+
367
+
368
+ def calc_Semantic_Segmentation(y_pred,y_true):
369
+ batchsize = y_true.shape[0]
370
+ with torch.no_grad():
371
+ assert y_pred.shape == y_true.shape
372
+ f1, auc = 0, 0
373
+ y_true = y_true.cpu().numpy()
374
+ y_pred = y_pred.cpu().numpy()
375
+ for i in range(batchsize):
376
+ true = y_true[i].flatten()
377
+ true = true.astype(np.int)
378
+ pred = y_pred[i].flatten()
379
+
380
+ precision, recall, thresholds = precision_recall_curve(true, pred)
381
+
382
+ # auc
383
+ auc += roc_auc_score(true, pred)
384
+ # auc += roc_auc_score(np.array(true>0).astype(np.int), pred)
385
+ f1 += max([(2 * p * r) / (p + r+1e-10) for p, r in zip(precision, recall)])
386
+
387
+ return f1/batchsize, auc/batchsize, np.array(0), np.array(0)
visual_utils.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms.functional import normalize
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ def denormalize(tensor, mean, std):
7
+ mean = np.array(mean)
8
+ std = np.array(std)
9
+
10
+ _mean = -mean/std
11
+ _std = 1/std
12
+ return normalize(tensor, _mean, _std)
13
+
14
+
15
+ class Denormalize(object):
16
+ def __init__(self, mean, std):
17
+ mean = np.array(mean)
18
+ std = np.array(std)
19
+ self._mean = -mean/std
20
+ self._std = 1/std
21
+
22
+ def __call__(self, tensor):
23
+ if isinstance(tensor, np.ndarray):
24
+ return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1)
25
+ return normalize(tensor, self._mean, self._std)
26
+
27
+
28
+ def fix_bn(model):
29
+ for m in model.modules():
30
+ if isinstance(m, nn.BatchNorm2d):
31
+ m.eval()
32
+ m.weight.requires_grad = False
33
+ m.bias.requires_grad = False
34
+
35
+
36
+ def color_map(dataset):
37
+ if dataset=='voc':
38
+ return voc_cmap()
39
+ elif dataset=='cityscapes':
40
+ return cityscapes_cmap()
41
+ elif dataset=='ade':
42
+ return ade_cmap()
43
+ elif dataset =='isaid':
44
+ return isaid_cmap()
45
+ elif dataset =='SAR2020':
46
+ return SAR2020_cmap()
47
+ elif dataset =='Unify_single':
48
+ return unify_single_cmap()
49
+ elif dataset == 'Unify_double':
50
+ return unify_cmap()
51
+ elif dataset == 'Unify_YIJISAR':
52
+ return unify_YIJISAR_cmap()
53
+ elif dataset == 'Unify_Vai':
54
+ return unify_Vai_cmap()
55
+
56
+ def unify_Vai_cmap():
57
+ cmap = np.zeros((255, 3), dtype=np.uint8)
58
+ colors = [
59
+ [0, 0, 0],#0
60
+ [0, 255, 0],
61
+ [0, 0, 255],
62
+ [0, 0, 255],
63
+ [0, 0, 0],
64
+ [159,129,183],
65
+ [0, 255, 255], #6
66
+ [255,195,128],
67
+ [0, 0, 0],
68
+ [255,255,0],
69
+ [0, 0, 0], #10
70
+ [0, 0, 0],
71
+ [0,0,0],#12
72
+ [0, 0, 0],
73
+ [255,0,0], #14
74
+
75
+ ]
76
+ for i in range(len(colors)):
77
+ cmap[i] = colors[i]
78
+
79
+ return cmap.astype(np.uint8)
80
+
81
+ def cityscapes_cmap():
82
+ return np.array([(128, 64,128), (244, 35,232), ( 70, 70, 70), (102,102,156), (190,153,153), (153,153,153), (250,170, 30),
83
+ (220,220, 0), (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142),
84
+ ( 0, 0, 70), ( 0, 60,100), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0, 0)],
85
+ dtype=np.uint8)
86
+
87
+ def SAR2020_cmap():
88
+ cmap = np.zeros((256, 3), dtype=np.uint8)
89
+ colors = [[255,255,255], [255,255,0], [0,0,255], [0, 255,0], [255,0,0], [0,255,255]]
90
+ for i in range(len(colors)):
91
+ cmap[i] = colors[i]
92
+
93
+ return cmap.astype(np.uint8)
94
+
95
+
96
+ def isaid_cmap():
97
+ cmap = np.zeros((255, 3), dtype=np.uint8)
98
+ colors = [[0, 0, 0],
99
+ [0,63,0],
100
+ [0,63,191],
101
+ [0, 127, 63],
102
+ [0, 63, 255],
103
+ [0,100,155],
104
+ [0, 0, 191], [0, 127, 127],
105
+ [0, 127, 255], [0, 191, 127], [0, 0, 63], [0, 191, 127], [0, 127, 191],[0, 63, 63], [0 ,0 ,255], [0,63,127]]
106
+ for i in range(len(colors)):
107
+ cmap[i] = colors[i]
108
+
109
+ return cmap.astype(np.uint8)
110
+
111
+
112
+
113
+ def unify_cmap():
114
+ cmap = np.zeros((255, 3), dtype=np.uint8)
115
+ colors = [
116
+ [0, 0, 0],
117
+ [0, 127, 255],
118
+ # [0, 0, 191],
119
+ [0, 63, 0],
120
+ [0, 127, 63],
121
+ [0, 63, 255],
122
+ [0, 127, 127],
123
+ # [0, 0, 127],
124
+ [0, 0, 63],
125
+ [0, 63, 127],
126
+ [0, 63, 191],
127
+ [0, 63, 63],
128
+ [0, 127, 191],
129
+ [0, 191, 127],
130
+ [0, 100, 155],
131
+ [0, 0, 255],
132
+ # [255, 255, 255], #0
133
+ # [0,63,0], #1
134
+ # [0,127,63],#2
135
+ # [0, 63, 255],#3
136
+ # [0, 127, 127],#4
137
+ # [0,128,0], #5,
138
+ # [0, 0, 63],#6
139
+ # [0,63,127],#7
140
+ # [0, 63, 191],#8
141
+ # [0, 63, 63],#9
142
+ # [0, 127, 191], #10
143
+ # [0, 191, 127],#11
144
+ # [0,100,155], #12
145
+ # [0, 0, 255], #13
146
+ [0,255,0], #14
147
+ [0,153,204], #15
148
+ [204,204,68], #16
149
+ [255, 204, 51], #17
150
+ [255, 255, 204], #18
151
+ [0, 255, 255],
152
+ [255,102,102], #20
153
+ [0,255,0] ,#21
154
+ [255,255,0],#22
155
+ [255,0,0],#23
156
+ [255,195,128],#24
157
+ [153,102,153]
158
+ ]
159
+ for i in range(len(colors)):
160
+ cmap[i] = colors[i]
161
+
162
+ return cmap.astype(np.uint8)
163
+
164
+ def unify_YIJISAR_cmap():
165
+ cmap = np.zeros((255, 3), dtype=np.uint8)
166
+ colors = [
167
+ [255, 255, 0],#0
168
+ [0, 255, 0],
169
+ [0, 63, 0],
170
+ [0, 0, 255],
171
+ [0, 63, 255],
172
+ [255, 0, 0],
173
+ [0, 0, 63],
174
+ [0, 63, 127],
175
+ [0, 63, 191],
176
+ [0,255, 255],
177
+ [0, 127, 191],
178
+ [0, 191, 127],
179
+ [0,0,0],#12
180
+ [0, 0, 255],
181
+ [0,255,0], #14
182
+ [0,153,204], #15
183
+ [204,204,68], #16
184
+ [255, 204, 51], #17
185
+ [255, 255, 204], #18
186
+ [0, 255, 255],
187
+ [255,102,102], #20
188
+ [0,255,0] ,#21
189
+ [255,255,0],#22
190
+ [255,0,0],#23
191
+ [255,195,128],#24
192
+ [153,102,153]
193
+ ]
194
+ for i in range(len(colors)):
195
+ cmap[i] = colors[i]
196
+
197
+ return cmap.astype(np.uint8)
198
+
199
+
200
+ def unify_single_cmap():
201
+ cmap = np.zeros((255, 3), dtype=np.uint8)
202
+ colors = [
203
+ [0, 127, 255],#0
204
+ [0, 0, 0],
205
+ [0, 0, 0],
206
+ [0, 0, 0], #3
207
+ [0, 0, 0], #4
208
+ [0,255,0], #5
209
+ [0, 0, 0],
210
+ [0, 0, 0], # 7
211
+ [0, 0, 0], # 8
212
+ [0, 0, 0],
213
+ [0, 0, 0], # 10
214
+ [0, 0, 0], # 11
215
+ [0, 0, 0],
216
+ [0, 0, 0], # 13
217
+ [0, 0, 0], # 14
218
+ [0, 0, 0], #15
219
+ [159,129,183], #16
220
+ [0, 0, 0], #17
221
+ [255, 195, 128], #18
222
+ [0, 0, 0],
223
+ [255, 0, 0],#20
224
+ [255,255,0],
225
+ [0,0,255], #22
226
+ [0, 0, 0],
227
+ [0, 0, 0],
228
+ [0,0,0]
229
+ ]
230
+ for i in range(len(colors)):
231
+ cmap[i] = colors[i]
232
+
233
+ return cmap.astype(np.uint8)
234
+
235
+ def ade_cmap():
236
+ cmap = np.zeros((256, 3), dtype=np.uint8)
237
+ colors = [
238
+ [0, 0, 0],
239
+ [120, 120, 120],
240
+ [180, 120, 120],
241
+ [6, 230, 230],
242
+ [80, 50, 50],
243
+ [4, 200, 3],
244
+ [120, 120, 80],
245
+ [140, 140, 140],
246
+ [204, 5, 255],
247
+ [230, 230, 230],
248
+ [4, 250, 7],
249
+ [224, 5, 255],
250
+ [235, 255, 7],
251
+ [150, 5, 61],
252
+ [120, 120, 70],
253
+ [8, 255, 51],
254
+ [255, 6, 82],
255
+ [143, 255, 140],
256
+ [204, 255, 4],
257
+ [255, 51, 7],
258
+ [204, 70, 3],
259
+ [0, 102, 200],
260
+ [61, 230, 250],
261
+ [255, 6, 51],
262
+ [11, 102, 255],
263
+ [255, 7, 71],
264
+ [255, 9, 224],
265
+ [9, 7, 230],
266
+ [220, 220, 220],
267
+ [255, 9, 92],
268
+ [112, 9, 255],
269
+ [8, 255, 214],
270
+ [7, 255, 224],
271
+ [255, 184, 6],
272
+ [10, 255, 71],
273
+ [255, 41, 10],
274
+ [7, 255, 255],
275
+ [224, 255, 8],
276
+ [102, 8, 255],
277
+ [255, 61, 6],
278
+ [255, 194, 7],
279
+ [255, 122, 8],
280
+ [0, 255, 20],
281
+ [255, 8, 41],
282
+ [255, 5, 153],
283
+ [6, 51, 255],
284
+ [235, 12, 255],
285
+ [160, 150, 20],
286
+ [0, 163, 255],
287
+ [140, 140, 140],
288
+ [250, 10, 15],
289
+ [20, 255, 0],
290
+ [31, 255, 0],
291
+ [255, 31, 0],
292
+ [255, 224, 0],
293
+ [153, 255, 0],
294
+ [0, 0, 255],
295
+ [255, 71, 0],
296
+ [0, 235, 255],
297
+ [0, 173, 255],
298
+ [31, 0, 255],
299
+ [11, 200, 200],
300
+ [255, 82, 0],
301
+ [0, 255, 245],
302
+ [0, 61, 255],
303
+ [0, 255, 112],
304
+ [0, 255, 133],
305
+ [255, 0, 0],
306
+ [255, 163, 0],
307
+ [255, 102, 0],
308
+ [194, 255, 0],
309
+ [0, 143, 255],
310
+ [51, 255, 0],
311
+ [0, 82, 255],
312
+ [0, 255, 41],
313
+ [0, 255, 173],
314
+ [10, 0, 255],
315
+ [173, 255, 0],
316
+ [0, 255, 153],
317
+ [255, 92, 0],
318
+ [255, 0, 255],
319
+ [255, 0, 245],
320
+ [255, 0, 102],
321
+ [255, 173, 0],
322
+ [255, 0, 20],
323
+ [255, 184, 184],
324
+ [0, 31, 255],
325
+ [0, 255, 61],
326
+ [0, 71, 255],
327
+ [255, 0, 204],
328
+ [0, 255, 194],
329
+ [0, 255, 82],
330
+ [0, 10, 255],
331
+ [0, 112, 255],
332
+ [51, 0, 255],
333
+ [0, 194, 255],
334
+ [0, 122, 255],
335
+ [0, 255, 163],
336
+ [255, 153, 0],
337
+ [0, 255, 10],
338
+ [255, 112, 0],
339
+ [143, 255, 0],
340
+ [82, 0, 255],
341
+ [163, 255, 0],
342
+ [255, 235, 0],
343
+ [8, 184, 170],
344
+ [133, 0, 255],
345
+ [0, 255, 92],
346
+ [184, 0, 255],
347
+ [255, 0, 31],
348
+ [0, 184, 255],
349
+ [0, 214, 255],
350
+ [255, 0, 112],
351
+ [92, 255, 0],
352
+ [0, 224, 255],
353
+ [112, 224, 255],
354
+ [70, 184, 160],
355
+ [163, 0, 255],
356
+ [153, 0, 255],
357
+ [71, 255, 0],
358
+ [255, 0, 163],
359
+ [255, 204, 0],
360
+ [255, 0, 143],
361
+ [0, 255, 235],
362
+ [133, 255, 0],
363
+ [255, 0, 235],
364
+ [245, 0, 255],
365
+ [255, 0, 122],
366
+ [255, 245, 0],
367
+ [10, 190, 212],
368
+ [214, 255, 0],
369
+ [0, 204, 255],
370
+ [20, 0, 255],
371
+ [255, 255, 0],
372
+ [0, 153, 255],
373
+ [0, 41, 255],
374
+ [0, 255, 204],
375
+ [41, 0, 255],
376
+ [41, 255, 0],
377
+ [173, 0, 255],
378
+ [0, 245, 255],
379
+ [71, 0, 255],
380
+ [122, 0, 255],
381
+ [0, 255, 184],
382
+ [0, 92, 255],
383
+ [184, 255, 0],
384
+ [0, 133, 255],
385
+ [255, 214, 0],
386
+ [25, 194, 194],
387
+ [102, 255, 0],
388
+ [92, 0, 255]
389
+ ]
390
+
391
+ for i in range(len(colors)):
392
+ cmap[i] = colors[i]
393
+
394
+ return cmap.astype(np.uint8)
395
+
396
+
397
+ def voc_cmap(N=256, normalized=False):
398
+ def bitget(byteval, idx):
399
+ return ((byteval & (1 << idx)) != 0)
400
+
401
+ dtype = 'float32' if normalized else 'uint8'
402
+ cmap = np.zeros((N, 3), dtype=dtype)
403
+ for i in range(N):
404
+ r = g = b = 0
405
+ c = i
406
+ for j in range(8):
407
+ r = r | (bitget(c, 0) << 7-j)
408
+ g = g | (bitget(c, 1) << 7-j)
409
+ b = b | (bitget(c, 2) << 7-j)
410
+ c = c >> 3
411
+
412
+ cmap[i] = np.array([r, g, b])
413
+
414
+ cmap = cmap/255 if normalized else cmap
415
+ return cmap
416
+
417
+
418
+ class Label2Color(object):
419
+ def __init__(self, cmap):
420
+ self.cmap = cmap
421
+
422
+ def __call__(self, lbls):
423
+ return self.cmap[lbls]
424
+
425
+
426
+ def convert_bn2gn(module):
427
+ mod = module
428
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
429
+ num_features = module.num_features
430
+ num_groups = num_features//16
431
+ mod = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
432
+ for name, child in module.named_children():
433
+ mod.add_module(name, convert_bn2gn(child))
434
+ del module
435
+ return mod