Charlie Amalet committed
Commit 0b808ab · 1 Parent(s): ae4bd74

Upload Zoedepth folder

Files changed (41)
  1. zoedepth/data/__init__.py +24 -0
  2. zoedepth/data/data_mono.py +573 -0
  3. zoedepth/data/ddad.py +117 -0
  4. zoedepth/data/diml_indoor_test.py +125 -0
  5. zoedepth/data/diml_outdoor_test.py +114 -0
  6. zoedepth/data/diode.py +125 -0
  7. zoedepth/data/hypersim.py +138 -0
  8. zoedepth/data/ibims.py +81 -0
  9. zoedepth/data/preprocess.py +154 -0
  10. zoedepth/data/sun_rgbd_loader.py +106 -0
  11. zoedepth/data/transforms.py +481 -0
  12. zoedepth/data/vkitti.py +151 -0
  13. zoedepth/data/vkitti2.py +187 -0
  14. zoedepth/models/__init__.py +24 -0
  15. zoedepth/models/base_models/__init__.py +24 -0
  16. zoedepth/models/base_models/midas.py +377 -0
  17. zoedepth/models/builder.py +51 -0
  18. zoedepth/models/depth_model.py +152 -0
  19. zoedepth/models/layers/attractor.py +208 -0
  20. zoedepth/models/layers/dist_layers.py +121 -0
  21. zoedepth/models/layers/localbins_layers.py +169 -0
  22. zoedepth/models/layers/patch_transformer.py +91 -0
  23. zoedepth/models/model_io.py +92 -0
  24. zoedepth/models/zoedepth/__init__.py +31 -0
  25. zoedepth/models/zoedepth/config_zoedepth.json +58 -0
  26. zoedepth/models/zoedepth/config_zoedepth_kitti.json +22 -0
  27. zoedepth/models/zoedepth/zoedepth_v1.py +250 -0
  28. zoedepth/models/zoedepth_nk/__init__.py +31 -0
  29. zoedepth/models/zoedepth_nk/config_zoedepth_nk.json +67 -0
  30. zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py +333 -0
  31. zoedepth/trainers/base_trainer.py +326 -0
  32. zoedepth/trainers/builder.py +48 -0
  33. zoedepth/trainers/loss.py +316 -0
  34. zoedepth/trainers/zoedepth_nk_trainer.py +143 -0
  35. zoedepth/trainers/zoedepth_trainer.py +177 -0
  36. zoedepth/utils/__init__.py +24 -0
  37. zoedepth/utils/arg_utils.py +33 -0
  38. zoedepth/utils/config.py +437 -0
  39. zoedepth/utils/easydict/__init__.py +158 -0
  40. zoedepth/utils/geometry.py +98 -0
  41. zoedepth/utils/misc.py +368 -0
zoedepth/data/__init__.py ADDED
@@ -0,0 +1,24 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
zoedepth/data/data_mono.py ADDED
@@ -0,0 +1,573 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ # This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
26
+
27
+ import itertools
28
+ import os
29
+ import random
30
+
31
+ import numpy as np
32
+ import cv2
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.utils.data.distributed
36
+ from zoedepth.utils.easydict import EasyDict as edict
37
+ from PIL import Image, ImageOps
38
+ from torch.utils.data import DataLoader, Dataset
39
+ from torchvision import transforms
40
+
41
+ from zoedepth.utils.config import change_dataset
42
+
43
+ from .ddad import get_ddad_loader
44
+ from .diml_indoor_test import get_diml_indoor_loader
45
+ from .diml_outdoor_test import get_diml_outdoor_loader
46
+ from .diode import get_diode_loader
47
+ from .hypersim import get_hypersim_loader
48
+ from .ibims import get_ibims_loader
49
+ from .sun_rgbd_loader import get_sunrgbd_loader
50
+ from .vkitti import get_vkitti_loader
51
+ from .vkitti2 import get_vkitti2_loader
52
+
53
+ from .preprocess import CropParams, get_white_border, get_black_border
54
+
55
+
56
+ def _is_pil_image(img):
57
+ return isinstance(img, Image.Image)
58
+
59
+
60
+ def _is_numpy_image(img):
61
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
62
+
63
+
64
+ def preprocessing_transforms(mode, **kwargs):
65
+ return transforms.Compose([
66
+ ToTensor(mode=mode, **kwargs)
67
+ ])
68
+
69
+
70
+ class DepthDataLoader(object):
71
+ def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
72
+ """
73
+ Data loader for depth datasets
74
+
75
+ Args:
76
+ config (dict): Config dictionary. Refer to utils/config.py
77
+ mode (str): "train" or "online_eval"
78
+ device (str, optional): Device to load the data on. Defaults to 'cpu'.
79
+ transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
80
+ """
81
+
82
+ self.config = config
83
+
84
+ if config.dataset == 'ibims':
85
+ self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
86
+ return
87
+
88
+ if config.dataset == 'sunrgbd':
89
+ self.data = get_sunrgbd_loader(
90
+ data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
91
+ return
92
+
93
+ if config.dataset == 'diml_indoor':
94
+ self.data = get_diml_indoor_loader(
95
+ data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
96
+ return
97
+
98
+ if config.dataset == 'diml_outdoor':
99
+ self.data = get_diml_outdoor_loader(
100
+ data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
101
+ return
102
+
103
+ if "diode" in config.dataset:
104
+ self.data = get_diode_loader(
105
+ config[config.dataset+"_root"], batch_size=1, num_workers=1)
106
+ return
107
+
108
+ if config.dataset == 'hypersim_test':
109
+ self.data = get_hypersim_loader(
110
+ config.hypersim_test_root, batch_size=1, num_workers=1)
111
+ return
112
+
113
+ if config.dataset == 'vkitti':
114
+ self.data = get_vkitti_loader(
115
+ config.vkitti_root, batch_size=1, num_workers=1)
116
+ return
117
+
118
+ if config.dataset == 'vkitti2':
119
+ self.data = get_vkitti2_loader(
120
+ config.vkitti2_root, batch_size=1, num_workers=1)
121
+ return
122
+
123
+ if config.dataset == 'ddad':
124
+ self.data = get_ddad_loader(config.ddad_root, resize_shape=(
125
+ 352, 1216), batch_size=1, num_workers=1)
126
+ return
127
+
128
+ img_size = self.config.get("img_size", None)
129
+ img_size = img_size if self.config.get(
130
+ "do_input_resize", False) else None
131
+
132
+ if transform is None:
133
+ transform = preprocessing_transforms(mode, size=img_size)
134
+
135
+ if mode == 'train':
136
+
137
+ Dataset = DataLoadPreprocess
138
+ self.training_samples = Dataset(
139
+ config, mode, transform=transform, device=device)
140
+
141
+ if config.distributed:
142
+ self.train_sampler = torch.utils.data.distributed.DistributedSampler(
143
+ self.training_samples)
144
+ else:
145
+ self.train_sampler = None
146
+
147
+ self.data = DataLoader(self.training_samples,
148
+ batch_size=config.batch_size,
149
+ shuffle=(self.train_sampler is None),
150
+ num_workers=config.workers,
151
+ pin_memory=True,
152
+ persistent_workers=True,
153
+ # prefetch_factor=2,
154
+ sampler=self.train_sampler)
155
+
156
+ elif mode == 'online_eval':
157
+ self.testing_samples = DataLoadPreprocess(
158
+ config, mode, transform=transform)
159
+ if config.distributed: # redundant. here only for readability and to be more explicit
160
+ # Give whole test set to all processes (and report evaluation only on one) regardless
161
+ self.eval_sampler = None
162
+ else:
163
+ self.eval_sampler = None
164
+ self.data = DataLoader(self.testing_samples, 1,
165
+ shuffle=kwargs.get("shuffle_test", False),
166
+ num_workers=1,
167
+ pin_memory=False,
168
+ sampler=self.eval_sampler)
169
+
170
+ elif mode == 'test':
171
+ self.testing_samples = DataLoadPreprocess(
172
+ config, mode, transform=transform)
173
+ self.data = DataLoader(self.testing_samples,
174
+ 1, shuffle=False, num_workers=1)
175
+
176
+ else:
177
+ print(
178
+ 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
179
+
180
+
181
+ def repetitive_roundrobin(*iterables):
182
+ """
183
+ cycles through iterables but sample wise
184
+ first yield first sample from first iterable then first sample from second iterable and so on
185
+ then second sample from first iterable then second sample from second iterable and so on
186
+
187
+ If one iterable is shorter than the others, it is repeated until all iterables are exhausted
188
+ repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
189
+ """
190
+ # Repetitive roundrobin
191
+ iterables_ = [iter(it) for it in iterables]
192
+ exhausted = [False] * len(iterables)
193
+ while not all(exhausted):
194
+ for i, it in enumerate(iterables_):
195
+ try:
196
+ yield next(it)
197
+ except StopIteration:
198
+ exhausted[i] = True
199
+ iterables_[i] = itertools.cycle(iterables[i])
200
+ # First elements may get repeated if one iterable is shorter than the others
201
+ yield next(iterables_[i])
202
+
203
+
204
+ class RepetitiveRoundRobinDataLoader(object):
205
+ def __init__(self, *dataloaders):
206
+ self.dataloaders = dataloaders
207
+
208
+ def __iter__(self):
209
+ return repetitive_roundrobin(*self.dataloaders)
210
+
211
+ def __len__(self):
212
+ # First samples get repeated, thats why the plus one
213
+ return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
214
+
215
+
216
+ class MixedNYUKITTI(object):
217
+ def __init__(self, config, mode, device='cpu', **kwargs):
218
+ config = edict(config)
219
+ config.workers = config.workers // 2
220
+ self.config = config
221
+ nyu_conf = change_dataset(edict(config), 'nyu')
222
+ kitti_conf = change_dataset(edict(config), 'kitti')
223
+
224
+ # make nyu default for testing
225
+ self.config = config = nyu_conf
226
+ img_size = self.config.get("img_size", None)
227
+ img_size = img_size if self.config.get(
228
+ "do_input_resize", False) else None
229
+ if mode == 'train':
230
+ nyu_loader = DepthDataLoader(
231
+ nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
232
+ kitti_loader = DepthDataLoader(
233
+ kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
234
+ # It has been changed to repetitive roundrobin
235
+ self.data = RepetitiveRoundRobinDataLoader(
236
+ nyu_loader, kitti_loader)
237
+ else:
238
+ self.data = DepthDataLoader(nyu_conf, mode, device=device).data
239
+
240
+
241
+ def remove_leading_slash(s):
242
+ if s[0] == '/' or s[0] == '\\':
243
+ return s[1:]
244
+ return s
245
+
246
+
247
+ class CachedReader:
248
+ def __init__(self, shared_dict=None):
249
+ if shared_dict:
250
+ self._cache = shared_dict
251
+ else:
252
+ self._cache = {}
253
+
254
+ def open(self, fpath):
255
+ im = self._cache.get(fpath, None)
256
+ if im is None:
257
+ im = self._cache[fpath] = Image.open(fpath)
258
+ return im
259
+
260
+
261
+ class ImReader:
262
+ def __init__(self):
263
+ pass
264
+
265
+ # @cache
266
+ def open(self, fpath):
267
+ return Image.open(fpath)
268
+
269
+
270
+ class DataLoadPreprocess(Dataset):
271
+ def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
272
+ self.config = config
273
+ if mode == 'online_eval':
274
+ with open(config.filenames_file_eval, 'r') as f:
275
+ self.filenames = f.readlines()
276
+ else:
277
+ with open(config.filenames_file, 'r') as f:
278
+ self.filenames = f.readlines()
279
+
280
+ self.mode = mode
281
+ self.transform = transform
282
+ self.to_tensor = ToTensor(mode)
283
+ self.is_for_online_eval = is_for_online_eval
284
+ if config.use_shared_dict:
285
+ self.reader = CachedReader(config.shared_dict)
286
+ else:
287
+ self.reader = ImReader()
288
+
289
+ def postprocess(self, sample):
290
+ return sample
291
+
292
+ def __getitem__(self, idx):
293
+ sample_path = self.filenames[idx]
294
+ focal = float(sample_path.split()[2])
295
+ sample = {}
296
+
297
+ if self.mode == 'train':
298
+ if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
299
+ image_path = os.path.join(
300
+ self.config.data_path, remove_leading_slash(sample_path.split()[3]))
301
+ depth_path = os.path.join(
302
+ self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
303
+ else:
304
+ image_path = os.path.join(
305
+ self.config.data_path, remove_leading_slash(sample_path.split()[0]))
306
+ depth_path = os.path.join(
307
+ self.config.gt_path, remove_leading_slash(sample_path.split()[1]))
308
+
309
+ image = self.reader.open(image_path)
310
+ depth_gt = self.reader.open(depth_path)
311
+ w, h = image.size
312
+
313
+ if self.config.do_kb_crop:
314
+ height = image.height
315
+ width = image.width
316
+ top_margin = int(height - 352)
317
+ left_margin = int((width - 1216) / 2)
318
+ depth_gt = depth_gt.crop(
319
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
320
+ image = image.crop(
321
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
322
+
323
+ # Avoid blank boundaries due to pixel registration?
324
+ # Train images have white border. Test images have black border.
325
+ if self.config.dataset == 'nyu' and self.config.avoid_boundary:
326
+ # print("Avoiding Blank Boundaries!")
327
+ # We just crop and pad again with reflect padding to original size
328
+ # original_size = image.size
329
+ crop_params = get_white_border(np.array(image, dtype=np.uint8))
330
+ image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
331
+ depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
332
+
333
+ # Use reflect padding to fill the blank
334
+ image = np.array(image)
335
+ image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
336
+ image = Image.fromarray(image)
337
+
338
+ depth_gt = np.array(depth_gt)
339
+ depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
340
+ depth_gt = Image.fromarray(depth_gt)
341
+
342
+
343
+ if self.config.do_random_rotate and (self.config.aug):
344
+ random_angle = (random.random() - 0.5) * 2 * self.config.degree
345
+ image = self.rotate_image(image, random_angle)
346
+ depth_gt = self.rotate_image(
347
+ depth_gt, random_angle, flag=Image.NEAREST)
348
+
349
+ image = np.asarray(image, dtype=np.float32) / 255.0
350
+ depth_gt = np.asarray(depth_gt, dtype=np.float32)
351
+ depth_gt = np.expand_dims(depth_gt, axis=2)
352
+
353
+ if self.config.dataset == 'nyu':
354
+ depth_gt = depth_gt / 1000.0
355
+ else:
356
+ depth_gt = depth_gt / 256.0
357
+
358
+ if self.config.aug and (self.config.random_crop):
359
+ image, depth_gt = self.random_crop(
360
+ image, depth_gt, self.config.input_height, self.config.input_width)
361
+
362
+ if self.config.aug and self.config.random_translate:
363
+ # print("Random Translation!")
364
+ image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)
365
+
366
+ image, depth_gt = self.train_preprocess(image, depth_gt)
367
+ mask = np.logical_and(depth_gt > self.config.min_depth,
368
+ depth_gt < self.config.max_depth).squeeze()[None, ...]
369
+ sample = {'image': image, 'depth': depth_gt, 'focal': focal,
370
+ 'mask': mask, **sample}
371
+
372
+ else:
373
+ if self.mode == 'online_eval':
374
+ data_path = self.config.data_path_eval
375
+ else:
376
+ data_path = self.config.data_path
377
+
378
+ image_path = os.path.join(
379
+ data_path, remove_leading_slash(sample_path.split()[0]))
380
+ image = np.asarray(self.reader.open(image_path),
381
+ dtype=np.float32) / 255.0
382
+
383
+ if self.mode == 'online_eval':
384
+ gt_path = self.config.gt_path_eval
385
+ depth_path = os.path.join(
386
+ gt_path, remove_leading_slash(sample_path.split()[1]))
387
+ has_valid_depth = False
388
+ try:
389
+ depth_gt = self.reader.open(depth_path)
390
+ has_valid_depth = True
391
+ except IOError:
392
+ depth_gt = False
393
+ # print('Missing gt for {}'.format(image_path))
394
+
395
+ if has_valid_depth:
396
+ depth_gt = np.asarray(depth_gt, dtype=np.float32)
397
+ depth_gt = np.expand_dims(depth_gt, axis=2)
398
+ if self.config.dataset == 'nyu':
399
+ depth_gt = depth_gt / 1000.0
400
+ else:
401
+ depth_gt = depth_gt / 256.0
402
+
403
+ mask = np.logical_and(
404
+ depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
405
+ else:
406
+ mask = False
407
+
408
+ if self.config.do_kb_crop:
409
+ height = image.shape[0]
410
+ width = image.shape[1]
411
+ top_margin = int(height - 352)
412
+ left_margin = int((width - 1216) / 2)
413
+ image = image[top_margin:top_margin + 352,
414
+ left_margin:left_margin + 1216, :]
415
+ if self.mode == 'online_eval' and has_valid_depth:
416
+ depth_gt = depth_gt[top_margin:top_margin +
417
+ 352, left_margin:left_margin + 1216, :]
418
+
419
+ if self.mode == 'online_eval':
420
+ sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
421
+ 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
422
+ 'mask': mask}
423
+ else:
424
+ sample = {'image': image, 'focal': focal}
425
+
426
+ if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
427
+ mask = np.logical_and(depth_gt > self.config.min_depth,
428
+ depth_gt < self.config.max_depth).squeeze()[None, ...]
429
+ sample['mask'] = mask
430
+
431
+ if self.transform:
432
+ sample = self.transform(sample)
433
+
434
+ sample = self.postprocess(sample)
435
+ sample['dataset'] = self.config.dataset
436
+ sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}
437
+
438
+ return sample
439
+
440
+ def rotate_image(self, image, angle, flag=Image.BILINEAR):
441
+ result = image.rotate(angle, resample=flag)
442
+ return result
443
+
444
+ def random_crop(self, img, depth, height, width):
445
+ assert img.shape[0] >= height
446
+ assert img.shape[1] >= width
447
+ assert img.shape[0] == depth.shape[0]
448
+ assert img.shape[1] == depth.shape[1]
449
+ x = random.randint(0, img.shape[1] - width)
450
+ y = random.randint(0, img.shape[0] - height)
451
+ img = img[y:y + height, x:x + width, :]
452
+ depth = depth[y:y + height, x:x + width, :]
453
+
454
+ return img, depth
455
+
456
+ def random_translate(self, img, depth, max_t=20):
457
+ assert img.shape[0] == depth.shape[0]
458
+ assert img.shape[1] == depth.shape[1]
459
+ p = self.config.translate_prob
460
+ do_translate = random.random()
461
+ if do_translate > p:
462
+ return img, depth
463
+ x = random.randint(-max_t, max_t)
464
+ y = random.randint(-max_t, max_t)
465
+ M = np.float32([[1, 0, x], [0, 1, y]])
466
+ # print(img.shape, depth.shape)
467
+ img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
468
+ depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
469
+ depth = depth.squeeze()[..., None] # add channel dim back. Affine warp removes it
470
+ # print("after", img.shape, depth.shape)
471
+ return img, depth
472
+
473
+ def train_preprocess(self, image, depth_gt):
474
+ if self.config.aug:
475
+ # Random flipping
476
+ do_flip = random.random()
477
+ if do_flip > 0.5:
478
+ image = (image[:, ::-1, :]).copy()
479
+ depth_gt = (depth_gt[:, ::-1, :]).copy()
480
+
481
+ # Random gamma, brightness, color augmentation
482
+ do_augment = random.random()
483
+ if do_augment > 0.5:
484
+ image = self.augment_image(image)
485
+
486
+ return image, depth_gt
487
+
488
+ def augment_image(self, image):
489
+ # gamma augmentation
490
+ gamma = random.uniform(0.9, 1.1)
491
+ image_aug = image ** gamma
492
+
493
+ # brightness augmentation
494
+ if self.config.dataset == 'nyu':
495
+ brightness = random.uniform(0.75, 1.25)
496
+ else:
497
+ brightness = random.uniform(0.9, 1.1)
498
+ image_aug = image_aug * brightness
499
+
500
+ # color augmentation
501
+ colors = np.random.uniform(0.9, 1.1, size=3)
502
+ white = np.ones((image.shape[0], image.shape[1]))
503
+ color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
504
+ image_aug *= color_image
505
+ image_aug = np.clip(image_aug, 0, 1)
506
+
507
+ return image_aug
508
+
509
+ def __len__(self):
510
+ return len(self.filenames)
511
+
512
+
513
+ class ToTensor(object):
514
+ def __init__(self, mode, do_normalize=False, size=None):
515
+ self.mode = mode
516
+ self.normalize = transforms.Normalize(
517
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
518
+ self.size = size
519
+ if size is not None:
520
+ self.resize = transforms.Resize(size=size)
521
+ else:
522
+ self.resize = nn.Identity()
523
+
524
+ def __call__(self, sample):
525
+ image, focal = sample['image'], sample['focal']
526
+ image = self.to_tensor(image)
527
+ image = self.normalize(image)
528
+ image = self.resize(image)
529
+
530
+ if self.mode == 'test':
531
+ return {'image': image, 'focal': focal}
532
+
533
+ depth = sample['depth']
534
+ if self.mode == 'train':
535
+ depth = self.to_tensor(depth)
536
+ return {**sample, 'image': image, 'depth': depth, 'focal': focal}
537
+ else:
538
+ has_valid_depth = sample['has_valid_depth']
539
+ image = self.resize(image)
540
+ return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
541
+ 'image_path': sample['image_path'], 'depth_path': sample['depth_path']}
542
+
543
+ def to_tensor(self, pic):
544
+ if not (_is_pil_image(pic) or _is_numpy_image(pic)):
545
+ raise TypeError(
546
+ 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
547
+
548
+ if isinstance(pic, np.ndarray):
549
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
550
+ return img
551
+
552
+ # handle PIL Image
553
+ if pic.mode == 'I':
554
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
555
+ elif pic.mode == 'I;16':
556
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
557
+ else:
558
+ img = torch.ByteTensor(
559
+ torch.ByteStorage.from_buffer(pic.tobytes()))
560
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
561
+ if pic.mode == 'YCbCr':
562
+ nchannel = 3
563
+ elif pic.mode == 'I;16':
564
+ nchannel = 1
565
+ else:
566
+ nchannel = len(pic.mode)
567
+ img = img.view(pic.size[1], pic.size[0], nchannel)
568
+
569
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
570
+ if isinstance(img, torch.ByteTensor):
571
+ return img.float()
572
+ else:
573
+ return img
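
DepthDataLoader above dispatches on config.dataset: the test-only sets (ibims, sunrgbd, diml_indoor, diml_outdoor, diode*, hypersim_test, vkitti, vkitti2, ddad) are handed off to their dedicated loaders, while train/online_eval/test fall through to DataLoadPreprocess. A minimal usage sketch follows, assuming only the config keys that the 'diml_indoor' branch actually reads; the directory path is a placeholder and is not part of this upload.

from zoedepth.utils.easydict import EasyDict as edict
from zoedepth.data.data_mono import DepthDataLoader

# Hypothetical minimal config: the 'diml_indoor' branch only reads
# config.dataset and config.diml_indoor_root before returning.
config = edict(dataset="diml_indoor", diml_indoor_root="datasets/diml/indoor/test")  # placeholder path
loader = DepthDataLoader(config, mode="online_eval").data

for sample in loader:
    # keys produced by the DIML ToTensor defined in diml_indoor_test.py below
    print(sample["image"].shape, sample["depth"].shape)
    break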
zoedepth/data/ddad.py ADDED
@@ -0,0 +1,117 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import os
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+
+
+ class ToTensor(object):
+     def __init__(self, resize_shape):
+         # self.normalize = transforms.Normalize(
+         #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         self.normalize = lambda x : x
+         self.resize = transforms.Resize(resize_shape)
+
+     def __call__(self, sample):
+         image, depth = sample['image'], sample['depth']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         depth = self.to_tensor(depth)
+
+         image = self.resize(image)
+
+         return {'image': image, 'depth': depth, 'dataset': "ddad"}
+
+     def to_tensor(self, pic):
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
+
+
+ class DDAD(Dataset):
+     def __init__(self, data_dir_root, resize_shape):
+         import glob
+
+         # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
+         self.image_files = glob.glob(os.path.join(data_dir_root, '*.png'))
+         self.depth_files = [r.replace("_rgb.png", "_depth.npy")
+                             for r in self.image_files]
+         self.transform = ToTensor(resize_shape)
+
+     def __getitem__(self, idx):
+
+         image_path = self.image_files[idx]
+         depth_path = self.depth_files[idx]
+
+         image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+         depth = np.load(depth_path)  # meters
+
+         # depth[depth > 8] = -1
+         depth = depth[..., None]
+
+         sample = dict(image=image, depth=depth)
+         sample = self.transform(sample)
+
+         if idx == 0:
+             print(sample["image"].shape)
+
+         return sample
+
+     def __len__(self):
+         return len(self.image_files)
+
+
+ def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs):
+     dataset = DDAD(data_dir_root, resize_shape)
+     return DataLoader(dataset, batch_size, **kwargs)
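
get_ddad_loader above is invoked from data_mono.py with resize_shape=(352, 1216). A hedged standalone sketch; the directory path is a placeholder, and the loader expects <name>_rgb.png images paired with <name>_depth.npy depth maps as globbed above.

from zoedepth.data.ddad import get_ddad_loader

# Placeholder directory (not part of this upload).
loader = get_ddad_loader("datasets/ddad/ddad_val", resize_shape=(352, 1216), batch_size=1)

for sample in loader:
    print(sample["image"].shape, sample["depth"].shape)
    break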
zoedepth/data/diml_indoor_test.py ADDED
@@ -0,0 +1,125 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+ self.resize = transforms.Resize((480, 640))
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ image = self.resize(image)
48
+
49
+ return {'image': image, 'depth': depth, 'dataset': "diml_indoor"}
50
+
51
+ def to_tensor(self, pic):
52
+
53
+ if isinstance(pic, np.ndarray):
54
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
55
+ return img
56
+
57
+ # # handle PIL Image
58
+ if pic.mode == 'I':
59
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
60
+ elif pic.mode == 'I;16':
61
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
62
+ else:
63
+ img = torch.ByteTensor(
64
+ torch.ByteStorage.from_buffer(pic.tobytes()))
65
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
66
+ if pic.mode == 'YCbCr':
67
+ nchannel = 3
68
+ elif pic.mode == 'I;16':
69
+ nchannel = 1
70
+ else:
71
+ nchannel = len(pic.mode)
72
+ img = img.view(pic.size[1], pic.size[0], nchannel)
73
+
74
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
75
+ if isinstance(img, torch.ByteTensor):
76
+ return img.float()
77
+ else:
78
+ return img
79
+
80
+
81
+ class DIML_Indoor(Dataset):
82
+ def __init__(self, data_dir_root):
83
+ import glob
84
+
85
+ # image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
86
+ self.image_files = glob.glob(os.path.join(
87
+ data_dir_root, "LR", '*', 'color', '*.png'))
88
+ self.depth_files = [r.replace("color", "depth_filled").replace(
89
+ "_c.png", "_depth_filled.png") for r in self.image_files]
90
+ self.transform = ToTensor()
91
+
92
+ def __getitem__(self, idx):
93
+ image_path = self.image_files[idx]
94
+ depth_path = self.depth_files[idx]
95
+
96
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
97
+ depth = np.asarray(Image.open(depth_path),
98
+ dtype='uint16') / 1000.0 # mm to meters
99
+
100
+ # print(np.shape(image))
101
+ # print(np.shape(depth))
102
+
103
+ # depth[depth > 8] = -1
104
+ depth = depth[..., None]
105
+
106
+ sample = dict(image=image, depth=depth)
107
+
108
+ # return sample
109
+ sample = self.transform(sample)
110
+
111
+ if idx == 0:
112
+ print(sample["image"].shape)
113
+
114
+ return sample
115
+
116
+ def __len__(self):
117
+ return len(self.image_files)
118
+
119
+
120
+ def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs):
121
+ dataset = DIML_Indoor(data_dir_root)
122
+ return DataLoader(dataset, batch_size, **kwargs)
123
+
124
+ # get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR")
125
+ # get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR")
zoedepth/data/diml_outdoor_test.py ADDED
@@ -0,0 +1,114 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+
40
+ def __call__(self, sample):
41
+ image, depth = sample['image'], sample['depth']
42
+ image = self.to_tensor(image)
43
+ image = self.normalize(image)
44
+ depth = self.to_tensor(depth)
45
+
46
+ return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"}
47
+
48
+ def to_tensor(self, pic):
49
+
50
+ if isinstance(pic, np.ndarray):
51
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
52
+ return img
53
+
54
+ # # handle PIL Image
55
+ if pic.mode == 'I':
56
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
57
+ elif pic.mode == 'I;16':
58
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
59
+ else:
60
+ img = torch.ByteTensor(
61
+ torch.ByteStorage.from_buffer(pic.tobytes()))
62
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
63
+ if pic.mode == 'YCbCr':
64
+ nchannel = 3
65
+ elif pic.mode == 'I;16':
66
+ nchannel = 1
67
+ else:
68
+ nchannel = len(pic.mode)
69
+ img = img.view(pic.size[1], pic.size[0], nchannel)
70
+
71
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
72
+ if isinstance(img, torch.ByteTensor):
73
+ return img.float()
74
+ else:
75
+ return img
76
+
77
+
78
+ class DIML_Outdoor(Dataset):
79
+ def __init__(self, data_dir_root):
80
+ import glob
81
+
82
+ # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
83
+ self.image_files = glob.glob(os.path.join(
84
+ data_dir_root, "*", 'outleft', '*.png'))
85
+ self.depth_files = [r.replace("outleft", "depthmap")
86
+ for r in self.image_files]
87
+ self.transform = ToTensor()
88
+
89
+ def __getitem__(self, idx):
90
+ image_path = self.image_files[idx]
91
+ depth_path = self.depth_files[idx]
92
+
93
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
94
+ depth = np.asarray(Image.open(depth_path),
95
+ dtype='uint16') / 1000.0 # mm to meters
96
+
97
+ # depth[depth > 8] = -1
98
+ depth = depth[..., None]
99
+
100
+ sample = dict(image=image, depth=depth, dataset="diml_outdoor")
101
+
102
+ # return sample
103
+ return self.transform(sample)
104
+
105
+ def __len__(self):
106
+ return len(self.image_files)
107
+
108
+
109
+ def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs):
110
+ dataset = DIML_Outdoor(data_dir_root)
111
+ return DataLoader(dataset, batch_size, **kwargs)
112
+
113
+ # get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR")
114
+ # get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR")
zoedepth/data/diode.py ADDED
@@ -0,0 +1,125 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+ self.resize = transforms.Resize(480)
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+ image = self.to_tensor(image)
44
+ image = self.normalize(image)
45
+ depth = self.to_tensor(depth)
46
+
47
+ image = self.resize(image)
48
+
49
+ return {'image': image, 'depth': depth, 'dataset': "diode"}
50
+
51
+ def to_tensor(self, pic):
52
+
53
+ if isinstance(pic, np.ndarray):
54
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
55
+ return img
56
+
57
+ # # handle PIL Image
58
+ if pic.mode == 'I':
59
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
60
+ elif pic.mode == 'I;16':
61
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
62
+ else:
63
+ img = torch.ByteTensor(
64
+ torch.ByteStorage.from_buffer(pic.tobytes()))
65
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
66
+ if pic.mode == 'YCbCr':
67
+ nchannel = 3
68
+ elif pic.mode == 'I;16':
69
+ nchannel = 1
70
+ else:
71
+ nchannel = len(pic.mode)
72
+ img = img.view(pic.size[1], pic.size[0], nchannel)
73
+
74
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
75
+
76
+ if isinstance(img, torch.ByteTensor):
77
+ return img.float()
78
+ else:
79
+ return img
80
+
81
+
82
+ class DIODE(Dataset):
83
+ def __init__(self, data_dir_root):
84
+ import glob
85
+
86
+ # image paths are of the form <data_dir_root>/scene_#/scan_#/*.png
87
+ self.image_files = glob.glob(
88
+ os.path.join(data_dir_root, '*', '*', '*.png'))
89
+ self.depth_files = [r.replace(".png", "_depth.npy")
90
+ for r in self.image_files]
91
+ self.depth_mask_files = [
92
+ r.replace(".png", "_depth_mask.npy") for r in self.image_files]
93
+ self.transform = ToTensor()
94
+
95
+ def __getitem__(self, idx):
96
+ image_path = self.image_files[idx]
97
+ depth_path = self.depth_files[idx]
98
+ depth_mask_path = self.depth_mask_files[idx]
99
+
100
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
101
+ depth = np.load(depth_path) # in meters
102
+ valid = np.load(depth_mask_path) # binary
103
+
104
+ # depth[depth > 8] = -1
105
+ # depth = depth[..., None]
106
+
107
+ sample = dict(image=image, depth=depth, valid=valid)
108
+
109
+ # return sample
110
+ sample = self.transform(sample)
111
+
112
+ if idx == 0:
113
+ print(sample["image"].shape)
114
+
115
+ return sample
116
+
117
+ def __len__(self):
118
+ return len(self.image_files)
119
+
120
+
121
+ def get_diode_loader(data_dir_root, batch_size=1, **kwargs):
122
+ dataset = DIODE(data_dir_root)
123
+ return DataLoader(dataset, batch_size, **kwargs)
124
+
125
+ # get_diode_loader(data_dir_root="datasets/diode/val/outdoor")
zoedepth/data/hypersim.py ADDED
@@ -0,0 +1,138 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import glob
26
+ import os
27
+
28
+ import h5py
29
+ import numpy as np
30
+ import torch
31
+ from PIL import Image
32
+ from torch.utils.data import DataLoader, Dataset
33
+ from torchvision import transforms
34
+
35
+
36
+ def hypersim_distance_to_depth(npyDistance):
37
+ intWidth, intHeight, fltFocal = 1024, 768, 886.81
38
+
39
+ npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
40
+ 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
41
+ npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
42
+ intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
43
+ npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
44
+ npyImageplane = np.concatenate(
45
+ [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)
46
+
47
+ npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
48
+ return npyDepth
49
+
50
+
51
+ class ToTensor(object):
52
+ def __init__(self):
53
+ # self.normalize = transforms.Normalize(
54
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
+ self.normalize = lambda x: x
56
+ self.resize = transforms.Resize((480, 640))
57
+
58
+ def __call__(self, sample):
59
+ image, depth = sample['image'], sample['depth']
60
+ image = self.to_tensor(image)
61
+ image = self.normalize(image)
62
+ depth = self.to_tensor(depth)
63
+
64
+ image = self.resize(image)
65
+
66
+ return {'image': image, 'depth': depth, 'dataset': "hypersim"}
67
+
68
+ def to_tensor(self, pic):
69
+
70
+ if isinstance(pic, np.ndarray):
71
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
72
+ return img
73
+
74
+ # # handle PIL Image
75
+ if pic.mode == 'I':
76
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
77
+ elif pic.mode == 'I;16':
78
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
79
+ else:
80
+ img = torch.ByteTensor(
81
+ torch.ByteStorage.from_buffer(pic.tobytes()))
82
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
83
+ if pic.mode == 'YCbCr':
84
+ nchannel = 3
85
+ elif pic.mode == 'I;16':
86
+ nchannel = 1
87
+ else:
88
+ nchannel = len(pic.mode)
89
+ img = img.view(pic.size[1], pic.size[0], nchannel)
90
+
91
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
92
+ if isinstance(img, torch.ByteTensor):
93
+ return img.float()
94
+ else:
95
+ return img
96
+
97
+
98
+ class HyperSim(Dataset):
99
+ def __init__(self, data_dir_root):
100
+ # image paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.tonemap.jpg
101
+ # depth paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.depth_meters.hdf5
102
+ self.image_files = glob.glob(os.path.join(
103
+ data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg'))
104
+ self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace(
105
+ ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files]
106
+ self.transform = ToTensor()
107
+
108
+ def __getitem__(self, idx):
109
+ image_path = self.image_files[idx]
110
+ depth_path = self.depth_files[idx]
111
+
112
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
113
+
114
+ # depth from hdf5
115
+ depth_fd = h5py.File(depth_path, "r")
116
+ # in meters (Euclidean distance)
117
+ distance_meters = np.array(depth_fd['dataset'])
118
+ depth = hypersim_distance_to_depth(
119
+ distance_meters) # in meters (planar depth)
120
+
121
+ # depth[depth > 8] = -1
122
+ depth = depth[..., None]
123
+
124
+ sample = dict(image=image, depth=depth)
125
+ sample = self.transform(sample)
126
+
127
+ if idx == 0:
128
+ print(sample["image"].shape)
129
+
130
+ return sample
131
+
132
+ def __len__(self):
133
+ return len(self.image_files)
134
+
135
+
136
+ def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs):
137
+ dataset = HyperSim(data_dir_root)
138
+ return DataLoader(dataset, batch_size, **kwargs)
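
hypersim_distance_to_depth above converts the Euclidean ray distance stored in the HDF5 files into planar depth by projecting onto the optical axis: depth = distance * f / sqrt(x^2 + y^2 + f^2), with f = 886.81 and (x, y) the pixel's image-plane offset. A small self-contained check of that relation (all values are illustrative only):

import numpy as np

# Single-pixel version of the conversion used in hypersim_distance_to_depth.
f = 886.81                    # focal length assumed by the function above
x, y = 100.5, -50.5           # image-plane offsets of an arbitrary pixel (illustrative)
distance = 5.0                # Euclidean distance along the viewing ray, in meters

planar_depth = distance * f / np.sqrt(x ** 2 + y ** 2 + f ** 2)
print(planar_depth)           # slightly less than 5.0 for off-axis pixels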
zoedepth/data/ibims.py ADDED
@@ -0,0 +1,81 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms as T
32
+
33
+
34
+ class iBims(Dataset):
35
+ def __init__(self, config):
36
+ root_folder = config.ibims_root
37
+ with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f:
38
+ imglist = f.read().split()
39
+
40
+ samples = []
41
+ for basename in imglist:
42
+ img_path = os.path.join(root_folder, 'rgb', basename + ".png")
43
+ depth_path = os.path.join(root_folder, 'depth', basename + ".png")
44
+ valid_mask_path = os.path.join(
45
+ root_folder, 'mask_invalid', basename+".png")
46
+ transp_mask_path = os.path.join(
47
+ root_folder, 'mask_transp', basename+".png")
48
+
49
+ samples.append(
50
+ (img_path, depth_path, valid_mask_path, transp_mask_path))
51
+
52
+ self.samples = samples
53
+ # self.normalize = T.Normalize(
54
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
+ self.normalize = lambda x : x
56
+
57
+ def __getitem__(self, idx):
58
+ img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx]
59
+
60
+ img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0
61
+ depth = np.asarray(Image.open(depth_path),
62
+ dtype=np.uint16).astype('float')*50.0/65535
63
+
64
+ mask_valid = np.asarray(Image.open(valid_mask_path))
65
+ mask_transp = np.asarray(Image.open(transp_mask_path))
66
+
67
+ # depth = depth * mask_valid * mask_transp
68
+ depth = np.where(mask_valid * mask_transp, depth, -1)
69
+
70
+ img = torch.from_numpy(img).permute(2, 0, 1)
71
+ img = self.normalize(img)
72
+ depth = torch.from_numpy(depth).unsqueeze(0)
73
+ return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims')
74
+
75
+ def __len__(self):
76
+ return len(self.samples)
77
+
78
+
79
+ def get_ibims_loader(config, batch_size=1, **kwargs):
80
+ dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs)
81
+ return dataloader
zoedepth/data/preprocess.py ADDED
@@ -0,0 +1,154 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import numpy as np
26
+ from dataclasses import dataclass
27
+ from typing import Tuple, List
28
+
29
+ # dataclass to store the crop parameters
30
+ @dataclass
31
+ class CropParams:
32
+ top: int
33
+ bottom: int
34
+ left: int
35
+ right: int
36
+
37
+
38
+
39
+ def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams:
40
+ gray_image = np.mean(rgb_image, axis=channel_axis)
41
+ h, w = gray_image.shape
42
+
43
+
44
+ def num_value_pixels(arr):
45
+ return np.sum(np.abs(arr - value) < level_diff_threshold)
46
+
47
+ def is_above_tolerance(arr, total_pixels):
48
+ return (num_value_pixels(arr) / total_pixels) > tolerance
49
+
50
+ # Crop top border until number of value pixels become below tolerance
51
+ top = min_border
52
+ while is_above_tolerance(gray_image[top, :], w) and top < h-1:
53
+ top += 1
54
+ if top > cut_off:
55
+ break
56
+
57
+ # Crop bottom border until number of value pixels become below tolerance
58
+ bottom = h - min_border
59
+ while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0:
60
+ bottom -= 1
61
+ if h - bottom > cut_off:
62
+ break
63
+
64
+ # Crop left border until number of value pixels become below tolerance
65
+ left = min_border
66
+ while is_above_tolerance(gray_image[:, left], h) and left < w-1:
67
+ left += 1
68
+ if left > cut_off:
69
+ break
70
+
71
+ # Crop right border until number of value pixels become below tolerance
72
+ right = w - min_border
73
+ while is_above_tolerance(gray_image[:, right], h) and right > 0:
74
+ right -= 1
75
+ if w - right > cut_off:
76
+ break
77
+
78
+
79
+ return CropParams(top, bottom, left, right)
80
+
81
+
82
+ def get_white_border(rgb_image, value=255, **kwargs) -> CropParams:
83
+ """Crops the white border of the RGB.
84
+
85
+ Args:
86
+ rgb: RGB image, shape (H, W, 3).
87
+ Returns:
88
+ Crop parameters.
89
+ """
90
+ if value == 255:
91
+ # assert range of values in rgb image is [0, 255]
92
+ assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]."
93
+ assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]."
94
+ elif value == 1:
95
+ # assert range of values in rgb image is [0, 1]
96
+ assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]."
97
+
98
+ return get_border_params(rgb_image, value=value, **kwargs)
99
+
100
+ def get_black_border(rgb_image, **kwargs) -> CropParams:
101
+ """Crops the black border of the RGB.
102
+
103
+ Args:
104
+ rgb: RGB image, shape (H, W, 3).
105
+
106
+ Returns:
107
+ Crop parameters.
108
+ """
109
+
110
+ return get_border_params(rgb_image, value=0, **kwargs)
111
+
112
+ def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray:
113
+ """Crops the image according to the crop parameters.
114
+
115
+ Args:
116
+ image: RGB or depth image, shape (H, W, 3) or (H, W).
117
+ crop_params: Crop parameters.
118
+
119
+ Returns:
120
+ Cropped image.
121
+ """
122
+ return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
123
+
124
+ def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]:
125
+ """Crops the images according to the crop parameters.
126
+
127
+ Args:
128
+ images: RGB or depth images, shape (H, W, 3) or (H, W).
129
+ crop_params: Crop parameters.
130
+
131
+ Returns:
132
+ Cropped images.
133
+ """
134
+ return tuple(crop_image(image, crop_params) for image in images)
135
+
136
+ def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]:
137
+ """Crops the white and black border of the RGB and depth images.
138
+
139
+ Args:
140
+ rgb: RGB image, shape (H, W, 3). This image is used to determine the border.
141
+ other_images: The other images to crop according to the border of the RGB image.
142
+ Returns:
143
+ Cropped RGB and other images.
144
+ """
145
+ # crop black border
146
+ crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
147
+ cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params)
148
+
149
+ # crop white border
150
+ crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
151
+ cropped_images = crop_images(*cropped_images, crop_params=crop_params)
152
+
153
+ return cropped_images
154
+
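A minimal usage sketch for the border-cropping helpers above (illustrative only, not part of the commit; the file paths and the use of a depth map as the second image are assumptions):

import numpy as np
from PIL import Image

from zoedepth.data.preprocess import crop_black_or_white_border

# Hypothetical input pair: an RGB frame and its aligned depth map (paths are placeholders).
rgb = np.asarray(Image.open("sample_rgb.jpg"), dtype=np.float32) / 255.0
depth = np.asarray(Image.open("sample_depth.png"), dtype=np.float32) / 1000.0

# The RGB image determines the crop; every extra image is cropped with the same parameters.
rgb_cropped, depth_cropped = crop_black_or_white_border(rgb, depth)
print(rgb_cropped.shape, depth_cropped.shape)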
zoedepth/data/sun_rgbd_loader.py ADDED
@@ -0,0 +1,106 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from torchvision import transforms
32
+
33
+
34
+ class ToTensor(object):
35
+ def __init__(self):
36
+ # self.normalize = transforms.Normalize(
37
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ self.normalize = lambda x : x
39
+
40
+ def __call__(self, sample):
41
+ image, depth = sample['image'], sample['depth']
42
+ image = self.to_tensor(image)
43
+ image = self.normalize(image)
44
+ depth = self.to_tensor(depth)
45
+
46
+ return {'image': image, 'depth': depth, 'dataset': "sunrgbd"}
47
+
48
+ def to_tensor(self, pic):
49
+
50
+ if isinstance(pic, np.ndarray):
51
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
52
+ return img
53
+
54
+ # # handle PIL Image
55
+ if pic.mode == 'I':
56
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
57
+ elif pic.mode == 'I;16':
58
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
59
+ else:
60
+ img = torch.ByteTensor(
61
+ torch.ByteStorage.from_buffer(pic.tobytes()))
62
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
63
+ if pic.mode == 'YCbCr':
64
+ nchannel = 3
65
+ elif pic.mode == 'I;16':
66
+ nchannel = 1
67
+ else:
68
+ nchannel = len(pic.mode)
69
+ img = img.view(pic.size[1], pic.size[0], nchannel)
70
+
71
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
72
+ if isinstance(img, torch.ByteTensor):
73
+ return img.float()
74
+ else:
75
+ return img
76
+
77
+
78
+ class SunRGBD(Dataset):
79
+ def __init__(self, data_dir_root):
80
+ # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze()
81
+ # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs]
82
+ # self.all_test = [os.path.join(data_dir_root, t) for t in all_test]
83
+ import glob
84
+ self.image_files = glob.glob(
85
+ os.path.join(data_dir_root, 'rgb', 'rgb', '*'))
86
+ self.depth_files = [
87
+ r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files]
88
+ self.transform = ToTensor()
89
+
90
+ def __getitem__(self, idx):
91
+ image_path = self.image_files[idx]
92
+ depth_path = self.depth_files[idx]
93
+
94
+ image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
95
+ depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0
96
+ depth[depth > 8] = -1
97
+ depth = depth[..., None]
98
+ return self.transform(dict(image=image, depth=depth))
99
+
100
+ def __len__(self):
101
+ return len(self.image_files)
102
+
103
+
104
+ def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs):
105
+ dataset = SunRGBD(data_dir_root)
106
+ return DataLoader(dataset, batch_size, **kwargs)
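A short sketch of driving this loader (illustrative only, not part of the commit; the dataset root is a placeholder and assumes the rgb/rgb and gt/gt folder layout expected by SunRGBD above):

from zoedepth.data.sun_rgbd_loader import get_sunrgbd_loader

loader = get_sunrgbd_loader("/path/to/sun_rgbd", batch_size=1, num_workers=2)
for sample in loader:
    # image: (B, 3, H, W) floats in [0, 1]; depth: (B, 1, H, W) in metres, -1 where invalid (> 8 m)
    print(sample["image"].shape, sample["depth"].shape, sample["dataset"][0])
    break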
zoedepth/data/transforms.py ADDED
@@ -0,0 +1,481 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import math
26
+ import random
27
+
28
+ import cv2
29
+ import numpy as np
30
+
31
+
32
+ class RandomFliplr(object):
33
+ """Horizontal flip of the sample with given probability.
34
+ """
35
+
36
+ def __init__(self, probability=0.5):
37
+ """Init.
38
+
39
+ Args:
40
+ probability (float, optional): Flip probability. Defaults to 0.5.
41
+ """
42
+ self.__probability = probability
43
+
44
+ def __call__(self, sample):
45
+ prob = random.random()
46
+
47
+ if prob < self.__probability:
48
+ for k, v in sample.items():
49
+ if len(v.shape) >= 2:
50
+ sample[k] = np.fliplr(v).copy()
51
+
52
+ return sample
53
+
54
+
55
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
56
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
57
+
58
+ Args:
59
+ sample (dict): sample
60
+ size (tuple): image size
61
+
62
+ Returns:
63
+ tuple: new size
64
+ """
65
+ shape = list(sample["disparity"].shape)
66
+
67
+ if shape[0] >= size[0] and shape[1] >= size[1]:
68
+ return sample
69
+
70
+ scale = [0, 0]
71
+ scale[0] = size[0] / shape[0]
72
+ scale[1] = size[1] / shape[1]
73
+
74
+ scale = max(scale)
75
+
76
+ shape[0] = math.ceil(scale * shape[0])
77
+ shape[1] = math.ceil(scale * shape[1])
78
+
79
+ # resize
80
+ sample["image"] = cv2.resize(
81
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
82
+ )
83
+
84
+ sample["disparity"] = cv2.resize(
85
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
86
+ )
87
+ sample["mask"] = cv2.resize(
88
+ sample["mask"].astype(np.float32),
89
+ tuple(shape[::-1]),
90
+ interpolation=cv2.INTER_NEAREST,
91
+ )
92
+ sample["mask"] = sample["mask"].astype(bool)
93
+
94
+ return tuple(shape)
95
+
96
+
97
+ class RandomCrop(object):
98
+ """Get a random crop of the sample with the given size (width, height).
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ width,
104
+ height,
105
+ resize_if_needed=False,
106
+ image_interpolation_method=cv2.INTER_AREA,
107
+ ):
108
+ """Init.
109
+
110
+ Args:
111
+ width (int): output width
112
+ height (int): output height
113
+ resize_if_needed (bool, optional): If True, sample might be upsampled to ensure
114
+ that a crop of size (width, height) is possbile. Defaults to False.
115
+ """
116
+ self.__size = (height, width)
117
+ self.__resize_if_needed = resize_if_needed
118
+ self.__image_interpolation_method = image_interpolation_method
119
+
120
+ def __call__(self, sample):
121
+
122
+ shape = sample["disparity"].shape
123
+
124
+ if self.__size[0] > shape[0] or self.__size[1] > shape[1]:
125
+ if self.__resize_if_needed:
126
+ shape = apply_min_size(
127
+ sample, self.__size, self.__image_interpolation_method
128
+ )
129
+ else:
130
+ raise Exception(
131
+ "Output size {} bigger than input size {}.".format(
132
+ self.__size, shape
133
+ )
134
+ )
135
+
136
+ offset = (
137
+ np.random.randint(shape[0] - self.__size[0] + 1),
138
+ np.random.randint(shape[1] - self.__size[1] + 1),
139
+ )
140
+
141
+ for k, v in sample.items():
142
+ if k == "code" or k == "basis":
143
+ continue
144
+
145
+ if len(sample[k].shape) >= 2:
146
+ sample[k] = v[
147
+ offset[0]: offset[0] + self.__size[0],
148
+ offset[1]: offset[1] + self.__size[1],
149
+ ]
150
+
151
+ return sample
152
+
153
+
154
+ class Resize(object):
155
+ """Resize sample to given size (width, height).
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ width,
161
+ height,
162
+ resize_target=True,
163
+ keep_aspect_ratio=False,
164
+ ensure_multiple_of=1,
165
+ resize_method="lower_bound",
166
+ image_interpolation_method=cv2.INTER_AREA,
167
+ letter_box=False,
168
+ ):
169
+ """Init.
170
+
171
+ Args:
172
+ width (int): desired output width
173
+ height (int): desired output height
174
+ resize_target (bool, optional):
175
+ True: Resize the full sample (image, mask, target).
176
+ False: Resize image only.
177
+ Defaults to True.
178
+ keep_aspect_ratio (bool, optional):
179
+ True: Keep the aspect ratio of the input sample.
180
+ Output sample might not have the given width and height, and
181
+ resize behaviour depends on the parameter 'resize_method'.
182
+ Defaults to False.
183
+ ensure_multiple_of (int, optional):
184
+ Output width and height is constrained to be multiple of this parameter.
185
+ Defaults to 1.
186
+ resize_method (str, optional):
187
+ "lower_bound": Output will be at least as large as the given size.
188
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
189
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
190
+ Defaults to "lower_bound".
191
+ """
192
+ self.__width = width
193
+ self.__height = height
194
+
195
+ self.__resize_target = resize_target
196
+ self.__keep_aspect_ratio = keep_aspect_ratio
197
+ self.__multiple_of = ensure_multiple_of
198
+ self.__resize_method = resize_method
199
+ self.__image_interpolation_method = image_interpolation_method
200
+ self.__letter_box = letter_box
201
+
202
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
203
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
204
+
205
+ if max_val is not None and y > max_val:
206
+ y = (np.floor(x / self.__multiple_of)
207
+ * self.__multiple_of).astype(int)
208
+
209
+ if y < min_val:
210
+ y = (np.ceil(x / self.__multiple_of)
211
+ * self.__multiple_of).astype(int)
212
+
213
+ return y
214
+
215
+ def get_size(self, width, height):
216
+ # determine new height and width
217
+ scale_height = self.__height / height
218
+ scale_width = self.__width / width
219
+
220
+ if self.__keep_aspect_ratio:
221
+ if self.__resize_method == "lower_bound":
222
+ # scale such that output size is lower bound
223
+ if scale_width > scale_height:
224
+ # fit width
225
+ scale_height = scale_width
226
+ else:
227
+ # fit height
228
+ scale_width = scale_height
229
+ elif self.__resize_method == "upper_bound":
230
+ # scale such that output size is upper bound
231
+ if scale_width < scale_height:
232
+ # fit width
233
+ scale_height = scale_width
234
+ else:
235
+ # fit height
236
+ scale_width = scale_height
237
+ elif self.__resize_method == "minimal":
238
+ # scale as least as possbile
239
+ if abs(1 - scale_width) < abs(1 - scale_height):
240
+ # fit width
241
+ scale_height = scale_width
242
+ else:
243
+ # fit height
244
+ scale_width = scale_height
245
+ else:
246
+ raise ValueError(
247
+ f"resize_method {self.__resize_method} not implemented"
248
+ )
249
+
250
+ if self.__resize_method == "lower_bound":
251
+ new_height = self.constrain_to_multiple_of(
252
+ scale_height * height, min_val=self.__height
253
+ )
254
+ new_width = self.constrain_to_multiple_of(
255
+ scale_width * width, min_val=self.__width
256
+ )
257
+ elif self.__resize_method == "upper_bound":
258
+ new_height = self.constrain_to_multiple_of(
259
+ scale_height * height, max_val=self.__height
260
+ )
261
+ new_width = self.constrain_to_multiple_of(
262
+ scale_width * width, max_val=self.__width
263
+ )
264
+ elif self.__resize_method == "minimal":
265
+ new_height = self.constrain_to_multiple_of(scale_height * height)
266
+ new_width = self.constrain_to_multiple_of(scale_width * width)
267
+ else:
268
+ raise ValueError(
269
+ f"resize_method {self.__resize_method} not implemented")
270
+
271
+ return (new_width, new_height)
272
+
273
+ def make_letter_box(self, sample):
274
+ top = bottom = (self.__height - sample.shape[0]) // 2
275
+ left = right = (self.__width - sample.shape[1]) // 2
276
+ sample = cv2.copyMakeBorder(
277
+ sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0)
278
+ return sample
279
+
280
+ def __call__(self, sample):
281
+ width, height = self.get_size(
282
+ sample["image"].shape[1], sample["image"].shape[0]
283
+ )
284
+
285
+ # resize sample
286
+ sample["image"] = cv2.resize(
287
+ sample["image"],
288
+ (width, height),
289
+ interpolation=self.__image_interpolation_method,
290
+ )
291
+
292
+ if self.__letter_box:
293
+ sample["image"] = self.make_letter_box(sample["image"])
294
+
295
+ if self.__resize_target:
296
+ if "disparity" in sample:
297
+ sample["disparity"] = cv2.resize(
298
+ sample["disparity"],
299
+ (width, height),
300
+ interpolation=cv2.INTER_NEAREST,
301
+ )
302
+
303
+ if self.__letter_box:
304
+ sample["disparity"] = self.make_letter_box(
305
+ sample["disparity"])
306
+
307
+ if "depth" in sample:
308
+ sample["depth"] = cv2.resize(
309
+ sample["depth"], (width,
310
+ height), interpolation=cv2.INTER_NEAREST
311
+ )
312
+
313
+ if self.__letter_box:
314
+ sample["depth"] = self.make_letter_box(sample["depth"])
315
+
316
+ sample["mask"] = cv2.resize(
317
+ sample["mask"].astype(np.float32),
318
+ (width, height),
319
+ interpolation=cv2.INTER_NEAREST,
320
+ )
321
+
322
+ if self.__letter_box:
323
+ sample["mask"] = self.make_letter_box(sample["mask"])
324
+
325
+ sample["mask"] = sample["mask"].astype(bool)
326
+
327
+ return sample
328
+
329
+
330
+ class ResizeFixed(object):
331
+ def __init__(self, size):
332
+ self.__size = size
333
+
334
+ def __call__(self, sample):
335
+ sample["image"] = cv2.resize(
336
+ sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR
337
+ )
338
+
339
+ sample["disparity"] = cv2.resize(
340
+ sample["disparity"], self.__size[::-
341
+ 1], interpolation=cv2.INTER_NEAREST
342
+ )
343
+
344
+ sample["mask"] = cv2.resize(
345
+ sample["mask"].astype(np.float32),
346
+ self.__size[::-1],
347
+ interpolation=cv2.INTER_NEAREST,
348
+ )
349
+ sample["mask"] = sample["mask"].astype(bool)
350
+
351
+ return sample
352
+
353
+
354
+ class Rescale(object):
355
+ """Rescale target values to the interval [0, max_val].
356
+ If input is constant, values are set to max_val / 2.
357
+ """
358
+
359
+ def __init__(self, max_val=1.0, use_mask=True):
360
+ """Init.
361
+
362
+ Args:
363
+ max_val (float, optional): Max output value. Defaults to 1.0.
364
+ use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True.
365
+ """
366
+ self.__max_val = max_val
367
+ self.__use_mask = use_mask
368
+
369
+ def __call__(self, sample):
370
+ disp = sample["disparity"]
371
+
372
+ if self.__use_mask:
373
+ mask = sample["mask"]
374
+ else:
375
+ mask = np.ones_like(disp, dtype=np.bool)
376
+
377
+ if np.sum(mask) == 0:
378
+ return sample
379
+
380
+ min_val = np.min(disp[mask])
381
+ max_val = np.max(disp[mask])
382
+
383
+ if max_val > min_val:
384
+ sample["disparity"][mask] = (
385
+ (disp[mask] - min_val) / (max_val - min_val) * self.__max_val
386
+ )
387
+ else:
388
+ sample["disparity"][mask] = np.ones_like(
389
+ disp[mask]) * self.__max_val / 2.0
390
+
391
+ return sample
392
+
393
+
394
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
395
+ class NormalizeImage(object):
396
+ """Normlize image by given mean and std.
397
+ """
398
+
399
+ def __init__(self, mean, std):
400
+ self.__mean = mean
401
+ self.__std = std
402
+
403
+ def __call__(self, sample):
404
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
405
+
406
+ return sample
407
+
408
+
409
+ class DepthToDisparity(object):
410
+ """Convert depth to disparity. Removes depth from sample.
411
+ """
412
+
413
+ def __init__(self, eps=1e-4):
414
+ self.__eps = eps
415
+
416
+ def __call__(self, sample):
417
+ assert "depth" in sample
418
+
419
+ sample["mask"][sample["depth"] < self.__eps] = False
420
+
421
+ sample["disparity"] = np.zeros_like(sample["depth"])
422
+ sample["disparity"][sample["depth"] >= self.__eps] = (
423
+ 1.0 / sample["depth"][sample["depth"] >= self.__eps]
424
+ )
425
+
426
+ del sample["depth"]
427
+
428
+ return sample
429
+
430
+
431
+ class DisparityToDepth(object):
432
+ """Convert disparity to depth. Removes disparity from sample.
433
+ """
434
+
435
+ def __init__(self, eps=1e-4):
436
+ self.__eps = eps
437
+
438
+ def __call__(self, sample):
439
+ assert "disparity" in sample
440
+
441
+ disp = np.abs(sample["disparity"])
442
+ sample["mask"][disp < self.__eps] = False
443
+
444
+ # print(sample["disparity"])
445
+ # print(sample["mask"].sum())
446
+ # exit()
447
+
448
+ sample["depth"] = np.zeros_like(disp)
449
+ sample["depth"][disp >= self.__eps] = (
450
+ 1.0 / disp[disp >= self.__eps]
451
+ )
452
+
453
+ del sample["disparity"]
454
+
455
+ return sample
456
+
457
+
458
+ class PrepareForNet(object):
459
+ """Prepare sample for usage as network input.
460
+ """
461
+
462
+ def __init__(self):
463
+ pass
464
+
465
+ def __call__(self, sample):
466
+ image = np.transpose(sample["image"], (2, 0, 1))
467
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
468
+
469
+ if "mask" in sample:
470
+ sample["mask"] = sample["mask"].astype(np.float32)
471
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
472
+
473
+ if "disparity" in sample:
474
+ disparity = sample["disparity"].astype(np.float32)
475
+ sample["disparity"] = np.ascontiguousarray(disparity)
476
+
477
+ if "depth" in sample:
478
+ depth = sample["depth"].astype(np.float32)
479
+ sample["depth"] = np.ascontiguousarray(depth)
480
+
481
+ return sample
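The transforms above all operate on a dict-style sample and are meant to be chained; a minimal composition sketch (illustrative only, not part of the commit; the sample contents are synthetic and this particular chain is an assumption):

import numpy as np
from torchvision.transforms import Compose

from zoedepth.data.transforms import (DepthToDisparity, NormalizeImage,
                                       PrepareForNet, Rescale, Resize)

transform = Compose([
    DepthToDisparity(),
    Resize(384, 384, keep_aspect_ratio=True, ensure_multiple_of=32),
    Rescale(max_val=1.0),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

sample = {
    "image": np.random.rand(480, 640, 3).astype(np.float32),
    "depth": np.random.uniform(0.5, 10.0, (480, 640)).astype(np.float32),
    "mask": np.ones((480, 640), dtype=bool),
}
out = transform(sample)
# image -> (3, 384, 512); disparity and mask -> (384, 512) after the keep-aspect-ratio resize
print(out["image"].shape, out["disparity"].shape, out["mask"].shape)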
zoedepth/data/vkitti.py ADDED
@@ -0,0 +1,151 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ from torch.utils.data import Dataset, DataLoader
27
+ from torchvision import transforms
28
+ import os
29
+
30
+ from PIL import Image
31
+ import numpy as np
32
+ import cv2
33
+
34
+
35
+ class ToTensor(object):
36
+ def __init__(self):
37
+ self.normalize = transforms.Normalize(
38
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ # self.resize = transforms.Resize((375, 1242))
40
+
41
+ def __call__(self, sample):
42
+ image, depth = sample['image'], sample['depth']
43
+
44
+ image = self.to_tensor(image)
45
+ image = self.normalize(image)
46
+ depth = self.to_tensor(depth)
47
+
48
+ # image = self.resize(image)
49
+
50
+ return {'image': image, 'depth': depth, 'dataset': "vkitti"}
51
+
52
+ def to_tensor(self, pic):
53
+
54
+ if isinstance(pic, np.ndarray):
55
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
56
+ return img
57
+
58
+ # # handle PIL Image
59
+ if pic.mode == 'I':
60
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
61
+ elif pic.mode == 'I;16':
62
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
63
+ else:
64
+ img = torch.ByteTensor(
65
+ torch.ByteStorage.from_buffer(pic.tobytes()))
66
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
67
+ if pic.mode == 'YCbCr':
68
+ nchannel = 3
69
+ elif pic.mode == 'I;16':
70
+ nchannel = 1
71
+ else:
72
+ nchannel = len(pic.mode)
73
+ img = img.view(pic.size[1], pic.size[0], nchannel)
74
+
75
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
76
+ if isinstance(img, torch.ByteTensor):
77
+ return img.float()
78
+ else:
79
+ return img
80
+
81
+
82
+ class VKITTI(Dataset):
83
+ def __init__(self, data_dir_root, do_kb_crop=True):
84
+ import glob
85
+ # image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
86
+ self.image_files = glob.glob(os.path.join(
87
+ data_dir_root, "test_color", '*.png'))
88
+ self.depth_files = [r.replace("test_color", "test_depth")
89
+ for r in self.image_files]
90
+ self.do_kb_crop = True
91
+ self.transform = ToTensor()
92
+
93
+ def __getitem__(self, idx):
94
+ image_path = self.image_files[idx]
95
+ depth_path = self.depth_files[idx]
96
+
97
+ image = Image.open(image_path)
98
+ depth = Image.open(depth_path)
99
+ depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
100
+ cv2.IMREAD_ANYDEPTH)
101
+ print("depth min max", depth.min(), depth.max())
102
+
103
+ # print(np.shape(image))
104
+ # print(np.shape(depth))
105
+
106
+ # depth[depth > 8] = -1
107
+
108
+ if self.do_kb_crop and False:
109
+ height = image.height
110
+ width = image.width
111
+ top_margin = int(height - 352)
112
+ left_margin = int((width - 1216) / 2)
113
+ depth = depth.crop(
114
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
115
+ image = image.crop(
116
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
117
+ # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216]
118
+
119
+ image = np.asarray(image, dtype=np.float32) / 255.0
120
+ # depth = np.asarray(depth, dtype=np.uint16) /1.
121
+ depth = depth[..., None]
122
+ sample = dict(image=image, depth=depth)
123
+
124
+ # return sample
125
+ sample = self.transform(sample)
126
+
127
+ if idx == 0:
128
+ print(sample["image"].shape)
129
+
130
+ return sample
131
+
132
+ def __len__(self):
133
+ return len(self.image_files)
134
+
135
+
136
+ def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs):
137
+ dataset = VKITTI(data_dir_root)
138
+ return DataLoader(dataset, batch_size, **kwargs)
139
+
140
+
141
+ if __name__ == "__main__":
142
+ loader = get_vkitti_loader(
143
+ data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test")
144
+ print("Total files", len(loader.dataset))
145
+ for i, sample in enumerate(loader):
146
+ print(sample["image"].shape)
147
+ print(sample["depth"].shape)
148
+ print(sample["dataset"])
149
+ print(sample['depth'].min(), sample['depth'].max())
150
+ if i > 5:
151
+ break
zoedepth/data/vkitti2.py ADDED
@@ -0,0 +1,187 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+
27
+ import cv2
28
+ import numpy as np
29
+ import torch
30
+ from PIL import Image
31
+ from torch.utils.data import DataLoader, Dataset
32
+ from torchvision import transforms
33
+
34
+
35
+ class ToTensor(object):
36
+ def __init__(self):
37
+ # self.normalize = transforms.Normalize(
38
+ # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ self.normalize = lambda x: x
40
+ # self.resize = transforms.Resize((375, 1242))
41
+
42
+ def __call__(self, sample):
43
+ image, depth = sample['image'], sample['depth']
44
+
45
+ image = self.to_tensor(image)
46
+ image = self.normalize(image)
47
+ depth = self.to_tensor(depth)
48
+
49
+ # image = self.resize(image)
50
+
51
+ return {'image': image, 'depth': depth, 'dataset': "vkitti"}
52
+
53
+ def to_tensor(self, pic):
54
+
55
+ if isinstance(pic, np.ndarray):
56
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
57
+ return img
58
+
59
+ # # handle PIL Image
60
+ if pic.mode == 'I':
61
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
62
+ elif pic.mode == 'I;16':
63
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
64
+ else:
65
+ img = torch.ByteTensor(
66
+ torch.ByteStorage.from_buffer(pic.tobytes()))
67
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
68
+ if pic.mode == 'YCbCr':
69
+ nchannel = 3
70
+ elif pic.mode == 'I;16':
71
+ nchannel = 1
72
+ else:
73
+ nchannel = len(pic.mode)
74
+ img = img.view(pic.size[1], pic.size[0], nchannel)
75
+
76
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
77
+ if isinstance(img, torch.ByteTensor):
78
+ return img.float()
79
+ else:
80
+ return img
81
+
82
+
83
+ class VKITTI2(Dataset):
84
+ def __init__(self, data_dir_root, do_kb_crop=True, split="test"):
85
+ import glob
86
+
87
+ # image paths are of the form <data_dir_root>/rgb/<scene>/<variant>/frames/<rgb,depth>/Camera<0,1>/rgb_{}.jpg
88
+ self.image_files = glob.glob(os.path.join(
89
+ data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True)
90
+ self.depth_files = [r.replace("/rgb/", "/depth/").replace(
91
+ "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
92
+ self.do_kb_crop = True
93
+ self.transform = ToTensor()
94
+
95
+ # If train test split is not created, then create one.
96
+ # Split is such that 8% of the frames from each scene are used for testing.
97
+ if not os.path.exists(os.path.join(data_dir_root, "train.txt")):
98
+ import random
99
+ scenes = set([os.path.basename(os.path.dirname(
100
+ os.path.dirname(os.path.dirname(f)))) for f in self.image_files])
101
+ train_files = []
102
+ test_files = []
103
+ for scene in scenes:
104
+ scene_files = [f for f in self.image_files if os.path.basename(
105
+ os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene]
106
+ random.shuffle(scene_files)
107
+ train_files.extend(scene_files[:int(len(scene_files) * 0.92)])
108
+ test_files.extend(scene_files[int(len(scene_files) * 0.92):])
109
+ with open(os.path.join(data_dir_root, "train.txt"), "w") as f:
110
+ f.write("\n".join(train_files))
111
+ with open(os.path.join(data_dir_root, "test.txt"), "w") as f:
112
+ f.write("\n".join(test_files))
113
+
114
+ if split == "train":
115
+ with open(os.path.join(data_dir_root, "train.txt"), "r") as f:
116
+ self.image_files = f.read().splitlines()
117
+ self.depth_files = [r.replace("/rgb/", "/depth/").replace(
118
+ "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
119
+ elif split == "test":
120
+ with open(os.path.join(data_dir_root, "test.txt"), "r") as f:
121
+ self.image_files = f.read().splitlines()
122
+ self.depth_files = [r.replace("/rgb/", "/depth/").replace(
123
+ "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
124
+
125
+ def __getitem__(self, idx):
126
+ image_path = self.image_files[idx]
127
+ depth_path = self.depth_files[idx]
128
+
129
+ image = Image.open(image_path)
130
+ # depth = Image.open(depth_path)
131
+ depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
132
+ cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m
133
+ depth = Image.fromarray(depth)
134
+ # print("dpeth min max", depth.min(), depth.max())
135
+
136
+ # print(np.shape(image))
137
+ # print(np.shape(depth))
138
+
139
+ if self.do_kb_crop:
140
+ if idx == 0:
141
+ print("Using KB input crop")
142
+ height = image.height
143
+ width = image.width
144
+ top_margin = int(height - 352)
145
+ left_margin = int((width - 1216) / 2)
146
+ depth = depth.crop(
147
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
148
+ image = image.crop(
149
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
150
+ # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216]
151
+
152
+ image = np.asarray(image, dtype=np.float32) / 255.0
153
+ # depth = np.asarray(depth, dtype=np.uint16) /1.
154
+ depth = np.asarray(depth, dtype=np.float32) / 1.
155
+ depth[depth > 80] = -1
156
+
157
+ depth = depth[..., None]
158
+ sample = dict(image=image, depth=depth)
159
+
160
+ # return sample
161
+ sample = self.transform(sample)
162
+
163
+ if idx == 0:
164
+ print(sample["image"].shape)
165
+
166
+ return sample
167
+
168
+ def __len__(self):
169
+ return len(self.image_files)
170
+
171
+
172
+ def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs):
173
+ dataset = VKITTI2(data_dir_root)
174
+ return DataLoader(dataset, batch_size, **kwargs)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ loader = get_vkitti2_loader(
179
+ data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2")
180
+ print("Total files", len(loader.dataset))
181
+ for i, sample in enumerate(loader):
182
+ print(sample["image"].shape)
183
+ print(sample["depth"].shape)
184
+ print(sample["dataset"])
185
+ print(sample['depth'].min(), sample['depth'].max())
186
+ if i > 5:
187
+ break
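The train/test split files are created lazily on first use, so a train-mode loader needs no extra preparation (illustrative sketch, not from this commit; the dataset root is a placeholder):

from torch.utils.data import DataLoader
from zoedepth.data.vkitti2 import VKITTI2

train_set = VKITTI2("/path/to/vkitti2", split="train")  # writes train.txt / test.txt on first run
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)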
zoedepth/models/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
zoedepth/models/base_models/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
zoedepth/models/base_models/midas.py ADDED
@@ -0,0 +1,377 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import numpy as np
28
+ from torchvision.transforms import Normalize
29
+
30
+
31
+ def denormalize(x):
32
+ """Reverses the imagenet normalization applied to the input.
33
+
34
+ Args:
35
+ x (torch.Tensor - shape(N,3,H,W)): input tensor
36
+
37
+ Returns:
38
+ torch.Tensor - shape(N,3,H,W): Denormalized input
39
+ """
40
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
41
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
42
+ return x * std + mean
43
+
44
+ def get_activation(name, bank):
45
+ def hook(model, input, output):
46
+ bank[name] = output
47
+ return hook
48
+
49
+
50
+ class Resize(object):
51
+ """Resize sample to given size (width, height).
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ width,
57
+ height,
58
+ resize_target=True,
59
+ keep_aspect_ratio=False,
60
+ ensure_multiple_of=1,
61
+ resize_method="lower_bound",
62
+ ):
63
+ """Init.
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ print("Params passed to Resize transform:")
86
+ print("\twidth: ", width)
87
+ print("\theight: ", height)
88
+ print("\tresize_target: ", resize_target)
89
+ print("\tkeep_aspect_ratio: ", keep_aspect_ratio)
90
+ print("\tensure_multiple_of: ", ensure_multiple_of)
91
+ print("\tresize_method: ", resize_method)
92
+
93
+ self.__width = width
94
+ self.__height = height
95
+
96
+ self.__keep_aspect_ratio = keep_aspect_ratio
97
+ self.__multiple_of = ensure_multiple_of
98
+ self.__resize_method = resize_method
99
+
100
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
101
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ if max_val is not None and y > max_val:
104
+ y = (np.floor(x / self.__multiple_of)
105
+ * self.__multiple_of).astype(int)
106
+
107
+ if y < min_val:
108
+ y = (np.ceil(x / self.__multiple_of)
109
+ * self.__multiple_of).astype(int)
110
+
111
+ return y
112
+
113
+ def get_size(self, width, height):
114
+ # determine new height and width
115
+ scale_height = self.__height / height
116
+ scale_width = self.__width / width
117
+
118
+ if self.__keep_aspect_ratio:
119
+ if self.__resize_method == "lower_bound":
120
+ # scale such that output size is lower bound
121
+ if scale_width > scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "upper_bound":
128
+ # scale such that output size is upper bound
129
+ if scale_width < scale_height:
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ elif self.__resize_method == "minimal":
136
+ # scale as least as possbile
137
+ if abs(1 - scale_width) < abs(1 - scale_height):
138
+ # fit width
139
+ scale_height = scale_width
140
+ else:
141
+ # fit height
142
+ scale_width = scale_height
143
+ else:
144
+ raise ValueError(
145
+ f"resize_method {self.__resize_method} not implemented"
146
+ )
147
+
148
+ if self.__resize_method == "lower_bound":
149
+ new_height = self.constrain_to_multiple_of(
150
+ scale_height * height, min_val=self.__height
151
+ )
152
+ new_width = self.constrain_to_multiple_of(
153
+ scale_width * width, min_val=self.__width
154
+ )
155
+ elif self.__resize_method == "upper_bound":
156
+ new_height = self.constrain_to_multiple_of(
157
+ scale_height * height, max_val=self.__height
158
+ )
159
+ new_width = self.constrain_to_multiple_of(
160
+ scale_width * width, max_val=self.__width
161
+ )
162
+ elif self.__resize_method == "minimal":
163
+ new_height = self.constrain_to_multiple_of(scale_height * height)
164
+ new_width = self.constrain_to_multiple_of(scale_width * width)
165
+ else:
166
+ raise ValueError(
167
+ f"resize_method {self.__resize_method} not implemented")
168
+
169
+ return (new_width, new_height)
170
+
171
+ def __call__(self, x):
172
+ width, height = self.get_size(*x.shape[-2:][::-1])
173
+ return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True)
174
+
175
+ class PrepForMidas(object):
176
+ def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True):
177
+ if isinstance(img_size, int):
178
+ img_size = (img_size, img_size)
179
+ net_h, net_w = img_size
180
+ self.normalization = Normalize(
181
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
182
+ self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \
183
+ if do_resize else nn.Identity()
184
+
185
+ def __call__(self, x):
186
+ return self.normalization(self.resizer(x))
187
+
188
+
189
+ class MidasCore(nn.Module):
190
+ def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True,
191
+ img_size=384, **kwargs):
192
+ """Midas Base model used for multi-scale feature extraction.
193
+
194
+ Args:
195
+ midas (torch.nn.Module): Midas model.
196
+ trainable (bool, optional): Train midas model. Defaults to False.
197
+ fetch_features (bool, optional): Extract multi-scale features. Defaults to True.
198
+ layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1').
199
+ freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False.
200
+ keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True.
201
+ img_size (int, tuple, optional): Input resolution. Defaults to 384.
202
+ """
203
+ super().__init__()
204
+ self.core = midas
205
+ self.output_channels = None
206
+ self.core_out = {}
207
+ self.trainable = trainable
208
+ self.fetch_features = fetch_features
209
+ # midas.scratch.output_conv = nn.Identity()
210
+ self.handles = []
211
+ # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1']
212
+ self.layer_names = layer_names
213
+
214
+ self.set_trainable(trainable)
215
+ self.set_fetch_features(fetch_features)
216
+
217
+ self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio,
218
+ img_size=img_size, do_resize=kwargs.get('do_resize', True))
219
+
220
+ if freeze_bn:
221
+ self.freeze_bn()
222
+
223
+ def set_trainable(self, trainable):
224
+ self.trainable = trainable
225
+ if trainable:
226
+ self.unfreeze()
227
+ else:
228
+ self.freeze()
229
+ return self
230
+
231
+ def set_fetch_features(self, fetch_features):
232
+ self.fetch_features = fetch_features
233
+ if fetch_features:
234
+ if len(self.handles) == 0:
235
+ self.attach_hooks(self.core)
236
+ else:
237
+ self.remove_hooks()
238
+ return self
239
+
240
+ def freeze(self):
241
+ for p in self.parameters():
242
+ p.requires_grad = False
243
+ self.trainable = False
244
+ return self
245
+
246
+ def unfreeze(self):
247
+ for p in self.parameters():
248
+ p.requires_grad = True
249
+ self.trainable = True
250
+ return self
251
+
252
+ def freeze_bn(self):
253
+ for m in self.modules():
254
+ if isinstance(m, nn.BatchNorm2d):
255
+ m.eval()
256
+ return self
257
+
258
+ def forward(self, x, denorm=False, return_rel_depth=False):
259
+ with torch.no_grad():
260
+ if denorm:
261
+ x = denormalize(x)
262
+ x = self.prep(x)
263
+ # print("Shape after prep: ", x.shape)
264
+
265
+ with torch.set_grad_enabled(self.trainable):
266
+
267
+ # print("Input size to Midascore", x.shape)
268
+ rel_depth = self.core(x)
269
+ # print("Output from midas shape", rel_depth.shape)
270
+ if not self.fetch_features:
271
+ return rel_depth
272
+ out = [self.core_out[k] for k in self.layer_names]
273
+
274
+ if return_rel_depth:
275
+ return rel_depth, out
276
+ return out
277
+
278
+ def get_rel_pos_params(self):
279
+ for name, p in self.core.pretrained.named_parameters():
280
+ if "relative_position" in name:
281
+ yield p
282
+
283
+ def get_enc_params_except_rel_pos(self):
284
+ for name, p in self.core.pretrained.named_parameters():
285
+ if "relative_position" not in name:
286
+ yield p
287
+
288
+ def freeze_encoder(self, freeze_rel_pos=False):
289
+ if freeze_rel_pos:
290
+ for p in self.core.pretrained.parameters():
291
+ p.requires_grad = False
292
+ else:
293
+ for p in self.get_enc_params_except_rel_pos():
294
+ p.requires_grad = False
295
+ return self
296
+
297
+ def attach_hooks(self, midas):
298
+ if len(self.handles) > 0:
299
+ self.remove_hooks()
300
+ if "out_conv" in self.layer_names:
301
+ self.handles.append(list(midas.scratch.output_conv.children())[
302
+ 3].register_forward_hook(get_activation("out_conv", self.core_out)))
303
+ if "r4" in self.layer_names:
304
+ self.handles.append(midas.scratch.refinenet4.register_forward_hook(
305
+ get_activation("r4", self.core_out)))
306
+ if "r3" in self.layer_names:
307
+ self.handles.append(midas.scratch.refinenet3.register_forward_hook(
308
+ get_activation("r3", self.core_out)))
309
+ if "r2" in self.layer_names:
310
+ self.handles.append(midas.scratch.refinenet2.register_forward_hook(
311
+ get_activation("r2", self.core_out)))
312
+ if "r1" in self.layer_names:
313
+ self.handles.append(midas.scratch.refinenet1.register_forward_hook(
314
+ get_activation("r1", self.core_out)))
315
+ if "l4_rn" in self.layer_names:
316
+ self.handles.append(midas.scratch.layer4_rn.register_forward_hook(
317
+ get_activation("l4_rn", self.core_out)))
318
+
319
+ return self
320
+
321
+ def remove_hooks(self):
322
+ for h in self.handles:
323
+ h.remove()
324
+ return self
325
+
326
+ def __del__(self):
327
+ self.remove_hooks()
328
+
329
+ def set_output_channels(self, model_type):
330
+ self.output_channels = MIDAS_SETTINGS[model_type]
331
+
332
+ @staticmethod
333
+ def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs):
334
+ if midas_model_type not in MIDAS_SETTINGS:
335
+ raise ValueError(
336
+ f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}")
337
+ if "img_size" in kwargs:
338
+ kwargs = MidasCore.parse_img_size(kwargs)
339
+ img_size = kwargs.pop("img_size", [384, 384])
340
+ print("img_size", img_size)
341
+ midas = torch.hub.load("intel-isl/MiDaS", midas_model_type,
342
+ pretrained=use_pretrained_midas, force_reload=force_reload)
343
+ kwargs.update({'keep_aspect_ratio': force_keep_ar})
344
+ midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features,
345
+ freeze_bn=freeze_bn, img_size=img_size, **kwargs)
346
+ midas_core.set_output_channels(midas_model_type)
347
+ return midas_core
348
+
349
+ @staticmethod
350
+ def build_from_config(config):
351
+ return MidasCore.build(**config)
352
+
353
+ @staticmethod
354
+ def parse_img_size(config):
355
+ assert 'img_size' in config
356
+ if isinstance(config['img_size'], str):
357
+ assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W"
358
+ config['img_size'] = list(map(int, config['img_size'].split(",")))
359
+ assert len(
360
+ config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W"
361
+ elif isinstance(config['img_size'], int):
362
+ config['img_size'] = [config['img_size'], config['img_size']]
363
+ else:
364
+ assert isinstance(config['img_size'], list) and len(
365
+ config['img_size']) == 2, "img_size should be a list of H,W"
366
+ return config
367
+
368
+
369
+ nchannels2models = {
370
+ tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"],
371
+ (512, 256, 128, 64, 64): ["MiDaS_small"]
372
+ }
373
+
374
+ # Model name to number of output channels
375
+ MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items()
376
+ for m in v
377
+ }
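A hedged sketch of using the MiDaS wrapper above as a standalone feature extractor (illustrative only, not part of the commit; it assumes network access, since MidasCore.build fetches the DPT weights through torch.hub):

import torch

from zoedepth.models.base_models.midas import MidasCore

core = MidasCore.build("DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True,
                       fetch_features=True, img_size=[384, 512])
core.eval()

x = torch.rand(1, 3, 480, 640)  # RGB in [0, 1]; PrepForMidas handles resizing and normalization
with torch.no_grad():
    rel_depth, features = core(x, return_rel_depth=True)
print(rel_depth.shape, [f.shape for f in features])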
zoedepth/models/builder.py ADDED
@@ -0,0 +1,51 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ from importlib import import_module
26
+ from zoedepth.models.depth_model import DepthModel
27
+
28
+ def build_model(config) -> DepthModel:
29
+ """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface.
30
+ This function should be used to construct models for training and evaluation.
31
+
32
+ Args:
33
+ config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder.
34
+
35
+ Returns:
36
+ torch.nn.Module: Model corresponding to name and version as specified in config
37
+ """
38
+ module_name = f"zoedepth.models.{config.model}"
39
+ try:
40
+ module = import_module(module_name)
41
+ except ModuleNotFoundError as e:
42
+ # print the original error message
43
+ print(e)
44
+ raise ValueError(
45
+ f"Model {config.model} not found. Refer above error for details.") from e
46
+ try:
47
+ get_version = getattr(module, "get_version")
48
+ except AttributeError as e:
49
+ raise ValueError(
50
+ f"Model {config.model} has no get_version function.") from e
51
+ return get_version(config.version_name).build_from_config(config)
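A sketch of the intended entry point (illustrative only, not part of the commit; it assumes that get_config from zoedepth/utils/config.py in this upload behaves like the upstream ZoeDepth get_config(model_name, mode) and that the configured pretrained checkpoint is reachable):

import torch

from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config

config = get_config("zoedepth", "infer")          # assumed signature: (model_name, mode)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_model(config).to(device)
model.eval()

x = torch.rand(1, 3, 384, 512).to(device)
with torch.no_grad():
    depth = model.infer(x)                        # metric depth, shape (1, 1, 384, 512)
print(depth.shape)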
zoedepth/models/depth_model.py ADDED
@@ -0,0 +1,152 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torchvision import transforms
30
+ import PIL.Image
31
+ from PIL import Image
32
+ from typing import Union
33
+
34
+
35
+ class DepthModel(nn.Module):
36
+ def __init__(self):
37
+ super().__init__()
38
+ self.device = 'cpu'
39
+
40
+ def to(self, device) -> nn.Module:
41
+ self.device = device
42
+ return super().to(device)
43
+
44
+ def forward(self, x, *args, **kwargs):
45
+ raise NotImplementedError
46
+
47
+ def _infer(self, x: torch.Tensor):
48
+ """
49
+ Inference interface for the model
50
+ Args:
51
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
52
+ Returns:
53
+ torch.Tensor: output tensor of shape (b, 1, h, w)
54
+ """
55
+ return self(x)['metric_depth']
56
+
57
+ def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor:
58
+ """
59
+ Inference interface for the model with padding augmentation
60
+ Padding augmentation fixes the boundary artifacts in the output depth map.
61
+ Boundary artifacts are sometimes caused by the model being trained on the NYU raw dataset, which has a black or white border around the image.
62
+ This augmentation pads the input image and crops the prediction back to the original size / view.
63
+
64
+ Note: This augmentation is not required for the models trained with 'avoid_boundary'=True.
65
+ Args:
66
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
67
+ pad_input (bool, optional): whether to pad the input or not. Defaults to True.
68
+ fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3.
69
+ fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3.
70
+ upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'.
71
+ padding_mode (str, optional): padding mode. Defaults to "reflect".
72
+ Returns:
73
+ torch.Tensor: output tensor of shape (b, 1, h, w)
74
+ """
75
+ # assert x is nchw and c = 3
76
+ assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())
77
+ assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1])
78
+
79
+ if pad_input:
80
+ assert fh > 0 or fw > 0, "at least one of fh and fw must be greater than 0"
81
+ pad_h = int(np.sqrt(x.shape[2]/2) * fh)
82
+ pad_w = int(np.sqrt(x.shape[3]/2) * fw)
83
+ padding = [pad_w, pad_w]
84
+ if pad_h > 0:
85
+ padding += [pad_h, pad_h]
86
+
87
+ x = F.pad(x, padding, mode=padding_mode, **kwargs)
88
+ out = self._infer(x)
89
+ if out.shape[-2:] != x.shape[-2:]:
90
+ out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
91
+ if pad_input:
92
+ # crop back to the original size, handling the case where pad_h or pad_w is 0
93
+ if pad_h > 0:
94
+ out = out[:, :, pad_h:-pad_h,:]
95
+ if pad_w > 0:
96
+ out = out[:, :, :, pad_w:-pad_w]
97
+ return out
98
+
99
+ def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor:
100
+ """
101
+ Inference interface for the model with horizontal flip augmentation
102
+ Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip.
103
+ Args:
104
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
105
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
106
+ Returns:
107
+ torch.Tensor: output tensor of shape (b, 1, h, w)
108
+ """
109
+ # infer with horizontal flip and average
110
+ out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs)
111
+ out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs)
112
+ out = (out + torch.flip(out_flip, dims=[3])) / 2
113
+ return out
114
+
115
+ def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor:
116
+ """
117
+ Inference interface for the model
118
+ Args:
119
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
120
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
121
+ with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
122
+ Returns:
123
+ torch.Tensor: output tensor of shape (b, 1, h, w)
124
+ """
125
+ if with_flip_aug:
126
+ return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs)
127
+ else:
128
+ return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs)
129
+
130
+ @torch.no_grad()
131
+ def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]:
132
+ """
133
+ Inference interface for the model for PIL image
134
+ Args:
135
+ pil_img (PIL.Image.Image): input PIL image
136
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
137
+ with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
138
+ output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy".
139
+ """
140
+ x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device)
141
+ out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs)
142
+ if output_type == "numpy":
143
+ return out_tensor.squeeze().cpu().numpy()
144
+ elif output_type == "pil":
145
+ # uint16 is required for depth pil image
146
+ out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16)
147
+ return Image.fromarray(out_16bit_numpy)
148
+ elif output_type == "tensor":
149
+ return out_tensor.squeeze().cpu()
150
+ else:
151
+ raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'")
152
+
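A quick inference sketch against the interface defined above; the torch.hub entrypoint name "ZoeD_N" follows the upstream ZoeDepth README and is an assumption here, not part of this diff:

    import torch
    from PIL import Image

    # Hypothetical: load a pretrained ZoeDepth model via torch.hub and run infer_pil.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    zoe = torch.hub.load("isl-org/ZoeDepth", "ZoeD_N", pretrained=True).to(device)
    depth = zoe.infer_pil(Image.open("example.jpg"))  # numpy array (H, W) of metric depth; pad + flip augmentation on by default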
zoedepth/models/layers/attractor.py ADDED
@@ -0,0 +1,208 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ @torch.jit.script
30
+ def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
31
+ """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor
32
+
33
+ Args:
34
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
35
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
36
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
37
+
38
+ Returns:
39
+ torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
40
+ """
41
+ return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx)
42
+
43
+
44
+ @torch.jit.script
45
+ def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
46
+ """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
47
+ This is the default one according to the accompanying paper.
48
+
49
+ Args:
50
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
51
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
52
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
53
+
54
+ Returns:
55
+ torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
56
+ """
57
+ return dx.div(1+alpha*dx.pow(gamma))
58
+
59
+
60
+ class AttractorLayer(nn.Module):
61
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
62
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
63
+ """
64
+ Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
65
+ """
66
+ super().__init__()
67
+
68
+ self.n_attractors = n_attractors
69
+ self.n_bins = n_bins
70
+ self.min_depth = min_depth
71
+ self.max_depth = max_depth
72
+ self.alpha = alpha
73
+ self.gamma = gamma
74
+ self.kind = kind
75
+ self.attractor_type = attractor_type
76
+ self.memory_efficient = memory_efficient
77
+
78
+ self._net = nn.Sequential(
79
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
80
+ nn.ReLU(inplace=True),
81
+ nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm
82
+ nn.ReLU(inplace=True)
83
+ )
84
+
85
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
86
+ """
87
+ Args:
88
+ x (torch.Tensor) : feature block; shape - n, c, h, w
89
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
90
+
91
+ Returns:
92
+ tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w
93
+ """
94
+ if prev_b_embedding is not None:
95
+ if interpolate:
96
+ prev_b_embedding = nn.functional.interpolate(
97
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
98
+ x = x + prev_b_embedding
99
+
100
+ A = self._net(x)
101
+ eps = 1e-3
102
+ A = A + eps
103
+ n, c, h, w = A.shape
104
+ A = A.view(n, self.n_attractors, 2, h, w)
105
+ A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w
106
+ A_normed = A[:, :, 0, ...] # n, na, h, w
107
+
108
+ b_prev = nn.functional.interpolate(
109
+ b_prev, (h, w), mode='bilinear', align_corners=True)
110
+ b_centers = b_prev
111
+
112
+ if self.attractor_type == 'exp':
113
+ dist = exp_attractor
114
+ else:
115
+ dist = inv_attractor
116
+
117
+ if not self.memory_efficient:
118
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
119
+ # .shape N, nbins, h, w
120
+ delta_c = func(dist(A_normed.unsqueeze(
121
+ 2) - b_centers.unsqueeze(1)), dim=1)
122
+ else:
123
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
124
+ for i in range(self.n_attractors):
125
+ # .shape N, nbins, h, w
126
+ delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)
127
+
128
+ if self.kind == 'mean':
129
+ delta_c = delta_c / self.n_attractors
130
+
131
+ b_new_centers = b_centers + delta_c
132
+ B_centers = (self.max_depth - self.min_depth) * \
133
+ b_new_centers + self.min_depth
134
+ B_centers, _ = torch.sort(B_centers, dim=1)
135
+ B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
136
+ return b_new_centers, B_centers
137
+
138
+
139
+ class AttractorLayerUnnormed(nn.Module):
140
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
141
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
142
+ """
143
+ Attractor layer for bin centers. Bin centers are unbounded
144
+ """
145
+ super().__init__()
146
+
147
+ self.n_attractors = n_attractors
148
+ self.n_bins = n_bins
149
+ self.min_depth = min_depth
150
+ self.max_depth = max_depth
151
+ self.alpha = alpha
152
+ self.gamma = gamma
153
+ self.kind = kind
154
+ self.attractor_type = attractor_type
155
+ self.memory_efficient = memory_efficient
156
+
157
+ self._net = nn.Sequential(
158
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
159
+ nn.ReLU(inplace=True),
160
+ nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
161
+ nn.Softplus()
162
+ )
163
+
164
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
165
+ """
166
+ Args:
167
+ x (torch.Tensor) : feature block; shape - n, c, h, w
168
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
169
+
170
+ Returns:
171
+ tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
172
+ """
173
+ if prev_b_embedding is not None:
174
+ if interpolate:
175
+ prev_b_embedding = nn.functional.interpolate(
176
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
177
+ x = x + prev_b_embedding
178
+
179
+ A = self._net(x)
180
+ n, c, h, w = A.shape
181
+
182
+ b_prev = nn.functional.interpolate(
183
+ b_prev, (h, w), mode='bilinear', align_corners=True)
184
+ b_centers = b_prev
185
+
186
+ if self.attractor_type == 'exp':
187
+ dist = exp_attractor
188
+ else:
189
+ dist = inv_attractor
190
+
191
+ if not self.memory_efficient:
192
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
193
+ # .shape N, nbins, h, w
194
+ delta_c = func(
195
+ dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)
196
+ else:
197
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
198
+ for i in range(self.n_attractors):
199
+ delta_c += dist(A[:, i, ...].unsqueeze(1) -
200
+ b_centers) # .shape N, nbins, h, w
201
+
202
+ if self.kind == 'mean':
203
+ delta_c = delta_c / self.n_attractors
204
+
205
+ b_new_centers = b_centers + delta_c
206
+ B_centers = b_new_centers
207
+
208
+ return b_new_centers, B_centers
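A small numeric sketch of the two attractor functions above (illustrative values only; it shows that the shift shrinks as |dx| grows, so far-away bin centers are barely attracted):

    import torch
    from zoedepth.models.layers.attractor import exp_attractor, inv_attractor

    dx = torch.tensor([0.5, 0.05, 0.005])            # attractor-minus-bin-center differences
    print(inv_attractor(dx, alpha=300., gamma=2))    # ~[0.0066, 0.0286, 0.0050]: the damped default used in the paper
    print(exp_attractor(dx, alpha=300., gamma=2))    # exponential variant; near-zero shift for large |dx|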
zoedepth/models/layers/dist_layers.py ADDED
@@ -0,0 +1,121 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ def log_binom(n, k, eps=1e-7):
30
+ """ log(nCk) using stirling approximation """
31
+ n = n + eps
32
+ k = k + eps
33
+ return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps)
34
+
35
+
36
+ class LogBinomial(nn.Module):
37
+ def __init__(self, n_classes=256, act=torch.softmax):
38
+ """Compute log binomial distribution for n_classes
39
+
40
+ Args:
41
+ n_classes (int, optional): number of output classes. Defaults to 256.
42
+ """
43
+ super().__init__()
44
+ self.K = n_classes
45
+ self.act = act
46
+ self.register_buffer('k_idx', torch.arange(
47
+ 0, n_classes).view(1, -1, 1, 1))
48
+ self.register_buffer('K_minus_1', torch.Tensor(
49
+ [self.K-1]).view(1, -1, 1, 1))
50
+
51
+ def forward(self, x, t=1., eps=1e-4):
52
+ """Compute log binomial distribution for x
53
+
54
+ Args:
55
+ x (torch.Tensor - NCHW): probabilities
56
+ t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1..
57
+ eps (float, optional): Small number for numerical stability. Defaults to 1e-4.
58
+
59
+ Returns:
60
+ torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
61
+ """
62
+ if x.ndim == 3:
63
+ x = x.unsqueeze(1) # make it nchw
64
+
65
+ one_minus_x = torch.clamp(1 - x, eps, 1)
66
+ x = torch.clamp(x, eps, 1)
67
+ y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
68
+ torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
69
+ return self.act(y/t, dim=1)
70
+
71
+
72
+ class ConditionalLogBinomial(nn.Module):
73
+ def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
74
+ """Conditional Log Binomial distribution
75
+
76
+ Args:
77
+ in_features (int): number of input channels in main feature
78
+ condition_dim (int): number of input channels in condition feature
79
+ n_classes (int, optional): Number of classes. Defaults to 256.
80
+ bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
81
+ p_eps (float, optional): small eps value. Defaults to 1e-4.
82
+ max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
83
+ min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
84
+ """
85
+ super().__init__()
86
+ self.p_eps = p_eps
87
+ self.max_temp = max_temp
88
+ self.min_temp = min_temp
89
+ self.log_binomial_transform = LogBinomial(n_classes, act=act)
90
+ bottleneck = (in_features + condition_dim) // bottleneck_factor
91
+ self.mlp = nn.Sequential(
92
+ nn.Conv2d(in_features + condition_dim, bottleneck,
93
+ kernel_size=1, stride=1, padding=0),
94
+ nn.GELU(),
95
+ # 2 for p linear norm, 2 for t linear norm
96
+ nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0),
97
+ nn.Softplus()
98
+ )
99
+
100
+ def forward(self, x, cond):
101
+ """Forward pass
102
+
103
+ Args:
104
+ x (torch.Tensor - NCHW): Main feature
105
+ cond (torch.Tensor - NCHW): condition feature
106
+
107
+ Returns:
108
+ torch.Tensor: Output log binomial distribution
109
+ """
110
+ pt = self.mlp(torch.concat((x, cond), dim=1))
111
+ p, t = pt[:, :2, ...], pt[:, 2:, ...]
112
+
113
+ p = p + self.p_eps
114
+ p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])
115
+
116
+ t = t + self.p_eps
117
+ t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
118
+ t = t.unsqueeze(1)
119
+ t = (self.max_temp - self.min_temp) * t + self.min_temp
120
+
121
+ return self.log_binomial_transform(p, t)
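A shape sketch for the conditional log-binomial head above (the dimensions are illustrative, e.g. 32 decoder channels plus 1 relative-depth channel as in the ZoeDepth head):

    import torch
    from zoedepth.models.layers.dist_layers import ConditionalLogBinomial

    head = ConditionalLogBinomial(in_features=33, condition_dim=128, n_classes=64)
    feat = torch.randn(2, 33, 24, 32)    # main feature (last decoder activation concatenated with relative depth)
    cond = torch.randn(2, 128, 24, 32)   # bin embedding used as the condition
    probs = head(feat, cond)             # (2, 64, 24, 32): per-pixel probabilities over 64 bins (softmax over dim=1)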
zoedepth/models/layers/localbins_layers.py ADDED
@@ -0,0 +1,169 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ class SeedBinRegressor(nn.Module):
30
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
31
+ """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.
32
+
33
+ Args:
34
+ in_features (int): input channels
35
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
36
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
37
+ min_depth (float, optional): Min depth value. Defaults to 1e-3.
38
+ max_depth (float, optional): Max depth value. Defaults to 10.
39
+ """
40
+ super().__init__()
41
+ self.version = "1_1"
42
+ self.min_depth = min_depth
43
+ self.max_depth = max_depth
44
+
45
+ self._net = nn.Sequential(
46
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
49
+ nn.ReLU(inplace=True)
50
+ )
51
+
52
+ def forward(self, x):
53
+ """
54
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
55
+ """
56
+ B = self._net(x)
57
+ eps = 1e-3
58
+ B = B + eps
59
+ B_widths_normed = B / B.sum(dim=1, keepdim=True)
60
+ B_widths = (self.max_depth - self.min_depth) * \
61
+ B_widths_normed # .shape NCHW
62
+ # pad has the form (left, right, top, bottom, front, back)
63
+ B_widths = nn.functional.pad(
64
+ B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
65
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
66
+
67
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
68
+ return B_widths_normed, B_centers
69
+
70
+
71
+ class SeedBinRegressorUnnormed(nn.Module):
72
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
73
+ """Bin center regressor network. Bin centers are unbounded
74
+
75
+ Args:
76
+ in_features (int): input channels
77
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
78
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
79
+ min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
80
+ max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
81
+ """
82
+ super().__init__()
83
+ self.version = "1_1"
84
+ self._net = nn.Sequential(
85
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
86
+ nn.ReLU(inplace=True),
87
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
88
+ nn.Softplus()
89
+ )
90
+
91
+ def forward(self, x):
92
+ """
93
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
94
+ """
95
+ B_centers = self._net(x)
96
+ return B_centers, B_centers
97
+
98
+
99
+ class Projector(nn.Module):
100
+ def __init__(self, in_features, out_features, mlp_dim=128):
101
+ """Projector MLP
102
+
103
+ Args:
104
+ in_features (int): input channels
105
+ out_features (int): output channels
106
+ mlp_dim (int, optional): hidden dimension. Defaults to 128.
107
+ """
108
+ super().__init__()
109
+
110
+ self._net = nn.Sequential(
111
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
112
+ nn.ReLU(inplace=True),
113
+ nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
114
+ )
115
+
116
+ def forward(self, x):
117
+ return self._net(x)
118
+
119
+
120
+
121
+ class LinearSplitter(nn.Module):
122
+ def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
123
+ super().__init__()
124
+
125
+ self.prev_nbins = prev_nbins
126
+ self.split_factor = split_factor
127
+ self.min_depth = min_depth
128
+ self.max_depth = max_depth
129
+
130
+ self._net = nn.Sequential(
131
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
132
+ nn.GELU(),
133
+ nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
134
+ nn.ReLU()
135
+ )
136
+
137
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
138
+ """
139
+ x : feature block; shape - n, c, h, w
140
+ b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
141
+ """
142
+ if prev_b_embedding is not None:
143
+ if interpolate:
144
+ prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
145
+ x = x + prev_b_embedding
146
+ S = self._net(x)
147
+ eps = 1e-3
148
+ S = S + eps
149
+ n, c, h, w = S.shape
150
+ S = S.view(n, self.prev_nbins, self.split_factor, h, w)
151
+ S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits
152
+
153
+ b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True)
154
+
155
+
156
+ b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for guarantees
157
+ # print(b_prev.shape, S_normed.shape)
158
+ # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat?
159
+ b = b_prev.unsqueeze(2) * S_normed
160
+ b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w
161
+
162
+ # calculate bin centers for loss calculation
163
+ B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W
164
+ # pad has the form (left, right, top, bottom, front, back)
165
+ B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth)
166
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
167
+
168
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...])
169
+ return b, B_centers
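A quick shape check for the seed bin regressor above (illustrative dimensions only):

    import torch
    from zoedepth.models.layers.localbins_layers import SeedBinRegressor

    reg = SeedBinRegressor(in_features=256, n_bins=16, min_depth=1e-3, max_depth=10)
    feat = torch.randn(2, 256, 12, 16)
    widths_normed, centers = reg(feat)   # both (2, 16, 12, 16); centers lie strictly inside (min_depth, max_depth)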
zoedepth/models/layers/patch_transformer.py ADDED
@@ -0,0 +1,91 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ class PatchTransformerEncoder(nn.Module):
30
+ def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False):
31
+ """ViT-like transformer block
32
+
33
+ Args:
34
+ in_channels (int): Input channels
35
+ patch_size (int, optional): patch size. Defaults to 10.
36
+ embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128.
37
+ num_heads (int, optional): number of attention heads. Defaults to 4.
38
+ use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False.
39
+ """
40
+ super(PatchTransformerEncoder, self).__init__()
41
+ self.use_class_token = use_class_token
42
+ encoder_layers = nn.TransformerEncoderLayer(
43
+ embedding_dim, num_heads, dim_feedforward=1024)
44
+ self.transformer_encoder = nn.TransformerEncoder(
45
+ encoder_layers, num_layers=4) # takes shape S,N,E
46
+
47
+ self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim,
48
+ kernel_size=patch_size, stride=patch_size, padding=0)
49
+
50
+ def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'):
51
+ """Generate positional encodings
52
+
53
+ Args:
54
+ sequence_length (int): Sequence length
55
+ embedding_dim (int): Embedding dimension
56
+
57
+ Returns:
58
+ torch.Tensor SBE: Positional encodings
59
+ """
60
+ position = torch.arange(
61
+ 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1)
62
+ index = torch.arange(
63
+ 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0)
64
+ div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim))
65
+ pos_encoding = position * div_term
66
+ pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1)
67
+ pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1)
68
+ return pos_encoding
69
+
70
+
71
+ def forward(self, x):
72
+ """Forward pass
73
+
74
+ Args:
75
+ x (torch.Tensor - NCHW): Input feature tensor
76
+
77
+ Returns:
78
+ torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim
79
+ """
80
+ embeddings = self.embedding_convPxP(x).flatten(
81
+ 2) # .shape = n,c,s = n, embedding_dim, s
82
+ if self.use_class_token:
83
+ # extra special token at start ?
84
+ embeddings = nn.functional.pad(embeddings, (1, 0))
85
+
86
+ # change to S,N,E format required by transformer
87
+ embeddings = embeddings.permute(2, 0, 1)
88
+ S, N, E = embeddings.shape
89
+ embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device)
90
+ x = self.transformer_encoder(embeddings) # .shape = S, N, E
91
+ return x
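A shape sketch for the patch transformer above (illustrative sizes; one token is produced per 10x10 patch of the input feature map):

    import torch
    from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder

    pte = PatchTransformerEncoder(in_channels=256, patch_size=10, embedding_dim=128, num_heads=4)
    feat = torch.randn(2, 256, 120, 160)
    tokens = pte(feat)   # (S, N, E) = (12 * 16, 2, 128) in the sequence-first layout the encoder expects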
zoedepth/models/model_io.py ADDED
@@ -0,0 +1,92 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+
27
+ def load_state_dict(model, state_dict):
28
+ """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict.
29
+
30
+ DataParallel prefixes state_dict keys with 'module.' when saving.
31
+ If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
32
+ If the model is a DataParallel model but the state_dict is not, then prefixes are added.
33
+ """
34
+ state_dict = state_dict.get('model', state_dict)
35
+ # if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
36
+
37
+ do_prefix = isinstance(
38
+ model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
39
+ state = {}
40
+ for k, v in state_dict.items():
41
+ if k.startswith('module.') and not do_prefix:
42
+ k = k[7:]
43
+
44
+ if not k.startswith('module.') and do_prefix:
45
+ k = 'module.' + k
46
+
47
+ state[k] = v
48
+
49
+ model.load_state_dict(state)
50
+ print("Loaded successfully")
51
+ return model
52
+
53
+
54
+ def load_wts(model, checkpoint_path):
55
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
56
+ return load_state_dict(model, ckpt)
57
+
58
+
59
+ def load_state_dict_from_url(model, url, **kwargs):
60
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
61
+ return load_state_dict(model, state_dict)
62
+
63
+
64
+ def load_state_from_resource(model, resource: str):
65
+ """Loads weights to the model from a given resource. A resource can be of following types:
66
+ 1. URL. Prefixed with "url::"
67
+ e.g. url::http(s)://url.resource.com/ckpt.pt
68
+
69
+ 2. Local path. Prefixed with "local::"
70
+ e.g. local::/path/to/ckpt.pt
71
+
72
+
73
+ Args:
74
+ model (torch.nn.Module): Model
75
+ resource (str): resource string
76
+
77
+ Returns:
78
+ torch.nn.Module: Model with loaded weights
79
+ """
80
+ print(f"Using pretrained resource {resource}")
81
+
82
+ if resource.startswith('url::'):
83
+ url = resource.split('url::')[1]
84
+ return load_state_dict_from_url(model, url, progress=True)
85
+
86
+ elif resource.startswith('local::'):
87
+ path = resource.split('local::')[1]
88
+ return load_wts(model, path)
89
+
90
+ else:
91
+ raise ValueError("Invalid resource type, only url:: and local:: are supported")
92
+
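Usage sketch for the resource loader above; the local path is a placeholder, and the URL is the one referenced by the configs in this commit:

    from zoedepth.models.model_io import load_state_from_resource

    # model = load_state_from_resource(model, "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt")
    # model = load_state_from_resource(model, "local::/path/to/ckpt.pt")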
zoedepth/models/zoedepth/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ from .zoedepth_v1 import ZoeDepth
26
+
27
+ all_versions = {
28
+ "v1": ZoeDepth,
29
+ }
30
+
31
+ get_version = lambda v : all_versions[v]
zoedepth/models/zoedepth/config_zoedepth.json ADDED
@@ -0,0 +1,58 @@
1
+ {
2
+ "model": {
3
+ "name": "ZoeDepth",
4
+ "version_name": "v1",
5
+ "n_bins": 64,
6
+ "bin_embedding_dim": 128,
7
+ "bin_centers_type": "softplus",
8
+ "n_attractors":[16, 8, 4, 1],
9
+ "attractor_alpha": 1000,
10
+ "attractor_gamma": 2,
11
+ "attractor_kind" : "mean",
12
+ "attractor_type" : "inv",
13
+ "midas_model_type" : "DPT_BEiT_L_384",
14
+ "min_temp": 0.0212,
15
+ "max_temp": 50.0,
16
+ "output_distribution": "logbinomial",
17
+ "memory_efficient": true,
18
+ "inverse_midas": false,
19
+ "img_size": [384, 512]
20
+ },
21
+
22
+ "train": {
23
+ "train_midas": true,
24
+ "use_pretrained_midas": true,
25
+ "trainer": "zoedepth",
26
+ "epochs": 5,
27
+ "bs": 16,
28
+ "optim_kwargs": {"lr": 0.000161, "wd": 0.01},
29
+ "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true},
30
+ "same_lr": false,
31
+ "w_si": 1,
32
+ "w_domain": 0.2,
33
+ "w_reg": 0,
34
+ "w_grad": 0,
35
+ "avoid_boundary": false,
36
+ "random_crop": false,
37
+ "input_width": 640,
38
+ "input_height": 480,
39
+ "midas_lr_factor": 1,
40
+ "encoder_lr_factor":10,
41
+ "pos_enc_lr_factor":10,
42
+ "freeze_midas_bn": true
43
+
44
+ },
45
+
46
+ "infer":{
47
+ "train_midas": false,
48
+ "use_pretrained_midas": false,
49
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt",
50
+ "force_keep_ar": true
51
+ },
52
+
53
+ "eval":{
54
+ "train_midas": false,
55
+ "use_pretrained_midas": false,
56
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt"
57
+ }
58
+ }
zoedepth/models/zoedepth/config_zoedepth_kitti.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "model": {
3
+ "bin_centers_type": "normed",
4
+ "img_size": [384, 768]
5
+ },
6
+
7
+ "train": {
8
+ },
9
+
10
+ "infer":{
11
+ "train_midas": false,
12
+ "use_pretrained_midas": false,
13
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt",
14
+ "force_keep_ar": true
15
+ },
16
+
17
+ "eval":{
18
+ "train_midas": false,
19
+ "use_pretrained_midas": false,
20
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt"
21
+ }
22
+ }
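As a sketch of how this file is meant to be consumed: the base config_zoedepth.json is loaded first and the dataset-specific file above only overrides a few keys. Assuming the get_config helper from zoedepth/utils/config.py accepts a dataset argument (an assumption based on the upstream repository, not verified in this excerpt):

    from zoedepth.utils.config import get_config

    conf = get_config("zoedepth", "eval", "kitti")   # hypothetical call: base config + KITTI overrides (normed bins, 384x768, ZoeD_M12_K weights)
    print(conf.bin_centers_type, conf.img_size)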
zoedepth/models/zoedepth/zoedepth_v1.py ADDED
@@ -0,0 +1,250 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import itertools
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from zoedepth.models.depth_model import DepthModel
30
+ from zoedepth.models.base_models.midas import MidasCore
31
+ from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
32
+ from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
33
+ from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor,
34
+ SeedBinRegressorUnnormed)
35
+ from zoedepth.models.model_io import load_state_from_resource
36
+
37
+
38
+ class ZoeDepth(DepthModel):
39
+ def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10,
40
+ n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True,
41
+ midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
42
+ """ZoeDepth model. This is the version of ZoeDepth that has a single metric head
43
+
44
+ Args:
45
+ core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features
46
+ n_bins (int, optional): Number of bin centers. Defaults to 64.
47
+ bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers.
48
+ For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus".
49
+ bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
50
+ min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3.
51
+ max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10.
52
+ n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
53
+ attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
54
+ attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
55
+ attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
56
+ attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
57
+ min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
58
+ max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
59
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
60
+ midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10.
61
+ encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
62
+ pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
63
+ """
64
+ super().__init__()
65
+
66
+ self.core = core
67
+ self.max_depth = max_depth
68
+ self.min_depth = min_depth
69
+ self.min_temp = min_temp
70
+ self.bin_centers_type = bin_centers_type
71
+
72
+ self.midas_lr_factor = midas_lr_factor
73
+ self.encoder_lr_factor = encoder_lr_factor
74
+ self.pos_enc_lr_factor = pos_enc_lr_factor
75
+ self.train_midas = train_midas
76
+ self.inverse_midas = inverse_midas
77
+
78
+ if self.encoder_lr_factor <= 0:
79
+ self.core.freeze_encoder(
80
+ freeze_rel_pos=self.pos_enc_lr_factor <= 0)
81
+
82
+ N_MIDAS_OUT = 32
83
+ btlnck_features = self.core.output_channels[0]
84
+ num_out_features = self.core.output_channels[1:]
85
+
86
+ self.conv2 = nn.Conv2d(btlnck_features, btlnck_features,
87
+ kernel_size=1, stride=1, padding=0) # btlnck conv
88
+
89
+ if bin_centers_type == "normed":
90
+ SeedBinRegressorLayer = SeedBinRegressor
91
+ Attractor = AttractorLayer
92
+ elif bin_centers_type == "softplus":
93
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
94
+ Attractor = AttractorLayerUnnormed
95
+ elif bin_centers_type == "hybrid1":
96
+ SeedBinRegressorLayer = SeedBinRegressor
97
+ Attractor = AttractorLayerUnnormed
98
+ elif bin_centers_type == "hybrid2":
99
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
100
+ Attractor = AttractorLayer
101
+ else:
102
+ raise ValueError(
103
+ "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
104
+
105
+ self.seed_bin_regressor = SeedBinRegressorLayer(
106
+ btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth)
107
+ self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
108
+ self.projectors = nn.ModuleList([
109
+ Projector(num_out, bin_embedding_dim)
110
+ for num_out in num_out_features
111
+ ])
112
+ self.attractors = nn.ModuleList([
113
+ Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth,
114
+ alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
115
+ for i in range(len(num_out_features))
116
+ ])
117
+
118
+ last_in = N_MIDAS_OUT + 1 # +1 for relative depth
119
+
120
+ # use log binomial instead of softmax
121
+ self.conditional_log_binomial = ConditionalLogBinomial(
122
+ last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)
123
+
124
+ def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
125
+ """
126
+ Args:
127
+ x (torch.Tensor): Input image tensor of shape (B, C, H, W)
128
+ return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False.
129
+ denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False.
130
+ return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False.
131
+
132
+ Returns:
133
+ dict: Dictionary containing the following keys:
134
+ - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W)
135
+ - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W)
136
+ - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True
137
+ - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True
138
+
139
+ """
140
+ b, c, h, w = x.shape
141
+ # print("input shape ", x.shape)
142
+ self.orig_input_width = w
143
+ self.orig_input_height = h
144
+ rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
145
+ # print("output shapes", rel_depth.shape, out.shape)
146
+
147
+ outconv_activation = out[0]
148
+ btlnck = out[1]
149
+ x_blocks = out[2:]
150
+
151
+ x_d0 = self.conv2(btlnck)
152
+ x = x_d0
153
+ _, seed_b_centers = self.seed_bin_regressor(x)
154
+
155
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
156
+ b_prev = (seed_b_centers - self.min_depth) / \
157
+ (self.max_depth - self.min_depth)
158
+ else:
159
+ b_prev = seed_b_centers
160
+
161
+ prev_b_embedding = self.seed_projector(x)
162
+
163
+ # unroll this loop for better performance
164
+ for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
165
+ b_embedding = projector(x)
166
+ b, b_centers = attractor(
167
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
168
+ b_prev = b.clone()
169
+ prev_b_embedding = b_embedding.clone()
170
+
171
+ last = outconv_activation
172
+
173
+ if self.inverse_midas:
174
+ # invert depth followed by normalization
175
+ rel_depth = 1.0 / (rel_depth + 1e-6)
176
+ rel_depth = (rel_depth - rel_depth.min()) / \
177
+ (rel_depth.max() - rel_depth.min())
178
+ # concat rel depth with last. First interpolate rel depth to last size
179
+ rel_cond = rel_depth.unsqueeze(1)
180
+ rel_cond = nn.functional.interpolate(
181
+ rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
182
+ last = torch.cat([last, rel_cond], dim=1)
183
+
184
+ b_embedding = nn.functional.interpolate(
185
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
186
+ x = self.conditional_log_binomial(last, b_embedding)
187
+
188
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
189
+ # print(x.shape, b_centers.shape)
190
+ b_centers = nn.functional.interpolate(
191
+ b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
192
+ out = torch.sum(x * b_centers, dim=1, keepdim=True)
193
+
194
+ # Structure output dict
195
+ output = dict(metric_depth=out)
196
+ if return_final_centers or return_probs:
197
+ output['bin_centers'] = b_centers
198
+
199
+ if return_probs:
200
+ output['probs'] = x
201
+
202
+ return output
203
+
204
+ def get_lr_params(self, lr):
205
+ """
206
+ Learning rate configuration for different layers of the model
207
+ Args:
208
+ lr (float) : Base learning rate
209
+ Returns:
210
+ list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
211
+ """
212
+ param_conf = []
213
+ if self.train_midas:
214
+ if self.encoder_lr_factor > 0:
215
+ param_conf.append({'params': self.core.get_enc_params_except_rel_pos(
216
+ ), 'lr': lr / self.encoder_lr_factor})
217
+
218
+ if self.pos_enc_lr_factor > 0:
219
+ param_conf.append(
220
+ {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor})
221
+
222
+ midas_params = self.core.core.scratch.parameters()
223
+ midas_lr_factor = self.midas_lr_factor
224
+ param_conf.append(
225
+ {'params': midas_params, 'lr': lr / midas_lr_factor})
226
+
227
+ remaining_modules = []
228
+ for name, child in self.named_children():
229
+ if name != 'core':
230
+ remaining_modules.append(child)
231
+ remaining_params = itertools.chain(
232
+ *[child.parameters() for child in remaining_modules])
233
+
234
+ param_conf.append({'params': remaining_params, 'lr': lr})
235
+
236
+ return param_conf
237
+
238
+ @staticmethod
239
+ def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
240
+ core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
241
+ train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
242
+ model = ZoeDepth(core, **kwargs)
243
+ if pretrained_resource:
244
+ assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
245
+ model = load_state_from_resource(model, pretrained_resource)
246
+ return model
247
+
248
+ @staticmethod
249
+ def build_from_config(config):
250
+ return ZoeDepth.build(**config)
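A minimal build sketch using the static constructor above; the values mirror the infer section of config_zoedepth.json and the pretrained resource is the one referenced there (MiDaS itself is still fetched through torch.hub, so a cached hub repo or network access is assumed):

    # Sketch only: parameters follow ZoeDepth.build as defined in this file.
    from zoedepth.models.zoedepth import ZoeDepth

    model = ZoeDepth.build(
        midas_model_type="DPT_BEiT_L_384",
        use_pretrained_midas=False,
        train_midas=False,
        pretrained_resource="url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt",
    ).eval()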
zoedepth/models/zoedepth_nk/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ from .zoedepth_nk_v1 import ZoeDepthNK
26
+
27
+ all_versions = {
28
+ "v1": ZoeDepthNK,
29
+ }
30
+
31
+ get_version = lambda v : all_versions[v]
zoedepth/models/zoedepth_nk/config_zoedepth_nk.json ADDED
@@ -0,0 +1,67 @@
1
+ {
2
+ "model": {
3
+ "name": "ZoeDepthNK",
4
+ "version_name": "v1",
5
+ "bin_conf" : [
6
+ {
7
+ "name": "nyu",
8
+ "n_bins": 64,
9
+ "min_depth": 1e-3,
10
+ "max_depth": 10.0
11
+ },
12
+ {
13
+ "name": "kitti",
14
+ "n_bins": 64,
15
+ "min_depth": 1e-3,
16
+ "max_depth": 80.0
17
+ }
18
+ ],
19
+ "bin_embedding_dim": 128,
20
+ "bin_centers_type": "softplus",
21
+ "n_attractors":[16, 8, 4, 1],
22
+ "attractor_alpha": 1000,
23
+ "attractor_gamma": 2,
24
+ "attractor_kind" : "mean",
25
+ "attractor_type" : "inv",
26
+ "min_temp": 0.0212,
27
+ "max_temp": 50.0,
28
+ "memory_efficient": true,
29
+ "midas_model_type" : "DPT_BEiT_L_384",
30
+ "img_size": [384, 512]
31
+ },
32
+
33
+ "train": {
34
+ "train_midas": true,
35
+ "use_pretrained_midas": true,
36
+ "trainer": "zoedepth_nk",
37
+ "epochs": 5,
38
+ "bs": 16,
39
+ "optim_kwargs": {"lr": 0.0002512, "wd": 0.01},
40
+ "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true},
41
+ "same_lr": false,
42
+ "w_si": 1,
43
+ "w_domain": 100,
44
+ "avoid_boundary": false,
45
+ "random_crop": false,
46
+ "input_width": 640,
47
+ "input_height": 480,
48
+ "w_grad": 0,
49
+ "w_reg": 0,
50
+ "midas_lr_factor": 10,
51
+ "encoder_lr_factor":10,
52
+ "pos_enc_lr_factor":10
53
+ },
54
+
55
+ "infer": {
56
+ "train_midas": false,
57
+ "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt",
58
+ "use_pretrained_midas": false,
59
+ "force_keep_ar": true
60
+ },
61
+
62
+ "eval": {
63
+ "train_midas": false,
64
+ "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt",
65
+ "use_pretrained_midas": false
66
+ }
67
+ }
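For completeness, a hedged loading sketch for the NK variant configured above; "ZoeD_NK" is the torch.hub entrypoint name used by the upstream ZoeDepth repository and is assumed here:

    import torch

    zoe_nk = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).eval()
    # Routes indoor-style inputs to the NYU head (max_depth 10 m) and outdoor-style inputs to the KITTI head (max_depth 80 m).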
zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py ADDED
@@ -0,0 +1,333 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import itertools
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from zoedepth.models.depth_model import DepthModel
31
+ from zoedepth.models.base_models.midas import MidasCore
32
+ from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
33
+ from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
34
+ from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor,
35
+ SeedBinRegressorUnnormed)
36
+ from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder
37
+ from zoedepth.models.model_io import load_state_from_resource
38
+
39
+
40
+ class ZoeDepthNK(DepthModel):
41
+ def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128,
42
+ n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp',
43
+ min_temp=5, max_temp=50,
44
+ memory_efficient=False, train_midas=True,
45
+ is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
46
+ """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts.
47
+
48
+ Args:
49
+ core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features
50
+
51
+ bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys:
52
+ "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float)
53
+
54
+ The length of this list determines the number of metric heads.
55
+ bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, a linear normalization trick is applied, which results in bounded bin centers.
56
+ For "softplus", softplus activation is used and the bin centers are thus unbounded. Defaults to "softplus".
57
+ bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
58
+
59
+ n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
60
+ attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
61
+ attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
62
+ attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
63
+ attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
64
+
65
+ min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
66
+ max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
67
+
68
+ memory_efficient (bool, optional): Whether to use the memory-efficient version of the attractor layers. The memory-efficient version is slower but is recommended in case of multiple metric heads in order to save GPU memory. Defaults to False.
69
+
70
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
71
+ is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True.
72
+ midas_lr_factor (int, optional): Learning rate reduction factor for the base midas model, excluding its encoder and positional encodings. Defaults to 1.
73
+ encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
74
+ pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
75
+
76
+ """
77
+
78
+ super().__init__()
79
+
80
+ self.core = core
81
+ self.bin_conf = bin_conf
82
+ self.min_temp = min_temp
83
+ self.max_temp = max_temp
84
+ self.memory_efficient = memory_efficient
85
+ self.train_midas = train_midas
86
+ self.is_midas_pretrained = is_midas_pretrained
87
+ self.midas_lr_factor = midas_lr_factor
88
+ self.encoder_lr_factor = encoder_lr_factor
89
+ self.pos_enc_lr_factor = pos_enc_lr_factor
90
+ self.inverse_midas = inverse_midas
91
+
92
+ N_MIDAS_OUT = 32
93
+ btlnck_features = self.core.output_channels[0]
94
+ num_out_features = self.core.output_channels[1:]
95
+ # self.scales = [16, 8, 4, 2] # spatial scale factors
96
+
97
+ self.conv2 = nn.Conv2d(
98
+ btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0)
99
+
100
+ # Transformer classifier on the bottleneck
101
+ self.patch_transformer = PatchTransformerEncoder(
102
+ btlnck_features, 1, 128, use_class_token=True)
103
+ self.mlp_classifier = nn.Sequential(
104
+ nn.Linear(128, 128),
105
+ nn.ReLU(),
106
+ nn.Linear(128, 2)
107
+ )
108
+
109
+ if bin_centers_type == "normed":
110
+ SeedBinRegressorLayer = SeedBinRegressor
111
+ Attractor = AttractorLayer
112
+ elif bin_centers_type == "softplus":
113
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
114
+ Attractor = AttractorLayerUnnormed
115
+ elif bin_centers_type == "hybrid1":
116
+ SeedBinRegressorLayer = SeedBinRegressor
117
+ Attractor = AttractorLayerUnnormed
118
+ elif bin_centers_type == "hybrid2":
119
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
120
+ Attractor = AttractorLayer
121
+ else:
122
+ raise ValueError(
123
+ "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
124
+ self.bin_centers_type = bin_centers_type
125
+ # We have bins for each bin conf.
126
+ # Create a map (ModuleDict) of 'name' -> seed_bin_regressor
127
+ self.seed_bin_regressors = nn.ModuleDict(
128
+ {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"])
129
+ for conf in bin_conf}
130
+ )
131
+
132
+ self.seed_projector = Projector(
133
+ btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
134
+ self.projectors = nn.ModuleList([
135
+ Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
136
+ for num_out in num_out_features
137
+ ])
138
+
139
+ # Create a map (ModuleDict) of 'name' -> attractors (ModuleList)
140
+ self.attractors = nn.ModuleDict(
141
+ {conf['name']: nn.ModuleList([
142
+ Attractor(bin_embedding_dim, n_attractors[i],
143
+ mlp_dim=bin_embedding_dim, alpha=attractor_alpha,
144
+ gamma=attractor_gamma, kind=attractor_kind,
145
+ attractor_type=attractor_type, memory_efficient=memory_efficient,
146
+ min_depth=conf["min_depth"], max_depth=conf["max_depth"])
147
+ for i in range(len(n_attractors))
148
+ ])
149
+ for conf in bin_conf}
150
+ )
151
+
152
+ last_in = N_MIDAS_OUT
153
+ # conditional log binomial for each bin conf
154
+ self.conditional_log_binomial = nn.ModuleDict(
155
+ {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp)
156
+ for conf in bin_conf}
157
+ )
158
+
159
+ def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
160
+ """
161
+ Args:
162
+ x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain.
163
+ return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False.
164
+ denorm (bool, optional): Whether to denormalize the input image. Defaults to False.
165
+ return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False.
166
+
167
+ Returns:
168
+ dict: Dictionary of outputs with keys:
169
+ - "rel_depth": Relative depth map of shape (B, 1, H, W)
170
+ - "metric_depth": Metric depth map of shape (B, 1, H, W)
171
+ - "domain_logits": Domain logits of shape (B, 2)
172
+ - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True
173
+ - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True
174
+ """
175
+ b, c, h, w = x.shape
176
+ self.orig_input_width = w
177
+ self.orig_input_height = h
178
+ rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
179
+
180
+ outconv_activation = out[0]
181
+ btlnck = out[1]
182
+ x_blocks = out[2:]
183
+
184
+ x_d0 = self.conv2(btlnck)
185
+ x = x_d0
186
+
187
+ # Predict which path to take
188
+ embedding = self.patch_transformer(x)[0] # N, E
189
+ domain_logits = self.mlp_classifier(embedding) # N, 2
190
+ domain_vote = torch.softmax(domain_logits.sum(
191
+ dim=0, keepdim=True), dim=-1) # 1, 2
192
+
193
+ # Get the path
194
+ bin_conf_name = ["nyu", "kitti"][torch.argmax(
195
+ domain_vote, dim=-1).squeeze().item()]
196
+
197
+ try:
198
+ conf = [c for c in self.bin_conf if c['name'] == bin_conf_name][0]
199
+ except IndexError:
200
+ raise ValueError(
201
+ f"bin_conf_name {bin_conf_name} not found in bin_confs")
202
+
203
+ min_depth = conf['min_depth']
204
+ max_depth = conf['max_depth']
205
+
206
+ seed_bin_regressor = self.seed_bin_regressors[bin_conf_name]
207
+ _, seed_b_centers = seed_bin_regressor(x)
208
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
209
+ b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth)
210
+ else:
211
+ b_prev = seed_b_centers
212
+ prev_b_embedding = self.seed_projector(x)
213
+
214
+ attractors = self.attractors[bin_conf_name]
215
+ for projector, attractor, x in zip(self.projectors, attractors, x_blocks):
216
+ b_embedding = projector(x)
217
+ b, b_centers = attractor(
218
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
219
+ b_prev = b
220
+ prev_b_embedding = b_embedding
221
+
222
+ last = outconv_activation
223
+
224
+ b_centers = nn.functional.interpolate(
225
+ b_centers, last.shape[-2:], mode='bilinear', align_corners=True)
226
+ b_embedding = nn.functional.interpolate(
227
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
228
+
229
+ clb = self.conditional_log_binomial[bin_conf_name]
230
+ x = clb(last, b_embedding)
231
+
232
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
233
+ # print(x.shape, b_centers.shape)
234
+ # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
235
+ out = torch.sum(x * b_centers, dim=1, keepdim=True)
236
+
237
+ output = dict(domain_logits=domain_logits, metric_depth=out)
238
+ if return_final_centers or return_probs:
239
+ output['bin_centers'] = b_centers
240
+
241
+ if return_probs:
242
+ output['probs'] = x
243
+ return output
244
+
245
+ def get_lr_params(self, lr):
246
+ """
247
+ Learning rate configuration for different layers of the model
248
+
249
+ Args:
250
+ lr (float) : Base learning rate
251
+ Returns:
252
+ list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
253
+ """
254
+ param_conf = []
255
+ if self.train_midas:
256
+ def get_rel_pos_params():
257
+ for name, p in self.core.core.pretrained.named_parameters():
258
+ if "relative_position" in name:
259
+ yield p
260
+
261
+ def get_enc_params_except_rel_pos():
262
+ for name, p in self.core.core.pretrained.named_parameters():
263
+ if "relative_position" not in name:
264
+ yield p
265
+
266
+ encoder_params = get_enc_params_except_rel_pos()
267
+ rel_pos_params = get_rel_pos_params()
268
+ midas_params = self.core.core.scratch.parameters()
269
+ midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0
270
+ param_conf.extend([
271
+ {'params': encoder_params, 'lr': lr / self.encoder_lr_factor},
272
+ {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor},
273
+ {'params': midas_params, 'lr': lr / midas_lr_factor}
274
+ ])
275
+
276
+ remaining_modules = []
277
+ for name, child in self.named_children():
278
+ if name != 'core':
279
+ remaining_modules.append(child)
280
+ remaining_params = itertools.chain(
281
+ *[child.parameters() for child in remaining_modules])
282
+ param_conf.append({'params': remaining_params, 'lr': lr})
283
+ return param_conf
284
+
285
+ def get_conf_parameters(self, conf_name):
286
+ """
287
+ Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
288
+ """
289
+ params = []
290
+ for name, child in self.named_children():
291
+ if isinstance(child, nn.ModuleDict):
292
+ for bin_conf_name, module in child.items():
293
+ if bin_conf_name == conf_name:
294
+ params += list(module.parameters())
295
+ return params
296
+
297
+ def freeze_conf(self, conf_name):
298
+ """
299
+ Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
300
+ """
301
+ for p in self.get_conf_parameters(conf_name):
302
+ p.requires_grad = False
303
+
304
+ def unfreeze_conf(self, conf_name):
305
+ """
306
+ Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
307
+ """
308
+ for p in self.get_conf_parameters(conf_name):
309
+ p.requires_grad = True
310
+
311
+ def freeze_all_confs(self):
312
+ """
313
+ Freezes all the parameters of all the ModuleDicts children
314
+ """
315
+ for name, child in self.named_children():
316
+ if isinstance(child, nn.ModuleDict):
317
+ for bin_conf_name, module in child.items():
318
+ for p in module.parameters():
319
+ p.requires_grad = False
320
+
321
+ @staticmethod
322
+ def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
323
+ core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
324
+ train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
325
+ model = ZoeDepthNK(core, **kwargs)
326
+ if pretrained_resource:
327
+ assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
328
+ model = load_state_from_resource(model, pretrained_resource)
329
+ return model
330
+
331
+ @staticmethod
332
+ def build_from_config(config):
333
+ return ZoeDepthNK.build(**config)
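A rough, untested sketch of how ZoeDepthNK.build_from_config above might be exercised. The bin_conf values are illustrative (they mirror the NYU/KITTI depth ranges used by the repo's configs), use_pretrained_midas is switched off to keep the sketch light, and torch.hub still needs to be able to fetch the MiDaS code for this to run.

import torch
from zoedepth.models.zoedepth_nk.zoedepth_nk_v1 import ZoeDepthNK
from zoedepth.utils.easydict import EasyDict as edict

config = {
    "midas_model_type": "DPT_BEiT_L_384",
    "use_pretrained_midas": False,   # skip the MiDaS weight download for this sketch
    "train_midas": False,
    "bin_centers_type": "softplus",
    "bin_conf": [                    # illustrative values; one entry per metric head
        edict(name="nyu", n_bins=64, min_depth=1e-3, max_depth=10),
        edict(name="kitti", n_bins=64, min_depth=1e-3, max_depth=80),
    ],
}

model = ZoeDepthNK.build_from_config(config)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 480, 640))
print(out["metric_depth"].shape, out["domain_logits"].shape)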
zoedepth/trainers/base_trainer.py ADDED
@@ -0,0 +1,326 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import os
26
+ import uuid
27
+ import warnings
28
+ from datetime import datetime as dt
29
+ from typing import Dict
30
+
31
+ import matplotlib.pyplot as plt
32
+ import numpy as np
33
+ import torch
34
+ import torch.distributed as dist
35
+ import torch.nn as nn
36
+ import torch.optim as optim
37
+ import wandb
38
+ from tqdm import tqdm
39
+
40
+ from zoedepth.utils.config import flatten
41
+ from zoedepth.utils.misc import RunningAverageDict, colorize, colors
42
+
43
+
44
+ def is_rank_zero(args):
45
+ return args.rank == 0
46
+
47
+
48
+ class BaseTrainer:
49
+ def __init__(self, config, model, train_loader, test_loader=None, device=None):
50
+ """ Base Trainer class for training a model."""
51
+
52
+ self.config = config
53
+ self.metric_criterion = "abs_rel"
54
+ if device is None:
55
+ device = torch.device(
56
+ 'cuda') if torch.cuda.is_available() else torch.device('cpu')
57
+ self.device = device
58
+ self.model = model
59
+ self.train_loader = train_loader
60
+ self.test_loader = test_loader
61
+ self.optimizer = self.init_optimizer()
62
+ self.scheduler = self.init_scheduler()
63
+
64
+ def resize_to_target(self, prediction, target):
65
+ if prediction.shape[2:] != target.shape[-2:]:
66
+ prediction = nn.functional.interpolate(
67
+ prediction, size=target.shape[-2:], mode="bilinear", align_corners=True
68
+ )
69
+ return prediction
70
+
71
+ def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"):
72
+ import glob
73
+ import os
74
+
75
+ from zoedepth.models.model_io import load_wts
76
+
77
+ if hasattr(self.config, "checkpoint"):
78
+ checkpoint = self.config.checkpoint
79
+ elif hasattr(self.config, "ckpt_pattern"):
80
+ pattern = self.config.ckpt_pattern
81
+ matches = glob.glob(os.path.join(
82
+ checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
83
+ if not (len(matches) > 0):
84
+ raise ValueError(f"No matches found for the pattern {pattern}")
85
+ checkpoint = matches[0]
86
+ else:
87
+ return
88
+ model = load_wts(self.model, checkpoint)
89
+ # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.
90
+ print("Loaded weights from {0}".format(checkpoint))
91
+ warnings.warn(
92
+ "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.")
93
+ self.model = model
94
+
95
+ def init_optimizer(self):
96
+ m = self.model.module if self.config.multigpu else self.model
97
+
98
+ if self.config.same_lr:
99
+ print("Using same LR")
100
+ if hasattr(m, 'core'):
101
+ m.core.unfreeze()
102
+ params = self.model.parameters()
103
+ else:
104
+ print("Using diff LR")
105
+ if not hasattr(m, 'get_lr_params'):
106
+ raise NotImplementedError(
107
+ f"Model {m.__class__.__name__} does not implement get_lr_params. Please implement it or use the same LR for all parameters.")
108
+
109
+ params = m.get_lr_params(self.config.lr)
110
+
111
+ return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd)
112
+
113
+ def init_scheduler(self):
114
+ lrs = [l['lr'] for l in self.optimizer.param_groups]
115
+ return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader),
116
+ cycle_momentum=self.config.cycle_momentum,
117
+ base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase)
118
+
119
+ def train_on_batch(self, batch, train_step):
120
+ raise NotImplementedError
121
+
122
+ def validate_on_batch(self, batch, val_step):
123
+ raise NotImplementedError
124
+
125
+ def raise_if_nan(self, losses):
126
+ for key, value in losses.items():
127
+ if torch.isnan(value):
128
+ raise ValueError(f"{key} is NaN, Stopping training")
129
+
130
+ @property
131
+ def iters_per_epoch(self):
132
+ return len(self.train_loader)
133
+
134
+ @property
135
+ def total_iters(self):
136
+ return self.config.epochs * self.iters_per_epoch
137
+
138
+ def should_early_stop(self):
139
+ if self.config.get('early_stop', False) and self.step > self.config.early_stop:
140
+ return True
141
+
142
+ def train(self):
143
+ print(f"Training {self.config.name}")
144
+ if self.config.uid is None:
145
+ self.config.uid = str(uuid.uuid4()).split('-')[-1]
146
+ run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}"
147
+ self.config.run_id = run_id
148
+ self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}"
149
+ self.should_write = ((not self.config.distributed)
150
+ or self.config.rank == 0)
151
+ self.should_log = self.should_write # and logging
152
+ if self.should_log:
153
+ tags = self.config.tags.split(
154
+ ',') if self.config.tags != '' else None
155
+ wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root,
156
+ tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork"))
157
+
158
+ self.model.train()
159
+ self.step = 0
160
+ best_loss = np.inf
161
+ validate_every = int(self.config.validate_every * self.iters_per_epoch)
162
+
163
+
164
+ if self.config.prefetch:
165
+
166
+ for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...",
167
+ total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader):
168
+ pass
169
+
170
+ losses = {}
171
+ def stringify_losses(L): return "; ".join(map(
172
+ lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items()))
173
+ for epoch in range(self.config.epochs):
174
+ if self.should_early_stop():
175
+ break
176
+
177
+ self.epoch = epoch
178
+ ################################# Train loop ##########################################################
179
+ if self.should_log:
180
+ wandb.log({"Epoch": epoch}, step=self.step)
181
+ pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train",
182
+ total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader)
183
+ for i, batch in pbar:
184
+ if self.should_early_stop():
185
+ print("Early stopping")
186
+ break
187
+ # print(f"Batch {self.step+1} on rank {self.config.rank}")
188
+ losses = self.train_on_batch(batch, i)
189
+ # print(f"trained batch {self.step+1} on rank {self.config.rank}")
190
+
191
+ self.raise_if_nan(losses)
192
+ if is_rank_zero(self.config) and self.config.print_losses:
193
+ pbar.set_description(
194
+ f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. Losses: {stringify_losses(losses)}")
195
+ self.scheduler.step()
196
+
197
+ if self.should_log and self.step % 50 == 0:
198
+ wandb.log({f"Train/{name}": loss.item()
199
+ for name, loss in losses.items()}, step=self.step)
200
+
201
+ self.step += 1
202
+
203
+ ########################################################################################################
204
+
205
+ if self.test_loader:
206
+ if (self.step % validate_every) == 0:
207
+ self.model.eval()
208
+ if self.should_write:
209
+ self.save_checkpoint(
210
+ f"{self.config.experiment_id}_latest.pt")
211
+
212
+ ################################# Validation loop ##################################################
213
+ # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes
214
+ metrics, test_losses = self.validate()
215
+ # print("Validated: {}".format(metrics))
216
+ if self.should_log:
217
+ wandb.log(
218
+ {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step)
219
+
220
+ wandb.log({f"Metrics/{k}": v for k,
221
+ v in metrics.items()}, step=self.step)
222
+
223
+ if (metrics[self.metric_criterion] < best_loss) and self.should_write:
224
+ self.save_checkpoint(
225
+ f"{self.config.experiment_id}_best.pt")
226
+ best_loss = metrics[self.metric_criterion]
227
+
228
+ self.model.train()
229
+
230
+ if self.config.distributed:
231
+ dist.barrier()
232
+ # print(f"Validated: {metrics} on device {self.config.rank}")
233
+
234
+ # print(f"Finished step {self.step} on device {self.config.rank}")
235
+ #################################################################################################
236
+
237
+ # Save / validate at the end
238
+ self.step += 1 # log as final point
239
+ self.model.eval()
240
+ self.save_checkpoint(f"{self.config.experiment_id}_latest.pt")
241
+ if self.test_loader:
242
+
243
+ ################################# Validation loop ##################################################
244
+ metrics, test_losses = self.validate()
245
+ # print("Validated: {}".format(metrics))
246
+ if self.should_log:
247
+ wandb.log({f"Test/{name}": tloss for name,
248
+ tloss in test_losses.items()}, step=self.step)
249
+ wandb.log({f"Metrics/{k}": v for k,
250
+ v in metrics.items()}, step=self.step)
251
+
252
+ if (metrics[self.metric_criterion] < best_loss) and self.should_write:
253
+ self.save_checkpoint(
254
+ f"{self.config.experiment_id}_best.pt")
255
+ best_loss = metrics[self.metric_criterion]
256
+
257
+ self.model.train()
258
+
259
+ def validate(self):
260
+ with torch.no_grad():
261
+ losses_avg = RunningAverageDict()
262
+ metrics_avg = RunningAverageDict()
263
+ for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)):
264
+ metrics, losses = self.validate_on_batch(batch, val_step=i)
265
+
266
+ if losses:
267
+ losses_avg.update(losses)
268
+ if metrics:
269
+ metrics_avg.update(metrics)
270
+
271
+ return metrics_avg.get_value(), losses_avg.get_value()
272
+
273
+ def save_checkpoint(self, filename):
274
+ if not self.should_write:
275
+ return
276
+ root = self.config.save_dir
277
+ if not os.path.isdir(root):
278
+ os.makedirs(root)
279
+
280
+ fpath = os.path.join(root, filename)
281
+ m = self.model.module if self.config.multigpu else self.model
282
+ torch.save(
283
+ {
284
+ "model": m.state_dict(),
285
+ "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size
286
+ "epoch": self.epoch
287
+ }, fpath)
288
+
289
+ def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None):
290
+ if not self.should_log:
291
+ return
292
+
293
+ if min_depth is None:
294
+ try:
295
+ min_depth = self.config.min_depth
296
+ max_depth = self.config.max_depth
297
+ except AttributeError:
298
+ min_depth = None
299
+ max_depth = None
300
+
301
+ depth = {k: colorize(v, vmin=min_depth, vmax=max_depth)
302
+ for k, v in depth.items()}
303
+ scalar_field = {k: colorize(
304
+ v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()}
305
+ images = {**rgb, **depth, **scalar_field}
306
+ wimages = {
307
+ prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]}
308
+ wandb.log(wimages, step=self.step)
309
+
310
+ def log_line_plot(self, data):
311
+ if not self.should_log:
312
+ return
313
+
314
+ plt.plot(data)
315
+ plt.ylabel("Scale factors")
316
+ wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step)
317
+ plt.close()
318
+
319
+ def log_bar_plot(self, title, labels, values):
320
+ if not self.should_log:
321
+ return
322
+
323
+ data = [[label, val] for (label, val) in zip(labels, values)]
324
+ table = wandb.Table(data=data, columns=["label", "value"])
325
+ wandb.log({title: wandb.plot.bar(table, "label",
326
+ "value", title=title)}, step=self.step)
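BaseTrainer above leaves train_on_batch and validate_on_batch abstract. The subclass below is a minimal illustrative sketch of the contract they are expected to fulfil; it is not one of the repo's trainers, and instantiating it still requires the full config produced by zoedepth.utils.config.get_config. Note that validate() expects a (metrics, losses) pair and that the "abs_rel" metric drives best-checkpoint selection.

import torch
import torch.nn as nn

from zoedepth.trainers.base_trainer import BaseTrainer


class L1Trainer(BaseTrainer):
    # Hypothetical trainer that optimises a plain L1 loss on metric depth.

    def train_on_batch(self, batch, train_step):
        images = batch["image"].to(self.device)
        depths_gt = batch["depth"].to(self.device)
        pred = self.resize_to_target(self.model(images)["metric_depth"], depths_gt)
        loss = nn.functional.l1_loss(pred, depths_gt)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return {"l1": loss.detach()}   # tensor values, so raise_if_nan and logging work

    def validate_on_batch(self, batch, val_step):
        with torch.no_grad():
            images = batch["image"].to(self.device)
            depths_gt = batch["depth"].to(self.device)
            pred = self.resize_to_target(self.model(images)["metric_depth"], depths_gt)
            abs_rel = (torch.abs(pred - depths_gt) / depths_gt.clamp(min=1e-6)).mean().item()
            l1 = nn.functional.l1_loss(pred, depths_gt).item()
        return {"abs_rel": abs_rel}, {"l1": l1}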
zoedepth/trainers/builder.py ADDED
@@ -0,0 +1,48 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ from importlib import import_module
26
+
27
+
28
+ def get_trainer(config):
29
+ """Builds and returns a trainer based on the config.
30
+
31
+ Args:
32
+ config (dict): the config dict (typically constructed using utils.config.get_config)
33
+ config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module
34
+
35
+ Raises:
36
+ ValueError: If the specified trainer does not exist under trainers/ folder
37
+
38
+ Returns:
39
+ Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object
40
+ """
41
+ assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format(
42
+ config)
43
+ try:
44
+ Trainer = getattr(import_module(
45
+ f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer')
46
+ except ModuleNotFoundError as e:
47
+ raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e
48
+ return Trainer
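A short usage sketch for get_trainer. Here config, model and the loaders are placeholders, assumed to come from zoedepth.utils.config.get_config, the model builder in zoedepth/models/builder.py and the data loaders in zoedepth/data respectively.

from zoedepth.trainers.builder import get_trainer

# config.trainer == "zoedepth_nk" resolves to zoedepth.trainers.zoedepth_nk_trainer.Trainer
Trainer = get_trainer(config)
trainer = Trainer(config, model, train_loader, test_loader=test_loader, device="cuda")
trainer.train()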
zoedepth/trainers/loss.py ADDED
@@ -0,0 +1,316 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torch.cuda.amp as amp
29
+ import numpy as np
30
+
31
+
32
+ KEY_OUTPUT = 'metric_depth'
33
+
34
+
35
+ def extract_key(prediction, key):
36
+ if isinstance(prediction, dict):
37
+ return prediction[key]
38
+ return prediction
39
+
40
+
41
+ # Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7)
42
+ class SILogLoss(nn.Module):
43
+ """SILog loss (pixel-wise)"""
44
+ def __init__(self, beta=0.15):
45
+ super(SILogLoss, self).__init__()
46
+ self.name = 'SILog'
47
+ self.beta = beta
48
+
49
+ def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False):
50
+ input = extract_key(input, KEY_OUTPUT)
51
+ if input.shape[-1] != target.shape[-1] and interpolate:
52
+ input = nn.functional.interpolate(
53
+ input, target.shape[-2:], mode='bilinear', align_corners=True)
54
+ intr_input = input
55
+ else:
56
+ intr_input = input
57
+
58
+ if target.ndim == 3:
59
+ target = target.unsqueeze(1)
60
+
61
+ if mask is not None:
62
+ if mask.ndim == 3:
63
+ mask = mask.unsqueeze(1)
64
+
65
+ input = input[mask]
66
+ target = target[mask]
67
+
68
+ with amp.autocast(enabled=False): # amp causes NaNs in this loss function
69
+ alpha = 1e-7
70
+ g = torch.log(input + alpha) - torch.log(target + alpha)
71
+
72
+ # n, c, h, w = g.shape
73
+ # norm = 1/(h*w)
74
+ # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2
75
+
76
+ Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2)
77
+
78
+ loss = 10 * torch.sqrt(Dg)
79
+
80
+ if torch.isnan(loss):
81
+ print("Nan SILog loss")
82
+ print("input:", input.shape)
83
+ print("target:", target.shape)
84
+ print("G", torch.sum(torch.isnan(g)))
85
+ print("Input min max", torch.min(input), torch.max(input))
86
+ print("Target min max", torch.min(target), torch.max(target))
87
+ print("Dg", torch.isnan(Dg))
88
+ print("loss", torch.isnan(loss))
89
+
90
+ if not return_interpolated:
91
+ return loss
92
+
93
+ return loss, intr_input
94
+
95
+
96
+ def grad(x):
97
+ # x.shape : n, c, h, w
98
+ diff_x = x[..., 1:, 1:] - x[..., 1:, :-1]
99
+ diff_y = x[..., 1:, 1:] - x[..., :-1, 1:]
100
+ mag = diff_x**2 + diff_y**2
101
+ # angle_ratio
102
+ angle = torch.atan(diff_y / (diff_x + 1e-10))
103
+ return mag, angle
104
+
105
+
106
+ def grad_mask(mask):
107
+ return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:]
108
+
109
+
110
+ class GradL1Loss(nn.Module):
111
+ """Gradient loss"""
112
+ def __init__(self):
113
+ super(GradL1Loss, self).__init__()
114
+ self.name = 'GradL1'
115
+
116
+ def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False):
117
+ input = extract_key(input, KEY_OUTPUT)
118
+ if input.shape[-1] != target.shape[-1] and interpolate:
119
+ input = nn.functional.interpolate(
120
+ input, target.shape[-2:], mode='bilinear', align_corners=True)
121
+ intr_input = input
122
+ else:
123
+ intr_input = input
124
+
125
+ grad_gt = grad(target)
126
+ grad_pred = grad(input)
127
+ mask_g = grad_mask(mask)
128
+
129
+ loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g])
130
+ loss = loss + \
131
+ nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g])
132
+ if not return_interpolated:
133
+ return loss
134
+ return loss, intr_input
135
+
136
+
137
+ class OrdinalRegressionLoss(object):
138
+
139
+ def __init__(self, ord_num, beta, discretization="SID"):
140
+ self.ord_num = ord_num
141
+ self.beta = beta
142
+ self.discretization = discretization
143
+
144
+ def _create_ord_label(self, gt):
145
+ N, one, H, W = gt.shape
146
+ # print("gt shape:", gt.shape)
147
+
148
+ ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device)
149
+ if self.discretization == "SID":
150
+ label = self.ord_num * torch.log(gt) / np.log(self.beta)
151
+ else:
152
+ label = self.ord_num * (gt - 1.0) / (self.beta - 1.0)
153
+ label = label.long()
154
+ mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \
155
+ .view(1, self.ord_num, 1, 1).to(gt.device)
156
+ mask = mask.repeat(N, 1, H, W).contiguous().long()
157
+ mask = (mask > label)
158
+ ord_c0[mask] = 0
159
+ ord_c1 = 1 - ord_c0
160
+ # implementation according to the paper.
161
+ # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device)
162
+ # ord_label[:, 0::2, :, :] = ord_c0
163
+ # ord_label[:, 1::2, :, :] = ord_c1
164
+ # reimplementation for fast speed.
165
+ ord_label = torch.cat((ord_c0, ord_c1), dim=1)
166
+ return ord_label, mask
167
+
168
+ def __call__(self, prob, gt):
169
+ """
170
+ :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor
171
+ :param gt: depth ground truth, NXHxW, torch.Tensor
172
+ :return: loss: loss value, torch.float
173
+ """
174
+ # N, C, H, W = prob.shape
175
+ valid_mask = gt > 0.
176
+ ord_label, mask = self._create_ord_label(gt)
177
+ # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape))
178
+ entropy = -prob * ord_label
179
+ loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)]
180
+ return loss.mean()
181
+
182
+
183
+ class DiscreteNLLLoss(nn.Module):
184
+ """Cross entropy loss"""
185
+ def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64):
186
+ super(DiscreteNLLLoss, self).__init__()
187
+ self.name = 'CrossEntropy'
188
+ self.ignore_index = -(depth_bins + 1)
189
+ # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index)
190
+ self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
191
+ self.min_depth = min_depth
192
+ self.max_depth = max_depth
193
+ self.depth_bins = depth_bins
194
+ self.alpha = 1
195
+ self.zeta = 1 - min_depth
196
+ self.beta = max_depth + self.zeta
197
+
198
+ def quantize_depth(self, depth):
199
+ # depth : N1HW
200
+ # output : NCHW
201
+
202
+ # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins
203
+ depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha)
204
+ depth = depth * (self.depth_bins - 1)
205
+ depth = torch.round(depth)
206
+ depth = depth.long()
207
+ return depth
208
+
209
+
210
+
211
+ def _dequantize_depth(self, depth):
212
+ """
213
+ Inverse of quantization
214
+ depth : NCHW -> N1HW
215
+ """
216
+ # Get the center of the bin by inverting quantize_depth (log-uniform bin spacing)
+ depth = depth.float() / (self.depth_bins - 1)
+ return self.alpha * (self.beta / self.alpha) ** depth
217
+
218
+
219
+
220
+
221
+ def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False):
222
+ input = extract_key(input, KEY_OUTPUT)
223
+ # assert torch.all(input <= 0), "Input should be negative"
224
+
225
+ if input.shape[-1] != target.shape[-1] and interpolate:
226
+ input = nn.functional.interpolate(
227
+ input, target.shape[-2:], mode='bilinear', align_corners=True)
228
+ intr_input = input
229
+ else:
230
+ intr_input = input
231
+
232
+ # assert torch.all(input)<=1)
233
+ if target.ndim == 3:
234
+ target = target.unsqueeze(1)
235
+
236
+ target = self.quantize_depth(target)
237
+ if mask is not None:
238
+ if mask.ndim == 3:
239
+ mask = mask.unsqueeze(1)
240
+
241
+ # Set the mask to ignore_index
242
+ mask = mask.long()
243
+ input = input * mask + (1 - mask) * self.ignore_index
244
+ target = target * mask + (1 - mask) * self.ignore_index
245
+
246
+
247
+
248
+ input = input.flatten(2) # N, nbins, H*W
249
+ target = target.flatten(1) # N, H*W
250
+ loss = self._loss_func(input, target)
251
+
252
+ if not return_interpolated:
253
+ return loss
254
+ return loss, intr_input
255
+
256
+
257
+
258
+
259
+ def compute_scale_and_shift(prediction, target, mask):
260
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
261
+ a_00 = torch.sum(mask * prediction * prediction, (1, 2))
262
+ a_01 = torch.sum(mask * prediction, (1, 2))
263
+ a_11 = torch.sum(mask, (1, 2))
264
+
265
+ # right hand side: b = [b_0, b_1]
266
+ b_0 = torch.sum(mask * prediction * target, (1, 2))
267
+ b_1 = torch.sum(mask * target, (1, 2))
268
+
269
+ # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
270
+ x_0 = torch.zeros_like(b_0)
271
+ x_1 = torch.zeros_like(b_1)
272
+
273
+ det = a_00 * a_11 - a_01 * a_01
274
+ # A needs to be a positive definite matrix.
275
+ valid = det > 0
276
+
277
+ x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
278
+ x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
279
+
280
+ return x_0, x_1
281
+ class ScaleAndShiftInvariantLoss(nn.Module):
282
+ def __init__(self):
283
+ super().__init__()
284
+ self.name = "SSILoss"
285
+
286
+ def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False):
287
+
288
+ if prediction.shape[-1] != target.shape[-1] and interpolate:
289
+ prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
290
+ intr_input = prediction
291
+ else:
292
+ intr_input = prediction
293
+
294
+
295
+ prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
296
+ assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
297
+
298
+ scale, shift = compute_scale_and_shift(prediction, target, mask)
299
+
300
+ scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
301
+
302
+ loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
303
+ if not return_interpolated:
304
+ return loss
305
+ return loss, intr_input
306
+
307
+
308
+
309
+
310
+ if __name__ == '__main__':
311
+ # Tests for DiscreteNLLLoss
312
+ celoss = DiscreteNLLLoss()
313
+ print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, ))
314
+
315
+ d = torch.Tensor([6.59, 3.8, 10.0])
316
+ print(celoss._dequantize_depth(celoss.quantize_depth(d)))
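For reference, the two key closed forms above, written out (notation is mine, not from the repo): SILogLoss computes a scale-invariant log error, and compute_scale_and_shift solves a masked per-image least-squares alignment.

% SILogLoss, with g the masked log-difference, beta = 0.15 and alpha = 1e-7:
\[
g_i = \log(\hat d_i + \alpha) - \log(d_i + \alpha), \qquad
\mathcal{L}_{\text{SILog}} = 10\,\sqrt{\operatorname{Var}(g) + \beta\,\bigl(\operatorname{mean}(g)\bigr)^2}
\]

% compute_scale_and_shift: per-image masked least squares over scale s and shift t
\[
(s^*, t^*) = \arg\min_{s,t}\ \sum_i m_i\,\bigl(s\,p_i + t - d_i\bigr)^2
\;\Longrightarrow\;
\begin{pmatrix} a_{00} & a_{01} \\ a_{01} & a_{11} \end{pmatrix}
\begin{pmatrix} s \\ t \end{pmatrix}
=
\begin{pmatrix} b_0 \\ b_1 \end{pmatrix}
\]
\[
a_{00} = \sum_i m_i p_i^2,\quad a_{01} = \sum_i m_i p_i,\quad a_{11} = \sum_i m_i,\quad
b_0 = \sum_i m_i p_i d_i,\quad b_1 = \sum_i m_i d_i
\]
\[
s^* = \frac{a_{11} b_0 - a_{01} b_1}{a_{00} a_{11} - a_{01}^2},\qquad
t^* = \frac{a_{00} b_1 - a_{01} b_0}{a_{00} a_{11} - a_{01}^2}
\qquad(\text{used only where } a_{00} a_{11} - a_{01}^2 > 0).
\]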
zoedepth/trainers/zoedepth_nk_trainer.py ADDED
@@ -0,0 +1,143 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.cuda.amp as amp
27
+ import torch.nn as nn
28
+
29
+ from zoedepth.trainers.loss import GradL1Loss, SILogLoss
30
+ from zoedepth.utils.config import DATASETS_CONFIG
31
+ from zoedepth.utils.misc import compute_metrics
32
+
33
+ from .base_trainer import BaseTrainer
34
+
35
+
36
+ class Trainer(BaseTrainer):
37
+ def __init__(self, config, model, train_loader, test_loader=None, device=None):
38
+ super().__init__(config, model, train_loader,
39
+ test_loader=test_loader, device=device)
40
+ self.device = device
41
+ self.silog_loss = SILogLoss()
42
+ self.grad_loss = GradL1Loss()
43
+ self.domain_classifier_loss = nn.CrossEntropyLoss()
44
+
45
+ self.scaler = amp.GradScaler(enabled=self.config.use_amp)
46
+
47
+ def train_on_batch(self, batch, train_step):
48
+ """
49
+ Expects a batch of images and depth as input
50
+ batch["image"].shape : batch_size, c, h, w
51
+ batch["depth"].shape : batch_size, 1, h, w
52
+
53
+ Assumes all images in a batch are from the same dataset
54
+ """
55
+
56
+ images, depths_gt = batch['image'].to(
57
+ self.device), batch['depth'].to(self.device)
58
+ # batch['dataset'] is a tensor strings all valued either 'nyu' or 'kitti'. labels nyu -> 0, kitti -> 1
59
+ dataset = batch['dataset'][0]
60
+ # Convert to 0s or 1s
61
+ domain_labels = torch.Tensor([dataset == 'kitti' for _ in range(
62
+ images.size(0))]).to(torch.long).to(self.device)
63
+
64
+ # m = self.model.module if self.config.multigpu else self.model
65
+
66
+ b, c, h, w = images.size()
67
+ mask = batch["mask"].to(self.device).to(torch.bool)
68
+
69
+ losses = {}
70
+
71
+ with amp.autocast(enabled=self.config.use_amp):
72
+ output = self.model(images)
73
+ pred_depths = output['metric_depth']
74
+ domain_logits = output['domain_logits']
75
+
76
+ l_si, pred = self.silog_loss(
77
+ pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True)
78
+ loss = self.config.w_si * l_si
79
+ losses[self.silog_loss.name] = l_si
80
+
81
+ if self.config.w_grad > 0:
82
+ l_grad = self.grad_loss(pred, depths_gt, mask=mask)
83
+ loss = loss + self.config.w_grad * l_grad
84
+ losses[self.grad_loss.name] = l_grad
85
+ else:
86
+ l_grad = torch.Tensor([0])
87
+
88
+ if self.config.w_domain > 0:
89
+ l_domain = self.domain_classifier_loss(
90
+ domain_logits, domain_labels)
91
+ loss = loss + self.config.w_domain * l_domain
92
+ losses["DomainLoss"] = l_domain
93
+ else:
94
+ l_domain = torch.Tensor([0.])
95
+
96
+ self.scaler.scale(loss).backward()
97
+
98
+ if self.config.clip_grad > 0:
99
+ self.scaler.unscale_(self.optimizer)
100
+ nn.utils.clip_grad_norm_(
101
+ self.model.parameters(), self.config.clip_grad)
102
+
103
+ self.scaler.step(self.optimizer)
104
+
105
+ if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0:
106
+ depths_gt[torch.logical_not(mask)] = -99
107
+ self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train",
108
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
109
+
110
+ self.scaler.update()
111
+ self.optimizer.zero_grad(set_to_none=True)
112
+
113
+ return losses
114
+
115
+ def validate_on_batch(self, batch, val_step):
116
+ images = batch['image'].to(self.device)
117
+ depths_gt = batch['depth'].to(self.device)
118
+ dataset = batch['dataset'][0]
119
+ if 'has_valid_depth' in batch:
120
+ if not batch['has_valid_depth']:
121
+ return None, None
122
+
123
+ depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0)
124
+ with amp.autocast(enabled=self.config.use_amp):
125
+ m = self.model.module if self.config.multigpu else self.model
126
+ pred_depths = m(images)["metric_depth"]
127
+ pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0)
128
+
129
+ mask = torch.logical_and(
130
+ depths_gt > self.config.min_depth, depths_gt < self.config.max_depth)
131
+ with amp.autocast(enabled=self.config.use_amp):
132
+ l_depth = self.silog_loss(
133
+ pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True)
134
+
135
+ metrics = compute_metrics(depths_gt, pred_depths, **self.config)
136
+ losses = {f"{self.silog_loss.name}": l_depth.item()}
137
+
138
+ if val_step == 1 and self.should_log:
139
+ depths_gt[torch.logical_not(mask)] = -99
140
+ self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test",
141
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
142
+
143
+ return metrics, losses
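For reference, a hedged sketch of the batch layout that train_on_batch above consumes; the keys are the ones read in the code, the sizes are illustrative.

import torch

batch = {
    "image":   torch.rand(4, 3, 480, 640),                    # batch_size, c, h, w
    "depth":   torch.rand(4, 1, 480, 640) * 10.0,             # metric depth (e.g. metres)
    "mask":    torch.ones(4, 1, 480, 640, dtype=torch.bool),  # valid-depth mask
    "dataset": ["nyu"] * 4,                                    # one domain per batch: nyu -> 0, kitti -> 1
}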
zoedepth/trainers/zoedepth_trainer.py ADDED
@@ -0,0 +1,177 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.cuda.amp as amp
27
+ import torch.nn as nn
28
+
29
+ from zoedepth.trainers.loss import GradL1Loss, SILogLoss
30
+ from zoedepth.utils.config import DATASETS_CONFIG
31
+ from zoedepth.utils.misc import compute_metrics
32
+ from zoedepth.data.preprocess import get_black_border
33
+
34
+ from .base_trainer import BaseTrainer
35
+ from torchvision import transforms
36
+ from PIL import Image
37
+ import numpy as np
38
+
39
+ class Trainer(BaseTrainer):
40
+ def __init__(self, config, model, train_loader, test_loader=None, device=None):
41
+ super().__init__(config, model, train_loader,
42
+ test_loader=test_loader, device=device)
43
+ self.device = device
44
+ self.silog_loss = SILogLoss()
45
+ self.grad_loss = GradL1Loss()
46
+ self.scaler = amp.GradScaler(enabled=self.config.use_amp)
47
+
48
+ def train_on_batch(self, batch, train_step):
49
+ """
50
+ Expects a batch of images and depth as input
51
+ batch["image"].shape : batch_size, c, h, w
52
+ batch["depth"].shape : batch_size, 1, h, w
53
+ """
54
+
55
+ images, depths_gt = batch['image'].to(
56
+ self.device), batch['depth'].to(self.device)
57
+ dataset = batch['dataset'][0]
58
+
59
+ b, c, h, w = images.size()
60
+ mask = batch["mask"].to(self.device).to(torch.bool)
61
+
62
+ losses = {}
63
+
64
+ with amp.autocast(enabled=self.config.use_amp):
65
+
66
+ output = self.model(images)
67
+ pred_depths = output['metric_depth']
68
+
69
+ l_si, pred = self.silog_loss(
70
+ pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True)
71
+ loss = self.config.w_si * l_si
72
+ losses[self.silog_loss.name] = l_si
73
+
74
+ if self.config.w_grad > 0:
75
+ l_grad = self.grad_loss(pred, depths_gt, mask=mask)
76
+ loss = loss + self.config.w_grad * l_grad
77
+ losses[self.grad_loss.name] = l_grad
78
+ else:
79
+ l_grad = torch.Tensor([0])
80
+
81
+ self.scaler.scale(loss).backward()
82
+
83
+ if self.config.clip_grad > 0:
84
+ self.scaler.unscale_(self.optimizer)
85
+ nn.utils.clip_grad_norm_(
86
+ self.model.parameters(), self.config.clip_grad)
87
+
88
+ self.scaler.step(self.optimizer)
89
+
90
+ if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0:
91
+ # -99 is treated as invalid depth in the log_images function and is colored grey.
92
+ depths_gt[torch.logical_not(mask)] = -99
93
+
94
+ self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train",
95
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
96
+
97
+ if self.config.get("log_rel", False):
98
+ self.log_images(
99
+ scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel")
100
+
101
+ self.scaler.update()
102
+ self.optimizer.zero_grad()
103
+
104
+ return losses
105
+
106
+ @torch.no_grad()
107
+ def eval_infer(self, x):
108
+ with amp.autocast(enabled=self.config.use_amp):
109
+ m = self.model.module if self.config.multigpu else self.model
110
+ pred_depths = m(x)['metric_depth']
111
+ return pred_depths
112
+
113
+ @torch.no_grad()
114
+ def crop_aware_infer(self, x):
115
+ # if we are not avoiding the black border, we can just use the normal inference
116
+ if not self.config.get("avoid_boundary", False):
117
+ return self.eval_infer(x)
118
+
119
+ # otherwise, we need to crop the image to avoid the black border
120
+ # For now, this may be a bit slow due to converting to numpy and back
121
+ # We assume no normalization is done on the input image
122
+
123
+ # get the black border
124
+ assert x.shape[0] == 1, "Only batch size 1 is supported for now"
125
+ x_pil = transforms.ToPILImage()(x[0].cpu())
126
+ x_np = np.array(x_pil, dtype=np.uint8)
127
+ black_border_params = get_black_border(x_np)
128
+ top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right
129
+ x_np_cropped = x_np[top:bottom, left:right, :]
130
+ x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped))
131
+
132
+ # run inference on the cropped image
133
+ pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device))
134
+
135
+ # resize the prediction to x_np_cropped's size
136
+ pred_depths_cropped = nn.functional.interpolate(
137
+ pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False)
138
+
139
+
140
+ # pad the prediction back to the original size
141
+ pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype)
142
+ pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped
143
+
144
+ return pred_depths
145
+
146
+
147
+
148
+ def validate_on_batch(self, batch, val_step):
149
+ images = batch['image'].to(self.device)
150
+ depths_gt = batch['depth'].to(self.device)
151
+ dataset = batch['dataset'][0]
152
+ mask = batch["mask"].to(self.device)
153
+ if 'has_valid_depth' in batch:
154
+ if not batch['has_valid_depth']:
155
+ return None, None
156
+
157
+ depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0)
158
+ mask = mask.squeeze().unsqueeze(0).unsqueeze(0)
159
+ if dataset == 'nyu':
160
+ pred_depths = self.crop_aware_infer(images)
161
+ else:
162
+ pred_depths = self.eval_infer(images)
163
+ pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0)
164
+
165
+ with amp.autocast(enabled=self.config.use_amp):
166
+ l_depth = self.silog_loss(
167
+ pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True)
168
+
169
+ metrics = compute_metrics(depths_gt, pred_depths, **self.config)
170
+ losses = {f"{self.silog_loss.name}": l_depth.item()}
171
+
172
+ if val_step == 1 and self.should_log:
173
+ depths_gt[torch.logical_not(mask)] = -99
174
+ self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test",
175
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
176
+
177
+ return metrics, losses
zoedepth/utils/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
zoedepth/utils/arg_utils.py ADDED
@@ -0,0 +1,33 @@
 
+
+
+ def infer_type(x): # hacky way to infer type from string args
+     if not isinstance(x, str):
+         return x
+
+     try:
+         x = int(x)
+         return x
+     except ValueError:
+         pass
+
+     try:
+         x = float(x)
+         return x
+     except ValueError:
+         pass
+
+     return x
+
+
+ def parse_unknown(unknown_args):
+     clean = []
+     for a in unknown_args:
+         if "=" in a:
+             k, v = a.split("=")
+             clean.extend([k, v])
+         else:
+             clean.append(a)
+
+     keys = clean[::2]
+     values = clean[1::2]
+     return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)}
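Editor's note: a minimal usage sketch for the two helpers above (illustrative only, not part of the committed file; the argparse wiring shown here is an assumption):

# Hypothetical example: forwarding unrecognized CLI flags as typed config overrides.
import argparse
from zoedepth.utils.arg_utils import parse_unknown

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", default="zoedepth")
args, unknown = parser.parse_known_args(["-m", "zoedepth", "--lr=0.001", "--n_bins", "64"])
overrides = parse_unknown(unknown)  # -> {'lr': 0.001, 'n_bins': 64}, values typed via infer_type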
zoedepth/utils/config.py ADDED
@@ -0,0 +1,437 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import json
+ import os
+
+ from zoedepth.utils.easydict import EasyDict as edict
+
+ from zoedepth.utils.arg_utils import infer_type
+ import pathlib
+ import platform
+
+ ROOT = pathlib.Path(__file__).parent.parent.resolve()
+
+ HOME_DIR = os.path.expanduser("~")
+
+ COMMON_CONFIG = {
+     "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"),
+     "project": "ZoeDepth",
+     "tags": '',
+     "notes": "",
+     "gpu": None,
+     "root": ".",
+     "uid": None,
+     "print_losses": False
+ }
+
+ DATASETS_CONFIG = {
+     "kitti": {
+         "dataset": "kitti",
+         "min_depth": 0.001,
+         "max_depth": 80,
+         "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+         "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+         "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
+         "input_height": 352,
+         "input_width": 1216, # 704
+         "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+         "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+         "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
+
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+
+         "do_random_rotate": True,
+         "degree": 1.0,
+         "do_kb_crop": True,
+         "garg_crop": True,
+         "eigen_crop": False,
+         "use_right": False
+     },
+     "kitti_test": {
+         "dataset": "kitti",
+         "min_depth": 0.001,
+         "max_depth": 80,
+         "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+         "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+         "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
+         "input_height": 352,
+         "input_width": 1216,
+         "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+         "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+         "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
+
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+
+         "do_random_rotate": False,
+         "degree": 1.0,
+         "do_kb_crop": True,
+         "garg_crop": True,
+         "eigen_crop": False,
+         "use_right": False
+     },
+     "nyu": {
+         "dataset": "nyu",
+         "avoid_boundary": False,
+         "min_depth": 1e-3, # originally 0.1
+         "max_depth": 10,
+         "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
+         "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
+         "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
+         "input_height": 480,
+         "input_width": 640,
+         "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
+         "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
+         "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 10,
+         "min_depth_diff": -10,
+         "max_depth_diff": 10,
+
+         "do_random_rotate": True,
+         "degree": 1.0,
+         "do_kb_crop": False,
+         "garg_crop": False,
+         "eigen_crop": True
+     },
+     "ibims": {
+         "dataset": "ibims",
+         "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"),
+         "eigen_crop": True,
+         "garg_crop": False,
+         "do_kb_crop": False,
+         "min_depth_eval": 0,
+         "max_depth_eval": 10,
+         "min_depth": 1e-3,
+         "max_depth": 10
+     },
+     "sunrgbd": {
+         "dataset": "sunrgbd",
+         "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"),
+         "eigen_crop": True,
+         "garg_crop": False,
+         "do_kb_crop": False,
+         "min_depth_eval": 0,
+         "max_depth_eval": 8,
+         "min_depth": 1e-3,
+         "max_depth": 10
+     },
+     "diml_indoor": {
+         "dataset": "diml_indoor",
+         "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"),
+         "eigen_crop": True,
+         "garg_crop": False,
+         "do_kb_crop": False,
+         "min_depth_eval": 0,
+         "max_depth_eval": 10,
+         "min_depth": 1e-3,
+         "max_depth": 10
+     },
+     "diml_outdoor": {
+         "dataset": "diml_outdoor",
+         "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"),
+         "eigen_crop": False,
+         "garg_crop": True,
+         "do_kb_crop": False,
+         "min_depth_eval": 2,
+         "max_depth_eval": 80,
+         "min_depth": 1e-3,
+         "max_depth": 80
+     },
+     "diode_indoor": {
+         "dataset": "diode_indoor",
+         "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"),
+         "eigen_crop": True,
+         "garg_crop": False,
+         "do_kb_crop": False,
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 10,
+         "min_depth": 1e-3,
+         "max_depth": 10
+     },
+     "diode_outdoor": {
+         "dataset": "diode_outdoor",
+         "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"),
+         "eigen_crop": False,
+         "garg_crop": True,
+         "do_kb_crop": False,
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+         "min_depth": 1e-3,
+         "max_depth": 80
+     },
+     "hypersim_test": {
+         "dataset": "hypersim_test",
+         "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"),
+         "eigen_crop": True,
+         "garg_crop": False,
+         "do_kb_crop": False,
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+         "min_depth": 1e-3,
+         "max_depth": 10
+     },
+     "vkitti": {
+         "dataset": "vkitti",
+         "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"),
+         "eigen_crop": False,
+         "garg_crop": True,
+         "do_kb_crop": True,
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+         "min_depth": 1e-3,
+         "max_depth": 80
+     },
+     "vkitti2": {
+         "dataset": "vkitti2",
+         "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"),
+         "eigen_crop": False,
+         "garg_crop": True,
+         "do_kb_crop": True,
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+         "min_depth": 1e-3,
+         "max_depth": 80,
+     },
+     "ddad": {
+         "dataset": "ddad",
+         "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"),
+         "eigen_crop": False,
+         "garg_crop": True,
+         "do_kb_crop": True,
+         "min_depth_eval": 1e-3,
+         "max_depth_eval": 80,
+         "min_depth": 1e-3,
+         "max_depth": 80,
+     },
+ }
+
+ ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"]
+ ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"]
+ ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR
+
+ COMMON_TRAINING_CONFIG = {
+     "dataset": "nyu",
+     "distributed": True,
+     "workers": 16,
+     "clip_grad": 0.1,
+     "use_shared_dict": False,
+     "shared_dict": None,
+     "use_amp": False,
+
+     "aug": True,
+     "random_crop": False,
+     "random_translate": False,
+     "translate_prob": 0.2,
+     "max_translation": 100,
+
+     "validate_every": 0.25,
+     "log_images_every": 0.1,
+     "prefetch": False,
+ }
+
+
+ def flatten(config, except_keys=('bin_conf')):
+     def recurse(inp):
+         if isinstance(inp, dict):
+             for key, value in inp.items():
+                 if key in except_keys:
+                     yield (key, value)
+                 if isinstance(value, dict):
+                     yield from recurse(value)
+                 else:
+                     yield (key, value)
+
+     return dict(list(recurse(config)))
+
+
+ def split_combined_args(kwargs):
+     """Splits the arguments that are combined with '__' into multiple arguments.
+     Combined arguments should have equal number of keys and values.
+     Keys are separated by '__' and Values are separated with ';'.
+     For example, '__n_bins__lr=256;0.001'
+
+     Args:
+         kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format.
+
+     Returns:
+         dict: Parsed dict with the combined arguments split into individual key-value pairs.
+     """
+     new_kwargs = dict(kwargs)
+     for key, value in kwargs.items():
+         if key.startswith("__"):
+             keys = key.split("__")[1:]
+             values = value.split(";")
+             assert len(keys) == len(
+                 values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})"
+             for k, v in zip(keys, values):
+                 new_kwargs[k] = v
+     return new_kwargs
+
+
+ def parse_list(config, key, dtype=int):
+     """Parse a list of values for the key if the value is a string. The values are separated by a comma.
+     Modifies the config in place.
+     """
+     if key in config:
+         if isinstance(config[key], str):
+             config[key] = list(map(dtype, config[key].split(',')))
+         assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]]
+                                                      ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}."
+
+
+ def get_model_config(model_name, model_version=None):
+     """Find and parse the .json config file for the model.
+
+     Args:
+         model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory.
+         model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None.
+
+     Returns:
+         easydict: the config dictionary for the model.
+     """
+     config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json"
+     config_file = os.path.join(ROOT, "models", model_name, config_fname)
+     if not os.path.exists(config_file):
+         return None
+
+     with open(config_file, "r") as f:
+         config = edict(json.load(f))
+
+     # handle dictionary inheritance
+     # only training config is supported for inheritance
+     if "inherit" in config.train and config.train.inherit is not None:
+         inherit_config = get_model_config(config.train["inherit"]).train
+         for key, value in inherit_config.items():
+             if key not in config.train:
+                 config.train[key] = value
+     return edict(config)
+
+
+ def update_model_config(config, mode, model_name, model_version=None, strict=False):
+     model_config = get_model_config(model_name, model_version)
+     if model_config is not None:
+         config = {**config, **
+                   flatten({**model_config.model, **model_config[mode]})}
+     elif strict:
+         raise ValueError(f"Config file for model {model_name} not found.")
+     return config
+
+
+ def check_choices(name, value, choices):
+     # return # No checks in dev branch
+     if value not in choices:
+         raise ValueError(f"{name} {value} not in supported choices {choices}")
+
+
+ KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase",
+                   "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1
+
+
+ def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
+     """Main entry point to get the config for the model.
+
+     Args:
+         model_name (str): name of the desired model.
+         mode (str, optional): "train" or "infer". Defaults to 'train'.
+         dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None.
+
+     Keyword Args: key-value pairs of arguments to overwrite the default config.
+
+     The order of precedence for overwriting the config is (Higher precedence first):
+     # 1. overwrite_kwargs
+     # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json
+     # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json
+     # 4. common_config: Default config for all models specified in COMMON_CONFIG
+
+     Returns:
+         easydict: The config dictionary for the model.
+     """
+
+
+     check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
+     check_choices("Mode", mode, ["train", "infer", "eval"])
+     if mode == "train":
+         check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])
+
+     config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
+     config = update_model_config(config, mode, model_name)
+
+     # update with model version specific config
+     version_name = overwrite_kwargs.get("version_name", config["version_name"])
+     config = update_model_config(config, mode, model_name, version_name)
+
+     # update with config version if specified
+     config_version = overwrite_kwargs.get("config_version", None)
+     if config_version is not None:
+         print("Overwriting config with config_version", config_version)
+         config = update_model_config(config, mode, model_name, config_version)
+
+     # update with overwrite_kwargs
+     # Combined args are useful for hyperparameter search
+     overwrite_kwargs = split_combined_args(overwrite_kwargs)
+     config = {**config, **overwrite_kwargs}
+
+     # Casting to bool # TODO: Not necessary. Remove and test
+     for key in KEYS_TYPE_BOOL:
+         if key in config:
+             config[key] = bool(config[key])
+
+     # Model specific post processing of config
+     parse_list(config, "n_attractors")
+
+     # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs
+     if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
+         bin_conf = config['bin_conf'] # list of dicts
+         n_bins = overwrite_kwargs['n_bins']
+         new_bin_conf = []
+         for conf in bin_conf:
+             conf['n_bins'] = n_bins
+             new_bin_conf.append(conf)
+         config['bin_conf'] = new_bin_conf
+
+     if mode == "train":
+         orig_dataset = dataset
+         if dataset == "mix":
+             dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader
+         if dataset is not None:
+             config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb
+
+     if dataset is not None:
+         config['dataset'] = dataset
+         config = {**DATASETS_CONFIG[dataset], **config}
+
+
+     config['model'] = model_name
+     typed_config = {k: infer_type(v) for k, v in config.items()}
+     # add hostname to config
+     config['hostname'] = platform.node()
+     return edict(typed_config)
+
+
+ def change_dataset(config, new_dataset):
+     config.update(DATASETS_CONFIG[new_dataset])
+     return config
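Editor's note: a short sketch of how get_config is typically called (illustrative only, not part of the committed file; the override keyword shown is an assumption):

# Hypothetical example: building a flattened EasyDict config with keyword overrides.
from zoedepth.utils.config import get_config

infer_conf = get_config("zoedepth", "infer")                      # model config + common config, flattened
train_conf = get_config("zoedepth", "train", dataset="nyu", n_bins="64")
print(train_conf.dataset, train_conf.max_depth)                   # EasyDict attribute access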
zoedepth/utils/easydict/__init__.py ADDED
@@ -0,0 +1,158 @@
+ """
+ EasyDict
+ Copy/pasted from https://github.com/makinacorpus/easydict
+ Original author: Mathieu Leplatre <mathieu.leplatre@makina-corpus.com>
+ """
+
+ class EasyDict(dict):
+     """
+     Get attributes
+
+     >>> d = EasyDict({'foo':3})
+     >>> d['foo']
+     3
+     >>> d.foo
+     3
+     >>> d.bar
+     Traceback (most recent call last):
+     ...
+     AttributeError: 'EasyDict' object has no attribute 'bar'
+
+     Works recursively
+
+     >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+     >>> isinstance(d.bar, dict)
+     True
+     >>> d.bar.x
+     1
+
+     Bullet-proof
+
+     >>> EasyDict({})
+     {}
+     >>> EasyDict(d={})
+     {}
+     >>> EasyDict(None)
+     {}
+     >>> d = {'a': 1}
+     >>> EasyDict(**d)
+     {'a': 1}
+     >>> EasyDict((('a', 1), ('b', 2)))
+     {'a': 1, 'b': 2}
+
+     Set attributes
+
+     >>> d = EasyDict()
+     >>> d.foo = 3
+     >>> d.foo
+     3
+     >>> d.bar = {'prop': 'value'}
+     >>> d.bar.prop
+     'value'
+     >>> d
+     {'foo': 3, 'bar': {'prop': 'value'}}
+     >>> d.bar.prop = 'newer'
+     >>> d.bar.prop
+     'newer'
+
+
+     Values extraction
+
+     >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+     >>> isinstance(d.bar, list)
+     True
+     >>> from operator import attrgetter
+     >>> list(map(attrgetter('x'), d.bar))
+     [1, 3]
+     >>> list(map(attrgetter('y'), d.bar))
+     [2, 4]
+     >>> d = EasyDict()
+     >>> list(d.keys())
+     []
+     >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+     >>> d.foo
+     3
+     >>> d.bar.x
+     1
+
+     Still like a dict though
+
+     >>> o = EasyDict({'clean':True})
+     >>> list(o.items())
+     [('clean', True)]
+
+     And like a class
+
+     >>> class Flower(EasyDict):
+     ...     power = 1
+     ...
+     >>> f = Flower()
+     >>> f.power
+     1
+     >>> f = Flower({'height': 12})
+     >>> f.height
+     12
+     >>> f['power']
+     1
+     >>> sorted(f.keys())
+     ['height', 'power']
+
+     update and pop items
+     >>> d = EasyDict(a=1, b='2')
+     >>> e = EasyDict(c=3.0, a=9.0)
+     >>> d.update(e)
+     >>> d.c
+     3.0
+     >>> d['c']
+     3.0
+     >>> d.get('c')
+     3.0
+     >>> d.update(a=4, b=4)
+     >>> d.b
+     4
+     >>> d.pop('a')
+     4
+     >>> d.a
+     Traceback (most recent call last):
+     ...
+     AttributeError: 'EasyDict' object has no attribute 'a'
+     """
+     def __init__(self, d=None, **kwargs):
+         if d is None:
+             d = {}
+         else:
+             d = dict(d)
+         if kwargs:
+             d.update(**kwargs)
+         for k, v in d.items():
+             setattr(self, k, v)
+         # Class attributes
+         for k in self.__class__.__dict__.keys():
+             if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
+                 setattr(self, k, getattr(self, k))
+
+     def __setattr__(self, name, value):
+         if isinstance(value, (list, tuple)):
+             value = [self.__class__(x)
+                      if isinstance(x, dict) else x for x in value]
+         elif isinstance(value, dict) and not isinstance(value, self.__class__):
+             value = self.__class__(value)
+         super(EasyDict, self).__setattr__(name, value)
+         super(EasyDict, self).__setitem__(name, value)
+
+     __setitem__ = __setattr__
+
+     def update(self, e=None, **f):
+         d = e or dict()
+         d.update(f)
+         for k in d:
+             setattr(self, k, d[k])
+
+     def pop(self, k, d=None):
+         delattr(self, k)
+         return super(EasyDict, self).pop(k, d)
+
+
+ if __name__ == "__main__":
+     import doctest
+     doctest.testmod()
zoedepth/utils/geometry.py ADDED
@@ -0,0 +1,98 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import numpy as np
+
+ def get_intrinsics(H,W):
+     """
+     Intrinsics for a pinhole camera model.
+     Assume fov of 55 degrees and central principal point.
+     """
+     f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0)
+     cx = 0.5 * W
+     cy = 0.5 * H
+     return np.array([[f, 0, cx],
+                      [0, f, cy],
+                      [0, 0, 1]])
+
+ def depth_to_points(depth, R=None, t=None):
+
+     K = get_intrinsics(depth.shape[1], depth.shape[2])
+     Kinv = np.linalg.inv(K)
+     if R is None:
+         R = np.eye(3)
+     if t is None:
+         t = np.zeros(3)
+
+     # M converts from your coordinate to PyTorch3D's coordinate system
+     M = np.eye(3)
+     M[0, 0] = -1.0
+     M[1, 1] = -1.0
+
+     height, width = depth.shape[1:3]
+
+     x = np.arange(width)
+     y = np.arange(height)
+     coord = np.stack(np.meshgrid(x, y), -1)
+     coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1
+     coord = coord.astype(np.float32)
+     # coord = torch.as_tensor(coord, dtype=torch.float32, device=device)
+     coord = coord[None] # bs, h, w, 3
+
+     D = depth[:, :, :, None, None]
+     # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape )
+     pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None]
+     # pts3D_1 live in your coordinate system. Convert them to Py3D's
+     pts3D_1 = M[None, None, None, ...] @ pts3D_1
+     # from reference to target viewpoint
+     pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None]
+     # pts3D_2 = pts3D_1
+     # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w
+     return pts3D_2[:, :, :, :3, 0][0]
+
+
+ def create_triangles(h, w, mask=None):
+     """
+     Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68
+     Creates mesh triangle indices from a given pixel grid size.
+     This function is not and need not be differentiable as triangle indices are
+     fixed.
+     Args:
+         h: (int) denoting the height of the image.
+         w: (int) denoting the width of the image.
+     Returns:
+         triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3)
+     """
+     x, y = np.meshgrid(range(w - 1), range(h - 1))
+     tl = y * w + x
+     tr = y * w + x + 1
+     bl = (y + 1) * w + x
+     br = (y + 1) * w + x + 1
+     triangles = np.array([tl, bl, tr, br, tr, bl])
+     triangles = np.transpose(triangles, (1, 2, 0)).reshape(
+         ((w - 1) * (h - 1) * 2, 3))
+     if mask is not None:
+         mask = mask.reshape(-1)
+         triangles = triangles[mask[triangles].all(1)]
+     return triangles
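Editor's note: a minimal sketch of how the two geometry helpers fit together (illustrative only, not part of the committed file; shapes follow the code above):

# Hypothetical example: back-projecting a (1, H, W) metric depth map and meshing the pixel grid.
import numpy as np
from zoedepth.utils.geometry import depth_to_points, create_triangles

depth = np.full((1, 480, 640), 2.0, dtype=np.float32)   # stand-in depth map, 2 m everywhere
points = depth_to_points(depth)                          # (480, 640, 3) points under the assumed 55-degree-FOV intrinsics
faces = create_triangles(480, 640)                       # (2*(480-1)*(640-1), 3) triangle indices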
zoedepth/utils/misc.py ADDED
@@ -0,0 +1,368 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ """Miscellaneous utility functions."""
+
+ from scipy import ndimage
+
+ import base64
+ import math
+ import re
+ from io import BytesIO
+
+ import matplotlib
+ import matplotlib.cm
+ import numpy as np
+ import requests
+ import torch
+ import torch.distributed as dist
+ import torch.nn
+ import torch.nn as nn
+ import torch.utils.data.distributed
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+
+ class RunningAverage:
+     def __init__(self):
+         self.avg = 0
+         self.count = 0
+
+     def append(self, value):
+         self.avg = (value + self.count * self.avg) / (self.count + 1)
+         self.count += 1
+
+     def get_value(self):
+         return self.avg
+
+
+ def denormalize(x):
+     """Reverses the imagenet normalization applied to the input.
+
+     Args:
+         x (torch.Tensor - shape(N,3,H,W)): input tensor
+
+     Returns:
+         torch.Tensor - shape(N,3,H,W): Denormalized input
+     """
+     mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+     std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+     return x * std + mean
+
+
+ class RunningAverageDict:
+     """A dictionary of running averages."""
+     def __init__(self):
+         self._dict = None
+
+     def update(self, new_dict):
+         if new_dict is None:
+             return
+
+         if self._dict is None:
+             self._dict = dict()
+             for key, value in new_dict.items():
+                 self._dict[key] = RunningAverage()
+
+         for key, value in new_dict.items():
+             self._dict[key].append(value)
+
+     def get_value(self):
+         if self._dict is None:
+             return None
+         return {key: value.get_value() for key, value in self._dict.items()}
+
+
+ def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
+     """Converts a depth map to a color image.
+
+     Args:
+         value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
+         vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
+         vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
+         cmap (str, optional): matplotlib colormap to use. Defaults to 'gray_r'.
+         invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
+         invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
+         background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
+         gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
+         value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
+
+     Returns:
+         numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
+     """
+     if isinstance(value, torch.Tensor):
+         value = value.detach().cpu().numpy()
+
+     value = value.squeeze()
+     if invalid_mask is None:
+         invalid_mask = value == invalid_val
+     mask = np.logical_not(invalid_mask)
+
+     # normalize
+     vmin = np.percentile(value[mask],2) if vmin is None else vmin
+     vmax = np.percentile(value[mask],85) if vmax is None else vmax
+     if vmin != vmax:
+         value = (value - vmin) / (vmax - vmin) # vmin..vmax
+     else:
+         # Avoid 0-division
+         value = value * 0.
+
+     # squeeze last dim if it exists
+     # grey out the invalid values
+
+     value[invalid_mask] = np.nan
+     cmapper = matplotlib.cm.get_cmap(cmap)
+     if value_transform:
+         value = value_transform(value)
+     # value = value / value.max()
+     value = cmapper(value, bytes=True)  # (nxmx4)
+
+     # img = value[:, :, :]
+     img = value[...]
+     img[invalid_mask] = background_color
+
+     # return img.transpose((2, 0, 1))
+     if gamma_corrected:
+         # gamma correction
+         img = img / 255
+         img = np.power(img, 2.2)
+         img = img * 255
+         img = img.astype(np.uint8)
+     return img
+
+
+ def count_parameters(model, include_all=False):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all)
+
+
+ def compute_errors(gt, pred):
+     """Compute metrics for 'pred' compared to 'gt'
+
+     Args:
+         gt (numpy.ndarray): Ground truth values
+         pred (numpy.ndarray): Predicted values
+
+         gt.shape should be equal to pred.shape
+
+     Returns:
+         dict: Dictionary containing the following metrics:
+             'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25
+             'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2
+             'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3
+             'abs_rel': Absolute relative error
+             'rmse': Root mean squared error
+             'log_10': Absolute log10 error
+             'sq_rel': Squared relative error
+             'rmse_log': Root mean squared error on the log scale
+             'silog': Scale invariant log error
+     """
+     thresh = np.maximum((gt / pred), (pred / gt))
+     a1 = (thresh < 1.25).mean()
+     a2 = (thresh < 1.25 ** 2).mean()
+     a3 = (thresh < 1.25 ** 3).mean()
+
+     abs_rel = np.mean(np.abs(gt - pred) / gt)
+     sq_rel = np.mean(((gt - pred) ** 2) / gt)
+
+     rmse = (gt - pred) ** 2
+     rmse = np.sqrt(rmse.mean())
+
+     rmse_log = (np.log(gt) - np.log(pred)) ** 2
+     rmse_log = np.sqrt(rmse_log.mean())
+
+     err = np.log(pred) - np.log(gt)
+     silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
+
+     log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean()
+     return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log,
+                 silog=silog, sq_rel=sq_rel)
+
+
+ def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs):
+     """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics.
+     """
+     if 'config' in kwargs:
+         config = kwargs['config']
+         garg_crop = config.garg_crop
+         eigen_crop = config.eigen_crop
+         min_depth_eval = config.min_depth_eval
+         max_depth_eval = config.max_depth_eval
+
+     if gt.shape[-2:] != pred.shape[-2:] and interpolate:
+         pred = nn.functional.interpolate(
+             pred, gt.shape[-2:], mode='bilinear', align_corners=True)
+
+     pred = pred.squeeze().cpu().numpy()
+     pred[pred < min_depth_eval] = min_depth_eval
+     pred[pred > max_depth_eval] = max_depth_eval
+     pred[np.isinf(pred)] = max_depth_eval
+     pred[np.isnan(pred)] = min_depth_eval
+
+     gt_depth = gt.squeeze().cpu().numpy()
+     valid_mask = np.logical_and(
+         gt_depth > min_depth_eval, gt_depth < max_depth_eval)
+
+     if garg_crop or eigen_crop:
+         gt_height, gt_width = gt_depth.shape
+         eval_mask = np.zeros(valid_mask.shape)
+
+         if garg_crop:
+             eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
+                       int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+         elif eigen_crop:
+             # print("-"*10, " EIGEN CROP ", "-"*10)
+             if dataset == 'kitti':
+                 eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
+                           int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+             else:
+                 # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images"
+                 eval_mask[45:471, 41:601] = 1
+     else:
+         eval_mask = np.ones(valid_mask.shape)
+     valid_mask = np.logical_and(valid_mask, eval_mask)
+     return compute_errors(gt_depth[valid_mask], pred[valid_mask])
+
+
+ #################################### Model utils ################################################
+
+
+ def parallelize(config, model, find_unused_parameters=True):
+
+     if config.gpu is not None:
+         torch.cuda.set_device(config.gpu)
+         model = model.cuda(config.gpu)
+
+     config.multigpu = False
+     if config.distributed:
+         # Use DDP
+         config.multigpu = True
+         config.rank = config.rank * config.ngpus_per_node + config.gpu
+         dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
+                                 world_size=config.world_size, rank=config.rank)
+         config.batch_size = int(config.batch_size / config.ngpus_per_node)
+         # config.batch_size = 8
+         config.workers = int(
+             (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node)
+         print("Device", config.gpu, "Rank", config.rank, "batch size",
+               config.batch_size, "Workers", config.workers)
+         torch.cuda.set_device(config.gpu)
+         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+         model = model.cuda(config.gpu)
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu,
+                                                           find_unused_parameters=find_unused_parameters)
+
+     elif config.gpu is None:
+         # Use DP
+         config.multigpu = True
+         model = model.cuda()
+         model = torch.nn.DataParallel(model)
+
+     return model
+
+
+ #################################################################################################
+
+
+ #####################################################################################################
+
+
+ class colors:
+     '''Colors class:
+     Reset all colors with colors.reset
+     Two subclasses fg for foreground and bg for background.
+     Use as colors.subclass.colorname.
+     i.e. colors.fg.red or colors.bg.green
+     Also, the generic bold, disable, underline, reverse, strikethrough,
+     and invisible work with the main class
+     i.e. colors.bold
+     '''
+     reset = '\033[0m'
+     bold = '\033[01m'
+     disable = '\033[02m'
+     underline = '\033[04m'
+     reverse = '\033[07m'
+     strikethrough = '\033[09m'
+     invisible = '\033[08m'
+
+     class fg:
+         black = '\033[30m'
+         red = '\033[31m'
+         green = '\033[32m'
+         orange = '\033[33m'
+         blue = '\033[34m'
+         purple = '\033[35m'
+         cyan = '\033[36m'
+         lightgrey = '\033[37m'
+         darkgrey = '\033[90m'
+         lightred = '\033[91m'
+         lightgreen = '\033[92m'
+         yellow = '\033[93m'
+         lightblue = '\033[94m'
+         pink = '\033[95m'
+         lightcyan = '\033[96m'
+
+     class bg:
+         black = '\033[40m'
+         red = '\033[41m'
+         green = '\033[42m'
+         orange = '\033[43m'
+         blue = '\033[44m'
+         purple = '\033[45m'
+         cyan = '\033[46m'
+         lightgrey = '\033[47m'
+
+
+ def printc(text, color):
+     print(f"{color}{text}{colors.reset}")
+
+ ############################################
+
+ def get_image_from_url(url):
+     response = requests.get(url)
+     img = Image.open(BytesIO(response.content)).convert("RGB")
+     return img
+
+ def url_to_torch(url, size=(384, 384)):
+     img = get_image_from_url(url)
+     img = img.resize(size, Image.ANTIALIAS)
+     img = torch.from_numpy(np.asarray(img)).float()
+     img = img.permute(2, 0, 1)
+     img.div_(255)
+     return img
+
+ def pil_to_batched_tensor(img):
+     return ToTensor()(img).unsqueeze(0)
+
+ def save_raw_16bit(depth, fpath="raw.png"):
+     if isinstance(depth, torch.Tensor):
+         depth = depth.squeeze().cpu().numpy()
+
+     assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
+     assert depth.ndim == 2, "Depth must be 2D"
+     depth = depth * 256 # scale for 16-bit png
+     depth = depth.astype(np.uint16)
+     depth = Image.fromarray(depth)
+     depth.save(fpath)
+     print("Saved raw depth to", fpath)
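Editor's note: a short sketch of the visualization and saving helpers above (illustrative only, not part of the committed file; file names are placeholders):

# Hypothetical example: colorizing and saving a predicted depth map.
import torch
from PIL import Image
from zoedepth.utils.misc import colorize, save_raw_16bit

pred = torch.rand(1, 1, 480, 640) * 10                   # stand-in prediction in metres
rgba = colorize(pred, vmin=1e-3, vmax=10)                # (480, 640, 4) uint8 image
Image.fromarray(rgba).save("pred_colored.png")
save_raw_16bit(pred, "pred_raw_16bit.png")               # 16-bit PNG, depth scaled by 256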