toshas committed on
Commit
77e5011
·
1 Parent(s): c7adb54

remove unused files

Browse files
marigold_depth_estimation_lcm.py DELETED
@@ -1,710 +0,0 @@
1
- # Copyright 2024 Bingxin Ke, Anton Obukhov, ETH Zurich and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # --------------------------------------------------------------------------
15
- # If you find this code useful, we kindly ask you to cite our paper in your work.
16
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
- # More information about the method can be found at https://marigoldmonodepth.github.io
18
- # --------------------------------------------------------------------------
19
-
20
-
21
- import math
22
- from typing import Dict, Union, Tuple
23
-
24
- import matplotlib
25
- import numpy as np
26
- import torch
27
- from PIL import Image
28
- from scipy.optimize import minimize
29
- from torch.utils.data import DataLoader, TensorDataset
30
- from tqdm.auto import tqdm
31
- from transformers import CLIPTextModel, CLIPTokenizer
32
-
33
- from diffusers import (
34
- AutoencoderKL,
35
- DDIMScheduler,
36
- DiffusionPipeline,
37
- UNet2DConditionModel,
38
- )
39
- from diffusers.utils import BaseOutput, check_min_version
40
-
41
-
42
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
43
- check_min_version("0.27.0.dev0")
44
-
45
-
46
- class MarigoldDepthConsistencyOutput(BaseOutput):
47
- """
48
- Output class for Marigold monocular depth prediction pipeline.
49
-
50
- Args:
51
- depth_np (`np.ndarray`):
52
- Predicted depth map, with depth values in the range of [0, 1].
53
- depth_colored (`None` or `PIL.Image.Image`):
54
- Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
55
- depth_latent (`torch.Tensor`):
56
- Depth map's latent, with the shape of [4, h, w].
57
- uncertainty (`None` or `np.ndarray`):
58
- Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
59
- """
60
-
61
- depth_np: np.ndarray
62
- depth_colored: Union[None, Image.Image]
63
- depth_latent: torch.Tensor
64
- uncertainty: Union[None, np.ndarray]
65
-
66
-
67
- class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
68
- """
69
- Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
70
-
71
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
72
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
73
-
74
- Args:
75
- unet (`UNet2DConditionModel`):
76
- Conditional U-Net to denoise the depth latent, conditioned on image latent.
77
- vae (`AutoencoderKL`):
78
- Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
79
- to and from latent representations.
80
- scheduler (`DDIMScheduler`):
81
- A scheduler to be used in combination with `unet` to denoise the encoded image latents.
82
- text_encoder (`CLIPTextModel`):
83
- Text-encoder, for empty text embedding.
84
- tokenizer (`CLIPTokenizer`):
85
- CLIP tokenizer.
86
- """
87
-
88
- rgb_latent_scale_factor = 0.18215
89
- depth_latent_scale_factor = 0.18215
90
-
91
- def __init__(
92
- self,
93
- unet: UNet2DConditionModel,
94
- vae: AutoencoderKL,
95
- scheduler: DDIMScheduler,
96
- text_encoder: CLIPTextModel,
97
- tokenizer: CLIPTokenizer,
98
- ):
99
- super().__init__()
100
-
101
- self.register_modules(
102
- unet=unet,
103
- vae=vae,
104
- scheduler=scheduler,
105
- text_encoder=text_encoder,
106
- tokenizer=tokenizer,
107
- )
108
-
109
- self.empty_text_embed = None
110
-
111
- @torch.no_grad()
112
- def __call__(
113
- self,
114
- input_image: Image,
115
- denoising_steps: int = 1,
116
- ensemble_size: int = 1,
117
- processing_res: int = 768,
118
- match_input_res: bool = True,
119
- batch_size: int = 0,
120
- depth_latent_init: torch.Tensor = None,
121
- depth_latent_init_strength: float = 0.1,
122
- return_depth_latent: bool = False,
123
- seed: int = None,
124
- color_map: str = "Spectral",
125
- show_progress_bar: bool = True,
126
- ensemble_kwargs: Dict = None,
127
- ) -> MarigoldDepthConsistencyOutput:
128
- """
129
- Function invoked when calling the pipeline.
130
-
131
- Args:
132
- input_image (`Image`):
133
- Input RGB (or gray-scale) image.
134
- processing_res (`int`, *optional*, defaults to `768`):
135
- Maximum resolution of processing.
136
- If set to 0: will not resize at all.
137
- match_input_res (`bool`, *optional*, defaults to `True`):
138
- Resize depth prediction to match input resolution.
139
- Only valid if `processing_res` is not None.
140
- denoising_steps (`int`, *optional*, defaults to `1`):
141
- Number of diffusion denoising steps (consistency) during inference.
142
- ensemble_size (`int`, *optional*, defaults to `1`):
143
- Number of predictions to be ensembled.
144
- batch_size (`int`, *optional*, defaults to `0`):
145
- Inference batch size, no bigger than `ensemble_size`.
146
- If set to 0, the script will automatically decide the proper batch size.
147
- depth_latent_init (`torch.Tensor`, *optional*, defaults to `None`):
148
- Initial depth map latent for better temporal consistency.
149
- depth_latent_init_strength (`float`, *optional*, defaults to `0.1`)
150
- Degree of initial depth latent influence, must be between 0 and 1.
151
- return_depth_latent (`bool`, defaults to False)
152
- Whether to return the depth latent.
153
- seed (`int`, *optional*, defaults to `None`)
154
- Reproducibility seed.
155
- show_progress_bar (`bool`, *optional*, defaults to `True`):
156
- Display a progress bar of diffusion denoising.
157
- color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
158
- Colormap used to colorize the depth map.
159
- ensemble_kwargs (`dict`, *optional*, defaults to `None`):
160
- Arguments for detailed ensembling settings.
161
- Returns:
162
- `MarigoldDepthConsistencyOutput`: Output class for Marigold monocular depth prediction pipeline, including:
163
- - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
164
- - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
165
- values in [0, 1]. None if `color_map` is `None`
166
- - **depth_latent** (`torch.Tensor`) Predicted depth map latent
167
- - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
168
- coming from ensembling. None if `ensemble_size = 1`
169
- """
170
-
171
- device = self.device
172
- input_size = input_image.size
173
-
174
- if not match_input_res:
175
- assert (
176
- processing_res is not None
177
- ), "Value error: `resize_output_back` is only valid with "
178
- assert processing_res >= 0, "Value error: `processing_res` must be non-negative"
179
- assert (
180
- 1 <= denoising_steps <= 10
181
- ), "Value error: This model degrades with large number of steps"
182
- assert ensemble_size >= 1
183
-
184
- # ----------------- Image Preprocess -----------------
185
- # Resize image
186
- if processing_res > 0:
187
- input_image = self.resize_max_res(
188
- input_image, max_edge_resolution=processing_res
189
- )
190
- # Convert the image to RGB in order to (1) remove the alpha channel and (2) convert B&W to 3-channel
191
- input_image = input_image.convert("RGB")
192
- image = np.asarray(input_image)
193
-
194
- # Normalize rgb values
195
- rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
196
- rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
197
- rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
198
- rgb_norm = rgb_norm.to(device)
199
- assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
200
-
201
- # ----------------- Predicting depth -----------------
202
- # Batch repeated input image
203
- duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
204
- batch_dataset = TensorDataset(duplicated_rgb)
205
- if batch_size > 0:
206
- _bs = batch_size
207
- else:
208
- _bs = self._find_batch_size(
209
- ensemble_size=ensemble_size,
210
- input_res=max(duplicated_rgb.shape[-2:]),
211
- dtype=self.dtype,
212
- )
213
-
214
- batch_loader = DataLoader(batch_dataset, batch_size=_bs, shuffle=False)
215
-
216
- # Predict depth maps (batched)
217
- depth_pred_ls = []
218
- if show_progress_bar:
219
- iterable = tqdm(
220
- batch_loader, desc=" " * 2 + "Inference batches", leave=False
221
- )
222
- else:
223
- iterable = batch_loader
224
- depth_latent = None
225
- for batch in iterable:
226
- (batched_img,) = batch
227
- depth_pred_raw, depth_latent = self.single_infer(
228
- rgb_in=batched_img,
229
- num_inference_steps=denoising_steps,
230
- depth_latent_init=depth_latent_init,
231
- depth_latent_init_strength=depth_latent_init_strength,
232
- seed=seed,
233
- show_pbar=show_progress_bar,
234
- )
235
- depth_pred_ls.append(depth_pred_raw.detach())
236
- depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
237
- torch.cuda.empty_cache() # clear vram cache for ensembling
238
-
239
- # ----------------- Test-time ensembling -----------------
240
- if ensemble_size > 1:
241
- depth_pred, pred_uncert = self.ensemble_depths(
242
- depth_preds, **(ensemble_kwargs or {})
243
- )
244
- else:
245
- depth_pred = depth_preds
246
- pred_uncert = None
247
-
248
- # ----------------- Post processing -----------------
249
- # Scale prediction to [0, 1]
250
- min_d = torch.min(depth_pred)
251
- max_d = torch.max(depth_pred)
252
- depth_pred = (depth_pred - min_d) / (max_d - min_d)
253
- if return_depth_latent:
254
- if ensemble_size > 1:
255
- depth_latent = self._encode_depth(2 * depth_pred - 1)
256
- else:
257
- depth_latent = None
258
-
259
- # Convert to numpy
260
- depth_pred = depth_pred.cpu().numpy().astype(np.float32)
261
-
262
- # Resize back to original resolution
263
- if match_input_res:
264
- pred_img = Image.fromarray(depth_pred)
265
- pred_img = pred_img.resize(input_size)
266
- depth_pred = np.asarray(pred_img)
267
-
268
- # Clip output range
269
- depth_pred = depth_pred.clip(0, 1)
270
-
271
- # Colorize
272
- if color_map is not None:
273
- depth_colored = self.colorize_depth_maps(
274
- depth_pred, 0, 1, cmap=color_map
275
- ).squeeze() # [3, H, W], value in (0, 1)
276
- depth_colored = (depth_colored * 255).astype(np.uint8)
277
- depth_colored_hwc = self.chw2hwc(depth_colored)
278
- depth_colored_img = Image.fromarray(depth_colored_hwc)
279
- else:
280
- depth_colored_img = None
281
- return MarigoldDepthConsistencyOutput(
282
- depth_np=depth_pred,
283
- depth_colored=depth_colored_img,
284
- depth_latent=depth_latent,
285
- uncertainty=pred_uncert,
286
- )
287
-
288
- def _encode_empty_text(self):
289
- """
290
- Encode text embedding for empty prompt.
291
- """
292
- prompt = ""
293
- text_inputs = self.tokenizer(
294
- prompt,
295
- padding="do_not_pad",
296
- max_length=self.tokenizer.model_max_length,
297
- truncation=True,
298
- return_tensors="pt",
299
- )
300
- text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
301
- self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
302
-
303
- @torch.no_grad()
304
- def single_infer(
305
- self,
306
- rgb_in: torch.Tensor,
307
- num_inference_steps: int,
308
- depth_latent_init: torch.Tensor,
309
- depth_latent_init_strength: float,
310
- seed: int,
311
- show_pbar: bool,
312
- ) -> Tuple[torch.Tensor, torch.Tensor]:
313
- """
314
- Perform an individual depth prediction without ensembling.
315
-
316
- Args:
317
- rgb_in (`torch.Tensor`):
318
- Input RGB image.
319
- num_inference_steps (`int`):
320
- Number of diffusion denoising steps (DDIM) during inference.
321
- depth_latent_init (`torch.Tensor`, `optional`):
322
- Initial depth latent
323
- depth_latent_init_strength (`float`, `optional`):
324
- Degree of initial depth latent influence, must be between 0 and 1
325
- seed (`int`, *optional*, defaults to `None`)
326
- Reproducibility seed.
327
- show_pbar (`bool`):
328
- Display a progress bar of diffusion denoising.
329
- Returns:
330
- `torch.Tensor`: Predicted depth map.
331
- """
332
- device = rgb_in.device
333
-
334
- # Set timesteps
335
- self.scheduler.set_timesteps(num_inference_steps, device=device)
336
- timesteps = self.scheduler.timesteps # [T]
337
-
338
- # Encode image
339
- rgb_latent = self._encode_rgb(rgb_in)
340
-
341
- # Initial depth map (noise)
342
- if seed is None:
343
- rng = None
344
- else:
345
- rng = torch.Generator(device=device)
346
- rng.manual_seed(seed)
347
- depth_latent = torch.randn(
348
- rgb_latent.shape, device=device, dtype=self.dtype, generator=rng
349
- ) # [B, 4, h, w]
350
-
351
- if depth_latent_init is not None:
352
- assert 0.0 <= depth_latent_init_strength <= 1.0
353
- assert (
354
- depth_latent_init.dim() == 4
355
- and depth_latent.dim() == 4
356
- and depth_latent_init.shape[0] == 1
357
- )
358
- if depth_latent.shape[0] != 1:
359
- depth_latent_init = depth_latent_init.repeat(
360
- depth_latent.shape[0], 1, 1, 1
361
- )
362
- depth_latent *= 1.0 - depth_latent_init_strength
363
- depth_latent = depth_latent + depth_latent_init * depth_latent_init_strength
364
-
365
- # Batched empty text embedding
366
- if self.empty_text_embed is None:
367
- self._encode_empty_text()
368
- batch_empty_text_embed = self.empty_text_embed.repeat(
369
- (rgb_latent.shape[0], 1, 1)
370
- ) # [B, 2, 1024]
371
-
372
- # Denoising loop
373
- if show_pbar:
374
- iterable = tqdm(
375
- enumerate(timesteps),
376
- total=len(timesteps),
377
- leave=False,
378
- desc=" " * 4 + "Diffusion denoising",
379
- )
380
- else:
381
- iterable = enumerate(timesteps)
382
-
383
- for i, t in iterable:
384
- unet_input = torch.cat(
385
- [rgb_latent, depth_latent], dim=1
386
- ) # this order is important
387
-
388
- # predict the noise residual
389
- noise_pred = self.unet(
390
- unet_input, t, encoder_hidden_states=batch_empty_text_embed
391
- ).sample # [B, 4, h, w]
392
-
393
- # compute the previous noisy sample x_t -> x_t-1
394
- depth_latent = self.scheduler.step(
395
- noise_pred, t, depth_latent, generator=rng
396
- ).prev_sample
397
-
398
- depth = self._decode_depth(depth_latent)
399
-
400
- # clip prediction
401
- depth = torch.clip(depth, -1.0, 1.0)
402
- # shift to [0, 1]
403
- depth = (depth + 1.0) / 2.0
404
-
405
- return depth, depth_latent
406
-
407
- def _encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
408
- """
409
- Encode depth image into latent.
410
-
411
- Args:
412
- depth_in (`torch.Tensor`):
413
- Input Depth image to be encoded.
414
-
415
- Returns:
416
- `torch.Tensor`: Depth latent.
417
- """
418
- # encode
419
- dims = depth_in.squeeze().shape
420
- h = self.vae.encoder(depth_in.reshape(1, 1, *dims).repeat(1, 3, 1, 1))
421
- moments = self.vae.quant_conv(h)
422
- mean, _ = torch.chunk(moments, 2, dim=1)
423
- depth_latent = mean * self.depth_latent_scale_factor
424
- return depth_latent
425
-
426
- def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
427
- """
428
- Encode RGB image into latent.
429
-
430
- Args:
431
- rgb_in (`torch.Tensor`):
432
- Input RGB image to be encoded.
433
-
434
- Returns:
435
- `torch.Tensor`: Image latent.
436
- """
437
- # encode
438
- h = self.vae.encoder(rgb_in)
439
- moments = self.vae.quant_conv(h)
440
- mean, logvar = torch.chunk(moments, 2, dim=1)
441
- # scale latent
442
- rgb_latent = mean * self.rgb_latent_scale_factor
443
- return rgb_latent
444
-
445
- def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
446
- """
447
- Decode depth latent into depth map.
448
-
449
- Args:
450
- depth_latent (`torch.Tensor`):
451
- Depth latent to be decoded.
452
-
453
- Returns:
454
- `torch.Tensor`: Decoded depth map.
455
- """
456
- # scale latent
457
- depth_latent = depth_latent / self.depth_latent_scale_factor
458
- # decode
459
- z = self.vae.post_quant_conv(depth_latent)
460
- stacked = self.vae.decoder(z)
461
- # mean of output channels
462
- depth_mean = stacked.mean(dim=1, keepdim=True)
463
- return depth_mean
464
-
465
- @staticmethod
466
- def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
467
- """
468
- Resize image to limit maximum edge length while keeping aspect ratio.
469
-
470
- Args:
471
- img (`Image.Image`):
472
- Image to be resized.
473
- max_edge_resolution (`int`):
474
- Maximum edge length (pixel).
475
-
476
- Returns:
477
- `Image.Image`: Resized image.
478
- """
479
- original_width, original_height = img.size
480
- downscale_factor = min(
481
- max_edge_resolution / original_width, max_edge_resolution / original_height
482
- )
483
-
484
- new_width = int(original_width * downscale_factor)
485
- new_height = int(original_height * downscale_factor)
486
-
487
- resized_img = img.resize((new_width, new_height))
488
- return resized_img
489
-
490
- @staticmethod
491
- def colorize_depth_maps(
492
- depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
493
- ):
494
- """
495
- Colorize depth maps.
496
- """
497
- assert len(depth_map.shape) >= 2, "Invalid dimension"
498
-
499
- if isinstance(depth_map, torch.Tensor):
500
- depth = depth_map.detach().squeeze().numpy()
501
- elif isinstance(depth_map, np.ndarray):
502
- depth = depth_map.copy().squeeze()
503
- # reshape to [ (B,) H, W ]
504
- if depth.ndim < 3:
505
- depth = depth[np.newaxis, :, :]
506
-
507
- # colorize
508
- cm = matplotlib.colormaps[cmap]
509
- depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
510
- img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
511
- img_colored_np = np.rollaxis(img_colored_np, 3, 1)
512
-
513
- if valid_mask is not None:
514
- if isinstance(depth_map, torch.Tensor):
515
- valid_mask = valid_mask.detach().numpy()
516
- valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
517
- if valid_mask.ndim < 3:
518
- valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
519
- else:
520
- valid_mask = valid_mask[:, np.newaxis, :, :]
521
- valid_mask = np.repeat(valid_mask, 3, axis=1)
522
- img_colored_np[~valid_mask] = 0
523
-
524
- if isinstance(depth_map, torch.Tensor):
525
- img_colored = torch.from_numpy(img_colored_np).float()
526
- elif isinstance(depth_map, np.ndarray):
527
- img_colored = img_colored_np
528
-
529
- return img_colored
530
-
531
- @staticmethod
532
- def chw2hwc(chw):
533
- assert 3 == len(chw.shape)
534
- if isinstance(chw, torch.Tensor):
535
- hwc = torch.permute(chw, (1, 2, 0))
536
- elif isinstance(chw, np.ndarray):
537
- hwc = np.moveaxis(chw, 0, -1)
538
- return hwc
539
-
540
- @staticmethod
541
- def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
542
- """
543
- Automatically search for suitable operating batch size.
544
-
545
- Args:
546
- ensemble_size (`int`):
547
- Number of predictions to be ensembled.
548
- input_res (`int`):
549
- Operating resolution of the input image.
550
-
551
- Returns:
552
- `int`: Operating batch size.
553
- """
554
- # Search table for suggested max. inference batch size
555
- bs_search_table = [
556
- # tested on A100-PCIE-80GB
557
- {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
558
- {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
559
- # tested on A100-PCIE-40GB
560
- {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
561
- {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
562
- {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
563
- {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
564
- # tested on RTX3090, RTX4090
565
- {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
566
- {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
567
- {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
568
- {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
569
- {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
570
- {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
571
- # tested on GTX1080Ti
572
- {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
573
- {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
574
- {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
575
- {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
576
- {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
577
- ]
578
-
579
- if not torch.cuda.is_available():
580
- return 1
581
-
582
- total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
583
- filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
584
- for settings in sorted(
585
- filtered_bs_search_table,
586
- key=lambda k: (k["res"], -k["total_vram"]),
587
- ):
588
- if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
589
- bs = settings["bs"]
590
- if bs > ensemble_size:
591
- bs = ensemble_size
592
- elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
593
- bs = math.ceil(ensemble_size / 2)
594
- return bs
595
-
596
- return 1
597
-
598
- @staticmethod
599
- def ensemble_depths(
600
- input_images: torch.Tensor,
601
- regularizer_strength: float = 0.02,
602
- max_iter: int = 2,
603
- tol: float = 1e-3,
604
- reduction: str = "median",
605
- max_res: int = None,
606
- ):
607
- """
608
- To ensemble multiple affine-invariant depth images (up to scale and shift),
609
- by estimating the scale and shift that align them.
610
- """
611
-
612
- def inter_distances(tensors: torch.Tensor):
613
- """
614
- To calculate the distance between each two depth maps.
615
- """
616
- distances = []
617
- for i, j in torch.combinations(torch.arange(tensors.shape[0])):
618
- arr1 = tensors[i : i + 1]
619
- arr2 = tensors[j : j + 1]
620
- distances.append(arr1 - arr2)
621
- dist = torch.concatenate(distances, dim=0)
622
- return dist
623
-
624
- device = input_images.device
625
- dtype = input_images.dtype
626
- np_dtype = np.float32
627
-
628
- original_input = input_images.clone()
629
- n_img = input_images.shape[0]
630
- ori_shape = input_images.shape
631
-
632
- if max_res is not None:
633
- scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
634
- if scale_factor < 1:
635
- downscaler = torch.nn.Upsample(
636
- scale_factor=scale_factor, mode="nearest"
637
- )
638
- input_images = downscaler(torch.from_numpy(input_images)).numpy()
639
-
640
- # init guess
641
- _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
642
- _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
643
- s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
644
- t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
645
- x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
646
-
647
- input_images = input_images.to(device)
648
-
649
- # objective function
650
- def closure(x):
651
- l = len(x)
652
- s = x[: int(l / 2)]
653
- t = x[int(l / 2) :]
654
- s = torch.from_numpy(s).to(dtype=dtype).to(device)
655
- t = torch.from_numpy(t).to(dtype=dtype).to(device)
656
-
657
- transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
658
- dists = inter_distances(transformed_arrays)
659
- sqrt_dist = torch.sqrt(torch.mean(dists**2))
660
-
661
- if "mean" == reduction:
662
- pred = torch.mean(transformed_arrays, dim=0)
663
- elif "median" == reduction:
664
- pred = torch.median(transformed_arrays, dim=0).values
665
- else:
666
- raise ValueError
667
-
668
- near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
669
- far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
670
-
671
- err = sqrt_dist + (near_err + far_err) * regularizer_strength
672
- err = err.detach().cpu().numpy().astype(np_dtype)
673
- return err
674
-
675
- res = minimize(
676
- closure,
677
- x,
678
- method="BFGS",
679
- tol=tol,
680
- options={"maxiter": max_iter, "disp": False},
681
- )
682
- x = res.x
683
- l = len(x)
684
- s = x[: int(l / 2)]
685
- t = x[int(l / 2) :]
686
-
687
- # Prediction
688
- s = torch.from_numpy(s).to(dtype=dtype).to(device)
689
- t = torch.from_numpy(t).to(dtype=dtype).to(device)
690
- transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
691
- if "mean" == reduction:
692
- aligned_images = torch.mean(transformed_arrays, dim=0)
693
- std = torch.std(transformed_arrays, dim=0)
694
- uncertainty = std
695
- elif "median" == reduction:
696
- aligned_images = torch.median(transformed_arrays, dim=0).values
697
- # MAD (median absolute deviation) as uncertainty indicator
698
- abs_dev = torch.abs(transformed_arrays - aligned_images)
699
- mad = torch.median(abs_dev, dim=0).values
700
- uncertainty = mad
701
- else:
702
- raise ValueError(f"Unknown reduction method: {reduction}")
703
-
704
- # Scale and shift to [0, 1]
705
- _min = torch.min(aligned_images)
706
- _max = torch.max(aligned_images)
707
- aligned_images = (aligned_images - _min) / (_max - _min)
708
- uncertainty /= _max - _min
709
-
710
- return aligned_images, uncertainty
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
marigold_logo_square.jpg DELETED

Git LFS Details

  • SHA256: bd5f1e527678fc913aee17ab69831551cfdb2934f673e9e97a7f011103b63c9e
  • Pointer size: 130 Bytes
  • Size of remote file: 76 kB