Fabrice-TIERCELIN commited on
Commit
feba46d
·
verified ·
1 Parent(s): 3d1f37f

Delete clipseg/evaluation_utils.py

Browse files
Files changed (1) hide show
  1. clipseg/evaluation_utils.py +0 -292
clipseg/evaluation_utils.py DELETED
@@ -1,292 +0,0 @@
1
- from torch.functional import Tensor
2
- from general_utils import load_model
3
- from torch.utils.data import DataLoader
4
- import torch
5
- import numpy as np
6
-
7
- def denorm(img):
8
-
9
- np_input = False
10
- if isinstance(img, np.ndarray):
11
- img = torch.from_numpy(img)
12
- np_input = True
13
-
14
- mean = torch.Tensor([0.485, 0.456, 0.406])
15
- std = torch.Tensor([0.229, 0.224, 0.225])
16
-
17
- img_denorm = (img*std[:,None,None]) + mean[:,None,None]
18
-
19
- if np_input:
20
- img_denorm = np.clip(img_denorm.numpy(), 0, 1)
21
- else:
22
- img_denorm = torch.clamp(img_denorm, 0, 1)
23
-
24
- return img_denorm
25
-
26
-
27
- def norm(img):
28
- mean = torch.Tensor([0.485, 0.456, 0.406])
29
- std = torch.Tensor([0.229, 0.224, 0.225])
30
- return (img - mean[:,None,None]) / std[:,None,None]
31
-
32
-
33
- def fast_iou_curve(p, g):
34
-
35
- g = g[p.sort().indices]
36
- p = torch.sigmoid(p.sort().values)
37
-
38
- scores = []
39
- vals = np.linspace(0, 1, 50)
40
-
41
- for q in vals:
42
-
43
- n = int(len(g) * q)
44
-
45
- valid = torch.where(p > q)[0]
46
- if len(valid) > 0:
47
- n = int(valid[0])
48
- else:
49
- n = len(g)
50
-
51
- fn = g[:n].sum()
52
- tn = n - fn
53
- tp = g[n:].sum()
54
- fp = len(g) - n - tp
55
-
56
- iou = tp / (tp + fn + fp)
57
-
58
- precision = tp / (tp + fp)
59
- recall = tp / (tp + fn)
60
-
61
- scores += [iou]
62
-
63
- return vals, scores
64
-
65
-
66
- def fast_rp_curve(p, g):
67
-
68
- g = g[p.sort().indices]
69
- p = torch.sigmoid(p.sort().values)
70
-
71
- precisions, recalls = [], []
72
- vals = np.linspace(p.min(), p.max(), 250)
73
-
74
- for q in p[::100000]:
75
-
76
- n = int(len(g) * q)
77
-
78
- valid = torch.where(p > q)[0]
79
- if len(valid) > 0:
80
- n = int(valid[0])
81
- else:
82
- n = len(g)
83
-
84
- fn = g[:n].sum()
85
- tn = n - fn
86
- tp = g[n:].sum()
87
- fp = len(g) - n - tp
88
-
89
- iou = tp / (tp + fn + fp)
90
-
91
- precision = tp / (tp + fp)
92
- recall = tp / (tp + fn)
93
-
94
- precisions += [precision]
95
- recalls += [recall]
96
-
97
- return recalls, precisions
98
-
99
-
100
- # Image processing
101
-
102
- def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
103
- brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
104
- import cv2
105
-
106
- rw = rect_width
107
-
108
- out = []
109
- for img, mask in zip(batch[1], batch[2]):
110
-
111
- img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
112
- mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)
113
-
114
- img *= brightness
115
- img_bl = img
116
- if blur > 0: # best 5
117
- img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)
118
-
119
- if grayscale:
120
- img_bl = img_bl[1][None]
121
-
122
- #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl
123
- # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask)
124
- img_inp = img*mask + (bg_fac) * img_bl * (1-mask)
125
-
126
- if rect:
127
- _, bbox = crop_mask(img, mask, context=0.1)
128
- img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None]
129
- img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None]
130
- img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
131
- img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
132
-
133
-
134
- if center_context is not None:
135
- img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)
136
-
137
- if colorize:
138
- img_gray = denorm(img)
139
- img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
140
- img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
141
- img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
142
- img_inp = norm(img_inp)
143
-
144
- if outline:
145
- cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
146
- outline_img = np.zeros(mask.shape, dtype=np.uint8)
147
- cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
148
- outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
149
- img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
150
- img_inp = norm(img_inp)
151
-
152
- out += [img_inp]
153
-
154
- return torch.stack(out)
155
-
156
-
157
- def object_crop(img, mask, context=0.0, square=False, image_size=224):
158
- img_crop, bbox = crop_mask(img, mask, context=context, square=square)
159
- img_crop = pad_to_square(img_crop, channel_dim=0)
160
- img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0)
161
- return img_crop
162
-
163
-
164
- def crop_mask(img, mask, context=0.0, square=False):
165
-
166
- assert img.shape[1:] == mask.shape
167
-
168
- bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()]
169
- bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()]
170
- bbox = [int(x) for x in bbox]
171
-
172
- width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
173
-
174
- # square mask
175
- if square:
176
- bbox[0] = int(max(0, bbox[0] - context * height))
177
- bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
178
- bbox[2] = int(max(0, bbox[2] - context * width))
179
- bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
180
-
181
- width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
182
- if height > width:
183
- bbox[2] = int(max(0, (bbox[2] - 0.5*height)))
184
- bbox[3] = bbox[2] + height
185
- else:
186
- bbox[0] = int(max(0, (bbox[0] - 0.5*width)))
187
- bbox[1] = bbox[0] + width
188
- else:
189
- bbox[0] = int(max(0, bbox[0] - context * height))
190
- bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
191
- bbox[2] = int(max(0, bbox[2] - context * width))
192
- bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
193
-
194
- width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
195
- img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
196
- return img_crop, bbox
197
-
198
-
199
- def pad_to_square(img, channel_dim=2, fill=0):
200
- """
201
-
202
-
203
- add padding such that a squared image is returned """
204
-
205
- from torchvision.transforms.functional import pad
206
-
207
- if channel_dim == 2:
208
- img = img.permute(2, 0, 1)
209
- elif channel_dim == 0:
210
- pass
211
- else:
212
- raise ValueError('invalid channel_dim')
213
-
214
- h, w = img.shape[1:]
215
- pady1 = pady2 = padx1 = padx2 = 0
216
-
217
- if h > w:
218
- padx1 = (h - w) // 2
219
- padx2 = h - w - padx1
220
- elif w > h:
221
- pady1 = (w - h) // 2
222
- pady2 = w - h - pady1
223
-
224
- img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant')
225
-
226
- if channel_dim == 2:
227
- img_padded = img_padded.permute(1, 2, 0)
228
-
229
- return img_padded
230
-
231
-
232
- # qualitative
233
-
234
- def split_sentence(inp, limit=9):
235
- t_new, current_len = [], 0
236
- for k, t in enumerate(inp.split(' ')):
237
- current_len += len(t) + 1
238
- t_new += [t+' ']
239
- # not last
240
- if current_len > limit and k != len(inp.split(' ')) - 1:
241
- current_len = 0
242
- t_new += ['\n']
243
-
244
- t_new = ''.join(t_new)
245
- return t_new
246
-
247
-
248
- from matplotlib import pyplot as plt
249
-
250
-
251
- def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
252
-
253
- row_off = 0 if labels is None else 1
254
- _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)))
255
- [a.axis('off') for a in ax.flatten()]
256
-
257
- if labels is not None:
258
- for j in range(len(labels)):
259
- t_new = split_sentence(labels[j], limit=6)
260
- ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale)
261
-
262
-
263
- for i in range(len(imgs)):
264
- ax[i + row_off,0].imshow(imgs[i])
265
- for j in range(len(preds)):
266
- img = preds[j][i][0].detach().cpu().numpy()
267
-
268
- if gt_labels is not None and labels[j] == gt_labels[i]:
269
- print(j, labels[j], gt_labels[i])
270
- edgecolor = 'red'
271
- if aps is not None:
272
- ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
273
- else:
274
- edgecolor = 'k'
275
-
276
- rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none",
277
- edgecolor=edgecolor, linewidth=3)
278
- ax[i + row_off,1 + j].add_patch(rect)
279
-
280
- if vmax is None:
281
- this_vmax = 1
282
- elif vmax == 'per_prompt':
283
- this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))])
284
- elif vmax == 'per_image':
285
- this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))])
286
-
287
- ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)
288
-
289
-
290
- # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max())
291
- plt.tight_layout()
292
- plt.subplots_adjust(wspace=0.05, hspace=0.05)