zino36 committed on
Commit
23f7322
1 Parent(s): 43b45af

Upload 4 files

Files changed (4)
  1. demo.py +330 -0
  2. misc.py +17 -0
  3. model.py +68 -0
  4. sparse_ga.py +1039 -0
demo.py ADDED
@@ -0,0 +1,330 @@
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# sparse gradio demo functions
# --------------------------------------------------------
import math
import gradio
import os
import numpy as np
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
import tempfile
import shutil

from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.demo import get_args_parser as dust3r_get_args_parser

import matplotlib.pyplot as pl


class SparseGAState():
    def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
        self.sparse_ga = sparse_ga
        self.cache_dir = cache_dir
        self.outfile_name = outfile_name
        self.should_delete = should_delete

    def __del__(self):
        if not self.should_delete:
            return
        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        self.cache_dir = None
        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
            os.remove(self.outfile_name)
        self.outfile_name = None


def get_args_parser():
    parser = dust3r_get_args_parser()
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--gradio_delete_cache', default=None, type=int,
                        help='age/frequency at which gradio removes the file. If >0, matching cache is purged')

    actions = parser._actions
    for action in actions:
        if action.dest == 'model_name':
            action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
    # change defaults
    parser.prog = 'mast3r demo'
    return parser


def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
        valid_msk = np.isfinite(pts.sum(axis=1))
        pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
        scene.add_geometry(pct)
    else:
        meshes = []
        for i in range(len(imgs)):
            pts3d_i = pts3d[i].reshape(imgs[i].shape)
            msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile


def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
    """
    extract 3D_model (glb file) from a reconstructed scene
    """
    if scene_state is None:
        return None
    outfile = scene_state.outfile_name
    if outfile is None:
        return None

    # get optimized values from scene
    scene = scene_state.sparse_ga
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()

    # 3D pointcloud from depthmap, poses and intrinsics
    if TSDF_thresh > 0:
        tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
        pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
    else:
        pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
    msk = to_numpy([c > min_conf_thr for c in confs])
    return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)


def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
                            filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
                            win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
    """
    from a list of images, run mast3r inference and the sparse global aligner,
    then run get_3D_model_from_scene
    """
    imgs = load_images(filelist, size=image_size, verbose=not silent)
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
        filelist = [filelist[0], filelist[0] + '_2']

    scene_graph_params = [scenegraph_type]
    if scenegraph_type in ["swin", "logwin"]:
        scene_graph_params.append(str(winsize))
    elif scenegraph_type == "oneref":
        scene_graph_params.append(str(refid))
    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
        scene_graph_params.append('noncyclic')
    scene_graph = '-'.join(scene_graph_params)
    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
    if optim_level == 'coarse':
        niter2 = 0
    # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.cache_dir is not None:
        cache_dir = current_scene_state.cache_dir
    elif gradio_delete_cache:
        cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
    else:
        cache_dir = os.path.join(outdir, 'cache')
    scene = sparse_global_alignment(filelist, pairs, cache_dir,
                                    model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
                                    opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
                                    matching_conf_thr=matching_conf_thr, **kw)
    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.outfile_name is not None:
        outfile_name = current_scene_state.outfile_name
    else:
        outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)

    scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
    outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size, TSDF_thresh)
    return scene_state, outfile


def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
    num_files = len(inputfiles) if inputfiles is not None else 1
    show_win_controls = scenegraph_type in ["swin", "logwin"]
    show_winsize = scenegraph_type in ["swin", "logwin"]
    show_cyclic = scenegraph_type in ["swin", "logwin"]
    max_winsize, min_winsize = 1, 1
    if scenegraph_type == "swin":
        if win_cyclic:
            max_winsize = max(1, math.ceil((num_files - 1) / 2))
        else:
            max_winsize = num_files - 1
    elif scenegraph_type == "logwin":
        if win_cyclic:
            half_size = math.ceil((num_files - 1) / 2)
            max_winsize = max(1, math.ceil(math.log(half_size, 2)))
        else:
            max_winsize = max(1, math.ceil(math.log(num_files, 2)))
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
    win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
    win_col = gradio.Column(visible=show_win_controls)
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
    return win_col, winsize, win_cyclic, refid


def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
              share=False, gradio_delete_cache=False):
    if not silent:
        print('Outputting stuff in', tmpdirname)

    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
                                  silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)

    def get_context(delete_cache):
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        title = "MASt3R Demo"
        if delete_cache:
            return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
        else:
            return gradio.Blocks(css=css, title="MASt3R Demo")  # for compatibility with older versions

    with get_context(gradio_delete_cache) as demo:
        # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                with gradio.Column():
                    with gradio.Row():
                        lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
                        niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
                                               label="num_iterations", info="For coarse alignment!")
                        lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
                        niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
                                               label="num_iterations", info="For refinement!")
                        optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
                                                      value='refine', label="OptLevel",
                                                      info="Optimization level")
                    with gradio.Row():
                        matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
                                                          minimum=0., maximum=30., step=0.1,
                                                          info="Before Fallback to Regr3D!")
                        shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                            info="Only optimize one set of intrinsics for all views")
                        scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
                                                           ("swin: sliding window", "swin"),
                                                           ("logwin: sliding window with long range", "logwin"),
                                                           ("oneref: match one image with all", "oneref")],
                                                          value='complete', label="Scenegraph",
                                                          info="Define how to make pairs",
                                                          interactive=True)
                        with gradio.Column(visible=False) as win_col:
                            winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                                    minimum=1, maximum=1, step=1)
                            win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
                        refid = gradio.Slider(label="Scene Graph: Id", value=0,
                                              minimum=0, maximum=0, step=1, visible=False)
            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
                TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
                # two post-processing options implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()

            # events
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                                   outputs=[win_col, winsize, win_cyclic, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            win_cyclic.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            run_btn.click(fn=recon_fun,
                          inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                                  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
                          outputs=[scene, outmodel])
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                 outputs=outmodel)
            cam_size.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
                            outputs=outmodel)
            TSDF_thresh.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
                               outputs=outmodel)
            as_pointcloud.change(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                 outputs=outmodel)
            mask_sky.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
                            outputs=outmodel)
            clean_depth.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
                               outputs=outmodel)
            transparent_cams.change(model_from_scene_fun,
                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                            clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                    outputs=outmodel)
    demo.launch(share=share, server_name=server_name, server_port=server_port)
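
For context, a minimal launcher sketch (not part of this commit) showing how the functions above are typically wired together. It assumes the files are importable as mast3r.demo and mast3r.model, and that the argument names (device, weights, model_name, image_size, server_name, server_port, silent) come from dust3r's get_args_parser; treat those as assumptions, not a definitive entry point.

# Hypothetical launcher sketch: wires get_args_parser(), AsymmetricMASt3R and main_demo().
import tempfile
import torch

from mast3r.model import AsymmetricMASt3R
from mast3r.demo import get_args_parser, main_demo

if __name__ == '__main__':
    args = get_args_parser().parse_args()
    device = args.device if torch.cuda.is_available() else 'cpu'
    # local checkpoint if --weights was given, otherwise a hub id built from --model_name (assumed naming)
    weights = args.weights if args.weights is not None else 'naver/' + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights).to(device)
    with tempfile.TemporaryDirectory(suffix='_mast3r_demo') as tmpdirname:
        main_demo(tmpdirname, model, device, args.image_size,
                  server_name=args.server_name, server_port=args.server_port,
                  silent=args.silent, share=args.share,
                  gradio_delete_cache=args.gradio_delete_cache)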
misc.py ADDED
@@ -0,0 +1,17 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utility functions for MASt3R
# --------------------------------------------------------
import os
import hashlib


def mkdir_for(f):
    os.makedirs(os.path.dirname(f), exist_ok=True)
    return f


def hash_md5(s):
    return hashlib.md5(s.encode('utf-8')).hexdigest()
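
A quick usage sketch of the two helpers, matching how sparse_ga.py uses them for its on-disk cache (the paths below are made up for illustration):

# Illustrative only: mkdir_for() creates parent directories and returns the path unchanged,
# hash_md5() gives a deterministic hex id used as a cache key for an image path.
from mast3r.utils.misc import mkdir_for, hash_md5

cache_file = mkdir_for('/tmp/mast3r_cache/forward/abc/def.pth')  # parent dirs now exist
key = hash_md5('imgs/0001.jpg')  # 32-char hex string, stable across runs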
model.py ADDED
@@ -0,0 +1,68 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R model class
# --------------------------------------------------------
import torch
import torch.nn.functional as F
import os

from mast3r.catmlp_dpt_head import mast3r_head_factory

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.model import AsymmetricCroCo3DStereo  # noqa
from dust3r.utils.misc import transpose_to_landscape  # noqa


inf = float('inf')


def load_model(model_path, device, verbose=True):
    if verbose:
        print('... loading model from', model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)


class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        else:
            return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        assert img_size[0] % patch_size == 0 and img_size[
            1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
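
A short, hedged sketch of the two loading paths that from_pretrained() above supports; the hub id mirrors the checkpoint name exposed in demo.get_args_parser(), and the local path is only an example:

# Sketch: hub checkpoint vs. local .pth file (names are assumptions, not guarantees).
from mast3r.model import AsymmetricMASt3R

# hub name -> handled by the parent class' from_pretrained()
model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to('cuda')

# local file -> goes through load_model(), which rebuilds the network from ckpt['args']
# model = AsymmetricMASt3R.from_pretrained('checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth')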
sparse_ga.py ADDED
@@ -0,0 +1,1039 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R Sparse Global Alignment
# --------------------------------------------------------
from tqdm import tqdm
import roma
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from collections import namedtuple
from functools import lru_cache
from scipy import sparse as sp

from mast3r.utils.misc import mkdir_for, hash_md5
from mast3r.cloud_opt.utils.losses import gamma_loss
from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.geometry import inv, geotrf  # noqa
from dust3r.utils.device import to_cpu, to_numpy, todevice  # noqa
from dust3r.post_process import estimate_focal_knowing_depth  # noqa
from dust3r.optim_factory import adjust_learning_rate_by_lr  # noqa
from dust3r.cloud_opt.base_opt import clean_pointcloud
from dust3r.viz import SceneViz


class SparseGA():
    def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
        def fetch_img(im):
            def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
            for im1, im2 in pairs_in:
                if im1['instance'] == im:
                    return torgb(im1['img'])
                if im2['instance'] == im:
                    return torgb(im2['img'])
        self.canonical_paths = canonical_paths
        self.img_paths = img_paths
        self.imgs = [fetch_img(img) for img in img_paths]
        self.intrinsics = res_fine['intrinsics']
        self.cam2w = res_fine['cam2w']
        self.depthmaps = res_fine['depthmaps']
        self.pts3d = res_fine['pts3d']
        self.pts3d_colors = []
        self.working_device = self.cam2w.device
        for i in range(len(self.imgs)):
            im = self.imgs[i]
            x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
            self.pts3d_colors.append(im[y, x])
            assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
        self.n_imgs = len(self.imgs)

    def get_focals(self):
        return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)

    def get_principal_points(self):
        return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)

    def get_im_poses(self):
        return self.cam2w

    def get_sparse_pts3d(self):
        return self.pts3d

    def get_dense_pts3d(self, clean_depth=True, subsample=8):
        assert self.canonical_paths, 'cache_path is required for dense 3d points'
        device = self.cam2w.device
        confs = []
        base_focals = []
        anchors = {}
        for i, canon_path in enumerate(self.canonical_paths):
            (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
            confs.append(conf)
            base_focals.append(focal)

            H, W = conf.shape
            pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
            idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
            anchors[i] = (pixels, idxs[i], offsets[i])

        # densify sparse depthmaps
        pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [
            d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True)

        if clean_depth:
            confs = clean_pointcloud(confs, self.intrinsics, inv(self.cam2w), depthmaps, pts3d)

        return pts3d, depthmaps, confs

    def get_pts3d_colors(self):
        return self.pts3d_colors

    def get_depthmaps(self):
        return self.depthmaps

    def get_masks(self):
        return [slice(None, None) for _ in range(len(self.imgs))]

    def show(self, show_cams=True):
        pts3d, _, confs = self.get_dense_pts3d()
        show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
                            [p.clip(min=-50, max=50) for p in pts3d],
                            masks=[c > 1 for c in confs])


def convert_dust3r_pairs_naming(imgs, pairs_in):
    for pair_id in range(len(pairs_in)):
        for i in range(2):
            pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']]
    return pairs_in


def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
                            device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
    """ Sparse alignment with MASt3R
        imgs: list of image paths
        cache_path: path where to dump temporary files (str)

        lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching)
        lr2, niter2: learning rate and #iterations for refinement (2D reproj error)

        lora_depth: smart dimensionality reduction with depthmaps
    """
    # Convert pair naming convention from dust3r to mast3r
    pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)
    # forward pass
    pairs, cache_path = forward_mast3r(pairs_in, model,
                                       cache_path=cache_path, subsample=subsample,
                                       desc_conf=desc_conf, device=device)

    # extract canonical pointmaps
    tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
        prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)

    # compute minimal spanning tree
    mst = compute_min_spanning_tree(pairwise_scores)

    # remove all edges not in the spanning tree?
    # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
    # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}

    # smartly combine all useful data
    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
        condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)

    imgs, res_coarse, res_fine = sparse_scene_optimizer(
        imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst,
        shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)

    return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)


def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
                           preds_21, canonical_paths, mst, cache_path,
                           lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
                           lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
                           lossd=gamma_loss(1.1),
                           opt_pp=True, opt_depth=True,
                           schedule=cosine_schedule, depth_mode='add', exp_depth=False,
                           lora_depth=False,  # dict(k=96, gamma=15, min_norm=.5),
                           shared_intrinsics=False,
                           init={}, device='cuda', dtype=torch.float32,
                           matching_conf_thr=5., loss_dust3r_w=0.01,
                           verbose=True, dbg=()):

    # extrinsic parameters
    vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
    quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
    trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]

    # initialize
    ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
    median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
    for img in imgs:
        idx = imgs.index(img)
        init_values = init.setdefault(img, {})
        if verbose and init_values:
            print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')

        K = init_values.get('intrinsics')
        if K is not None:
            K = K.detach()
            focal = K[:2, :2].diag().mean()
            pp = K[:2, 2]
            base_focals[idx] = focal
            pps[idx] = pp
        pps[idx] /= imsizes[idx]  # default principal_point would be (0.5, 0.5)

        depth = init_values.get('depthmap')
        if depth is not None:
            core_depth[idx] = depth.detach()

        median_depths[idx] = med_depth = core_depth[idx].median()
        core_depth[idx] /= med_depth

        cam2w = init_values.get('cam2w')
        if cam2w is not None:
            rot = cam2w[:3, :3].detach()
            cam_center = cam2w[:3, 3].detach()
            quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
            trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
            trans[idx].data[:] = cam_center + rot @ trans_offset
            del rot
            assert False, 'inverse kinematic chain not yet implemented'

    # intrinsics parameters
    if shared_intrinsics:
        # Optimize a single set of intrinsics for all cameras. Use averages as init.
        confs = torch.stack([torch.load(pth)[0][2].mean() for pth in canonical_paths]).to(pps)
        weighting = confs / confs.sum()
        pp = nn.Parameter((weighting @ pps).to(dtype))
        pps = [pp for _ in range(len(imgs))]
        focal_m = weighting @ base_focals
        log_focal = nn.Parameter(focal_m.view(1).log().to(dtype))
        log_focals = [log_focal for _ in range(len(imgs))]
    else:
        pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
        log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]

    diags = imsizes.float().norm(dim=1)
    min_focals = 0.25 * diags  # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
    max_focals = 10 * diags

    assert len(mst[1]) == len(pps) - 1

    def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
        # make intrinsics
        focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals)
        pps = torch.stack(pps)
        K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, 0:2, 2] = pps * imsizes
        if trans is None:
            return K

        # security! optimization is always trying to crush the scale down
        sizes = torch.cat(log_sizes).exp()
        global_scaling = 1 / sizes.min()

        # compute distance of camera to focal plane
        # tan(fov) = W/2 / focal
        z_cameras = sizes * median_depths * focals / base_focals

        # make extrinsic
        rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone()
        rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1))
        rel_cam2cam[:, :3, 3] = torch.stack(trans)

        # cameras are defined as a kinematic chain
        tmp_cam2w = [None] * len(K)
        tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]]
        for i, j in mst[1]:
            # i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i
            tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j]
        tmp_cam2w = torch.stack(tmp_cam2w)

        # smart reparameterization of cameras
        trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1)
        new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1))
        cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2),
                           vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1)

        depthmaps = []
        for i in range(len(imgs)):
            core_depth_img = core_depth[i]
            if exp_depth:
                core_depth_img = core_depth_img.exp()
            if lora_depth:  # compute core_depth as a low-rank decomposition of 3d points
                core_depth_img = lora_depth_proj[i] @ core_depth_img
            if depth_mode == 'add':
                core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i])
            elif depth_mode == 'mul':
                core_depth_img = z_cameras[i] * core_depth_img
            else:
                raise ValueError(f'Bad {depth_mode=}')
            depthmaps.append(global_scaling * core_depth_img)

        return K, (inv(cam2w), cam2w), depthmaps

    K = make_K_cam_depth(log_focals, pps, None, None, None, None)

    if shared_intrinsics:
        print('init focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
    else:
        print('init focals =', to_numpy(K[:, 0, 0]))

    # spectral low-rank projection of depthmaps
    if lora_depth:
        core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
            imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
    if exp_depth:
        core_depth = [d.clip(min=1e-4).log() for d in core_depth]
    core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
    log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]

    # Fetch img slices
    _, confs_sum, imgs_slices = corres

    # Define which pairs are fine to use with matching
    def matching_check(x): return x.max() > matching_conf_thr
    is_matching_ok = {}
    for s in imgs_slices:
        is_matching_ok[s.img1, s.img2] = matching_check(s.confs)

    # Prepare slices and corres for losses
    dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
    loss3d_slices = [s for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
    cleaned_corres2d = []
    for cci, (img1, pix1, confs, confsum, imgs_slices) in enumerate(corres2d):
        cf_sum = 0
        pix1_filtered = []
        confs_filtered = []
        curstep = 0
        cleaned_slices = []
        for img2, slice2 in imgs_slices:
            if is_matching_ok[img1, img2]:
                tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
                pix1_filtered.append(pix1[tslice])
                confs_filtered.append(confs[tslice])
                cleaned_slices.append((img2, slice2))
            curstep += slice2.stop - slice2.start
        if pix1_filtered != []:
            pix1_filtered = torch.cat(pix1_filtered)
            confs_filtered = torch.cat(confs_filtered)
            cf_sum = confs_filtered.sum()
        cleaned_corres2d.append((img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices))

    def loss_dust3r(cam2w, pts3d, pix_loss):
        # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
        loss = 0.
        cf_sum = 0.
        for s in dust3r_slices:
            if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
                continue
            # fallback to dust3r regression
            tgt_pts, tgt_confs = preds_21[imgs[s.img2]][imgs[s.img1]]
            tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
            cf_sum += tgt_confs.sum()
            loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
        return loss / cf_sum if cf_sum != 0. else 0.

    def loss_3d(K, w2cam, pts3d, pix_loss):
        # For each correspondence, we have two 3D points (one for each image of the pair).
        # For each 3D point, we have 2 reproj errors
        if any(v.get('freeze') for v in init.values()):
            pts3d_1 = []
            pts3d_2 = []
            confs = []
            for s in loss3d_slices:
                if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
                    continue
                pts3d_1.append(pts3d[s.img1][s.slice1])
                pts3d_2.append(pts3d[s.img2][s.slice2])
                confs.append(s.confs)
        else:
            pts3d_1 = [pts3d[s.img1][s.slice1] for s in loss3d_slices]
            pts3d_2 = [pts3d[s.img2][s.slice2] for s in loss3d_slices]
            confs = [s.confs for s in loss3d_slices]

        if pts3d_1 != []:
            confs = torch.cat(confs)
            pts3d_1 = torch.cat(pts3d_1)
            pts3d_2 = torch.cat(pts3d_2)
            loss = confs @ pix_loss(pts3d_1, pts3d_2)
            cf_sum = confs.sum()
        else:
            loss = 0.
            cf_sum = 1.

        return loss / cf_sum

    def loss_2d(K, w2cam, pts3d, pix_loss):
        # For each correspondence, we have two 3D points (one for each image of the pair).
        # For each 3D point, we have 2 reproj errors
        proj_matrix = K @ w2cam[:, :3]
        loss = npix = 0
        for img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices in cleaned_corres2d:
            if init[imgs[img1]].get('freeze', 0) >= 1:
                continue  # no need
            pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in cleaned_slices]
            if pts3d_in_img1 != []:
                pts3d_in_img1 = torch.cat(pts3d_in_img1)
                loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
                npix += confs_filtered.sum()

        return loss / npix if npix != 0 else 0.

    def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
        # create optimizer
        params = pps + log_focals + quats + trans + log_sizes + core_depth
        optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
        ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)

        with tqdm(total=niter) as bar:
            for iter in range(niter or 1):
                K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
                pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
                if niter == 0:
                    break

                alpha = (iter / niter)
                lr = schedule(alpha, lr_base, lr_end)
                adjust_learning_rate_by_lr(optimizer, lr)
                pix_loss = ploss(1 - alpha)
                optimizer.zero_grad()
                loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
                loss.backward()
                optimizer.step()

                # make sure the pose remains well optimizable
                for i in range(len(imgs)):
                    quats[i].data[:] /= quats[i].data.norm()

                loss = float(loss)
                if loss != loss:
                    break  # NaN loss
                bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
                bar.update(1)

        if niter:
            print(f'>> final loss = {loss}')
        return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
                    depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])

    # at start, don't optimize 3d points
    for i, img in enumerate(imgs):
        trainable = not (init[img].get('freeze'))
        pps[i].requires_grad_(False)
        log_focals[i].requires_grad_(False)
        quats[i].requires_grad_(trainable)
        trans[i].requires_grad_(trainable)
        log_sizes[i].requires_grad_(trainable)
        core_depth[i].requires_grad_(False)

    res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)

    res_fine = None
    if niter2:
        # now we can optimize 3d points
        for i, img in enumerate(imgs):
            if init[img].get('freeze', 0) >= 1:
                continue
            pps[i].requires_grad_(bool(opt_pp))
            log_focals[i].requires_grad_(True)
            core_depth[i].requires_grad_(opt_depth)

        # refinement with 2d reproj
        res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)

    K = make_K_cam_depth(log_focals, pps, None, None, None, None)
    if shared_intrinsics:
        print('Final focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
    else:
        print('Final focals =', to_numpy(K[:, 0, 0]))

    return imgs, res_coarse, res_fine


@lru_cache
def mask110(device, dtype):
    return torch.tensor((1, 1, 0), device=device, dtype=dtype)


def proj3d(inv_K, pixels, z):
    if pixels.shape[-1] == 2:
        pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
    return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype))


def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
    focals = K[:, 0, 0]
    invK = inv(K)
    all_pts3d = []
    depth_out = []

    for img, (pixels, idxs, offsets) in anchors.items():
        # from depthmaps to 3d points
        if base_focals is None:
            pass
        else:
            # compensate for focal
            # depth + depth * (offset - 1) * base_focal / focal
            # = depth * (1 + (offset - 1) * (base_focal / focal))
            offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img])

        pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets)
        if ret_depth:
            depth_out.append(pts3d[..., 2])  # before camera rotation

        # rotate to world coordinate
        pts3d = geotrf(cam2w[img], pts3d)
        all_pts3d.append(pts3d)

    if ret_depth:
        return all_pts3d, depth_out
    return all_pts3d


def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
    base_focals = []
    anchors = {}
    confs = []
    for i, canon_path in enumerate(canonical_paths):
        (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
        confs.append(conf)
        base_focals.append(focal)
        H, W = conf.shape
        pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
        idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
        anchors[i] = (pixels, idxs[i], offsets[i])

    # densify sparse depthmaps
    pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [
        d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True)

    return pts3d, depthmaps_out, confs


@torch.no_grad()
def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
                   device='cuda', subsample=8, **matching_kw):
    res_paths = {}

    for img1, img2 in tqdm(pairs):
        idx1 = hash_md5(img1['instance'])
        idx2 = hash_md5(img2['instance'])

        path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
        path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
        path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
        path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'

        if os.path.isfile(path_corres2) and not os.path.isfile(path_corres):
            score, (xy1, xy2, confs) = torch.load(path_corres2)
            torch.save((score, (xy2, xy1, confs)), path_corres)

        if not all(os.path.isfile(p) for p in (path1, path2, path_corres)):
            if model is None:
                continue
            res = symmetric_inference(model, img1, img2, device=device)
            X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
            C11, C21, C22, C12 = [r['conf'][0] for r in res]
            descs = [r['desc'][0] for r in res]
            qonfs = [r[desc_conf][0] for r in res]

            # save
            torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
            torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))

            # perform reciprocal matching
            corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)

            conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
            matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
            if cache_path is not None:
                torch.save((matching_score, corres), mkdir_for(path_corres))

        res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres

    del model
    torch.cuda.empty_cache()

    return res_paths, cache_path


def symmetric_inference(model, img1, img2, device):
    shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
    shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
    img1 = img1['img'].to(device, non_blocking=True)
    img2 = img2['img'].to(device, non_blocking=True)

    # compute encoder only once
    feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)

    def decoder(feat1, feat2, pos1, pos2, shape1, shape2):
        dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2)
        with torch.cuda.amp.autocast(enabled=False):
            res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2)
        return res1, res2

    # decoder 1-2
    res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2)
    # decoder 2-1
    res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1)

    return (res11, res21, res22, res12)


def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
    feat11, feat21, feat22, feat12 = feats
    qonf11, qonf21, qonf22, qonf12 = qonfs
    assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
    assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape

    if '3d' in ptmap_key:
        opt = dict(device='cpu', workers=32)
    else:
        opt = dict(device=device, dist='dot', block_size=2**13)

    # matching the two pairs
    idx1 = []
    idx2 = []
    qonf1 = []
    qonf2 = []
    # TODO add non symmetric / pixel_tol options
    for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
                         (feat12, feat22, qonf12.cpu(), qonf22.cpu())]:
        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)

        idx1.append(np.r_[nn1to2[0], nn2to1[1]])
        idx2.append(np.r_[nn1to2[1], nn2to1[0]])
        qonf1.append(QA.ravel()[idx1[-1]])
        qonf2.append(QB.ravel()[idx2[-1]])

    # merge corres from opposite pairs
    H1, W1 = feat11.shape[:2]
    H2, W2 = feat22.shape[:2]
    cat = np.concatenate

    xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
    corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx]))

    return todevice(corres, device)


@torch.no_grad()
def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
                           cache_path=None, device='cuda', **kw):
    canonical_views = {}
    pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
    canonical_paths = []
    preds_21 = {}

    for img in tqdm(imgs):
        if cache_path:
            cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
            canonical_paths.append(cache)
        try:
            (canon, canon2, cconf), focal = torch.load(cache, map_location=device)
        except IOError:
            # cache does not exist yet, we create it!
            canon = focal = None

        # collect all pred1
        n_pairs = sum((img in pair) for pair in tmp_pairs)

        ptmaps11 = None
        pixels = {}
        n = 0
        for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
            score = None
            if img == img1:
                X, C, X2, C2 = torch.load(path1, map_location=device)
                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
                pixels[img2] = xy1, confs
                if img not in preds_21:
                    preds_21[img] = {}
                # Subsample preds_21
                preds_21[img][img2] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()

            if img == img2:
                X, C, X2, C2 = torch.load(path2, map_location=device)
                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
                pixels[img1] = xy2, confs
                if img not in preds_21:
                    preds_21[img] = {}
                preds_21[img][img1] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()

            if score is not None:
                i, j = imgs.index(img1), imgs.index(img2)
                # score = score[0]
                # score = np.log1p(score[2])
                score = score[2]
                pairwise_scores[i, j] = score
                pairwise_scores[j, i] = score

                if canon is not None:
                    continue
                if ptmaps11 is None:
                    H, W = C.shape
                    ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
                    confs11 = torch.empty((n_pairs, H, W), device=device)

                ptmaps11[n] = X
                confs11[n] = C
                n += 1

        if canon is None:
            canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
            del ptmaps11
            del confs11

        # compute focals
        H, W = canon.shape[:2]
        pp = torch.tensor([W / 2, H / 2], device=device)
        if focal is None:
            focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
            if cache:
                torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))

        # extract depth offsets with correspondences
        core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
        idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)

        canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)

    return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21


def load_corres(path_corres, device, min_conf_thr):
    score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
    valid = confs > min_conf_thr if min_conf_thr else slice(None)
    # valid = (xy1 > 0).all(dim=1) & (xy2 > 0).all(dim=1) & (xy1 < 512).all(dim=1) & (xy2 < 512).all(dim=1)
    # print(f'keeping {valid.sum()} / {len(valid)} correspondences')
    return score, (xy1[valid], xy2[valid], confs[valid])


PairOfSlices = namedtuple(
    'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum')


def condense_data(imgs, tmp_paths, canonical_views, preds_21, dtype=torch.float32):
    # aggregate all data properly
    set_imgs = set(imgs)

    principal_points = []
    shapes = []
    focals = []
    core_depth = []
    img_anchors = {}
    tmp_pixels = {}

    for idx1, img1 in enumerate(imgs):
        # load stuff
        pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]

        principal_points.append(pp)
        shapes.append(shape)
        focals.append(focal)
        core_depth.append(anchors)

        img_uv1 = []
        img_idxs = []
        img_offs = []
        cur_n = [0]

        for img2, (pixels, match_confs) in pixels_confs.items():
            if img2 not in set_imgs:
                continue
            assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
            img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
            img_idxs.append(idxs[img2])
            img_offs.append(offsets[img2])
            cur_n.append(cur_n[-1] + len(pixels))
            # store the position of 3d points
            tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:])
        img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))

    all_confs = []
    imgs_slices = []
    corres2d = {img: [] for img in range(len(imgs))}

    for img1, img2 in tmp_paths:
        try:
            pix1, confs1, slice1 = tmp_pixels[img1, img2]
            pix2, confs2, slice2 = tmp_pixels[img2, img1]
        except KeyError:
            continue
        img1 = imgs.index(img1)
        img2 = imgs.index(img2)
        confs = (confs1 * confs2).sqrt()

        # prepare for loss_3d
        all_confs.append(confs)
        anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]]
        anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]]
        imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1,
                                        img2, slice2, pix2, anchor_idxs2,
                                        confs, float(confs.sum())))

        # prepare for loss_2d
        corres2d[img1].append((pix1, confs, img2, slice2))
        corres2d[img2].append((pix2, confs, img1, slice1))

    all_confs = torch.cat(all_confs)
    corres = (all_confs, float(all_confs.sum()), imgs_slices)

    def aggreg_matches(img1, list_matches):
        pix1, confs, img2, slice2 = zip(*list_matches)
        all_pix1 = torch.cat(pix1).to(dtype)
        all_confs = torch.cat(confs).to(dtype)
        return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)]
    corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]

    imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device)  # (W,H)
    principal_points = torch.stack(principal_points)
    focals = torch.cat(focals)

    # Subsample preds_21
    subsamp_preds_21 = {}
    for imk, imv in preds_21.items():
        subsamp_preds_21[imk] = {}
        for im2k, (pred, conf) in preds_21[imk].items():
            idxs = img_anchors[imgs.index(im2k)][1]
            subsamp_preds_21[imk][im2k] = (pred[idxs], conf[idxs])  # anchors subsample

    return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d, subsamp_preds_21


def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
    assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for img={i}'

    # canonical pointmap is just a weighted average
    confs11 = confs11.unsqueeze(-1) - 0.999
    canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)

    canon_depth = ptmaps11[..., 2].unsqueeze(1)
    S = slice(subsample // 2, None, subsample)
    center_depth = canon_depth[:, :, S, S]
    center_depth = torch.clip(center_depth, min=torch.finfo(center_depth.dtype).eps)

    stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
    stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)

    if mode == 'avg-reldepth':
        rel_depth = stacked_depth / center_depth
        stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
        canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()

    elif mode == 'avg-angle':
        xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
        stacked_xy = F.pixel_unshuffle(xy, subsample)
        B, _, H, W = stacked_xy.shape
        stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
        stacked_radius.clip_(min=1e-8)

        stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
        avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)

        # back to depth
        stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)

        canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
    else:
        raise ValueError(f'bad {mode=}')

    confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
    return canon, canon2, confs


def anchor_depth_offsets(canon_depth, pixels, subsample=8):
    device = canon_depth.device

    # create a 2D grid of anchor 3D points
    H1, W1 = canon_depth.shape
    yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample]
    H2, W2 = yx.shape[1:]
    cy, cx = yx.reshape(2, -1)
    core_depth = canon_depth[cy, cx]
    assert (core_depth > 0).all()

    # slave 3d points (attached to core 3d points)
    core_idxs = {}  # core_idxs[img2] = {corr_idx:core_idx}
    core_offs = {}  # core_offs[img2] = {corr_idx:3d_offset}

    for img2, (xy1, _confs) in pixels.items():
        px, py = xy1.long().T

        # find nearest anchor == block quantization
        core_idx = (py // subsample) * W2 + (px // subsample)
        core_idxs[img2] = core_idx.to(device)

        # compute relative depth offsets w.r.t. anchors
        ref_z = core_depth[core_idx]
        pts_z = canon_depth[py, px]
        offset = pts_z / ref_z
        core_offs[img2] = offset.detach().to(device)

    return core_idxs, core_offs


def spectral_clustering(graph, k=None, normalized_cuts=False):
    graph.fill_diagonal_(0)

    # graph laplacian
    degrees = graph.sum(dim=-1)
    laplacian = torch.diag(degrees) - graph
    if normalized_cuts:
        i_inv = torch.diag(degrees.sqrt().reciprocal())
        laplacian = i_inv @ laplacian @ i_inv

    # compute eigenvectors!
    eigval, eigvec = torch.linalg.eigh(laplacian)
    return eigval[:k], eigvec[:, :k]


def sim_func(p1, p2, gamma):
    diff = (p1 - p2).norm(dim=-1)
    avg_depth = (p1[:, :, 2] + p2[:, :, 2])
    rel_distance = diff / avg_depth
    sim = torch.exp(-gamma * rel_distance.square())
    return sim


def backproj(K, depthmap, subsample):
    H, W = depthmap.shape
    uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2)
    xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3)
    return xyz


def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
                              normalized_cuts=True, gamma=7, min_norm=5):
    try:
        if cache_path:
            cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
        lora_proj = torch.load(cache_path, map_location=K.device)

    except IOError:
        # reconstruct 3d points in camera coordinates
        xyz = backproj(K, depthmap, subsample)

        # compute all distances
        xyz = xyz.reshape(-1, 3)
        graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
        _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)

        if cache_path:
            torch.save(lora_proj.cpu(), mkdir_for(cache_path))

    lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)

    # depthmap ~= lora_proj @ coeffs
    return coeffs, lora_proj


def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
    # encode the pointmap
    coeffs = torch.linalg.pinv(lora_proj) @ x

    # rectify the norm of basis vectors to be ~ equal
    if coeffs.ndim == 1:
        coeffs = coeffs[:, None]
    if global_norm:
        lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1]
    elif min_norm:
        lora_proj *= coeffs.norm(dim=1).clip(min=min_norm)
    # can have rounding errors here!
    coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float()

    return lora_proj.detach(), coeffs.detach()


@torch.no_grad()
def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
    # recover 3d points
    core_depth = []
    lora_proj = []

    for i, img in enumerate(tqdm(imgs)):
        cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None
        depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
                                                cache_path=cache, **kw)
        core_depth.append(depth)
        lora_proj.append(proj)

    return core_depth, lora_proj


def reproj2d(Trf, pts3d):
    res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3]
    clipped_z = res[:, 2:3].clip(min=1e-3)  # make sure we don't have nans!
    uv = res[:, 0:2] / clipped_z
    return uv.clip(min=-1000, max=2000)


def bfs(tree, start_node):
    order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
    ranks = np.arange(len(order))
    ranks[order] = ranks.copy()
    return ranks, predecessors


def compute_min_spanning_tree(pws):
    sparse_graph = sp.dok_array(pws.shape)
    for i, j in pws.nonzero().cpu().tolist():
        sparse_graph[i, j] = -float(pws[i, j])
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph)

    # now reorder the oriented edges, starting from the central point
    ranks1, _ = bfs(msp, 0)
    ranks2, _ = bfs(msp, ranks1.argmax())
    ranks1, _ = bfs(msp, ranks2.argmax())
    # this is the point farthest from any leaf
    root = np.minimum(ranks1, ranks2).argmax()

    # find the ordered list of edges that describe the tree
    order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
    order = order[1:]  # the root does not have a predecessor
    edges = [(predecessors[i], i) for i in order]

    return root, edges


def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
    viz = SceneViz()

    cc = cam2w[:, :3, 3]
    cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
    colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))

    if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
        cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
    else:
        imgs = shapes_or_imgs
        cam_kws = dict(images=imgs, cam_size=cs)
    if K is not None:
        viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)

    if gt_cam2w is not None:
        if gt_K is None:
            gt_K = K
        viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)

    if pts3d is not None:
        for i, p in enumerate(pts3d):
            if not len(p):
                continue
            if masks is None:
                viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist()))
            else:
                viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
    viz.show(**kw)
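
To tie the pieces together, here is a condensed, hedged sketch of the non-gradio path through this module, mirroring what demo.get_reconstructed_scene does. The image file names are made up, and the hub checkpoint id is assumed to match the model name exposed in the demo parser; the keyword arguments are forwarded through sparse_global_alignment's **kw into sparse_scene_optimizer as defined above.

# Sketch only: run MASt3R pairwise inference, then sparse global alignment, outside gradio.
import torch
from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to(device)

filelist = ['imgs/0001.jpg', 'imgs/0002.jpg', 'imgs/0003.jpg']  # example paths
imgs = load_images(filelist, size=512)
pairs = make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True)

# coarse 3D matching loss (lr1/niter1) followed by 2D reprojection refinement (lr2/niter2)
scene = sparse_global_alignment(filelist, pairs, cache_path='cache/',
                                model=model, lr1=0.07, niter1=500, lr2=0.014, niter2=200,
                                device=device, opt_depth=True, matching_conf_thr=5.)

poses = scene.get_im_poses()      # (N, 4, 4) cam-to-world matrices
focals = scene.get_focals()       # per-image focal lengths
pts3d, depthmaps, confs = scene.get_dense_pts3d(clean_depth=True)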