RamAnanth1 committed
Commit 2c19c0f
1 Parent(s): a31136c

Create utils.py

Files changed (1): utils.py +361 -0
utils.py ADDED
@@ -0,0 +1,361 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

"""Miscellaneous utility functions."""

import base64
import math
import re
from io import BytesIO

import matplotlib
import matplotlib.cm
import numpy as np
import requests
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.data.distributed
from PIL import Image
from scipy import ndimage
from torchvision.transforms import ToTensor


class RunningAverage:
    """Keeps a running (streaming) average of appended values."""

    def __init__(self):
        self.avg = 0
        self.count = 0

    def append(self, value):
        self.avg = (value + self.count * self.avg) / (self.count + 1)
        self.count += 1

    def get_value(self):
        return self.avg


def denormalize(x):
    """Reverses the ImageNet normalization applied to the input.

    Args:
        x (torch.Tensor - shape(N,3,H,W)): input tensor
    Returns:
        torch.Tensor - shape(N,3,H,W): Denormalized input
    """
    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    return x * std + mean


class RunningAverageDict:
    """A dictionary of running averages, keyed like the dicts passed to update()."""

    def __init__(self):
        self._dict = None

    def update(self, new_dict):
        if new_dict is None:
            return

        if self._dict is None:
            self._dict = dict()
            for key, value in new_dict.items():
                self._dict[key] = RunningAverage()

        for key, value in new_dict.items():
            self._dict[key].append(value)

    def get_value(self):
        if self._dict is None:
            return None
        return {key: value.get_value() for key, value in self._dict.items()}

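
# Example: a minimal usage sketch for RunningAverage/RunningAverageDict,
# accumulating per-batch metric dicts (the keys below are illustrative):
#
#     metrics = RunningAverageDict()
#     metrics.update({'abs_rel': 0.12, 'rmse': 0.45})
#     metrics.update({'abs_rel': 0.10, 'rmse': 0.41})
#     metrics.get_value()  # {'abs_rel': 0.11, 'rmse': 0.43}
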
def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Converts a depth map to a color image.

    Args:
        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, the 2nd percentile of valid values is used. Defaults to None.
        vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, the 85th percentile of valid values is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'gray_r'.
        invalid_val (int, optional): Value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGBA color given to invalid pixels. Defaults to (128, 128, 128, 255).
        gamma_corrected (bool, optional): Apply gamma correction to the colored image. Defaults to False.
        value_transform (Callable, optional): Transform applied to valid pixels before coloring. Defaults to None.

    Returns:
        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize to [0, 1], by default using percentiles of the valid region
    vmin = np.percentile(value[mask], 2) if vmin is None else vmin
    vmax = np.percentile(value[mask], 85) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # grey out the invalid values
    value[invalid_mask] = np.nan
    # note: matplotlib.cm.get_cmap is deprecated in newer matplotlib;
    # matplotlib.colormaps[cmap] is the modern equivalent
    cmapper = matplotlib.cm.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
    value = cmapper(value, bytes=True)  # (H, W, 4)

    img = value[...]
    img[invalid_mask] = background_color

    if gamma_corrected:
        # gamma correction
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    return img

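
# Example: a minimal sketch of colorizing a depth map and saving it as a PNG,
# assuming 'depth' is an (H, W) tensor or array of valid depth values:
#
#     colored = colorize(depth, cmap='magma_r')  # (H, W, 4) uint8 RGBA
#     Image.fromarray(colored).save('depth_colored.png')
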
def count_parameters(model, include_all=False):
    return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all)


def compute_errors(gt, pred):
    """Compute metrics for 'pred' compared to 'gt'.

    Args:
        gt (numpy.ndarray): Ground truth values
        pred (numpy.ndarray): Predicted values
        gt.shape should be equal to pred.shape

    Returns:
        dict: Dictionary containing the following metrics:
            'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25
            'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2
            'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3
            'abs_rel': Absolute relative error
            'rmse': Root mean squared error
            'log_10': Absolute log10 error
            'sq_rel': Squared relative error
            'rmse_log': Root mean squared error on the log scale
            'silog': Scale invariant log error
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100

    log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean()
    return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log,
                silog=silog, sq_rel=sq_rel)

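
# Example: a tiny numeric check, assuming gt and pred are flat arrays of
# valid (positive) depths:
#
#     gt = np.array([1.0, 2.0, 4.0])
#     pred = np.array([1.0, 2.2, 3.6])
#     compute_errors(gt, pred)['a1']  # 1.0, since every ratio is below 1.25
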
def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs):
    """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or as specified via arguments. Refer to compute_errors for more details on metrics.
    """
    if 'config' in kwargs:
        config = kwargs['config']
        garg_crop = config.garg_crop
        eigen_crop = config.eigen_crop
        min_depth_eval = config.min_depth_eval
        max_depth_eval = config.max_depth_eval

    if gt.shape[-2:] != pred.shape[-2:] and interpolate:
        pred = nn.functional.interpolate(
            pred, gt.shape[-2:], mode='bilinear', align_corners=True)

    pred = pred.squeeze().cpu().numpy()
    pred[pred < min_depth_eval] = min_depth_eval
    pred[pred > max_depth_eval] = max_depth_eval
    pred[np.isinf(pred)] = max_depth_eval
    pred[np.isnan(pred)] = min_depth_eval

    gt_depth = gt.squeeze().cpu().numpy()
    valid_mask = np.logical_and(
        gt_depth > min_depth_eval, gt_depth < max_depth_eval)

    if garg_crop or eigen_crop:
        gt_height, gt_width = gt_depth.shape
        eval_mask = np.zeros(valid_mask.shape)

        if garg_crop:
            eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
                      int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1

        elif eigen_crop:
            if dataset == 'kitti':
                eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
                          int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
            else:
                # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images"
                eval_mask[45:471, 41:601] = 1
    else:
        eval_mask = np.ones(valid_mask.shape)
    valid_mask = np.logical_and(valid_mask, eval_mask)
    return compute_errors(gt_depth[valid_mask], pred[valid_mask])

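
# Example: a minimal evaluation sketch, assuming 'gt' and 'pred' are
# (1, 1, H, W) depth tensors in meters; the default crop and depth range
# match NYU Depth v2:
#
#     metrics = compute_metrics(gt, pred, dataset='nyu')
#     print(metrics['abs_rel'], metrics['rmse'])
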
#################################### Model utils ################################################


def parallelize(config, model, find_unused_parameters=True):
    if config.gpu is not None:
        torch.cuda.set_device(config.gpu)
        model = model.cuda(config.gpu)

    config.multigpu = False
    if config.distributed:
        # Use DDP
        config.multigpu = True
        config.rank = config.rank * config.ngpus_per_node + config.gpu
        dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
                                world_size=config.world_size, rank=config.rank)
        config.batch_size = int(config.batch_size / config.ngpus_per_node)
        config.workers = int(
            (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node)
        print("Device", config.gpu, "Rank", config.rank, "batch size",
              config.batch_size, "Workers", config.workers)
        torch.cuda.set_device(config.gpu)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(config.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu,
                                                          find_unused_parameters=find_unused_parameters)

    elif config.gpu is None:
        # Use DP
        config.multigpu = True
        model = model.cuda()
        model = torch.nn.DataParallel(model)

    return model

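
# Example: a minimal single-process, single-GPU sketch; the config object only
# needs the fields parallelize reads, and SimpleNamespace here stands in for
# the project's real config:
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(gpu=0, distributed=False)
#     model = parallelize(config, model)
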
#################################################################################################


class colors:
    '''Colors class:
    Reset all colors with colors.reset.
    Two subclasses: fg for foreground and bg for background.
    Use as colors.subclass.colorname, e.g. colors.fg.red or colors.bg.green.
    The generic bold, disable, underline, reverse, strikethrough, and
    invisible attributes live on the main class, e.g. colors.bold.
    '''
    reset = '\033[0m'
    bold = '\033[01m'
    disable = '\033[02m'
    underline = '\033[04m'
    reverse = '\033[07m'
    strikethrough = '\033[09m'
    invisible = '\033[08m'

    class fg:
        black = '\033[30m'
        red = '\033[31m'
        green = '\033[32m'
        orange = '\033[33m'
        blue = '\033[34m'
        purple = '\033[35m'
        cyan = '\033[36m'
        lightgrey = '\033[37m'
        darkgrey = '\033[90m'
        lightred = '\033[91m'
        lightgreen = '\033[92m'
        yellow = '\033[93m'
        lightblue = '\033[94m'
        pink = '\033[95m'
        lightcyan = '\033[96m'

    class bg:
        black = '\033[40m'
        red = '\033[41m'
        green = '\033[42m'
        orange = '\033[43m'
        blue = '\033[44m'
        purple = '\033[45m'
        cyan = '\033[46m'
        lightgrey = '\033[47m'


def printc(text, color):
    print(f"{color}{text}{colors.reset}")

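
# Example: printc("Training complete", colors.fg.green) prints the message in
# green, then resets the terminal style.
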
############################################


def get_image_from_url(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img


def url_to_torch(url, size=(384, 384)):
    img = get_image_from_url(url)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    img = img.resize(size, Image.LANCZOS)
    img = torch.from_numpy(np.asarray(img)).float()
    img = img.permute(2, 0, 1)
    img.div_(255)
    return img


def pil_to_batched_tensor(img):
    return ToTensor()(img).unsqueeze(0)


def save_raw_16bit(depth, fpath="raw.png"):
    if isinstance(depth, torch.Tensor):
        depth = depth.squeeze().cpu().numpy()

    assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
    assert depth.ndim == 2, "Depth must be 2D"
    depth = depth * 256  # scale for 16-bit png
    depth = depth.astype(np.uint16)
    depth = Image.fromarray(depth)
    depth.save(fpath)
    print("Saved raw depth to", fpath)