selmee committed on
Commit
cbcd8d9
1 Parent(s): 28e7ea2

Upload 30 files

src/depth_pro.egg-info/PKG-INFO ADDED
@@ -0,0 +1,111 @@
Metadata-Version: 2.1
Name: depth_pro
Version: 0.1
Summary: Inference/Network/Model code for Apple Depth Pro monocular depth estimation.
Project-URL: Homepage, https://github.com/apple/ml-depth-pro
Project-URL: Repository, https://github.com/apple/ml-depth-pro
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: timm
Requires-Dist: numpy<2
Requires-Dist: pillow_heif
Requires-Dist: matplotlib

## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second

This software project accompanies the research paper:
**Depth Pro: Sharp Monocular Metric Depth in Less Than a Second**,
*Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, and Vladlen Koltun*.

![](data/depth-pro-teaser.jpg)

We present a foundation model for zero-shot metric monocular depth estimation. Our model, Depth Pro, synthesizes high-resolution depth maps with unparalleled sharpness and high-frequency details. The predictions are metric, with absolute scale, without relying on the availability of metadata such as camera intrinsics. The model is also fast, producing a 2.25-megapixel depth map in 0.3 seconds on a standard GPU. These characteristics are enabled by a number of technical contributions, including an efficient multi-scale vision transformer for dense prediction, a training protocol that combines real and synthetic datasets to achieve high metric accuracy alongside fine boundary tracing, dedicated evaluation metrics for boundary accuracy in estimated depth maps, and state-of-the-art focal length estimation from a single image.

The model in this repository is a reference implementation that has been re-trained. Its performance is close to the model reported in the paper but does not match it exactly.

## Getting Started

We recommend setting up a virtual environment. Using e.g. miniconda, the `depth_pro` package can be installed via:

```bash
conda create -n depth-pro -y python=3.9
conda activate depth-pro

pip install -e .
```

To download the pretrained checkpoints, run the snippet below:
```bash
source get_pretrained_models.sh  # Files will be downloaded to the `checkpoints` directory.
```

### Running from the command line

We provide a helper script to run the model directly on a single image:
```bash
# Run prediction on a single image:
depth-pro-run -i ./data/example.jpg
# Run `depth-pro-run -h` for available options.
```

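The helper also works on a whole folder: `src/depth_pro/cli/run.py` recursively globs the given directory and, when `-o` is set, writes a compressed `.npz` depth map plus a color-mapped JPEG for each image. A minimal invocation sketch (the `./output` directory name is only an example):

```bash
# Batch prediction over a folder; per-image .npz and .jpg results go to ./output.
depth-pro-run -i ./data -o ./output --skip-display -v
```
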
### Running from Python

```python
from PIL import Image
import depth_pro

# Load model and preprocessing transform.
model, transform = depth_pro.create_model_and_transforms()
model.eval()

# Load and preprocess an image.
image_path = "./data/example.jpg"
image, _, f_px = depth_pro.load_rgb(image_path)
image = transform(image)

# Run inference.
prediction = model.infer(image, f_px=f_px)
depth = prediction["depth"]  # Depth in [m].
focallength_px = prediction["focallength_px"]  # Focal length in pixels.
```

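To keep or visualize the result, the prediction can be post-processed the same way the bundled CLI script does: save the raw metric depth as a compressed `.npz` and write a color-mapped JPEG with the reversed `turbo` colormap. A minimal sketch continuing from the snippet above (output filenames are illustrative):

```python
import numpy as np
import PIL.Image
from matplotlib import pyplot as plt

depth_np = depth.detach().cpu().numpy().squeeze()
np.savez_compressed("example_depth.npz", depth=depth_np)  # raw depth in meters

# Color-mapped visualization, mirroring src/depth_pro/cli/run.py.
normalized = (depth_np - depth_np.min()) / max(depth_np.max() - depth_np.min(), 1e-8)
color = (plt.get_cmap("turbo_r")(normalized)[..., :3] * 255).astype(np.uint8)
PIL.Image.fromarray(color).save("example_depth.jpg", format="JPEG", quality=90)
```
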
### Evaluation (boundary metrics)

Our boundary metrics are implemented in `eval/boundary_metrics.py` and can be used as follows:

```python
# For a depth-based dataset:
boundary_f1 = SI_boundary_F1(predicted_depth, target_depth)

# For a mask-based dataset (image matting / segmentation):
boundary_recall = SI_boundary_Recall(predicted_depth, target_mask)
```

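For reference, a self-contained version of the calls above with imports and synthetic 2D inputs; the import path assumes the package layout in this commit (`src/depth_pro/eval/boundary_metrics.py`), and the shapes and values are arbitrary:

```python
import numpy as np
from depth_pro.eval.boundary_metrics import SI_boundary_F1, SI_boundary_Recall

rng = np.random.default_rng(0)
predicted_depth = rng.uniform(0.5, 10.0, size=(240, 320))  # metric depth in meters, 2D
target_depth = predicted_depth * rng.uniform(0.98, 1.02, size=(240, 320))
print("SI boundary F1:", SI_boundary_F1(predicted_depth, target_depth))

target_mask = np.zeros((240, 320), dtype=np.float32)  # alpha matte / segmentation mask
target_mask[60:180, 80:240] = 1.0
print("SI boundary recall:", SI_boundary_Recall(predicted_depth, target_mask))
```
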
## Citation

If you find our work useful, please cite the following paper:

```bibtex
@article{Bochkovskii2024:arxiv,
  author  = {Aleksei Bochkovskii and Ama\"{e}l Delaunoy and Hugo Germain and Marcel Santos and
             Yichao Zhou and Stephan R. Richter and Vladlen Koltun},
  title   = {Depth Pro: Sharp Monocular Metric Depth in Less Than a Second},
  journal = {arXiv},
  year    = {2024},
}
```

## License
This sample code is released under the [LICENSE](LICENSE) terms.

The model weights are released under the [LICENSE](LICENSE) terms.

## Acknowledgements

Our codebase is built on multiple open-source contributions; please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details.

Please check the paper for a complete list of references and datasets used in this work.
src/depth_pro.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,28 @@
ACKNOWLEDGEMENTS.md
CODE_OF_CONDUCT.md
CONTRIBUTING.md
LICENSE
README.md
get_pretrained_models.sh
pyproject.toml
data/depth-pro-teaser.jpg
data/example.jpg
src/depth_pro/__init__.py
src/depth_pro/depth_pro.py
src/depth_pro/utils.py
src/depth_pro.egg-info/PKG-INFO
src/depth_pro.egg-info/SOURCES.txt
src/depth_pro.egg-info/dependency_links.txt
src/depth_pro.egg-info/entry_points.txt
src/depth_pro.egg-info/requires.txt
src/depth_pro.egg-info/top_level.txt
src/depth_pro/cli/__init__.py
src/depth_pro/cli/run.py
src/depth_pro/eval/boundary_metrics.py
src/depth_pro/eval/dis5k_sample_list.txt
src/depth_pro/network/__init__.py
src/depth_pro/network/decoder.py
src/depth_pro/network/encoder.py
src/depth_pro/network/fov.py
src/depth_pro/network/vit.py
src/depth_pro/network/vit_factory.py
src/depth_pro.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
src/depth_pro.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
[console_scripts]
depth-pro-run = depth_pro.cli:run_main
src/depth_pro.egg-info/requires.txt ADDED
@@ -0,0 +1,6 @@
torch
torchvision
timm
numpy<2
pillow_heif
matplotlib
src/depth_pro.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
depth_pro
src/depth_pro/__init__.py ADDED
@@ -0,0 +1,5 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Depth Pro package."""

from .depth_pro import create_model_and_transforms  # noqa
from .utils import load_rgb  # noqa
src/depth_pro/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (286 Bytes).
 
src/depth_pro/__pycache__/depth_pro.cpython-39.pyc ADDED
Binary file (7.82 kB).
 
src/depth_pro/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.25 kB).
 
src/depth_pro/cli/__init__.py ADDED
@@ -0,0 +1,4 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Depth Pro CLI and tools."""

from .run import main as run_main  # noqa
src/depth_pro/cli/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (239 Bytes).
 
src/depth_pro/cli/__pycache__/run.cpython-39.pyc ADDED
Binary file (3.39 kB).
 
src/depth_pro/cli/run.py ADDED
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""Sample script to run DepthPro.

Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""


import argparse
import logging
from pathlib import Path

import numpy as np
import PIL.Image
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from depth_pro import create_model_and_transforms, load_rgb

LOGGER = logging.getLogger(__name__)


def get_torch_device() -> torch.device:
    """Get the Torch device."""
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    return device


def run(args):
    """Run Depth Pro on a sample image."""
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Load model.
    model, transform = create_model_and_transforms(
        device=get_torch_device(),
        precision=torch.half,
    )
    model.eval()

    image_paths = [args.image_path]
    if args.image_path.is_dir():
        image_paths = args.image_path.glob("**/*")
        relative_path = args.image_path
    else:
        relative_path = args.image_path.parent

    if not args.skip_display:
        plt.ion()
        fig = plt.figure()
        ax_rgb = fig.add_subplot(121)
        ax_disp = fig.add_subplot(122)

    for image_path in tqdm(image_paths):
        # Load image and focal length from exif info (if found).
        try:
            LOGGER.info(f"Loading image {image_path} ...")
            image, _, f_px = load_rgb(image_path)
        except Exception as e:
            LOGGER.error(str(e))
            continue
        # Run prediction. If `f_px` is provided, it is used to estimate the final metric depth;
        # otherwise the model estimates `f_px` to compute the depth metricness.
        prediction = model.infer(transform(image), f_px=f_px)

        # Extract the depth and focal length.
        depth = prediction["depth"].detach().cpu().numpy().squeeze()
        if f_px is not None:
            LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}")
        elif prediction["focallength_px"] is not None:
            focallength_px = prediction["focallength_px"].detach().cpu().item()
            LOGGER.info(f"Estimated focal length: {focallength_px}")

        # Save depth as an npz file.
        if args.output_path is not None:
            output_file = (
                args.output_path
                / image_path.relative_to(relative_path).parent
                / image_path.stem
            )
            LOGGER.info(f"Saving depth map to: {str(output_file)}")
            output_file.parent.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(output_file, depth=depth)

            # Save as color-mapped "turbo" jpg image.
            cmap = plt.get_cmap("turbo_r")
            normalized_depth = (depth - depth.min()) / (
                depth.max() - depth.min()
            )
            color_depth = (cmap(normalized_depth)[..., :3] * 255).astype(
                np.uint8
            )
            color_map_output_file = str(output_file) + ".jpg"
            LOGGER.info(f"Saving color-mapped depth to: {color_map_output_file}")
            PIL.Image.fromarray(color_depth).save(
                color_map_output_file, format="JPEG", quality=90
            )

        # Display the image and estimated depth map.
        if not args.skip_display:
            ax_rgb.imshow(image)
            ax_disp.imshow(depth, cmap="turbo_r")
            fig.canvas.draw()
            fig.canvas.flush_events()

    LOGGER.info("Done predicting depth!")
    if not args.skip_display:
        plt.show(block=True)


def main():
    """Run DepthPro inference example."""
    parser = argparse.ArgumentParser(
        description="Inference scripts of DepthPro with PyTorch models."
    )
    parser.add_argument(
        "-i",
        "--image-path",
        type=Path,
        default="./data/example.jpg",
        help="Path to input image.",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=Path,
        help="Path to store output files.",
    )
    parser.add_argument(
        "--skip-display",
        action="store_true",
        help="Skip matplotlib display.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show verbose output.",
    )

    run(parser.parse_args())


if __name__ == "__main__":
    main()
src/depth_pro/depth_pro.py ADDED
@@ -0,0 +1,300 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# Depth Pro: Sharp Monocular Metric Depth in Less Than a Second


from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor,
)

from .network.decoder import MultiresConvDecoder
from .network.encoder import DepthProEncoder
from .network.fov import FOVNetwork
from .network.vit_factory import VIT_CONFIG_DICT, ViTPreset, create_vit


@dataclass
class DepthProConfig:
    """Configuration for DepthPro."""

    patch_encoder_preset: ViTPreset
    image_encoder_preset: ViTPreset
    decoder_features: int

    checkpoint_uri: Optional[str] = None
    fov_encoder_preset: Optional[ViTPreset] = None
    use_fov_head: bool = True


DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
    patch_encoder_preset="dinov2l16_384",
    image_encoder_preset="dinov2l16_384",
    checkpoint_uri="./checkpoints/depth_pro.pt",
    decoder_features=256,
    use_fov_head=True,
    fov_encoder_preset="dinov2l16_384",
)


def create_backbone_model(
    preset: ViTPreset
) -> Tuple[nn.Module, ViTPreset]:
    """Create and load a backbone model given a config.

    Args:
    ----
        preset: A backbone preset to load pre-defined configs.

    Returns:
    -------
        A Torch module and the associated config.

    """
    if preset in VIT_CONFIG_DICT:
        config = VIT_CONFIG_DICT[preset]
        model = create_vit(preset=preset, use_pretrained=False)
    else:
        raise KeyError(f"Preset {preset} not found.")

    return model, config


def create_model_and_transforms(
    config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
    device: torch.device = torch.device("cpu"),
    precision: torch.dtype = torch.float32,
) -> Tuple[DepthPro, Compose]:
    """Create a DepthPro model and load weights from `config.checkpoint_uri`.

    Args:
    ----
        config: The configuration for the DPT model architecture.
        device: The optional Torch device to load the model onto, default runs on "cpu".
        precision: The optional precision used for the model, default is FP32.

    Returns:
    -------
        The Torch DepthPro model and associated Transform.

    """
    patch_encoder, patch_encoder_config = create_backbone_model(
        preset=config.patch_encoder_preset
    )
    image_encoder, _ = create_backbone_model(
        preset=config.image_encoder_preset
    )

    fov_encoder = None
    if config.use_fov_head and config.fov_encoder_preset is not None:
        fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)

    dims_encoder = patch_encoder_config.encoder_feature_dims
    hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
    encoder = DepthProEncoder(
        dims_encoder=dims_encoder,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        hook_block_ids=hook_block_ids,
        decoder_features=config.decoder_features,
    )
    decoder = MultiresConvDecoder(
        dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
        dim_decoder=config.decoder_features,
    )
    model = DepthPro(
        encoder=encoder,
        decoder=decoder,
        last_dims=(32, 1),
        use_fov_head=config.use_fov_head,
        fov_encoder=fov_encoder,
    ).to(device)

    if precision == torch.half:
        model.half()

    transform = Compose(
        [
            ToTensor(),
            Lambda(lambda x: x.to(device)),
            Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ConvertImageDtype(precision),
        ]
    )

    if config.checkpoint_uri is not None:
        state_dict = torch.load(config.checkpoint_uri, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict=state_dict, strict=True
        )

        if len(unexpected_keys) != 0:
            raise KeyError(
                f"Found unexpected keys when loading monodepth: {unexpected_keys}"
            )

        # fc_norm is only used for the classification head,
        # which we do not use. We only use the encoding.
        missing_keys = [key for key in missing_keys if "fc_norm" not in key]
        if len(missing_keys) != 0:
            raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")

    return model, transform


class DepthPro(nn.Module):
    """DepthPro network."""

    def __init__(
        self,
        encoder: DepthProEncoder,
        decoder: MultiresConvDecoder,
        last_dims: tuple[int, int],
        use_fov_head: bool = True,
        fov_encoder: Optional[nn.Module] = None,
    ):
        """Initialize DepthPro.

        Args:
        ----
            encoder: The DepthProEncoder backbone.
            decoder: The MultiresConvDecoder decoder.
            last_dims: The dimension for the last convolution layers.
            use_fov_head: Whether to use the field-of-view head.
            fov_encoder: A separate encoder for the field of view.

        """
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        dim_decoder = decoder.dim_decoder
        self.head = nn.Sequential(
            nn.Conv2d(
                dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
            ),
            nn.ConvTranspose2d(
                in_channels=dim_decoder // 2,
                out_channels=dim_decoder // 2,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
            ),
            nn.Conv2d(
                dim_decoder // 2,
                last_dims[0],
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(True),
            nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

        # Set the final convolution layer's bias to be 0.
        self.head[4].bias.data.fill_(0)

        # Set the FOV estimation head.
        if use_fov_head:
            self.fov = FOVNetwork(num_features=dim_decoder, fov_encoder=fov_encoder)

    @property
    def img_size(self) -> int:
        """Return the internal image size of the network."""
        return self.encoder.img_size

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Decode by projection and fusion of multi-resolution encodings.

        Args:
        ----
            x (torch.Tensor): Input image.

        Returns:
        -------
            The canonical inverse depth map [m] and the optional estimated field of view [deg].

        """
        _, _, H, W = x.shape
        print("Width:", W)
        print("Height:", H)
        assert H == self.img_size and W == self.img_size

        encodings = self.encoder(x)
        features, features_0 = self.decoder(encodings)
        canonical_inverse_depth = self.head(features)

        fov_deg = None
        if hasattr(self, "fov"):
            fov_deg = self.fov.forward(x, features_0.detach())

        return canonical_inverse_depth, fov_deg

    @torch.no_grad()
    def infer(
        self,
        x: torch.Tensor,
        f_px: Optional[Union[float, torch.Tensor]] = None,
        interpolation_mode="bilinear",
    ) -> Mapping[str, torch.Tensor]:
        """Infer depth and fov for a given image.

        If the image is not at network resolution, it is resized to 1536x1536 and
        the estimated depth is resized to the original image resolution.
        Note: if the focal length is given, the estimated value is ignored and the provided
        focal length is used to generate the metric depth values.

        Args:
        ----
            x (torch.Tensor): Input image.
            f_px (torch.Tensor): Optional focal length in pixels corresponding to `x`.
            interpolation_mode (str): Interpolation function for downsampling/upsampling.

        Returns:
        -------
            Tensor dictionary (torch.Tensor): depth [m], focallength [pixels].

        """
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        _, _, H, W = x.shape
        resize = H != self.img_size or W != self.img_size

        if resize:
            x = nn.functional.interpolate(
                x,
                size=(self.img_size, self.img_size),
                mode=interpolation_mode,
                align_corners=False,
            )

        canonical_inverse_depth, fov_deg = self.forward(x)
        if f_px is None:
            f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg.to(torch.float)))

        inverse_depth = canonical_inverse_depth * (W / f_px)
        f_px = f_px.squeeze()

        if resize:
            inverse_depth = nn.functional.interpolate(
                inverse_depth, size=(H, W), mode=interpolation_mode, align_corners=False
            )

        depth = 1.0 / torch.clamp(inverse_depth, min=1e-4, max=1e4)

        return {
            "depth": depth.squeeze(),
            "focallength_px": f_px,
        }
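As a usage note for this module, `create_model_and_transforms` accepts a target device and precision, which is how `cli/run.py` runs the model in half precision. A short sketch under the assumption that a CUDA device and the default checkpoint at `./checkpoints/depth_pro.pt` are available:

```python
import torch

from depth_pro import create_model_and_transforms, load_rgb

# Half-precision inference on GPU (assumes CUDA and the downloaded checkpoint).
model, transform = create_model_and_transforms(
    device=torch.device("cuda:0"),
    precision=torch.half,
)
model.eval()

image, _, f_px = load_rgb("./data/example.jpg")
prediction = model.infer(transform(image), f_px=f_px)
print(prediction["depth"].shape, prediction["focallength_px"])
```
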
src/depth_pro/eval/boundary_metrics.py ADDED
@@ -0,0 +1,332 @@
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+
5
+
6
+ def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:
7
+ """Find connected components in the given row and column indices.
8
+
9
+ Args:
10
+ ----
11
+ r (np.ndarray): Row indices.
12
+ c (np.ndarray): Column indices.
13
+
14
+ Yields:
15
+ ------
16
+ List[int]: Indices of connected components.
17
+
18
+ """
19
+ indices = [0]
20
+ for i in range(1, r.size):
21
+ if r[i] == r[indices[-1]] and c[i] == c[indices[-1]] + 1:
22
+ indices.append(i)
23
+ else:
24
+ yield indices
25
+ indices = [i]
26
+ yield indices
27
+
28
+
29
+ def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
30
+ """Apply Non-Maximum Suppression (NMS) horizontally on the given ratio matrix.
31
+
32
+ Args:
33
+ ----
34
+ ratio (np.ndarray): Input ratio matrix.
35
+ threshold (float): Threshold for NMS.
36
+
37
+ Returns:
38
+ -------
39
+ np.ndarray: Binary mask after applying NMS.
40
+
41
+ """
42
+ mask = np.zeros_like(ratio, dtype=bool)
43
+ r, c = np.nonzero(ratio > threshold)
44
+ if len(r) == 0:
45
+ return mask
46
+ for ids in connected_component(r, c):
47
+ values = [ratio[r[i], c[i]] for i in ids]
48
+ mi = np.argmax(values)
49
+ mask[r[ids[mi]], c[ids[mi]]] = True
50
+ return mask
51
+
52
+
53
+ def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
54
+ """Apply Non-Maximum Suppression (NMS) vertically on the given ratio matrix.
55
+
56
+ Args:
57
+ ----
58
+ ratio (np.ndarray): Input ratio matrix.
59
+ threshold (float): Threshold for NMS.
60
+
61
+ Returns:
62
+ -------
63
+ np.ndarray: Binary mask after applying NMS.
64
+
65
+ """
66
+ return np.transpose(nms_horizontal(np.transpose(ratio), threshold))
67
+
68
+
69
+ def fgbg_depth(
70
+ d: np.ndarray, t: float
71
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
72
+ """Find foreground-background relations between neighboring pixels.
73
+
74
+ Args:
75
+ ----
76
+ d (np.ndarray): Depth matrix.
77
+ t (float): Threshold for comparison.
78
+
79
+ Returns:
80
+ -------
81
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
82
+ left, top, right, and bottom foreground-background relations.
83
+
84
+ """
85
+ right_is_big_enough = (d[..., :, 1:] / d[..., :, :-1]) > t
86
+ left_is_big_enough = (d[..., :, :-1] / d[..., :, 1:]) > t
87
+ bottom_is_big_enough = (d[..., 1:, :] / d[..., :-1, :]) > t
88
+ top_is_big_enough = (d[..., :-1, :] / d[..., 1:, :]) > t
89
+ return (
90
+ left_is_big_enough,
91
+ top_is_big_enough,
92
+ right_is_big_enough,
93
+ bottom_is_big_enough,
94
+ )
95
+
96
+
97
+ def fgbg_depth_thinned(
98
+ d: np.ndarray, t: float
99
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
100
+ """Find foreground-background relations between neighboring pixels with Non-Maximum Suppression.
101
+
102
+ Args:
103
+ ----
104
+ d (np.ndarray): Depth matrix.
105
+ t (float): Threshold for NMS.
106
+
107
+ Returns:
108
+ -------
109
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
110
+ left, top, right, and bottom foreground-background relations with NMS applied.
111
+
112
+ """
113
+ right_is_big_enough = nms_horizontal(d[..., :, 1:] / d[..., :, :-1], t)
114
+ left_is_big_enough = nms_horizontal(d[..., :, :-1] / d[..., :, 1:], t)
115
+ bottom_is_big_enough = nms_vertical(d[..., 1:, :] / d[..., :-1, :], t)
116
+ top_is_big_enough = nms_vertical(d[..., :-1, :] / d[..., 1:, :], t)
117
+ return (
118
+ left_is_big_enough,
119
+ top_is_big_enough,
120
+ right_is_big_enough,
121
+ bottom_is_big_enough,
122
+ )
123
+
124
+
125
+ def fgbg_binary_mask(
126
+ d: np.ndarray,
127
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
128
+ """Find foreground-background relations between neighboring pixels in binary masks.
129
+
130
+ Args:
131
+ ----
132
+ d (np.ndarray): Binary depth matrix.
133
+
134
+ Returns:
135
+ -------
136
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
137
+ left, top, right, and bottom foreground-background relations in binary masks.
138
+
139
+ """
140
+ assert d.dtype == bool
141
+ right_is_big_enough = d[..., :, 1:] & ~d[..., :, :-1]
142
+ left_is_big_enough = d[..., :, :-1] & ~d[..., :, 1:]
143
+ bottom_is_big_enough = d[..., 1:, :] & ~d[..., :-1, :]
144
+ top_is_big_enough = d[..., :-1, :] & ~d[..., 1:, :]
145
+ return (
146
+ left_is_big_enough,
147
+ top_is_big_enough,
148
+ right_is_big_enough,
149
+ bottom_is_big_enough,
150
+ )
151
+
152
+
153
+ def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
154
+ """Calculate edge recall for image matting.
155
+
156
+ Args:
157
+ ----
158
+ pr (np.ndarray): Predicted depth matrix.
159
+ gt (np.ndarray): Ground truth binary mask.
160
+ t (float): Threshold for NMS.
161
+
162
+ Returns:
163
+ -------
164
+ float: Edge recall value.
165
+
166
+ """
167
+ assert gt.dtype == bool
168
+ ap, bp, cp, dp = fgbg_depth_thinned(pr, t)
169
+ ag, bg, cg, dg = fgbg_binary_mask(gt)
170
+ return 0.25 * (
171
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
172
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
173
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
174
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
175
+ )
176
+
177
+
178
+ def boundary_f1(
179
+ pr: np.ndarray,
180
+ gt: np.ndarray,
181
+ t: float,
182
+ return_p: bool = False,
183
+ return_r: bool = False,
184
+ ) -> float:
185
+ """Calculate Boundary F1 score.
186
+
187
+ Args:
188
+ ----
189
+ pr (np.ndarray): Predicted depth matrix.
190
+ gt (np.ndarray): Ground truth depth matrix.
191
+ t (float): Threshold for comparison.
192
+ return_p (bool, optional): If True, return precision. Defaults to False.
193
+ return_r (bool, optional): If True, return recall. Defaults to False.
194
+
195
+ Returns:
196
+ -------
197
+ float: Boundary F1 score, or precision, or recall depending on the flags.
198
+
199
+ """
200
+ ap, bp, cp, dp = fgbg_depth(pr, t)
201
+ ag, bg, cg, dg = fgbg_depth(gt, t)
202
+
203
+ r = 0.25 * (
204
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
205
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
206
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
207
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
208
+ )
209
+ p = 0.25 * (
210
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ap), 1)
211
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bp), 1)
212
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cp), 1)
213
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dp), 1)
214
+ )
215
+ if r + p == 0:
216
+ return 0.0
217
+ if return_p:
218
+ return p
219
+ if return_r:
220
+ return r
221
+ return 2 * (r * p) / (r + p)
222
+
223
+
224
+ def get_thresholds_and_weights(
225
+ t_min: float, t_max: float, N: int
226
+ ) -> Tuple[np.ndarray, np.ndarray]:
227
+ """Generate thresholds and weights for the given range.
228
+
229
+ Args:
230
+ ----
231
+ t_min (float): Minimum threshold.
232
+ t_max (float): Maximum threshold.
233
+ N (int): Number of thresholds.
234
+
235
+ Returns:
236
+ -------
237
+ Tuple[np.ndarray, np.ndarray]: Array of thresholds and corresponding weights.
238
+
239
+ """
240
+ thresholds = np.linspace(t_min, t_max, N)
241
+ weights = thresholds / thresholds.sum()
242
+ return thresholds, weights
243
+
244
+
245
+ def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
246
+ """Inverts a depth map with numerical stability.
247
+
248
+ Args:
249
+ ----
250
+ depth (np.ndarray): Depth map to be inverted.
251
+ eps (float): Minimum value to avoid division by zero (default is 1e-6).
252
+
253
+ Returns:
254
+ -------
255
+ np.ndarray: Inverted depth map.
256
+
257
+ """
258
+ inverse_depth = 1.0 / depth.clip(min=eps)
259
+ return inverse_depth
260
+
261
+
262
+ def SI_boundary_F1(
263
+ predicted_depth: np.ndarray,
264
+ target_depth: np.ndarray,
265
+ t_min: float = 1.05,
266
+ t_max: float = 1.25,
267
+ N: int = 10,
268
+ ) -> float:
269
+ """Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.
270
+
271
+ Args:
272
+ ----
273
+ predicted_depth (np.ndarray): Predicted depth matrix.
274
+ target_depth (np.ndarray): Ground truth depth matrix.
275
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
276
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
277
+ N (int, optional): Number of thresholds. Defaults to 10.
278
+
279
+ Returns:
280
+ -------
281
+ float: Scale-Invariant Boundary F1 Score.
282
+
283
+ """
284
+ assert predicted_depth.ndim == target_depth.ndim == 2
285
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
286
+ f1_scores = np.array(
287
+ [
288
+ boundary_f1(invert_depth(predicted_depth), invert_depth(target_depth), t)
289
+ for t in thresholds
290
+ ]
291
+ )
292
+ return np.sum(f1_scores * weights)
293
+
294
+
295
+ def SI_boundary_Recall(
296
+ predicted_depth: np.ndarray,
297
+ target_mask: np.ndarray,
298
+ t_min: float = 1.05,
299
+ t_max: float = 1.25,
300
+ N: int = 10,
301
+ alpha_threshold: float = 0.1,
302
+ ) -> float:
303
+ """Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.
304
+
305
+ Args:
306
+ ----
307
+ predicted_depth (np.ndarray): Predicted depth matrix.
308
+ target_mask (np.ndarray): Ground truth binary mask.
309
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
310
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
311
+ N (int, optional): Number of thresholds. Defaults to 10.
312
+ alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.
313
+
314
+ Returns:
315
+ -------
316
+ float: Scale-Invariant Boundary Recall Score.
317
+
318
+ """
319
+ assert predicted_depth.ndim == target_mask.ndim == 2
320
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
321
+ thresholded_target = target_mask > alpha_threshold
322
+
323
+ recall_scores = np.array(
324
+ [
325
+ edge_recall_matting(
326
+ invert_depth(predicted_depth), thresholded_target, t=float(t)
327
+ )
328
+ for t in thresholds
329
+ ]
330
+ )
331
+ weighted_recall = np.sum(recall_scores * weights)
332
+ return weighted_recall
src/depth_pro/eval/dis5k_sample_list.txt ADDED
@@ -0,0 +1,200 @@
1
+ DIS5K/DIS-TE1/im/12#Graphics#4#TrafficSign#8245751856_821be14f86_o.jpg
2
+ DIS5K/DIS-TE1/im/13#Insect#4#Butterfly#16023994688_7ff8cdccb1_o.jpg
3
+ DIS5K/DIS-TE1/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205538.jpg
4
+ DIS5K/DIS-TE1/im/14#Kitchenware#8#SweetStand#4848284981_fc90f54b50_o.jpg
5
+ DIS5K/DIS-TE1/im/17#Non-motor Vehicle#4#Cart#15012855035_d10b57014f_o.jpg
6
+ DIS5K/DIS-TE1/im/2#Aircraft#5#Kite#13104545564_5afceec9bd_o.jpg
7
+ DIS5K/DIS-TE1/im/20#Sports#10#Skateboarding#8472763540_bb2390e928_o.jpg
8
+ DIS5K/DIS-TE1/im/21#Tool#14#Sword#32473146960_dcc6b77848_o.jpg
9
+ DIS5K/DIS-TE1/im/21#Tool#15#Tapeline#9680492386_2d2020f282_o.jpg
10
+ DIS5K/DIS-TE1/im/21#Tool#4#Flag#507752845_ef852100f0_o.jpg
11
+ DIS5K/DIS-TE1/im/21#Tool#6#Key#11966089533_3becd78b44_o.jpg
12
+ DIS5K/DIS-TE1/im/21#Tool#8#Scale#31946428472_d28def471b_o.jpg
13
+ DIS5K/DIS-TE1/im/22#Weapon#4#Rifle#8472656430_3eb908b211_o.jpg
14
+ DIS5K/DIS-TE1/im/8#Electronics#3#Earphone#1177468301_641df8c267_o.jpg
15
+ DIS5K/DIS-TE1/im/8#Electronics#9#MusicPlayer#2235782872_7d47847bb4_o.jpg
16
+ DIS5K/DIS-TE2/im/11#Furniture#13#Ladder#3878434417_2ed740586e_o.jpg
17
+ DIS5K/DIS-TE2/im/13#Insect#1#Ant#27047700955_3b3a1271f8_o.jpg
18
+ DIS5K/DIS-TE2/im/13#Insect#11#Spider#5567179191_38d1f65589_o.jpg
19
+ DIS5K/DIS-TE2/im/13#Insect#8#Locust#5237933769_e6687c05e4_o.jpg
20
+ DIS5K/DIS-TE2/im/14#Kitchenware#2#DishRack#70838854_40cf689da7_o.jpg
21
+ DIS5K/DIS-TE2/im/14#Kitchenware#8#SweetStand#8467929412_fef7f4275d_o.jpg
22
+ DIS5K/DIS-TE2/im/16#Music Instrument#2#Harp#28058219806_28e05ff24a_o.jpg
23
+ DIS5K/DIS-TE2/im/17#Non-motor Vehicle#1#BabyCarriage#29794777180_2e1695a0cf_o.jpg
24
+ DIS5K/DIS-TE2/im/19#Ship#3#Sailboat#22442908623_5977e3becf_o.jpg
25
+ DIS5K/DIS-TE2/im/2#Aircraft#5#Kite#44654358051_1400e71cc4_o.jpg
26
+ DIS5K/DIS-TE2/im/21#Tool#11#Stand#IMG_20210520_205442.jpg
27
+ DIS5K/DIS-TE2/im/21#Tool#17#Tripod#9318977876_34615ec9a0_o.jpg
28
+ DIS5K/DIS-TE2/im/5#Artifact#3#Handcraft#50860882577_8482143b1b_o.jpg
29
+ DIS5K/DIS-TE2/im/8#Electronics#10#Robot#3093360210_fee54dc5c5_o.jpg
30
+ DIS5K/DIS-TE2/im/8#Electronics#6#Microphone#47411477652_6da66cbc10_o.jpg
31
+ DIS5K/DIS-TE3/im/14#Kitchenware#4#Kitchenware#2451122898_ef883175dd_o.jpg
32
+ DIS5K/DIS-TE3/im/15#Machine#4#SewingMachine#9311164128_97ba1d3947_o.jpg
33
+ DIS5K/DIS-TE3/im/16#Music Instrument#2#Harp#7670920550_59e992fd7b_o.jpg
34
+ DIS5K/DIS-TE3/im/17#Non-motor Vehicle#1#BabyCarriage#8389984877_1fddf8715c_o.jpg
35
+ DIS5K/DIS-TE3/im/17#Non-motor Vehicle#3#Carriage#5947122724_98e0fc3d1f_o.jpg
36
+ DIS5K/DIS-TE3/im/2#Aircraft#2#Balloon#2487168092_641505883f_o.jpg
37
+ DIS5K/DIS-TE3/im/2#Aircraft#4#Helicopter#8401177591_06c71c8df2_o.jpg
38
+ DIS5K/DIS-TE3/im/20#Sports#1#Archery#12520003103_faa43ea3e0_o.jpg
39
+ DIS5K/DIS-TE3/im/21#Tool#11#Stand#IMG_20210709_221507.jpg
40
+ DIS5K/DIS-TE3/im/21#Tool#2#Clip#5656649687_63d0c6696d_o.jpg
41
+ DIS5K/DIS-TE3/im/21#Tool#6#Key#12878459244_6387a140ea_o.jpg
42
+ DIS5K/DIS-TE3/im/3#Aquatic#1#Lobster#109214461_f52b4b6093_o.jpg
43
+ DIS5K/DIS-TE3/im/4#Architecture#19#Windmill#20195851863_2627117e0e_o.jpg
44
+ DIS5K/DIS-TE3/im/5#Artifact#2#Cage#5821476369_ea23927487_o.jpg
45
+ DIS5K/DIS-TE3/im/8#Electronics#7#MobileHolder#49732997896_7f53c290b5_o.jpg
46
+ DIS5K/DIS-TE4/im/13#Insect#6#Centipede#15302179708_a267850881_o.jpg
47
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#11#Tricycle#5771069105_a3aef6f665_o.jpg
48
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#2#Bicycle#4245936196_fdf812dcb7_o.jpg
49
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#9#ShoppingCart#4674052920_a5b7a2b236_o.jpg
50
+ DIS5K/DIS-TE4/im/18#Plant#1#Bonsai#3539420884_ca8973e2c0_o.jpg
51
+ DIS5K/DIS-TE4/im/2#Aircraft#6#Parachute#33590416634_9d6f2325e7_o.jpg
52
+ DIS5K/DIS-TE4/im/20#Sports#1#Archery#46924476515_0be1caa684_o.jpg
53
+ DIS5K/DIS-TE4/im/20#Sports#8#Racket#19337607166_dd1985fb59_o.jpg
54
+ DIS5K/DIS-TE4/im/21#Tool#6#Key#3193329588_839b0c74ce_o.jpg
55
+ DIS5K/DIS-TE4/im/5#Artifact#2#Cage#5821886526_0573ba2d0d_o.jpg
56
+ DIS5K/DIS-TE4/im/5#Artifact#3#Handcraft#50105138282_3c1d02c968_o.jpg
57
+ DIS5K/DIS-TE4/im/8#Electronics#1#Antenna#4305034305_874f21a701_o.jpg
58
+ DIS5K/DIS-TR/im/1#Accessories#1#Bag#15554964549_3105e51b6f_o.jpg
59
+ DIS5K/DIS-TR/im/1#Accessories#1#Bag#41104261980_098a6c4a56_o.jpg
60
+ DIS5K/DIS-TR/im/1#Accessories#2#Clothes#2284764037_871b2e8ca4_o.jpg
61
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#1824643784_70d0134156_o.jpg
62
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#3590020230_37b09a29b3_o.jpg
63
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#4809652879_4da8a69f3b_o.jpg
64
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#792204934_f9b28f99b4_o.jpg
65
+ DIS5K/DIS-TR/im/1#Accessories#5#Jewelry#13909132974_c4750c5fb7_o.jpg
66
+ DIS5K/DIS-TR/im/1#Accessories#7#Shoe#2483391615_9199ece8d6_o.jpg
67
+ DIS5K/DIS-TR/im/1#Accessories#8#Watch#4343266960_f6633b029b_o.jpg
68
+ DIS5K/DIS-TR/im/10#Frame#2#BicycleFrame#17897573_42964dd104_o.jpg
69
+ DIS5K/DIS-TR/im/10#Frame#5#Rack#15898634812_64807069ff_o.jpg
70
+ DIS5K/DIS-TR/im/10#Frame#5#Rack#23928546819_c184cb0b60_o.jpg
71
+ DIS5K/DIS-TR/im/11#Furniture#19#Shower#6189119596_77bcfe80ee_o.jpg
72
+ DIS5K/DIS-TR/im/11#Furniture#2#Bench#3263647075_9306e280b5_o.jpg
73
+ DIS5K/DIS-TR/im/11#Furniture#5#CoatHanger#12774091054_cd5ff520ef_o.jpg
74
+ DIS5K/DIS-TR/im/11#Furniture#6#DentalChair#13878156865_d0439dcb32_o.jpg
75
+ DIS5K/DIS-TR/im/11#Furniture#9#Easel#5861024714_2070cd480c_o.jpg
76
+ DIS5K/DIS-TR/im/12#Graphics#4#TrafficSign#40621867334_f3c32ec189_o.jpg
77
+ DIS5K/DIS-TR/im/13#Insect#1#Ant#3295038190_db5dd0d4f4_o.jpg
78
+ DIS5K/DIS-TR/im/13#Insect#10#Mosquito#24341339_a88a1dad4c_o.jpg
79
+ DIS5K/DIS-TR/im/13#Insect#11#Spider#27171518270_63b78069ff_o.jpg
80
+ DIS5K/DIS-TR/im/13#Insect#11#Spider#49925050281_fa727c154e_o.jpg
81
+ DIS5K/DIS-TR/im/13#Insect#2#Beatle#279616486_2f1e64f591_o.jpg
82
+ DIS5K/DIS-TR/im/13#Insect#3#Bee#43892067695_82cf3e536b_o.jpg
83
+ DIS5K/DIS-TR/im/13#Insect#6#Centipede#20874281788_3e15c90a1c_o.jpg
84
+ DIS5K/DIS-TR/im/13#Insect#7#Dragonfly#14106671120_1b824d77e4_o.jpg
85
+ DIS5K/DIS-TR/im/13#Insect#8#Locust#21637491048_676ef7c9f7_o.jpg
86
+ DIS5K/DIS-TR/im/13#Insect#9#Mantis#1381120202_9dff6987b2_o.jpg
87
+ DIS5K/DIS-TR/im/14#Kitchenware#1#Cup#12812517473_327d6474b8_o.jpg
88
+ DIS5K/DIS-TR/im/14#Kitchenware#10#WineGlass#6402491641_389275d4d1_o.jpg
89
+ DIS5K/DIS-TR/im/14#Kitchenware#3#Hydrovalve#3129932040_8c05825004_o.jpg
90
+ DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#2881934780_87d5218ebb_o.jpg
91
+ DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205527.jpg
92
+ DIS5K/DIS-TR/im/14#Kitchenware#6#Spoon#32989113501_b69eccf0df_o.jpg
93
+ DIS5K/DIS-TR/im/14#Kitchenware#8#SweetStand#2867322189_c56d1e0b87_o.jpg
94
+ DIS5K/DIS-TR/im/15#Machine#1#Gear#19217846720_f5f2807475_o.jpg
95
+ DIS5K/DIS-TR/im/15#Machine#2#Machine#1620160659_9571b7a7ab_o.jpg
96
+ DIS5K/DIS-TR/im/16#Music Instrument#2#Harp#6012801603_1a6e2c16a6_o.jpg
97
+ DIS5K/DIS-TR/im/16#Music Instrument#5#Trombone#8683292118_d223c17ccb_o.jpg
98
+ DIS5K/DIS-TR/im/16#Music Instrument#6#Trumpet#8393262740_b8c216142c_o.jpg
99
+ DIS5K/DIS-TR/im/16#Music Instrument#8#Violin#1511267391_40e4949d68_o.jpg
100
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#1#BabyCarriage#6989512997_38b3dbc88b_o.jpg
101
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#14627183228_b2d68cf501_o.jpg
102
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#2932226475_1b2403e549_o.jpg
103
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#5420155648_86459905b8_o.jpg
104
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#2#Bicycle#IMG_20210513_134904.jpg
105
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#3#Carriage#3311962551_6f211b7bd6_o.jpg
106
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#4#Cart#2609732026_baf7fff3a1_o.jpg
107
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#5#Handcart#5821282211_201cefeaf2_o.jpg
108
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#7#Mower#5779003232_3bb3ae531a_o.jpg
109
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#10051622843_ace07e32b8_o.jpg
110
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#8075259294_f23e243849_o.jpg
111
+ DIS5K/DIS-TR/im/18#Plant#2#Tree#44800999741_e377e16dbb_o.jpg
112
+ DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#2631761913_3ac67d0223_o.jpg
113
+ DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#37707911566_e908a261b6_o.jpg
114
+ DIS5K/DIS-TR/im/2#Aircraft#3#HangGlider#2557220131_b8506920c5_o.jpg
115
+ DIS5K/DIS-TR/im/2#Aircraft#4#Helicopter#6215659280_5dbd9b4546_o.jpg
116
+ DIS5K/DIS-TR/im/2#Aircraft#6#Parachute#20185790493_e56fcaf8c6_o.jpg
117
+ DIS5K/DIS-TR/im/20#Sports#1#Archery#3871269982_ae4c59a7eb_o.jpg
118
+ DIS5K/DIS-TR/im/20#Sports#9#RockClimbing#9662433268_51299bc50e_o.jpg
119
+ DIS5K/DIS-TR/im/21#Tool#14#Sword#26258479365_2950d7fa37_o.jpg
120
+ DIS5K/DIS-TR/im/21#Tool#15#Tapeline#15505703447_e0fdeaa5a6_o.jpg
121
+ DIS5K/DIS-TR/im/21#Tool#4#Flag#26678602024_9b665742de_o.jpg
122
+ DIS5K/DIS-TR/im/21#Tool#4#Flag#5774823110_d603ce3cc8_o.jpg
123
+ DIS5K/DIS-TR/im/21#Tool#5#Hook#6867989814_dba18d673c_o.jpg
124
+ DIS5K/DIS-TR/im/22#Weapon#4#Rifle#4451713125_cd91719189_o.jpg
125
+ DIS5K/DIS-TR/im/3#Aquatic#2#Seadragon#4910944581_913139b238_o.jpg
126
+ DIS5K/DIS-TR/im/4#Architecture#12#Scaffold#3661448960_8aff24cc4d_o.jpg
127
+ DIS5K/DIS-TR/im/4#Architecture#13#Sculpture#6385318715_9a88d4eba7_o.jpg
128
+ DIS5K/DIS-TR/im/4#Architecture#17#Well#5011603479_75cf42808a_o.jpg
129
+ DIS5K/DIS-TR/im/5#Artifact#2#Cage#4892828841_7f1bc05682_o.jpg
130
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#15404211628_9e9ff2ce2e_o.jpg
131
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#3200169865_7c84cfcccf_o.jpg
132
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#5859295071_c217e7c22f_o.jpg
133
+ DIS5K/DIS-TR/im/6#Automobile#10#SteeringWheel#17200338026_f1e2122d8e_o.jpg
134
+ DIS5K/DIS-TR/im/6#Automobile#3#Car#3780893425_1a7d275e09_o.jpg
135
+ DIS5K/DIS-TR/im/6#Automobile#5#Crane#15282506502_1b1132a7c3_o.jpg
136
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#16767791875_8e6df41752_o.jpg
137
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#3291433361_38747324c4_o.jpg
138
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#4195104238_12a754c61a_o.jpg
139
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#49645415132_61e5664ecf_o.jpg
140
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#IMG_20210521_232406.jpg
141
+ DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#3298312021_92f431e3e9_o.jpg
142
+ DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#47950134773_fbfff63f4e_o.jpg
143
+ DIS5K/DIS-TR/im/7#Electrical#11#VacuumCleaner#5448403677_6a29e21881_o.jpg
144
+ DIS5K/DIS-TR/im/7#Electrical#2#CeilingLamp#611568868_680ed5d39f_o.jpg
145
+ DIS5K/DIS-TR/im/7#Electrical#3#Fan#3391683115_990525a693_o.jpg
146
+ DIS5K/DIS-TR/im/7#Electrical#6#StreetLamp#150049122_0692266618_o.jpg
147
+ DIS5K/DIS-TR/im/7#Electrical#9#TransmissionTower#31433908671_7e7e277dfe_o.jpg
148
+ DIS5K/DIS-TR/im/8#Electronics#1#Antenna#8727884873_e0622ee5c4_o.jpg
149
+ DIS5K/DIS-TR/im/8#Electronics#2#Camcorder#4172690390_7e5f280ace_o.jpg
150
+ DIS5K/DIS-TR/im/8#Electronics#3#Earphone#413984555_f290febdf5_o.jpg
151
+ DIS5K/DIS-TR/im/8#Electronics#5#Headset#30574225373_3717ed9fa4_o.jpg
152
+ DIS5K/DIS-TR/im/8#Electronics#6#Microphone#538006482_4aae4f5bd6_o.jpg
153
+ DIS5K/DIS-TR/im/8#Electronics#9#MusicPlayer#1306012480_2ea80d2afd_o.jpg
154
+ DIS5K/DIS-TR/im/9#Entertainment#1#GymEquipment#33071754135_8f3195cbd1_o.jpg
155
+ DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#2305807849_be53d724ea_o.jpg
156
+ DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#3862040422_5bbf903204_o.jpg
157
+ DIS5K/DIS-TR/im/9#Entertainment#3#OutdoorFitnessEquipment#10814507005_3dacaa28b3_o.jpg
158
+ DIS5K/DIS-TR/im/9#Entertainment#4#FerrisWheel#81640293_4b0ee62040_o.jpg
159
+ DIS5K/DIS-TR/im/9#Entertainment#5#Swing#49867339188_08073f4b76_o.jpg
160
+ DIS5K/DIS-VD/im/1#Accessories#1#Bag#6815402415_e01c1a41e6_o.jpg
161
+ DIS5K/DIS-VD/im/1#Accessories#5#Jewelry#2744070193_1486582e8d_o.jpg
162
+ DIS5K/DIS-VD/im/10#Frame#1#BasketballHoop#IMG_20210521_232650.jpg
163
+ DIS5K/DIS-VD/im/10#Frame#5#Rack#6156611713_49ebf12b1e_o.jpg
164
+ DIS5K/DIS-VD/im/11#Furniture#11#Handrail#3276641240_1b84b5af85_o.jpg
165
+ DIS5K/DIS-VD/im/11#Furniture#13#Ladder#33423266_5391cf47e9_o.jpg
166
+ DIS5K/DIS-VD/im/11#Furniture#17#Table#3725111755_4fc101e7ab_o.jpg
167
+ DIS5K/DIS-VD/im/11#Furniture#2#Bench#35556410400_7235b58070_o.jpg
168
+ DIS5K/DIS-VD/im/11#Furniture#4#Chair#3301769985_e49de6739f_o.jpg
169
+ DIS5K/DIS-VD/im/11#Furniture#6#DentalChair#23811071619_2a95c3a688_o.jpg
170
+ DIS5K/DIS-VD/im/11#Furniture#9#Easel#8322807354_df6d56542e_o.jpg
171
+ DIS5K/DIS-VD/im/13#Insect#10#Mosquito#12391674863_0cdf430d3f_o.jpg
172
+ DIS5K/DIS-VD/im/13#Insect#7#Dragonfly#14693028899_344ea118f2_o.jpg
173
+ DIS5K/DIS-VD/im/14#Kitchenware#10#WineGlass#4450148455_8f460f541a_o.jpg
174
+ DIS5K/DIS-VD/im/14#Kitchenware#3#Hydrovalve#IMG_20210520_203410.jpg
175
+ DIS5K/DIS-VD/im/15#Machine#3#PlowHarrow#34521712846_df4babb024_o.jpg
176
+ DIS5K/DIS-VD/im/16#Music Instrument#5#Trombone#6222242743_e7189405cd_o.jpg
177
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#12#Wheel#25677578797_ea47e1d9e8_o.jpg
178
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#2#Bicycle#5153474856_21560b081b_o.jpg
179
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#7#Mower#16992510572_8a6ff27398_o.jpg
180
+ DIS5K/DIS-VD/im/19#Ship#2#Canoe#40571458163_7faf8b73d9_o.jpg
181
+ DIS5K/DIS-VD/im/2#Aircraft#1#Airplane#4270588164_66a619e834_o.jpg
182
+ DIS5K/DIS-VD/im/2#Aircraft#4#Helicopter#86789665_650b94b2ee_o.jpg
183
+ DIS5K/DIS-VD/im/20#Sports#14#Wakesurfing#5589577652_5061c168d2_o.jpg
184
+ DIS5K/DIS-VD/im/21#Tool#10#Spade#37018312543_63b21b0784_o.jpg
185
+ DIS5K/DIS-VD/im/21#Tool#14#Sword#24789047250_42df9bf422_o.jpg
186
+ DIS5K/DIS-VD/im/21#Tool#18#Umbrella#IMG_20210513_140445.jpg
187
+ DIS5K/DIS-VD/im/21#Tool#6#Key#43939732715_5a6e28b518_o.jpg
188
+ DIS5K/DIS-VD/im/22#Weapon#1#Cannon#12758066705_90b54295e7_o.jpg
189
+ DIS5K/DIS-VD/im/22#Weapon#4#Rifle#8019368790_fb6dc469a7_o.jpg
190
+ DIS5K/DIS-VD/im/3#Aquatic#5#Shrimp#2582833427_7a99e7356e_o.jpg
191
+ DIS5K/DIS-VD/im/4#Architecture#12#Scaffold#1013402687_590750354e_o.jpg
192
+ DIS5K/DIS-VD/im/4#Architecture#13#Sculpture#17176841759_272a3ed6e3_o.jpg
193
+ DIS5K/DIS-VD/im/4#Architecture#14#Stair#15079108505_0d11281624_o.jpg
194
+ DIS5K/DIS-VD/im/4#Architecture#19#Windmill#2928111082_ceb3051c04_o.jpg
195
+ DIS5K/DIS-VD/im/4#Architecture#3#Crack#3551574032_17dd106d31_o.jpg
196
+ DIS5K/DIS-VD/im/4#Architecture#5#GasStation#4564307581_c3069bdc62_o.jpg
197
+ DIS5K/DIS-VD/im/4#Architecture#8#ObservationTower#2704526950_d4f0ddc807_o.jpg
198
+ DIS5K/DIS-VD/im/5#Artifact#3#Handcraft#10873642323_1bafce3aa5_o.jpg
199
+ DIS5K/DIS-VD/im/6#Automobile#11#Tractor#8594504006_0c2c557d85_o.jpg
200
+ DIS5K/DIS-VD/im/8#Electronics#3#Earphone#8106454803_1178d867cc_o.jpg
src/depth_pro/network/__init__.py ADDED
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Depth Pro network blocks."""
src/depth_pro/network/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (197 Bytes).
 
src/depth_pro/network/__pycache__/decoder.cpython-39.pyc ADDED
Binary file (5.17 kB).
 
src/depth_pro/network/__pycache__/encoder.cpython-39.pyc ADDED
Binary file (7.28 kB).
 
src/depth_pro/network/__pycache__/fov.cpython-39.pyc ADDED
Binary file (2.08 kB).
 
src/depth_pro/network/__pycache__/vit.cpython-39.pyc ADDED
Binary file (2.81 kB).
 
src/depth_pro/network/__pycache__/vit_factory.cpython-39.pyc ADDED
Binary file (2.94 kB).
 
src/depth_pro/network/decoder.py ADDED
@@ -0,0 +1,206 @@
1
+ """Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+
3
+ Dense Prediction Transformer Decoder architecture.
4
+
5
+ Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Iterable
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class MultiresConvDecoder(nn.Module):
17
+ """Decoder for multi-resolution encodings."""
18
+
19
+ def __init__(
20
+ self,
21
+ dims_encoder: Iterable[int],
22
+ dim_decoder: int,
23
+ ):
24
+ """Initialize multiresolution convolutional decoder.
25
+
26
+ Args:
27
+ ----
28
+ dims_encoder: Expected dims at each level from the encoder.
29
+ dim_decoder: Dim of decoder features.
30
+
31
+ """
32
+ super().__init__()
33
+ self.dims_encoder = list(dims_encoder)
34
+ self.dim_decoder = dim_decoder
35
+ self.dim_out = dim_decoder
36
+
37
+ num_encoders = len(self.dims_encoder)
38
+
39
+ # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
40
+ # when the dimensions mismatch. Otherwise we do not do anything, which is
41
+ # the default behavior of monodepth.
42
+ conv0 = (
43
+ nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
44
+ if self.dims_encoder[0] != dim_decoder
45
+ else nn.Identity()
46
+ )
47
+
48
+ convs = [conv0]
49
+ for i in range(1, num_encoders):
50
+ convs.append(
51
+ nn.Conv2d(
52
+ self.dims_encoder[i],
53
+ dim_decoder,
54
+ kernel_size=3,
55
+ stride=1,
56
+ padding=1,
57
+ bias=False,
58
+ )
59
+ )
60
+
61
+ self.convs = nn.ModuleList(convs)
62
+
63
+ fusions = []
64
+ for i in range(num_encoders):
65
+ fusions.append(
66
+ FeatureFusionBlock2d(
67
+ num_features=dim_decoder,
68
+ deconv=(i != 0),
69
+ batch_norm=False,
70
+ )
71
+ )
72
+ self.fusions = nn.ModuleList(fusions)
73
+
74
+ def forward(self, encodings: torch.Tensor) -> torch.Tensor:
75
+ """Decode the multi-resolution encodings."""
76
+ num_levels = len(encodings)
77
+ num_encoders = len(self.dims_encoder)
78
+
79
+ if num_levels != num_encoders:
80
+ raise ValueError(
81
+ f"Got encoder output levels={num_levels}, expected levels={num_encoders+1}."
82
+ )
83
+
84
+ # Project features of different encoder dims to the same decoder dim.
85
+ # Fuse features from the lowest resolution (num_levels-1)
86
+ # to the highest (0).
87
+ features = self.convs[-1](encodings[-1])
88
+ lowres_features = features
89
+ features = self.fusions[-1](features)
90
+ for i in range(num_levels - 2, -1, -1):
91
+ features_i = self.convs[i](encodings[i])
92
+ features = self.fusions[i](features, features_i)
93
+ return features, lowres_features
94
+
95
+
96
+ class ResidualBlock(nn.Module):
97
+ """Generic implementation of residual blocks.
98
+
99
+ This implements a generic residual block from
100
+ He et al. - Identity Mappings in Deep Residual Networks (2016),
101
+ https://arxiv.org/abs/1603.05027
102
+ which can be further customized via factory functions.
103
+ """
104
+
105
+ def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
106
+ """Initialize ResidualBlock."""
107
+ super().__init__()
108
+ self.residual = residual
109
+ self.shortcut = shortcut
110
+
111
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
112
+ """Apply residual block."""
113
+ delta_x = self.residual(x)
114
+
115
+ if self.shortcut is not None:
116
+ x = self.shortcut(x)
117
+
118
+ return x + delta_x
119
+
120
+
121
+ class FeatureFusionBlock2d(nn.Module):
122
+ """Feature fusion for DPT."""
123
+
124
+ def __init__(
125
+ self,
126
+ num_features: int,
127
+ deconv: bool = False,
128
+ batch_norm: bool = False,
129
+ ):
130
+ """Initialize feature fusion block.
131
+
132
+ Args:
133
+ ----
134
+ num_features: Input and output dimensions.
135
+ deconv: Whether to use deconv before the final output conv.
136
+ batch_norm: Whether to use batch normalization in resnet blocks.
137
+
138
+ """
139
+ super().__init__()
140
+
141
+ self.resnet1 = self._residual_block(num_features, batch_norm)
142
+ self.resnet2 = self._residual_block(num_features, batch_norm)
143
+
144
+ self.use_deconv = deconv
145
+ if deconv:
146
+ self.deconv = nn.ConvTranspose2d(
147
+ in_channels=num_features,
148
+ out_channels=num_features,
149
+ kernel_size=2,
150
+ stride=2,
151
+ padding=0,
152
+ bias=False,
153
+ )
154
+
155
+ self.out_conv = nn.Conv2d(
156
+ num_features,
157
+ num_features,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0,
161
+ bias=True,
162
+ )
163
+
164
+ self.skip_add = nn.quantized.FloatFunctional()
165
+
166
+ def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
167
+ """Process and fuse input features."""
168
+ x = x0
169
+
170
+ if x1 is not None:
171
+ res = self.resnet1(x1)
172
+ x = self.skip_add.add(x, res)
173
+
174
+ x = self.resnet2(x)
175
+
176
+ if self.use_deconv:
177
+ x = self.deconv(x)
178
+ x = self.out_conv(x)
179
+
180
+ return x
181
+
182
+ @staticmethod
183
+ def _residual_block(num_features: int, batch_norm: bool):
184
+ """Create a residual block."""
185
+
186
+ def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
187
+ layers = [
188
+ nn.ReLU(False),
189
+ nn.Conv2d(
190
+ num_features,
191
+ num_features,
192
+ kernel_size=3,
193
+ stride=1,
194
+ padding=1,
195
+ bias=not batch_norm,
196
+ ),
197
+ ]
198
+ if batch_norm:
199
+ layers.append(nn.BatchNorm2d(dim))
200
+ return layers
201
+
202
+ residual = nn.Sequential(
203
+ *_create_block(dim=num_features, batch_norm=batch_norm),
204
+ *_create_block(dim=num_features, batch_norm=batch_norm),
205
+ )
206
+ return ResidualBlock(residual)
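For orientation, `MultiresConvDecoder` expects one encoding per entry in `dims_encoder`, ordered from the highest to the lowest resolution with each level halving the spatial size, and it returns the fused highest-resolution features together with the lowest-resolution features that feed the FOV head. A shape-check sketch with small, hypothetical dimensions:

```python
import torch

from depth_pro.network.decoder import MultiresConvDecoder

decoder = MultiresConvDecoder(dims_encoder=[32, 64, 128], dim_decoder=32)
encodings = [
    torch.randn(1, 32, 64, 64),   # highest resolution
    torch.randn(1, 64, 32, 32),
    torch.randn(1, 128, 16, 16),  # lowest resolution
]
features, lowres_features = decoder(encodings)
print(features.shape)          # torch.Size([1, 32, 64, 64])
print(lowres_features.shape)   # torch.Size([1, 32, 16, 16])
```
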
src/depth_pro/network/encoder.py ADDED
@@ -0,0 +1,332 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # DepthProEncoder combining patch and image encoders.
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import Iterable, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class DepthProEncoder(nn.Module):
15
+ """DepthPro Encoder.
16
+
17
+ An encoder aimed at creating multi-resolution encodings from Vision Transformers.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ dims_encoder: Iterable[int],
23
+ patch_encoder: nn.Module,
24
+ image_encoder: nn.Module,
25
+ hook_block_ids: Iterable[int],
26
+ decoder_features: int,
27
+ ):
28
+ """Initialize DepthProEncoder.
29
+
30
+ The framework
31
+ 1. creates an image pyramid,
32
+ 2. generates overlapping patches with a sliding window at each pyramid level,
33
+ 3. creates batched encodings via vision transformer backbones,
34
+ 4. produces multi-resolution encodings.
35
+
36
+ Args:
37
+ ----
38
+ img_size: Backbone image resolution.
39
+ dims_encoder: Dimensions of the encoder at different layers.
40
+ patch_encoder: Backbone used for patches.
41
+ image_encoder: Backbone used for global image encoder.
42
+ hook_block_ids: Hooks to obtain intermediate features for the patch encoder model.
43
+ decoder_features: Number of feature output in the decoder.
44
+
45
+ """
46
+ super().__init__()
47
+
48
+ self.dims_encoder = list(dims_encoder)
49
+ self.patch_encoder = patch_encoder
50
+ self.image_encoder = image_encoder
51
+ self.hook_block_ids = list(hook_block_ids)
52
+
53
+ patch_encoder_embed_dim = patch_encoder.embed_dim
54
+ image_encoder_embed_dim = image_encoder.embed_dim
55
+
56
+ self.out_size = int(
57
+ patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
58
+ )
59
+
60
+ def _create_project_upsample_block(
61
+ dim_in: int,
62
+ dim_out: int,
63
+ upsample_layers: int,
64
+ dim_int: Optional[int] = None,
65
+ ) -> nn.Module:
66
+ if dim_int is None:
67
+ dim_int = dim_out
68
+ # Projection.
69
+ blocks = [
70
+ nn.Conv2d(
71
+ in_channels=dim_in,
72
+ out_channels=dim_int,
73
+ kernel_size=1,
74
+ stride=1,
75
+ padding=0,
76
+ bias=False,
77
+ )
78
+ ]
79
+
80
+ # Upsampling.
81
+ blocks += [
82
+ nn.ConvTranspose2d(
83
+ in_channels=dim_int if i == 0 else dim_out,
84
+ out_channels=dim_out,
85
+ kernel_size=2,
86
+ stride=2,
87
+ padding=0,
88
+ bias=False,
89
+ )
90
+ for i in range(upsample_layers)
91
+ ]
92
+
93
+ return nn.Sequential(*blocks)
94
+
95
+ self.upsample_latent0 = _create_project_upsample_block(
96
+ dim_in=patch_encoder_embed_dim,
97
+ dim_int=self.dims_encoder[0],
98
+ dim_out=decoder_features,
99
+ upsample_layers=3,
100
+ )
101
+ self.upsample_latent1 = _create_project_upsample_block(
102
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
103
+ )
104
+
105
+ self.upsample0 = _create_project_upsample_block(
106
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
107
+ )
108
+ self.upsample1 = _create_project_upsample_block(
109
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
110
+ )
111
+ self.upsample2 = _create_project_upsample_block(
112
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
113
+ )
114
+
115
+ self.upsample_lowres = nn.ConvTranspose2d(
116
+ in_channels=image_encoder_embed_dim,
117
+ out_channels=self.dims_encoder[3],
118
+ kernel_size=2,
119
+ stride=2,
120
+ padding=0,
121
+ bias=True,
122
+ )
123
+ self.fuse_lowres = nn.Conv2d(
124
+ in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
125
+ out_channels=self.dims_encoder[3],
126
+ kernel_size=1,
127
+ stride=1,
128
+ padding=0,
129
+ bias=True,
130
+ )
131
+
132
+ # Obtain intermediate outputs of the blocks.
133
+ self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
134
+ self._hook0
135
+ )
136
+ self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
137
+ self._hook1
138
+ )
139
+
140
+ def _hook0(self, model, input, output):
141
+ self.backbone_highres_hook0 = output
142
+
143
+ def _hook1(self, model, input, output):
144
+ self.backbone_highres_hook1 = output
145
+
146
+ @property
147
+ def img_size(self) -> int:
148
+ """Return the full image size of the SPN network."""
149
+ return self.patch_encoder.patch_embed.img_size[0] * 4
150
+
151
+ def _create_pyramid(
152
+ self, x: torch.Tensor
153
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
154
+ """Create a 3-level image pyramid."""
155
+ # Original resolution: 1536 by default.
156
+ x0 = x
157
+
158
+ # Middle resolution: 768 by default.
159
+ x1 = F.interpolate(
160
+ x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
161
+ )
162
+
163
+ # Low resolution: 384 by default, corresponding to the backbone resolution.
164
+ x2 = F.interpolate(
165
+ x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
166
+ )
167
+
168
+ return x0, x1, x2
169
+
170
+ def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
171
+ """Split the input into small patches with sliding window."""
172
+ patch_size = 384
173
+ patch_stride = int(patch_size * (1 - overlap_ratio))
174
+
175
+ image_size = x.shape[-1]
176
+ steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
177
+
178
+ x_patch_list = []
179
+ for j in range(steps):
180
+ j0 = j * patch_stride
181
+ j1 = j0 + patch_size
182
+
183
+ for i in range(steps):
184
+ i0 = i * patch_stride
185
+ i1 = i0 + patch_size
186
+ x_patch_list.append(x[..., j0:j1, i0:i1])
187
+
188
+ return torch.cat(x_patch_list, dim=0)
189
+
190
+ def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
191
+ """Merge the patched input into a image with sliding window."""
192
+ steps = int(math.sqrt(x.shape[0] // batch_size))
193
+
194
+ idx = 0
195
+
196
+ output_list = []
197
+ for j in range(steps):
198
+ output_row_list = []
199
+ for i in range(steps):
200
+ output = x[batch_size * idx : batch_size * (idx + 1)]
201
+
202
+ if j != 0:
203
+ output = output[..., padding:, :]
204
+ if i != 0:
205
+ output = output[..., :, padding:]
206
+ if j != steps - 1:
207
+ output = output[..., :-padding, :]
208
+ if i != steps - 1:
209
+ output = output[..., :, :-padding]
210
+
211
+ output_row_list.append(output)
212
+ idx += 1
213
+
214
+ output_row = torch.cat(output_row_list, dim=-1)
215
+ output_list.append(output_row)
216
+ output = torch.cat(output_list, dim=-2)
217
+ return output
218
+
219
+ def reshape_feature(
220
+ self, embeddings: torch.Tensor, width, height, cls_token_offset=1
221
+ ):
222
+ """Discard class token and reshape 1D feature map to a 2D grid."""
223
+ b, hw, c = embeddings.shape
224
+
225
+ # Remove class token.
226
+ if cls_token_offset > 0:
227
+ embeddings = embeddings[:, cls_token_offset:, :]
228
+
229
+ # Shape: (batch, height, width, dim) -> (batch, dim, height, width)
230
+ embeddings = embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)
231
+ return embeddings
232
+
233
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
234
+ """Encode input at multiple resolutions.
235
+
236
+ Args:
237
+ ----
238
+ x (torch.Tensor): Input image.
239
+
240
+ Returns:
241
+ -------
242
+ Multi-resolution encoded features.
243
+
244
+ """
245
+ batch_size = x.shape[0]
246
+
247
+ # Step 0: create a 3-level image pyramid.
248
+ x0, x1, x2 = self._create_pyramid(x)
249
+
250
+ # Step 1: split to create batched overlapped mini-images at the backbone (BeiT/ViT/Dino)
251
+ # resolution.
252
+ # 5x5 @ 384x384 at the highest resolution (1536x1536).
253
+ x0_patches = self.split(x0, overlap_ratio=0.25)
254
+ # 3x3 @ 384x384 at the middle resolution (768x768).
255
+ x1_patches = self.split(x1, overlap_ratio=0.5)
256
+ # 1x1 @ 384x384 at the lowest resolution (384x384).
257
+ x2_patches = x2
258
+
259
+ # Concatenate all the sliding window patches and form a batch of size (35=5x5+3x3+1x1).
260
+ x_pyramid_patches = torch.cat(
261
+ (x0_patches, x1_patches, x2_patches),
262
+ dim=0,
263
+ )
264
+
265
+ # Step 2: Run the backbone (BeiT) model on the large combined batch of patches.
266
+ x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
267
+ x_pyramid_encodings = self.reshape_feature(
268
+ x_pyramid_encodings, self.out_size, self.out_size
269
+ )
270
+
271
+ # Step 3: merging.
272
+ # Merge highres latent encoding.
273
+ x_latent0_encodings = self.reshape_feature(
274
+ self.backbone_highres_hook0,
275
+ self.out_size,
276
+ self.out_size,
277
+ )
278
+ x_latent0_features = self.merge(
279
+ x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
280
+ )
281
+
282
+ x_latent1_encodings = self.reshape_feature(
283
+ self.backbone_highres_hook1,
284
+ self.out_size,
285
+ self.out_size,
286
+ )
287
+ x_latent1_features = self.merge(
288
+ x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
289
+ )
290
+
291
+ # Split the batch of 35 pyramid encodings back into 5x5 + 3x3 + 1x1.
292
+ x0_encodings, x1_encodings, x2_encodings = torch.split(
293
+ x_pyramid_encodings,
294
+ [len(x0_patches), len(x1_patches), len(x2_patches)],
295
+ dim=0,
296
+ )
297
+
298
+ # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
299
+ x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
300
+
301
+ # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
302
+ x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
303
+
304
+ # 24x24 feature maps.
305
+ x2_features = x2_encodings
306
+
307
+ # Apply the image encoder model.
308
+ x_global_features = self.image_encoder(x2_patches)
309
+ x_global_features = self.reshape_feature(
310
+ x_global_features, self.out_size, self.out_size
311
+ )
312
+
313
+ # Upsample feature maps.
314
+ x_latent0_features = self.upsample_latent0(x_latent0_features)
315
+ x_latent1_features = self.upsample_latent1(x_latent1_features)
316
+
317
+ x0_features = self.upsample0(x0_features)
318
+ x1_features = self.upsample1(x1_features)
319
+ x2_features = self.upsample2(x2_features)
320
+
321
+ x_global_features = self.upsample_lowres(x_global_features)
322
+ x_global_features = self.fuse_lowres(
323
+ torch.cat((x2_features, x_global_features), dim=1)
324
+ )
325
+
326
+ return [
327
+ x_latent0_features,
328
+ x_latent1_features,
329
+ x0_features,
330
+ x1_features,
331
+ x_global_features,
332
+ ]
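The comments in `forward` quote concrete patch counts and merged feature-map sizes (5x5 patches at the highest level merging to 96x96 features, 3x3 merging to 48x48). A small self-contained sketch of the same sliding-window arithmetic, assuming the default 384-pixel patches and the 24-token `out_size` that both appear above:

```python
# Illustrative sketch: reproduce the sliding-window arithmetic used by
# DepthProEncoder.split and DepthProEncoder.merge for the default sizes.
import math

PATCH_SIZE = 384   # backbone input resolution
OUT_SIZE = 24      # 384 / 16: ViT tokens per side after patch embedding


def num_steps(image_size: int, overlap_ratio: float) -> int:
    """Number of sliding-window positions per side (same formula as split)."""
    stride = int(PATCH_SIZE * (1 - overlap_ratio))
    return int(math.ceil((image_size - PATCH_SIZE) / stride)) + 1


def merged_size(steps: int, padding: int) -> int:
    """Side length of the merged feature map (same trimming as merge)."""
    # Each of the (steps - 1) internal seams removes `padding` tokens from
    # both adjacent patches.
    return steps * OUT_SIZE - 2 * (steps - 1) * padding


# Highest pyramid level: 1536 px, 25% overlap -> 5x5 patches, 96x96 features.
steps_hi = num_steps(1536, overlap_ratio=0.25)
print(steps_hi, merged_size(steps_hi, padding=3))    # 5 96

# Middle pyramid level: 768 px, 50% overlap -> 3x3 patches, 48x48 features.
steps_mid = num_steps(768, overlap_ratio=0.5)
print(steps_mid, merged_size(steps_mid, padding=6))  # 3 48
```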
src/depth_pro/network/fov.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # Field of View network architecture.
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ class FOVNetwork(nn.Module):
12
+ """Field of View estimation network."""
13
+
14
+ def __init__(
15
+ self,
16
+ num_features: int,
17
+ fov_encoder: Optional[nn.Module] = None,
18
+ ):
19
+ """Initialize the Field of View estimation block.
20
+
21
+ Args:
22
+ ----
23
+ num_features: Number of features used.
24
+ fov_encoder: Optional encoder to bring additional network capacity.
25
+
26
+ """
27
+ super().__init__()
28
+
29
+ # Create FOV head.
30
+ fov_head0 = [
31
+ nn.Conv2d(
32
+ num_features, num_features // 2, kernel_size=3, stride=2, padding=1
33
+ ), # 128 x 24 x 24
34
+ nn.ReLU(True),
35
+ ]
36
+ fov_head = [
37
+ nn.Conv2d(
38
+ num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=1
39
+ ), # 64 x 12 x 12
40
+ nn.ReLU(True),
41
+ nn.Conv2d(
42
+ num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=1
43
+ ), # 32 x 6 x 6
44
+ nn.ReLU(True),
45
+ nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
46
+ ]
47
+ if fov_encoder is not None:
48
+ self.encoder = nn.Sequential(
49
+ fov_encoder, nn.Linear(fov_encoder.embed_dim, num_features // 2)
50
+ )
51
+ self.downsample = nn.Sequential(*fov_head0)
52
+ else:
53
+ fov_head = fov_head0 + fov_head
54
+ self.head = nn.Sequential(*fov_head)
55
+
56
+ def forward(self, x: torch.Tensor, lowres_feature: torch.Tensor) -> torch.Tensor:
57
+ """Forward the fov network.
58
+
59
+ Args:
60
+ ----
61
+ x (torch.Tensor): Input image.
62
+ lowres_feature (torch.Tensor): Low resolution feature.
63
+
64
+ Returns:
65
+ -------
66
+ The field of view tensor.
67
+
68
+ """
69
+ if hasattr(self, "encoder"):
70
+ x = F.interpolate(
71
+ x,
72
+ size=None,
73
+ scale_factor=0.25,
74
+ mode="bilinear",
75
+ align_corners=False,
76
+ )
77
+ x = self.encoder(x)[:, 1:].permute(0, 2, 1)
78
+ lowres_feature = self.downsample(lowres_feature)
79
+ x = x.reshape_as(lowres_feature) + lowres_feature
80
+ else:
81
+ x = lowres_feature
82
+ return self.head(x)
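The head above ends in a 1-channel convolution with a 6x6 kernel, so it collapses the 32x6x6 feature map into a single scalar per image. How that scalar becomes a focal length happens outside this file; assuming it is read as a horizontal field of view in degrees, the standard pinhole relation `f_px = 0.5 * W / tan(0.5 * fov)` applies, as in the sketch below (an assumption for illustration, not code from this commit):

```python
# Illustrative sketch (not taken from this diff): converting a field-of-view
# estimate to a focal length in pixels with the pinhole relation
#   f_px = 0.5 * width / tan(0.5 * fov).
# Treat the degrees interpretation of the head output as an assumption here.
import math


def fov_to_fpx(fov_deg: float, width_px: int) -> float:
    return 0.5 * width_px / math.tan(0.5 * math.radians(fov_deg))


print(round(fov_to_fpx(fov_deg=60.0, width_px=1536), 1))  # ~1330.2
```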
src/depth_pro/network/vit.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+
3
+
4
+ try:
5
+ from timm.layers import resample_abs_pos_embed
6
+ except ImportError as err:
7
+ print("ImportError: {0}".format(err))
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+
13
+ def make_vit_b16_backbone(
14
+ model,
15
+ encoder_feature_dims,
16
+ encoder_feature_layer_ids,
17
+ vit_features,
18
+ start_index=1,
19
+ use_grad_checkpointing=False,
20
+ ) -> nn.Module:
21
+ """Make a ViTb16 backbone for the DPT model."""
22
+ if use_grad_checkpointing:
23
+ model.set_grad_checkpointing()
24
+
25
+ vit_model = nn.Module()
26
+ vit_model.hooks = encoder_feature_layer_ids
27
+ vit_model.model = model
28
+ vit_model.features = encoder_feature_dims
29
+ vit_model.vit_features = vit_features
30
+ vit_model.model.start_index = start_index
31
+ vit_model.model.patch_size = vit_model.model.patch_embed.patch_size
32
+ vit_model.model.is_vit = True
33
+ vit_model.model.forward = vit_model.model.forward_features
34
+
35
+ return vit_model
36
+
37
+
38
+ def forward_features_eva_fixed(self, x):
39
+ """Encode features."""
40
+ x = self.patch_embed(x)
41
+ x, rot_pos_embed = self._pos_embed(x)
42
+ for blk in self.blocks:
43
+ if self.grad_checkpointing:
44
+ x = checkpoint(blk, x, rot_pos_embed)
45
+ else:
46
+ x = blk(x, rot_pos_embed)
47
+ x = self.norm(x)
48
+ return x
49
+
50
+
51
+ def resize_vit(model: nn.Module, img_size) -> nn.Module:
52
+ """Resample the ViT module to the given size."""
53
+ patch_size = model.patch_embed.patch_size
54
+ model.patch_embed.img_size = img_size
55
+ grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
56
+ model.patch_embed.grid_size = grid_size
57
+
58
+ pos_embed = resample_abs_pos_embed(
59
+ model.pos_embed,
60
+ grid_size, # img_size
61
+ num_prefix_tokens=(
62
+ 0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
63
+ ),
64
+ )
65
+ model.pos_embed = torch.nn.Parameter(pos_embed)
66
+
67
+ return model
68
+
69
+
70
+ def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
71
+ """Resample the ViT patch size to the given one."""
72
+ # interpolate patch embedding
73
+ if hasattr(model, "patch_embed"):
74
+ old_patch_size = model.patch_embed.patch_size
75
+
76
+ if (
77
+ new_patch_size[0] != old_patch_size[0]
78
+ or new_patch_size[1] != old_patch_size[1]
79
+ ):
80
+ patch_embed_proj = model.patch_embed.proj.weight
81
+ patch_embed_proj_bias = model.patch_embed.proj.bias
82
+ use_bias = True if patch_embed_proj_bias is not None else False
83
+ _, _, h, w = patch_embed_proj.shape
84
+
85
+ new_patch_embed_proj = torch.nn.functional.interpolate(
86
+ patch_embed_proj,
87
+ size=[new_patch_size[0], new_patch_size[1]],
88
+ mode="bicubic",
89
+ align_corners=False,
90
+ )
91
+ new_patch_embed_proj = (
92
+ new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
93
+ )
94
+
95
+ model.patch_embed.proj = nn.Conv2d(
96
+ in_channels=model.patch_embed.proj.in_channels,
97
+ out_channels=model.patch_embed.proj.out_channels,
98
+ kernel_size=new_patch_size,
99
+ stride=new_patch_size,
100
+ bias=use_bias,
101
+ )
102
+
103
+ if use_bias:
104
+ model.patch_embed.proj.bias = patch_embed_proj_bias
105
+
106
+ model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)
107
+
108
+ model.patch_size = new_patch_size
109
+ model.patch_embed.patch_size = new_patch_size
110
+ model.patch_embed.img_size = (
111
+ int(
112
+ model.patch_embed.img_size[0]
113
+ * new_patch_size[0]
114
+ / old_patch_size[0]
115
+ ),
116
+ int(
117
+ model.patch_embed.img_size[1]
118
+ * new_patch_size[1]
119
+ / old_patch_size[1]
120
+ ),
121
+ )
122
+
123
+ return model
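`resize_patch_embed` interpolates the patch-embedding kernel bicubically and then multiplies it by `(old_h / new_h) * (old_w / new_w)`. The sketch below shows the effect of that rescaling on a dummy constant kernel: without it, a kernel grown from 14x14 to 16x16 would respond more strongly to the same patch because it sums over more pixels.

```python
# Illustrative sketch: why resize_patch_embed rescales the interpolated
# patch-embedding kernel. A 14x14 kernel resampled to 16x16 covers more
# pixels per patch, so without rescaling its response to a constant patch
# would grow by roughly (16*16)/(14*14); the factor undoes that.
import torch
import torch.nn.functional as F

old = torch.ones(1, 1, 14, 14)  # constant dummy kernel with summed weight 196
new = F.interpolate(old, size=(16, 16), mode="bicubic", align_corners=False)
new = new * (14 / 16) * (14 / 16)  # same rescaling as in vit.py

# The summed weight (= response to an all-ones patch) is preserved: 196 vs ~196.
print(old.sum().item(), new.sum().item())
```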
src/depth_pro/network/vit_factory.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # Factory functions to build and load ViT models.
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import types
9
+ from dataclasses import dataclass
10
+ from typing import Dict, List, Literal, Optional
11
+
12
+ import timm
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from .vit import (
17
+ forward_features_eva_fixed,
18
+ make_vit_b16_backbone,
19
+ resize_patch_embed,
20
+ resize_vit,
21
+ )
22
+
23
+ LOGGER = logging.getLogger(__name__)
24
+
25
+
26
+ ViTPreset = Literal[
27
+ "dinov2l16_384",
28
+ ]
29
+
30
+
31
+ @dataclass
32
+ class ViTConfig:
33
+ """Configuration for ViT."""
34
+
35
+ in_chans: int
36
+ embed_dim: int
37
+
38
+ img_size: int = 384
39
+ patch_size: int = 16
40
+
41
+ # In case we need to rescale the backbone when loading from timm.
42
+ timm_preset: Optional[str] = None
43
+ timm_img_size: int = 384
44
+ timm_patch_size: int = 16
45
+
46
+ # The following 2 parameters are only used by DPT. See dpt_factory.py.
47
+ encoder_feature_layer_ids: List[int] = None
48
+ """The layers in the Beit/ViT used to constructs encoder features for DPT."""
49
+ encoder_feature_dims: List[int] = None
50
+ """The dimension of features of encoder layers from Beit/ViT features for DPT."""
51
+
52
+
53
+ VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
54
+ "dinov2l16_384": ViTConfig(
55
+ in_chans=3,
56
+ embed_dim=1024,
57
+ encoder_feature_layer_ids=[5, 11, 17, 23],
58
+ encoder_feature_dims=[256, 512, 1024, 1024],
59
+ img_size=384,
60
+ patch_size=16,
61
+ timm_preset="vit_large_patch14_dinov2",
62
+ timm_img_size=518,
63
+ timm_patch_size=14,
64
+ ),
65
+ }
66
+
67
+
68
+ def create_vit(
69
+ preset: ViTPreset,
70
+ use_pretrained: bool = False,
71
+ checkpoint_uri: str | None = None,
72
+ use_grad_checkpointing: bool = False,
73
+ ) -> nn.Module:
74
+ """Create and load a VIT backbone module.
75
+
76
+ Args:
77
+ ----
78
+ preset: The VIT preset to load the pre-defined config.
79
+ use_pretrained: Load pretrained weights if True, default is False.
80
+ checkpoint_uri: Checkpoint to load the weights from.
81
+ use_grad_checkpointing: Use gradient checkpointing.
82
+
83
+ Returns:
84
+ -------
85
+ A Torch ViT backbone module.
86
+
87
+ """
88
+ config = VIT_CONFIG_DICT[preset]
89
+
90
+ img_size = (config.img_size, config.img_size)
91
+ patch_size = (config.patch_size, config.patch_size)
92
+
93
+ if "eva02" in preset:
94
+ model = timm.create_model(config.timm_preset, pretrained=use_pretrained)
95
+ model.forward_features = types.MethodType(forward_features_eva_fixed, model)
96
+ else:
97
+ model = timm.create_model(
98
+ config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True
99
+ )
100
+ model = make_vit_b16_backbone(
101
+ model,
102
+ encoder_feature_dims=config.encoder_feature_dims,
103
+ encoder_feature_layer_ids=config.encoder_feature_layer_ids,
104
+ vit_features=config.embed_dim,
105
+ use_grad_checkpointing=use_grad_checkpointing,
106
+ )
107
+ if config.patch_size != config.timm_patch_size:
108
+ model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
109
+ if config.img_size != config.timm_img_size:
110
+ model.model = resize_vit(model.model, img_size=img_size)
111
+
112
+ if checkpoint_uri is not None:
113
+ state_dict = torch.load(checkpoint_uri, map_location="cpu")
114
+ missing_keys, unexpected_keys = model.load_state_dict(
115
+ state_dict=state_dict, strict=False
116
+ )
117
+
118
+ if len(unexpected_keys) != 0:
119
+ raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
120
+ if len(missing_keys) != 0:
121
+ raise KeyError(f"Keys are missing when loading vit: {missing_keys}")
122
+
123
+ LOGGER.info(model)
124
+ return model.model
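A hypothetical usage sketch for `create_vit`, based only on the signature and preset table above; the import path assumes the package layout of this commit, and `use_pretrained=False` avoids downloading timm weights:

```python
# Hypothetical usage sketch for create_vit, based on the signature above.
from depth_pro.network.vit_factory import VIT_CONFIG_DICT, create_vit

backbone = create_vit(preset="dinov2l16_384", use_pretrained=False)

config = VIT_CONFIG_DICT["dinov2l16_384"]
print(config.img_size, config.patch_size)  # 384 16 after resizing from timm's 518/14
print(backbone.embed_dim)                  # 1024 for the ViT-L backbone
```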
src/depth_pro/utils.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Tuple, Union
6
+
7
+ import numpy as np
8
+ import pillow_heif
9
+ from PIL import ExifTags, Image, TiffTags
10
+ from pillow_heif import register_heif_opener
11
+
12
+ register_heif_opener()
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+
16
+ def extract_exif(img_pil: Image.Image) -> Dict[str, Any]:
17
+ """Return exif information as a dictionary.
18
+
19
+ Args:
20
+ ----
21
+ img_pil: A Pillow image.
22
+
23
+ Returns:
24
+ -------
25
+ A dictionary with extracted EXIF information.
26
+
27
+ """
28
+ # Get full exif description from get_ifd(0x8769):
29
+ # cf https://pillow.readthedocs.io/en/stable/releasenotes/8.2.0.html#image-getexif-exif-and-gps-ifd
30
+ img_exif = img_pil.getexif().get_ifd(0x8769)
31
+ exif_dict = {ExifTags.TAGS[k]: v for k, v in img_exif.items() if k in ExifTags.TAGS}
32
+
33
+ tiff_tags = img_pil.getexif()
34
+ tiff_dict = {
35
+ TiffTags.TAGS_V2[k].name: v
36
+ for k, v in tiff_tags.items()
37
+ if k in TiffTags.TAGS_V2
38
+ }
39
+ return {**exif_dict, **tiff_dict}
40
+
41
+
42
+ def fpx_from_f35(width: float, height: float, f_mm: float = 50) -> float:
43
+ """Convert a focal length given in mm (35mm film equivalent) to pixels."""
44
+ return f_mm * np.sqrt(width**2.0 + height**2.0) / np.sqrt(36**2 + 24**2)
45
+
46
+
47
+ def load_rgb(
48
+ path: Union[Path, str], auto_rotate: bool = True, remove_alpha: bool = True
49
+ ) -> Tuple[np.ndarray, List[bytes], float]:
50
+ """Load an RGB image.
51
+
52
+ Args:
53
+ ----
54
+ path: The path to the image to load.
55
+ auto_rotate: Rotate the image based on the EXIF data, default is True.
56
+ remove_alpha: Remove the alpha channel, default is True.
57
+
58
+ Returns:
59
+ -------
60
+ img: The image loaded as a numpy array.
61
+ icc_profile: The color profile of the image.
62
+ f_px: The optional focal length in pixels, extracted from the EXIF data.
63
+
64
+ """
65
+ LOGGER.debug(f"Loading image {path} ...")
66
+
67
+ path = Path(path)
68
+ if path.suffix.lower() in [".heic"]:
69
+ heif_file = pillow_heif.open_heif(path, convert_hdr_to_8bit=True)
70
+ img_pil = heif_file.to_pillow()
71
+ else:
72
+ img_pil = Image.open(path)
73
+
74
+ img_exif = extract_exif(img_pil)
75
+ icc_profile = img_pil.info.get("icc_profile", None)
76
+
77
+ # Rotate the image.
78
+ if auto_rotate:
79
+ exif_orientation = img_exif.get("Orientation", 1)
80
+ if exif_orientation == 3:
81
+ img_pil = img_pil.transpose(Image.ROTATE_180)
82
+ elif exif_orientation == 6:
83
+ img_pil = img_pil.transpose(Image.ROTATE_270)
84
+ elif exif_orientation == 8:
85
+ img_pil = img_pil.transpose(Image.ROTATE_90)
86
+ elif exif_orientation != 1:
87
+ LOGGER.warning(f"Ignoring image orientation {exif_orientation}.")
88
+
89
+ img = np.array(img_pil)
90
+ # Convert to RGB if single channel.
91
+ if img.ndim < 3 or img.shape[2] == 1:
92
+ img = np.dstack((img, img, img))
93
+
94
+ if remove_alpha:
95
+ img = img[:, :, :3]
96
+
97
+ LOGGER.debug(f"\tHxW: {img.shape[0]}x{img.shape[1]}")
98
+
99
+ # Extract the focal length from exif data.
100
+ f_35mm = img_exif.get(
101
+ "FocalLengthIn35mmFilm",
102
+ img_exif.get(
103
+ "FocalLenIn35mmFilm", img_exif.get("FocalLengthIn35mmFormat", None)
104
+ ),
105
+ )
106
+ if f_35mm is not None and f_35mm > 0:
107
+ LOGGER.debug(f"\tfocal length @ 35mm film: {f_35mm}mm")
108
+ f_px = fpx_from_f35(img.shape[1], img.shape[0], f_35mm)
109
+ else:
110
+ f_px = None
111
+
112
+ return img, icc_profile, f_px
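A hypothetical usage sketch for the helpers above; the image path is a placeholder, and falling back to a 50mm 35mm-equivalent focal length is an illustrative assumption rather than behavior of the package:

```python
# Hypothetical usage sketch for load_rgb and fpx_from_f35. fpx_from_f35
# applies the 35mm-equivalent relation
#   f_px = f_mm * sqrt(W^2 + H^2) / sqrt(36^2 + 24^2).
from depth_pro.utils import fpx_from_f35, load_rgb

img, icc_profile, f_px = load_rgb("data/example.jpg")  # placeholder path
print(img.shape)  # (H, W, 3)

# If the EXIF data carried no focal length, assume a 50mm-equivalent lens.
if f_px is None:
    f_px = fpx_from_f35(img.shape[1], img.shape[0], f_mm=50)
print(f_px)
```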