Spaces:
Running
Running
Upload 2 files
Browse files- utils.py +387 -0
- visual_utils.py +435 -0
utils.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from torch.optim import SGD, Adam, AdamW
|
8 |
+
from tensorboardX import SummaryWriter
|
9 |
+
|
10 |
+
import sod_metric
|
11 |
+
class Averager():
    """Running weighted mean of the values passed to `add`."""

    def __init__(self):
        self.n = 0.0  # total weight accumulated so far
        self.v = 0.0  # current weighted mean

    def add(self, v, n=1.0):
        """Fold value `v`, carrying weight `n`, into the running mean."""
        self.v = (self.v * self.n + v * n) / (self.n + n)
        self.n += n

    def item(self):
        """Return the current weighted mean (0.0 before any `add`)."""
        return self.v
|
23 |
+
|
24 |
+
|
25 |
+
class Timer():
    """Wall-clock stopwatch based on `time.time()`."""

    def __init__(self):
        self.v = time.time()  # reference timestamp

    def s(self):
        """Restart the stopwatch from now."""
        self.v = time.time()

    def t(self):
        """Return the seconds elapsed since construction or the last `s()`."""
        return time.time() - self.v
|
35 |
+
|
36 |
+
|
37 |
+
def time_text(t):
    """Format a duration `t` given in seconds as hours, minutes or seconds.

    Returns a string with one decimal and a unit suffix: 'h' from 3600 s,
    'm' from 60 s, otherwise 's'.
    """
    if t >= 3600:
        return '{:.1f}h'.format(t / 3600)
    elif t >= 60:
        return '{:.1f}m'.format(t / 60)
    else:
        return '{:.1f}s'.format(t)
|
44 |
+
|
45 |
+
|
46 |
+
# Directory that `log` appends to; None means "print only".
_log_path = None


def set_log_path(path):
    """Set the directory in which `log` writes its files."""
    global _log_path
    _log_path = path


def log(obj, filename='log.txt'):
    """Print `obj` and, when a log directory is set, append it to `filename` there."""
    print(obj)
    if _log_path is not None:
        with open(os.path.join(_log_path, filename), 'a') as f:
            print(obj, file=f)
|
59 |
+
|
60 |
+
|
61 |
+
def ensure_path(path, remove=True):
    """Make sure directory `path` exists, optionally wiping an existing one.

    If `path` already exists and `remove` is true, the directory is deleted
    and recreated empty when its basename starts with '_' or when the user
    answers 'y' at the interactive prompt. Otherwise an existing directory
    is left untouched. A missing directory is always created.
    """
    basename = os.path.basename(path.rstrip('/'))
    if os.path.exists(path):
        # The prompt is only reached when the basename does not start with '_'.
        if remove and (basename.startswith('_')
                       or input('{} exists, remove? (y/[n]): '.format(path)) == 'y'):
            shutil.rmtree(path)
            os.makedirs(path, exist_ok=True)
    else:
        os.makedirs(path, exist_ok=True)
|
70 |
+
|
71 |
+
|
72 |
+
def set_save_path(save_path, remove=True):
    """Prepare `save_path` as the run directory and return logging handles.

    Ensures the directory exists (see `ensure_path` for the `remove`
    semantics), points `log` at it, and opens a tensorboardX `SummaryWriter`
    on its 'tensorboard' sub-directory.

    Returns:
        (log, writer): the module-level `log` function and the writer.
    """
    ensure_path(save_path, remove=remove)
    set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
    return log, writer
|
77 |
+
|
78 |
+
|
79 |
+
def compute_num_params(model, text=False):
    """Count the elements of all tensors yielded by `model.parameters()`.

    Args:
        model: object exposing `parameters()`.
        text: if true, return a string such as '1.2M' (from 1e6 upwards) or
            '3.4K' (below 1e6) instead of the raw integer.
    """
    tot = int(sum(p.numel() for p in model.parameters()))
    if not text:
        return tot
    if tot >= 1e6:
        return '{:.1f}M'.format(tot / 1e6)
    return '{:.1f}K'.format(tot / 1e3)
|
88 |
+
|
89 |
+
|
90 |
+
def make_optimizer(param_list, optimizer_spec, load_sd=False):
    """Build a torch optimizer from a spec dictionary.

    Args:
        param_list: parameters (or parameter groups) handed to the optimizer.
        optimizer_spec: dict with 'name' ('sgd', 'adam' or 'adamw'), 'args'
            (keyword arguments for the optimizer constructor) and, when
            `load_sd` is true, 'sd' (an optimizer state dict).
        load_sd: if true, load `optimizer_spec['sd']` into the new optimizer.

    Raises:
        KeyError: if the name is not one of the supported optimizers.
    """
    optimizer_classes = {
        'sgd': SGD,
        'adam': Adam,
        'adamw': AdamW,
    }
    optimizer_cls = optimizer_classes[optimizer_spec['name']]
    optimizer = optimizer_cls(param_list, **optimizer_spec['args'])
    if load_sd:
        optimizer.load_state_dict(optimizer_spec['sd'])
    return optimizer
|
100 |
+
|
101 |
+
|
102 |
+
def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.

    Args:
        shape: sequence of grid sizes, one per dimension.
        ranges: optional sequence of (v0, v1) bounds per dimension; every
            dimension spans (-1, 1) when omitted.
        flatten: accepted for interface compatibility but currently ignored;
            the grid is always returned unflattened.

    Returns:
        Float tensor of shape (*shape, len(shape)) holding the centre
        coordinate of every grid cell.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        r = (v1 - v0) / (2 * n)  # half the width of one cell
        coord_seqs.append(v0 + r + (2 * r) * torch.arange(n).float())
    # NOTE(review): newer torch warns when `indexing` is not passed; the
    # default used here corresponds to 'ij'.
    return torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
def calc_cod(y_pred, y_true):
    """Evaluate a batch with the `sod_metric` measures.

    Channel 0 of every sample in `y_pred` / `y_true` is moved to the CPU,
    scaled by 255 and fed to the metric accumulators.

    Returns:
        (sm, em, wfm, mae): S-measure, mean of the E-measure curve, weighted
        F-measure and MAE as reported by the respective `get_results()`.
    """
    assert y_pred.shape == y_true.shape
    batchsize = y_true.shape[0]

    metric_FM = sod_metric.Fmeasure()
    metric_WFM = sod_metric.WeightedFmeasure()
    metric_SM = sod_metric.Smeasure()
    metric_EM = sod_metric.Emeasure()
    metric_MAE = sod_metric.MAE()
    metrics = (metric_FM, metric_WFM, metric_SM, metric_EM, metric_MAE)

    with torch.no_grad():
        for i in range(batchsize):
            true = y_true[i, 0].cpu().data.numpy() * 255
            pred = y_pred[i, 0].cpu().data.numpy() * 255
            for metric in metrics:
                metric.step(pred=pred, gt=true)

        # NOTE(review): the plain F-measure is accumulated and queried but
        # never returned; kept in case its accumulator validates the inputs.
        fm = metric_FM.get_results()["fm"]
        wfm = metric_WFM.get_results()["wfm"]
        sm = metric_SM.get_results()["sm"]
        em = metric_EM.get_results()["em"]["curve"].mean()
        mae = metric_MAE.get_results()["mae"]

    return sm, em, wfm, mae
|
150 |
+
|
151 |
+
|
152 |
+
from sklearn.metrics import precision_recall_curve
|
153 |
+
|
154 |
+
|
155 |
+
def calc_f1(y_pred, y_true):
    """Batch-averaged best F1 and ROC-AUC, computed per sample.

    For every sample the flattened ground truth is cast to integers and
    compared with the flattened prediction scores.

    Returns:
        (f1, auc, 0, 0): mean over the batch of the maximum F1 along the
        precision-recall curve, mean ROC-AUC, and two zero-valued arrays
        kept as placeholders.
    """
    batchsize = y_true.shape[0]
    with torch.no_grad():
        assert y_pred.shape == y_true.shape
        f1, auc = 0, 0
        y_true = y_true.cpu().numpy()
        y_pred = y_pred.cpu().numpy()
        for i in range(batchsize):
            # `np.int` was removed in NumPy 1.24; the builtin is equivalent.
            true = y_true[i].flatten().astype(int)
            pred = y_pred[i].flatten()

            precision, recall, _ = precision_recall_curve(true, pred)

            auc += roc_auc_score(true, pred)
            # The epsilon keeps the ratio finite where precision + recall == 0.
            f1 += np.max((2 * precision * recall) / (precision + recall + 1e-10))

    return f1 / batchsize, auc / batchsize, np.array(0), np.array(0)
|
177 |
+
|
178 |
+
def calc_fmeasure(y_pred, y_true):
    """F-measure (beta^2 = 0.3) over the whole batch and mean per-image MAE.

    Channel 0 of every sample is used. Predictions and ground truth are
    scaled to uint8 and binarised with the adaptive threshold
    min(1.5 * mean, 255). Precision and recall are computed over all pixels
    of the batch pooled together.

    Returns:
        (fmeasure, MAE, 0, 0): the last two are zero-valued placeholders.
        `fmeasure` is 0.0 when precision and recall are both zero.
    """
    batchsize = y_true.shape[0]

    mae, preds, gts = [], [], []
    with torch.no_grad():
        for i in range(batchsize):
            gt_float = y_true[i, 0].cpu().data.numpy()
            pred_float = y_pred[i, 0].cpu().data.numpy()

            # Mean absolute error of this image.
            abs_diff = cv2.absdiff(gt_float.astype(float), pred_float.astype(float))
            mae.append(np.sum(abs_diff) / (pred_float.shape[1] * pred_float.shape[0]))

            pred = np.uint8(pred_float * 255)
            gt = np.uint8(gt_float * 255)

            ones = np.ones_like(pred_float)
            zeros = np.zeros_like(pred_float)
            pred_bin = np.where(pred > min(1.5 * np.mean(pred), 255), ones, zeros)
            gt_bin = np.where(gt > min(1.5 * np.mean(gt), 255), ones, zeros)

            # Collect arrays; extending a Python list pixel by pixel is slow.
            preds.append(pred_bin.ravel())
            gts.append(gt_bin.ravel())

        all_preds = np.concatenate(preds)
        all_gts = np.concatenate(gts)
        RECALL = recall_score(all_gts, all_preds)
        PERC = precision_score(all_gts, all_preds)

        denominator = 0.3 * PERC + RECALL
        # Guard the 0/0 case (no true positives at all).
        fmeasure = (1 + 0.3) * PERC * RECALL / denominator if denominator > 0 else 0.0
        MAE = np.mean(mae)

    return fmeasure, MAE, np.array(0), np.array(0)
|
210 |
+
|
211 |
+
from sklearn.metrics import roc_auc_score,recall_score,precision_score
|
212 |
+
import cv2
|
213 |
+
def calc_ber(y_pred, y_true):
    """Batch-averaged positive error, negative error and their mean, in percent.

    Inputs are permuted with (0, 2, 3, 1) and the last axis is squeezed, so
    4-D tensors are required (presumably N x 1 x H x W — TODO confirm). Each
    sample is scaled by 255 and binarised at threshold 125.

    Returns:
        (pos_err, neg_err, ber, 0): per-batch means of 100 * (1 - TPR),
        100 * (1 - TNR) and their average; the last item is a zero-valued
        placeholder.

    NOTE(review): a sample without positive (or negative) ground-truth
    pixels divides by zero here.
    """
    batchsize = y_true.shape[0]
    y_pred = y_pred.permute(0, 2, 3, 1).squeeze(-1)
    y_true = y_true.permute(0, 2, 3, 1).squeeze(-1)
    with torch.no_grad():
        assert y_pred.shape == y_true.shape
        pos_err, neg_err = 0, 0
        y_true = y_true.cpu().numpy()
        y_pred = y_pred.cpu().numpy()
        for i in range(batchsize):
            true = y_true[i].flatten()
            pred = y_pred[i].flatten()

            TP, TN, FP, FN, _, _ = get_binary_classification_metrics(pred * 255,
                                                                     true * 255, 125)
            pos_err += (1 - TP / (TP + FN)) * 100
            neg_err += (1 - TN / (TN + FP)) * 100

    return pos_err / batchsize, neg_err / batchsize, (pos_err + neg_err) / 2 / batchsize, np.array(0)
|
231 |
+
|
232 |
+
def get_binary_classification_metrics(pred, gt, threshold=None):
    """Confusion-matrix counts plus BER and accuracy for two arrays.

    Args:
        pred, gt: arrays of equal shape. When `threshold` is given, both are
            binarised with `> threshold`; otherwise they are used as boolean
            masks as-is.

    Returns:
        (TP, TN, FP, FN, BER, ACC)
    """
    if threshold is not None:
        gt = (gt > threshold)
        pred = (pred > threshold)
    not_gt = np.logical_not(gt)
    not_pred = np.logical_not(pred)
    TP = np.logical_and(gt, pred).sum()
    TN = np.logical_and(not_gt, not_pred).sum()
    FN = np.logical_and(gt, not_pred).sum()
    FP = np.logical_and(not_gt, pred).sum()
    BER = cal_ber(TN, TP, FN, FP)
    ACC = cal_acc(TN, TP, FN, FP)
    return TP, TN, FP, FN, BER, ACC
|
243 |
+
|
244 |
+
def cal_ber(tn, tp, fn, fp):
    """Balanced error rate: mean of the false-positive and false-negative rates."""
    return 0.5*(fp/(tn+fp) + fn/(fn+tp))
|
246 |
+
|
247 |
+
def cal_acc(tn, tp, fn, fp):
    """Accuracy: fraction of samples that are true positives or true negatives."""
    return (tp + tn) / (tp + tn + fp + fn)
|
249 |
+
|
250 |
+
def _sigmoid(x):
|
251 |
+
return 1 / (1 + np.exp(-x))
|
252 |
+
|
253 |
+
|
254 |
+
def _eval_pr(y_pred, y, num):
|
255 |
+
prec, recall = torch.zeros(num), torch.zeros(num)
|
256 |
+
thlist = torch.linspace(0, 1 - 1e-10, num)
|
257 |
+
for i in range(num):
|
258 |
+
y_temp = (y_pred >= thlist[i]).float()
|
259 |
+
tp = (y_temp * y).sum()
|
260 |
+
prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() +
|
261 |
+
1e-20)
|
262 |
+
return prec, recall
|
263 |
+
|
264 |
+
def _S_object(pred, gt):
    """Blend of a foreground and a background `_object` score.

    The foreground map keeps `pred` where `gt != 0`; the background map keeps
    `1 - pred` where `gt != 1`. The two scores are mixed with the mean of
    `gt` as the foreground weight. (Presumably the object-aware term of the
    S-measure — TODO confirm.)
    """
    zeros = torch.zeros_like(pred)
    fg = torch.where(gt == 0, zeros, pred)
    bg = torch.where(gt == 1, zeros, 1 - pred)
    o_fg = _object(fg, gt)
    o_bg = _object(bg, 1 - gt)
    u = gt.mean()
    return u * o_fg + (1 - u) * o_bg
|
272 |
+
|
273 |
+
def _object(pred, gt):
|
274 |
+
temp = pred[gt == 1]
|
275 |
+
x = temp.mean()
|
276 |
+
sigma_x = temp.std()
|
277 |
+
score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
|
278 |
+
|
279 |
+
return score
|
280 |
+
|
281 |
+
def _S_region(pred, gt):
    """Area-weighted sum of `_ssim` over four quadrants split at the `gt` centroid."""
    X, Y = _centroid(gt)
    gt1, gt2, gt3, gt4, w1, w2, w3, w4 = _divideGT(gt, X, Y)
    p1, p2, p3, p4 = _dividePrediction(pred, X, Y)
    Q1 = _ssim(p1, gt1)
    Q2 = _ssim(p2, gt2)
    Q3 = _ssim(p3, gt3)
    Q4 = _ssim(p4, gt4)
    return w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
|
291 |
+
|
292 |
+
def _centroid(gt):
|
293 |
+
rows, cols = gt.size()[-2:]
|
294 |
+
gt = gt.view(rows, cols)
|
295 |
+
if gt.sum() == 0:
|
296 |
+
X = torch.eye(1) * round(cols / 2)
|
297 |
+
Y = torch.eye(1) * round(rows / 2)
|
298 |
+
else:
|
299 |
+
total = gt.sum()
|
300 |
+
i = torch.from_numpy(np.arange(0, cols)).float().cuda()
|
301 |
+
j = torch.from_numpy(np.arange(0, rows)).float().cuda()
|
302 |
+
X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20)
|
303 |
+
Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20)
|
304 |
+
return X.long(), Y.long()
|
305 |
+
|
306 |
+
|
307 |
+
def _divideGT(gt, X, Y):
|
308 |
+
h, w = gt.size()[-2:]
|
309 |
+
area = h * w
|
310 |
+
gt = gt.view(h, w)
|
311 |
+
LT = gt[:Y, :X]
|
312 |
+
RT = gt[:Y, X:w]
|
313 |
+
LB = gt[Y:h, :X]
|
314 |
+
RB = gt[Y:h, X:w]
|
315 |
+
X = X.float()
|
316 |
+
Y = Y.float()
|
317 |
+
w1 = X * Y / area
|
318 |
+
w2 = (w - X) * Y / area
|
319 |
+
w3 = X * (h - Y) / area
|
320 |
+
w4 = 1 - w1 - w2 - w3
|
321 |
+
return LT, RT, LB, RB, w1, w2, w3, w4
|
322 |
+
|
323 |
+
|
324 |
+
def _dividePrediction(pred, X, Y):
|
325 |
+
h, w = pred.size()[-2:]
|
326 |
+
pred = pred.view(h, w)
|
327 |
+
LT = pred[:Y, :X]
|
328 |
+
RT = pred[:Y, X:w]
|
329 |
+
LB = pred[Y:h, :X]
|
330 |
+
RB = pred[Y:h, X:w]
|
331 |
+
return LT, RT, LB, RB
|
332 |
+
|
333 |
+
|
334 |
+
def _ssim(pred, gt):
|
335 |
+
gt = gt.float()
|
336 |
+
h, w = pred.size()[-2:]
|
337 |
+
N = h * w
|
338 |
+
x = pred.mean()
|
339 |
+
y = gt.mean()
|
340 |
+
sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
|
341 |
+
sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
|
342 |
+
sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
|
343 |
+
|
344 |
+
aplha = 4 * x * y * sigma_xy
|
345 |
+
beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
|
346 |
+
|
347 |
+
if aplha != 0:
|
348 |
+
Q = aplha / (beta + 1e-20)
|
349 |
+
elif aplha == 0 and beta == 0:
|
350 |
+
Q = 1.0
|
351 |
+
else:
|
352 |
+
Q = 0
|
353 |
+
return Q
|
354 |
+
|
355 |
+
def _eval_e(y_pred, y, num):
|
356 |
+
score = torch.zeros(num)
|
357 |
+
thlist = torch.linspace(0, 1 - 1e-10, num)
|
358 |
+
for i in range(num):
|
359 |
+
y_pred_th = (y_pred >= thlist[i]).float()
|
360 |
+
fm = y_pred_th - y_pred_th.mean()
|
361 |
+
gt = y - y.mean()
|
362 |
+
align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
|
363 |
+
enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
|
364 |
+
score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
|
365 |
+
return score
|
366 |
+
|
367 |
+
|
368 |
+
def calc_Semantic_Segmentation(y_pred, y_true):
    """Batch-averaged best F1 and ROC-AUC, computed per sample.

    For every sample the flattened ground truth is cast to integers and
    compared with the flattened prediction scores.

    Returns:
        (f1, auc, 0, 0): mean over the batch of the maximum F1 along the
        precision-recall curve, mean ROC-AUC, and two zero-valued arrays
        kept as placeholders.
    """
    batchsize = y_true.shape[0]
    with torch.no_grad():
        assert y_pred.shape == y_true.shape
        f1, auc = 0, 0
        y_true = y_true.cpu().numpy()
        y_pred = y_pred.cpu().numpy()
        for i in range(batchsize):
            # `np.int` was removed in NumPy 1.24; the builtin is equivalent.
            true = y_true[i].flatten().astype(int)
            pred = y_pred[i].flatten()

            precision, recall, _ = precision_recall_curve(true, pred)

            auc += roc_auc_score(true, pred)
            # The epsilon keeps the ratio finite where precision + recall == 0.
            f1 += np.max((2 * precision * recall) / (precision + recall + 1e-10))

    return f1 / batchsize, auc / batchsize, np.array(0), np.array(0)
|
visual_utils.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms.functional import normalize
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def denormalize(tensor, mean, std):
    """Undo a `normalize(tensor, mean, std)` by normalising with inverted statistics.

    Normalising with mean' = -mean/std and std' = 1/std computes
    (x - mean') / std' = x * std + mean.
    """
    mean = np.array(mean)
    std = np.array(std)

    inv_mean = -mean / std
    inv_std = 1 / std
    return normalize(tensor, inv_mean, inv_std)
|
13 |
+
|
14 |
+
|
15 |
+
class Denormalize(object):
    """Callable that reverts a mean/std normalisation (x * std + mean)."""

    def __init__(self, mean, std):
        mean = np.array(mean)
        std = np.array(std)
        # Inverted statistics: normalising with these undoes the original.
        self._mean = -mean/std
        self._std = 1/std

    def __call__(self, tensor):
        """Denormalise a NumPy array (channels on axis 0 of 3) or a torch tensor."""
        if isinstance(tensor, np.ndarray):
            # Per-channel statistics are broadcast over the two trailing axes.
            return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1)
        return normalize(tensor, self._mean, self._std)
|
26 |
+
|
27 |
+
|
28 |
+
def fix_bn(model):
    """Freeze every `nn.BatchNorm2d` inside `model`.

    Each such layer is switched to eval mode and its affine parameters stop
    requiring gradients. Layers built with `affine=False` have no weight or
    bias and are only switched to eval mode.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            for param in (m.weight, m.bias):
                if param is not None:
                    param.requires_grad = False
|
34 |
+
|
35 |
+
|
36 |
+
def color_map(dataset):
    """Return the colour table for `dataset`.

    Recognised names: 'voc', 'cityscapes', 'ade', 'isaid', 'SAR2020',
    'Unify_single', 'Unify_double', 'Unify_YIJISAR', 'Unify_Vai'.
    Any other name falls through and yields None.
    """
    if dataset == 'voc':
        return voc_cmap()
    if dataset == 'cityscapes':
        return cityscapes_cmap()
    if dataset == 'ade':
        return ade_cmap()
    if dataset == 'isaid':
        return isaid_cmap()
    if dataset == 'SAR2020':
        return SAR2020_cmap()
    if dataset == 'Unify_single':
        return unify_single_cmap()
    if dataset == 'Unify_double':
        return unify_cmap()
    if dataset == 'Unify_YIJISAR':
        return unify_YIJISAR_cmap()
    if dataset == 'Unify_Vai':
        return unify_Vai_cmap()
    return None
|
55 |
+
|
56 |
+
def unify_Vai_cmap():
    """Return the (255, 3) uint8 colour table for 'Unify_Vai'; unlisted labels are black."""
    # NOTE(review): the table has 255 rows, so label 255 is out of range.
    cmap = np.zeros((255, 3), dtype=np.uint8)
    palette = {
        1: (0, 255, 0),
        2: (0, 0, 255),
        3: (0, 0, 255),
        5: (159, 129, 183),
        6: (0, 255, 255),
        7: (255, 195, 128),
        9: (255, 255, 0),
        14: (255, 0, 0),
    }
    for label, rgb in palette.items():
        cmap[label] = rgb
    return cmap
|
80 |
+
|
81 |
+
def cityscapes_cmap():
    """Return the (20, 3) uint8 colour table used for 'cityscapes' (last row is black)."""
    return np.array(
        [
            (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
            (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
            (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
            (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
            (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 0),
        ],
        dtype=np.uint8,
    )
|
86 |
+
|
87 |
+
def SAR2020_cmap():
    """Return the (256, 3) uint8 colour table for 'SAR2020'; labels 6+ are black."""
    colors = [[255, 255, 255], [255, 255, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255]]
    cmap = np.zeros((256, 3), dtype=np.uint8)
    cmap[:len(colors)] = colors
    return cmap
|
94 |
+
|
95 |
+
|
96 |
+
def isaid_cmap():
    """Return the (255, 3) uint8 colour table for 'isaid'; labels 16+ are black."""
    colors = [
        [0, 0, 0],
        [0, 63, 0],
        [0, 63, 191],
        [0, 127, 63],
        [0, 63, 255],
        [0, 100, 155],
        [0, 0, 191],
        [0, 127, 127],
        [0, 127, 255],
        [0, 191, 127],
        [0, 0, 63],
        [0, 191, 127],
        [0, 127, 191],
        [0, 63, 63],
        [0, 0, 255],
        [0, 63, 127],
    ]
    # NOTE(review): the table has 255 rows, so label 255 is out of range.
    cmap = np.zeros((255, 3), dtype=np.uint8)
    cmap[:len(colors)] = colors
    return cmap
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
def unify_cmap():
    """Return the (255, 3) uint8 colour table for 'Unify_double'; labels 26+ are black."""
    colors = [
        [0, 0, 0],        # 0
        [0, 127, 255],
        [0, 63, 0],
        [0, 127, 63],
        [0, 63, 255],
        [0, 127, 127],    # 5
        [0, 0, 63],
        [0, 63, 127],
        [0, 63, 191],
        [0, 63, 63],
        [0, 127, 191],    # 10
        [0, 191, 127],
        [0, 100, 155],
        [0, 0, 255],
        [0, 255, 0],      # 14
        [0, 153, 204],    # 15
        [204, 204, 68],   # 16
        [255, 204, 51],   # 17
        [255, 255, 204],  # 18
        [0, 255, 255],
        [255, 102, 102],  # 20
        [0, 255, 0],      # 21
        [255, 255, 0],    # 22
        [255, 0, 0],      # 23
        [255, 195, 128],  # 24
        [153, 102, 153],
    ]
    # NOTE(review): the table has 255 rows, so label 255 is out of range.
    cmap = np.zeros((255, 3), dtype=np.uint8)
    cmap[:len(colors)] = colors
    return cmap
|
163 |
+
|
164 |
+
def unify_YIJISAR_cmap():
    """Return the (255, 3) uint8 colour table for 'Unify_YIJISAR'; labels 26+ are black."""
    colors = [
        [255, 255, 0],    # 0
        [0, 255, 0],
        [0, 63, 0],
        [0, 0, 255],
        [0, 63, 255],
        [255, 0, 0],      # 5
        [0, 0, 63],
        [0, 63, 127],
        [0, 63, 191],
        [0, 255, 255],
        [0, 127, 191],    # 10
        [0, 191, 127],
        [0, 0, 0],        # 12
        [0, 0, 255],
        [0, 255, 0],      # 14
        [0, 153, 204],    # 15
        [204, 204, 68],   # 16
        [255, 204, 51],   # 17
        [255, 255, 204],  # 18
        [0, 255, 255],
        [255, 102, 102],  # 20
        [0, 255, 0],      # 21
        [255, 255, 0],    # 22
        [255, 0, 0],      # 23
        [255, 195, 128],  # 24
        [153, 102, 153],
    ]
    # NOTE(review): the table has 255 rows, so label 255 is out of range.
    cmap = np.zeros((255, 3), dtype=np.uint8)
    cmap[:len(colors)] = colors
    return cmap
|
198 |
+
|
199 |
+
|
200 |
+
def unify_single_cmap():
    """Return the (255, 3) uint8 colour table for 'Unify_single'; unlisted labels are black."""
    # NOTE(review): the table has 255 rows, so label 255 is out of range.
    cmap = np.zeros((255, 3), dtype=np.uint8)
    palette = {
        0: (0, 127, 255),
        5: (0, 255, 0),
        16: (159, 129, 183),
        18: (255, 195, 128),
        20: (255, 0, 0),
        21: (255, 255, 0),
        22: (0, 0, 255),
    }
    for label, rgb in palette.items():
        cmap[label] = rgb
    return cmap
|
234 |
+
|
235 |
+
def ade_cmap():
    """Return the (256, 3) uint8 colour table for 'ade'.

    Row 0 is black, rows 1-150 hold the listed colours and the remaining
    rows are black.
    """
    colors = [
        [0, 0, 0],
        [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
        [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7],
        [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51],
        [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
        [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255],
        [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10],
        [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
        [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15],
        [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0],
        [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], [0, 255, 112],
        [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0],
        [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41], [0, 255, 173],
        [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184],
        [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204], [0, 255, 194],
        [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255],
        [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], [8, 184, 170],
        [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255],
        [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255], [112, 224, 255],
        [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235],
        [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212], [214, 255, 0],
        [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255],
        [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [184, 255, 0],
        [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255],
    ]

    cmap = np.zeros((256, 3), dtype=np.uint8)
    cmap[:len(colors)] = colors
    return cmap
|
395 |
+
|
396 |
+
|
397 |
+
def voc_cmap(N=256, normalized=False):
    """Build an N-entry colour table by spreading label bits over R, G and B.

    For every label, three bits at a time are consumed from the low end:
    bit 0 goes to red, bit 1 to green, bit 2 to blue, each placed one
    position lower (starting at the most significant bit) per round.

    Args:
        N: number of table rows.
        normalized: if true, return float values in [0, 1] instead of uint8.
    """
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            shift = 7 - j
            r = r | (bitget(c, 0) << shift)
            g = g | (bitget(c, 1) << shift)
            b = b | (bitget(c, 2) << shift)
            c = c >> 3

        cmap[i] = np.array([r, g, b])

    cmap = cmap/255 if normalized else cmap
    return cmap
|
416 |
+
|
417 |
+
|
418 |
+
class Label2Color(object):
    """Callable that maps label indices to colours by indexing a colour table."""

    def __init__(self, cmap):
        self.cmap = cmap  # indexable colour table, one row per label

    def __call__(self, lbls):
        """Return `cmap[lbls]`."""
        return self.cmap[lbls]
|
424 |
+
|
425 |
+
|
426 |
+
def convert_bn2gn(module):
    """Recursively replace batch-norm layers in `module` with `nn.GroupNorm`.

    Aims for roughly 16 channels per group (`num_features // 16` groups).
    The group count is kept at 1 or above and lowered until it divides the
    channel count, so layers with fewer than 16 channels, or with a channel
    count the initial guess does not divide, convert instead of raising.
    The new GroupNorm layers are freshly initialised; batch-norm parameters
    and statistics are not copied. Other modules are modified in place.

    Returns:
        The converted module (a new GroupNorm if `module` itself was a
        batch-norm layer, otherwise `module`).
    """
    mod = module
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        num_features = module.num_features
        num_groups = max(1, num_features // 16)
        # GroupNorm requires num_channels % num_groups == 0; ends at 1 at worst.
        while num_features % num_groups != 0:
            num_groups -= 1
        mod = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
    for name, child in module.named_children():
        mod.add_module(name, convert_bn2gn(child))
    return mod
|