zino36 commited on
Commit
20a4a01
1 Parent(s): fc5081c

Upload 7 files

Browse files
Files changed (7) hide show
  1. demo-2.py +283 -0
  2. demo.py +29 -312
  3. device.py +76 -0
  4. image.py +126 -0
  5. image_pairs.py +104 -0
  6. path_to_dust3r.py +19 -0
  7. viz.py +381 -0
demo-2.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # gradio demo
6
+ # --------------------------------------------------------
7
+ import argparse
8
+ import math
9
+ import builtins
10
+ import datetime
11
+ import gradio
12
+ import os
13
+ import torch
14
+ import numpy as np
15
+ import functools
16
+ import trimesh
17
+ import copy
18
+ from scipy.spatial.transform import Rotation
19
+
20
+ from dust3r.inference import inference
21
+ from dust3r.image_pairs import make_pairs
22
+ from dust3r.utils.image import load_images, rgb
23
+ from dust3r.utils.device import to_numpy
24
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
25
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
26
+
27
+ import matplotlib.pyplot as pl
28
+
29
+
30
+ def get_args_parser():
31
+ parser = argparse.ArgumentParser()
32
+ parser_url = parser.add_mutually_exclusive_group()
33
+ parser_url.add_argument("--local_network", action='store_true', default=False,
34
+ help="make app accessible on local network: address will be set to 0.0.0.0")
35
+ parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
36
+ parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
37
+ parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
38
+ "If None, will search for an available port starting at 7860."),
39
+ default=None)
40
+ parser_weights = parser.add_mutually_exclusive_group(required=True)
41
+ parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
42
+ parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
43
+ choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt",
44
+ "DUSt3R_ViTLarge_BaseDecoder_512_linear",
45
+ "DUSt3R_ViTLarge_BaseDecoder_224_linear"])
46
+ parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
47
+ parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
48
+ parser.add_argument("--silent", action='store_true', default=False,
49
+ help="silence logs")
50
+ return parser
51
+
52
+
53
+ def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
54
+ builtin_print = builtins.print
55
+
56
+ def print_with_timestamp(*args, **kwargs):
57
+ now = datetime.datetime.now()
58
+ formatted_date_time = now.strftime(time_format)
59
+
60
+ builtin_print(f'[{formatted_date_time}] ', end='') # print with time stamp
61
+ builtin_print(*args, **kwargs)
62
+
63
+ builtins.print = print_with_timestamp
64
+
65
+
66
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
67
+ cam_color=None, as_pointcloud=False,
68
+ transparent_cams=False, silent=False):
69
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
70
+ pts3d = to_numpy(pts3d)
71
+ imgs = to_numpy(imgs)
72
+ focals = to_numpy(focals)
73
+ cams2world = to_numpy(cams2world)
74
+
75
+ scene = trimesh.Scene()
76
+
77
+ # full pointcloud
78
+ if as_pointcloud:
79
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
80
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
81
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
82
+ scene.add_geometry(pct)
83
+ else:
84
+ meshes = []
85
+ for i in range(len(imgs)):
86
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
87
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
88
+ scene.add_geometry(mesh)
89
+
90
+ # add each camera
91
+ for i, pose_c2w in enumerate(cams2world):
92
+ if isinstance(cam_color, list):
93
+ camera_edge_color = cam_color[i]
94
+ else:
95
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
96
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
97
+ None if transparent_cams else imgs[i], focals[i],
98
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
99
+
100
+ rot = np.eye(4)
101
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
102
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
103
+ outfile = os.path.join(outdir, 'scene.glb')
104
+ if not silent:
105
+ print('(exporting 3D scene to', outfile, ')')
106
+ scene.export(file_obj=outfile)
107
+ return outfile
108
+
109
+
110
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
111
+ clean_depth=False, transparent_cams=False, cam_size=0.05):
112
+ """
113
+ extract 3D_model (glb file) from a reconstructed scene
114
+ """
115
+ if scene is None:
116
+ return None
117
+ # post processes
118
+ if clean_depth:
119
+ scene = scene.clean_pointcloud()
120
+ if mask_sky:
121
+ scene = scene.mask_sky()
122
+
123
+ # get optimized values from scene
124
+ rgbimg = scene.imgs
125
+ focals = scene.get_focals().cpu()
126
+ cams2world = scene.get_im_poses().cpu()
127
+ # 3D pointcloud from depthmap, poses and intrinsics
128
+ pts3d = to_numpy(scene.get_pts3d())
129
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
130
+ msk = to_numpy(scene.get_masks())
131
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
132
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
133
+
134
+
135
+ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
136
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
137
+ scenegraph_type, winsize, refid):
138
+ """
139
+ from a list of images, run dust3r inference, global aligner.
140
+ then run get_3D_model_from_scene
141
+ """
142
+ imgs = load_images(filelist, size=image_size, verbose=not silent)
143
+ if len(imgs) == 1:
144
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
145
+ imgs[1]['idx'] = 1
146
+ if scenegraph_type == "swin":
147
+ scenegraph_type = scenegraph_type + "-" + str(winsize)
148
+ elif scenegraph_type == "oneref":
149
+ scenegraph_type = scenegraph_type + "-" + str(refid)
150
+
151
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
152
+ output = inference(pairs, model, device, batch_size=1, verbose=not silent)
153
+
154
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
155
+ scene = global_aligner(output, device=device, mode=mode, verbose=not silent)
156
+ lr = 0.01
157
+
158
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
159
+ loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
160
+
161
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
162
+ clean_depth, transparent_cams, cam_size)
163
+
164
+ # also return rgb, depth and confidence imgs
165
+ # depth is normalized with the max value for all images
166
+ # we apply the jet colormap on the confidence maps
167
+ rgbimg = scene.imgs
168
+ depths = to_numpy(scene.get_depthmaps())
169
+ confs = to_numpy([c for c in scene.im_conf])
170
+ cmap = pl.get_cmap('jet')
171
+ depths_max = max([d.max() for d in depths])
172
+ depths = [d / depths_max for d in depths]
173
+ confs_max = max([d.max() for d in confs])
174
+ confs = [cmap(d / confs_max) for d in confs]
175
+
176
+ imgs = []
177
+ for i in range(len(rgbimg)):
178
+ imgs.append(rgbimg[i])
179
+ imgs.append(rgb(depths[i]))
180
+ imgs.append(rgb(confs[i]))
181
+
182
+ return scene, outfile, imgs
183
+
184
+
185
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
186
+ num_files = len(inputfiles) if inputfiles is not None else 1
187
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
188
+ if scenegraph_type == "swin":
189
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
190
+ minimum=1, maximum=max_winsize, step=1, visible=True)
191
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
192
+ maximum=num_files - 1, step=1, visible=False)
193
+ elif scenegraph_type == "oneref":
194
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
195
+ minimum=1, maximum=max_winsize, step=1, visible=False)
196
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
197
+ maximum=num_files - 1, step=1, visible=True)
198
+ else:
199
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
200
+ minimum=1, maximum=max_winsize, step=1, visible=False)
201
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
202
+ maximum=num_files - 1, step=1, visible=False)
203
+ return winsize, refid
204
+
205
+
206
+ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False):
207
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
208
+ model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
209
+ with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
210
+ # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
211
+ scene = gradio.State(None)
212
+ gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
213
+ with gradio.Column():
214
+ inputfiles = gradio.File(file_count="multiple")
215
+ with gradio.Row():
216
+ schedule = gradio.Dropdown(["linear", "cosine"],
217
+ value='linear', label="schedule", info="For global alignment!")
218
+ niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
219
+ label="num_iterations", info="For global alignment!")
220
+ scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
221
+ ("swin: sliding window", "swin"),
222
+ ("oneref: match one image with all", "oneref")],
223
+ value='complete', label="Scenegraph",
224
+ info="Define how to make pairs",
225
+ interactive=True)
226
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
227
+ minimum=1, maximum=1, step=1, visible=False)
228
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
229
+
230
+ run_btn = gradio.Button("Run")
231
+
232
+ with gradio.Row():
233
+ # adjust the confidence threshold
234
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
235
+ # adjust the camera size in the output pointcloud
236
+ cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
237
+ with gradio.Row():
238
+ as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
239
+ # two post process implemented
240
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky")
241
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
242
+ transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
243
+
244
+ outmodel = gradio.Model3D()
245
+ outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")
246
+
247
+ # events
248
+ scenegraph_type.change(set_scenegraph_options,
249
+ inputs=[inputfiles, winsize, refid, scenegraph_type],
250
+ outputs=[winsize, refid])
251
+ inputfiles.change(set_scenegraph_options,
252
+ inputs=[inputfiles, winsize, refid, scenegraph_type],
253
+ outputs=[winsize, refid])
254
+ run_btn.click(fn=recon_fun,
255
+ inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
256
+ mask_sky, clean_depth, transparent_cams, cam_size,
257
+ scenegraph_type, winsize, refid],
258
+ outputs=[scene, outmodel, outgallery])
259
+ min_conf_thr.release(fn=model_from_scene_fun,
260
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
261
+ clean_depth, transparent_cams, cam_size],
262
+ outputs=outmodel)
263
+ cam_size.change(fn=model_from_scene_fun,
264
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
265
+ clean_depth, transparent_cams, cam_size],
266
+ outputs=outmodel)
267
+ as_pointcloud.change(fn=model_from_scene_fun,
268
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
269
+ clean_depth, transparent_cams, cam_size],
270
+ outputs=outmodel)
271
+ mask_sky.change(fn=model_from_scene_fun,
272
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
273
+ clean_depth, transparent_cams, cam_size],
274
+ outputs=outmodel)
275
+ clean_depth.change(fn=model_from_scene_fun,
276
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
277
+ clean_depth, transparent_cams, cam_size],
278
+ outputs=outmodel)
279
+ transparent_cams.change(model_from_scene_fun,
280
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
281
+ clean_depth, transparent_cams, cam_size],
282
+ outputs=outmodel)
283
+ demo.launch(share=False, server_name=server_name, server_port=server_port)
demo.py CHANGED
@@ -3,328 +3,45 @@
3
  # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
  #
5
  # --------------------------------------------------------
6
- # sparse gradio demo functions
7
  # --------------------------------------------------------
8
- import math
9
- import gradio
10
  import os
11
- import numpy as np
12
- import functools
13
- import trimesh
14
- import copy
15
- from scipy.spatial.transform import Rotation
16
  import tempfile
17
- import shutil
18
 
19
- from sparse_ga import sparse_global_alignment
20
- from tsdf_optimizer import TSDFPostProcess
21
 
22
- import path_to_dust3r # noqa
23
- from image_pairs import make_pairs
24
- from image import load_images
25
- from device import to_numpy
26
- from viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
27
- from demo import get_args_parser as dust3r_get_args_parser
28
 
29
  import matplotlib.pyplot as pl
 
30
 
 
31
 
32
- class SparseGAState():
33
- def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
34
- self.sparse_ga = sparse_ga
35
- self.cache_dir = cache_dir
36
- self.outfile_name = outfile_name
37
- self.should_delete = should_delete
38
 
39
- def __del__(self):
40
- if not self.should_delete:
41
- return
42
- if self.cache_dir is not None and os.path.isdir(self.cache_dir):
43
- shutil.rmtree(self.cache_dir)
44
- self.cache_dir = None
45
- if self.outfile_name is not None and os.path.isfile(self.outfile_name):
46
- os.remove(self.outfile_name)
47
- self.outfile_name = None
48
-
49
-
50
- def get_args_parser():
51
- parser = dust3r_get_args_parser()
52
- parser.add_argument('--share', action='store_true')
53
- parser.add_argument('--gradio_delete_cache', default=None, type=int,
54
- help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
55
-
56
- actions = parser._actions
57
- for action in actions:
58
- if action.dest == 'model_name':
59
- action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
60
- # change defaults
61
- parser.prog = 'mast3r demo'
62
- return parser
63
-
64
-
65
- def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
66
- cam_color=None, as_pointcloud=False,
67
- transparent_cams=False, silent=False):
68
- assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
69
- pts3d = to_numpy(pts3d)
70
- imgs = to_numpy(imgs)
71
- focals = to_numpy(focals)
72
- cams2world = to_numpy(cams2world)
73
-
74
- scene = trimesh.Scene()
75
-
76
- # full pointcloud
77
- if as_pointcloud:
78
- pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
79
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
80
- valid_msk = np.isfinite(pts.sum(axis=1))
81
- pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
82
- scene.add_geometry(pct)
83
- else:
84
- meshes = []
85
- for i in range(len(imgs)):
86
- pts3d_i = pts3d[i].reshape(imgs[i].shape)
87
- msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
88
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
89
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
90
- scene.add_geometry(mesh)
91
-
92
- # add each camera
93
- for i, pose_c2w in enumerate(cams2world):
94
- if isinstance(cam_color, list):
95
- camera_edge_color = cam_color[i]
96
- else:
97
- camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
98
- add_scene_cam(scene, pose_c2w, camera_edge_color,
99
- None if transparent_cams else imgs[i], focals[i],
100
- imsize=imgs[i].shape[1::-1], screen_width=cam_size)
101
-
102
- rot = np.eye(4)
103
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
104
- scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
105
- if not silent:
106
- print('(exporting 3D scene to', outfile, ')')
107
- scene.export(file_obj=outfile)
108
- return outfile
109
-
110
-
111
- def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
112
- clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
113
- """
114
- extract 3D_model (glb file) from a reconstructed scene
115
- """
116
- if scene_state is None:
117
- return None
118
- outfile = scene_state.outfile_name
119
- if outfile is None:
120
- return None
121
-
122
- # get optimized values from scene
123
- scene = scene_state.sparse_ga
124
- rgbimg = scene.imgs
125
- focals = scene.get_focals().cpu()
126
- cams2world = scene.get_im_poses().cpu()
127
-
128
- # 3D pointcloud from depthmap, poses and intrinsics
129
- if TSDF_thresh > 0:
130
- tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
131
- pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
132
  else:
133
- pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
134
- msk = to_numpy([c > min_conf_thr for c in confs])
135
- return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
136
- transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
137
 
138
-
139
- def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
140
- filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
141
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
142
- win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
143
- """
144
- from a list of images, run mast3r inference, sparse global aligner.
145
- then run get_3D_model_from_scene
146
- """
147
- imgs = load_images(filelist, size=image_size, verbose=not silent)
148
- if len(imgs) == 1:
149
- imgs = [imgs[0], copy.deepcopy(imgs[0])]
150
- imgs[1]['idx'] = 1
151
- filelist = [filelist[0], filelist[0] + '_2']
152
-
153
- scene_graph_params = [scenegraph_type]
154
- if scenegraph_type in ["swin", "logwin"]:
155
- scene_graph_params.append(str(winsize))
156
- elif scenegraph_type == "oneref":
157
- scene_graph_params.append(str(refid))
158
- if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
159
- scene_graph_params.append('noncyclic')
160
- scene_graph = '-'.join(scene_graph_params)
161
- pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
162
- if optim_level == 'coarse':
163
- niter2 = 0
164
- # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
165
- if current_scene_state is not None and \
166
- not current_scene_state.should_delete and \
167
- current_scene_state.cache_dir is not None:
168
- cache_dir = current_scene_state.cache_dir
169
- elif gradio_delete_cache:
170
- cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
171
  else:
172
- cache_dir = os.path.join(outdir, 'cache')
173
- scene = sparse_global_alignment(filelist, pairs, cache_dir,
174
- model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
175
- opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
176
- matching_conf_thr=matching_conf_thr, **kw)
177
- if current_scene_state is not None and \
178
- not current_scene_state.should_delete and \
179
- current_scene_state.outfile_name is not None:
180
- outfile_name = current_scene_state.outfile_name
181
- else:
182
- outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
183
-
184
- scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
185
- outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
186
- clean_depth, transparent_cams, cam_size, TSDF_thresh)
187
- return scene_state, outfile
188
-
189
-
190
- def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
191
- num_files = len(inputfiles) if inputfiles is not None else 1
192
- show_win_controls = scenegraph_type in ["swin", "logwin"]
193
- show_winsize = scenegraph_type in ["swin", "logwin"]
194
- show_cyclic = scenegraph_type in ["swin", "logwin"]
195
- max_winsize, min_winsize = 1, 1
196
- if scenegraph_type == "swin":
197
- if win_cyclic:
198
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
199
- else:
200
- max_winsize = num_files - 1
201
- elif scenegraph_type == "logwin":
202
- if win_cyclic:
203
- half_size = math.ceil((num_files - 1) / 2)
204
- max_winsize = max(1, math.ceil(math.log(half_size, 2)))
205
- else:
206
- max_winsize = max(1, math.ceil(math.log(num_files, 2)))
207
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
208
- minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
209
- win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
210
- win_col = gradio.Column(visible=show_win_controls)
211
- refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
212
- maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
213
- return win_col, winsize, win_cyclic, refid
214
-
215
-
216
- def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
217
- share=False, gradio_delete_cache=False):
218
- if not silent:
219
- print('Outputing stuff in', tmpdirname)
220
-
221
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
222
- silent, image_size)
223
- model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
224
-
225
- def get_context(delete_cache):
226
- css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
227
- title = "MASt3R Demo"
228
- if delete_cache:
229
- return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
230
- else:
231
- return gradio.Blocks(css=css, title="MASt3R Demo") # for compatibility with older versions
232
-
233
- with get_context(gradio_delete_cache) as demo:
234
- # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
235
- scene = gradio.State(None)
236
- gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
237
- with gradio.Column():
238
- inputfiles = gradio.File(file_count="multiple")
239
- with gradio.Row():
240
- with gradio.Column():
241
- with gradio.Row():
242
- lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
243
- niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
244
- label="num_iterations", info="For coarse alignment!")
245
- lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
246
- niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
247
- label="num_iterations", info="For refinement!")
248
- optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
249
- value='refine', label="OptLevel",
250
- info="Optimization level")
251
- with gradio.Row():
252
- matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
253
- minimum=0., maximum=30., step=0.1,
254
- info="Before Fallback to Regr3D!")
255
- shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
256
- info="Only optimize one set of intrinsics for all views")
257
- scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
258
- ("swin: sliding window", "swin"),
259
- ("logwin: sliding window with long range", "logwin"),
260
- ("oneref: match one image with all", "oneref")],
261
- value='complete', label="Scenegraph",
262
- info="Define how to make pairs",
263
- interactive=True)
264
- with gradio.Column(visible=False) as win_col:
265
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
266
- minimum=1, maximum=1, step=1)
267
- win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
268
- refid = gradio.Slider(label="Scene Graph: Id", value=0,
269
- minimum=0, maximum=0, step=1, visible=False)
270
- run_btn = gradio.Button("Run")
271
-
272
- with gradio.Row():
273
- # adjust the confidence threshold
274
- min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
275
- # adjust the camera size in the output pointcloud
276
- cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
277
- TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
278
- with gradio.Row():
279
- as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
280
- # two post process implemented
281
- mask_sky = gradio.Checkbox(value=False, label="Mask sky")
282
- clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
283
- transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
284
-
285
- outmodel = gradio.Model3D()
286
-
287
- # events
288
- scenegraph_type.change(set_scenegraph_options,
289
- inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
290
- outputs=[win_col, winsize, win_cyclic, refid])
291
- inputfiles.change(set_scenegraph_options,
292
- inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
293
- outputs=[win_col, winsize, win_cyclic, refid])
294
- win_cyclic.change(set_scenegraph_options,
295
- inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
296
- outputs=[win_col, winsize, win_cyclic, refid])
297
- run_btn.click(fn=recon_fun,
298
- inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
299
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
300
- scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
301
- outputs=[scene, outmodel])
302
- min_conf_thr.release(fn=model_from_scene_fun,
303
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
304
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
305
- outputs=outmodel)
306
- cam_size.change(fn=model_from_scene_fun,
307
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
308
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
309
- outputs=outmodel)
310
- TSDF_thresh.change(fn=model_from_scene_fun,
311
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
312
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
313
- outputs=outmodel)
314
- as_pointcloud.change(fn=model_from_scene_fun,
315
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
316
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
317
- outputs=outmodel)
318
- mask_sky.change(fn=model_from_scene_fun,
319
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
320
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
321
- outputs=outmodel)
322
- clean_depth.change(fn=model_from_scene_fun,
323
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
324
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
325
- outputs=outmodel)
326
- transparent_cams.change(model_from_scene_fun,
327
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
328
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
329
- outputs=outmodel)
330
- demo.launch(share=share, server_name=server_name, server_port=server_port)
 
3
  # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
  #
5
  # --------------------------------------------------------
6
+ # gradio demo executable
7
  # --------------------------------------------------------
 
 
8
  import os
9
+ import torch
 
 
 
 
10
  import tempfile
11
+ from contextlib import nullcontext
12
 
13
+ from mast3r.demo import get_args_parser, main_demo
 
14
 
15
+ from mast3r.model import AsymmetricMASt3R
16
+ from mast3r.utils.misc import hash_md5
 
 
 
 
17
 
18
  import matplotlib.pyplot as pl
19
+ pl.ion()
20
 
21
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
22
 
23
+ if __name__ == '__main__':
24
+ parser = get_args_parser()
25
+ args = parser.parse_args()
 
 
 
26
 
27
+ if args.server_name is not None:
28
+ server_name = args.server_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  else:
30
+ server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
 
 
 
31
 
32
+ if args.weights is not None:
33
+ weights_path = args.weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  else:
35
+ weights_path = "naver/" + args.model_name
36
+
37
+ model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
38
+ chkpt_tag = hash_md5(weights_path)
39
+
40
+ def get_context(tmp_dir):
41
+ return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
42
+ else nullcontext(tmp_dir)
43
+ with get_context(args.tmp_dir) as tmpdirname:
44
+ cache_path = os.path.join(tmpdirname, chkpt_tag)
45
+ os.makedirs(cache_path, exist_ok=True)
46
+ main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
47
+ share=args.share, gradio_delete_cache=args.gradio_delete_cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
device.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions for DUSt3R
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def todevice(batch, device, callback=None, non_blocking=False):
12
+ ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
13
+
14
+ batch: list, tuple, dict of tensors or other things
15
+ device: pytorch device or 'numpy'
16
+ callback: function that would be called on every sub-elements.
17
+ '''
18
+ if callback:
19
+ batch = callback(batch)
20
+
21
+ if isinstance(batch, dict):
22
+ return {k: todevice(v, device) for k, v in batch.items()}
23
+
24
+ if isinstance(batch, (tuple, list)):
25
+ return type(batch)(todevice(x, device) for x in batch)
26
+
27
+ x = batch
28
+ if device == 'numpy':
29
+ if isinstance(x, torch.Tensor):
30
+ x = x.detach().cpu().numpy()
31
+ elif x is not None:
32
+ if isinstance(x, np.ndarray):
33
+ x = torch.from_numpy(x)
34
+ if torch.is_tensor(x):
35
+ x = x.to(device, non_blocking=non_blocking)
36
+ return x
37
+
38
+
39
+ to_device = todevice # alias
40
+
41
+
42
+ def to_numpy(x): return todevice(x, 'numpy')
43
+ def to_cpu(x): return todevice(x, 'cpu')
44
+ def to_cuda(x): return todevice(x, 'cuda')
45
+
46
+
47
+ def collate_with_cat(whatever, lists=False):
48
+ if isinstance(whatever, dict):
49
+ return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
50
+
51
+ elif isinstance(whatever, (tuple, list)):
52
+ if len(whatever) == 0:
53
+ return whatever
54
+ elem = whatever[0]
55
+ T = type(whatever)
56
+
57
+ if elem is None:
58
+ return None
59
+ if isinstance(elem, (bool, float, int, str)):
60
+ return whatever
61
+ if isinstance(elem, tuple):
62
+ return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
63
+ if isinstance(elem, dict):
64
+ return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
65
+
66
+ if isinstance(elem, torch.Tensor):
67
+ return listify(whatever) if lists else torch.cat(whatever)
68
+ if isinstance(elem, np.ndarray):
69
+ return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
70
+
71
+ # otherwise, we just chain lists
72
+ return sum(whatever, T())
73
+
74
+
75
+ def listify(elems):
76
+ return [x for e in elems for x in e]
image.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions about images (loading/converting...)
6
+ # --------------------------------------------------------
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+ import PIL.Image
11
+ from PIL.ImageOps import exif_transpose
12
+ import torchvision.transforms as tvf
13
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
14
+ import cv2 # noqa
15
+
16
+ try:
17
+ from pillow_heif import register_heif_opener # noqa
18
+ register_heif_opener()
19
+ heif_support_enabled = True
20
+ except ImportError:
21
+ heif_support_enabled = False
22
+
23
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
24
+
25
+
26
+ def img_to_arr( img ):
27
+ if isinstance(img, str):
28
+ img = imread_cv2(img)
29
+ return img
30
+
31
+ def imread_cv2(path, options=cv2.IMREAD_COLOR):
32
+ """ Open an image or a depthmap with opencv-python.
33
+ """
34
+ if path.endswith(('.exr', 'EXR')):
35
+ options = cv2.IMREAD_ANYDEPTH
36
+ img = cv2.imread(path, options)
37
+ if img is None:
38
+ raise IOError(f'Could not load image={path} with {options=}')
39
+ if img.ndim == 3:
40
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
41
+ return img
42
+
43
+
44
+ def rgb(ftensor, true_shape=None):
45
+ if isinstance(ftensor, list):
46
+ return [rgb(x, true_shape=true_shape) for x in ftensor]
47
+ if isinstance(ftensor, torch.Tensor):
48
+ ftensor = ftensor.detach().cpu().numpy() # H,W,3
49
+ if ftensor.ndim == 3 and ftensor.shape[0] == 3:
50
+ ftensor = ftensor.transpose(1, 2, 0)
51
+ elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
52
+ ftensor = ftensor.transpose(0, 2, 3, 1)
53
+ if true_shape is not None:
54
+ H, W = true_shape
55
+ ftensor = ftensor[:H, :W]
56
+ if ftensor.dtype == np.uint8:
57
+ img = np.float32(ftensor) / 255
58
+ else:
59
+ img = (ftensor * 0.5) + 0.5
60
+ return img.clip(min=0, max=1)
61
+
62
+
63
+ def _resize_pil_image(img, long_edge_size):
64
+ S = max(img.size)
65
+ if S > long_edge_size:
66
+ interp = PIL.Image.LANCZOS
67
+ elif S <= long_edge_size:
68
+ interp = PIL.Image.BICUBIC
69
+ new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
70
+ return img.resize(new_size, interp)
71
+
72
+
73
+ def load_images(folder_or_list, size, square_ok=False, verbose=True):
74
+ """ open and convert all images in a list or folder to proper input format for DUSt3R
75
+ """
76
+ if isinstance(folder_or_list, str):
77
+ if verbose:
78
+ print(f'>> Loading images from {folder_or_list}')
79
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
80
+
81
+ elif isinstance(folder_or_list, list):
82
+ if verbose:
83
+ print(f'>> Loading a list of {len(folder_or_list)} images')
84
+ root, folder_content = '', folder_or_list
85
+
86
+ else:
87
+ raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
88
+
89
+ supported_images_extensions = ['.jpg', '.jpeg', '.png']
90
+ if heif_support_enabled:
91
+ supported_images_extensions += ['.heic', '.heif']
92
+ supported_images_extensions = tuple(supported_images_extensions)
93
+
94
+ imgs = []
95
+ for path in folder_content:
96
+ if not path.lower().endswith(supported_images_extensions):
97
+ continue
98
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
99
+ W1, H1 = img.size
100
+ if size == 224:
101
+ # resize short side to 224 (then crop)
102
+ img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
103
+ else:
104
+ # resize long side to 512
105
+ img = _resize_pil_image(img, size)
106
+ W, H = img.size
107
+ cx, cy = W//2, H//2
108
+ if size == 224:
109
+ half = min(cx, cy)
110
+ img = img.crop((cx-half, cy-half, cx+half, cy+half))
111
+ else:
112
+ halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
113
+ if not (square_ok) and W == H:
114
+ halfh = 3*halfw/4
115
+ img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
116
+
117
+ W2, H2 = img.size
118
+ if verbose:
119
+ print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
120
+ imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
121
+ [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
122
+
123
+ assert imgs, 'no images foud at '+root
124
+ if verbose:
125
+ print(f' (Found {len(imgs)} images)')
126
+ return imgs
image_pairs.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities needed to load image pairs
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
12
+ pairs = []
13
+ if scene_graph == 'complete': # complete graph
14
+ for i in range(len(imgs)):
15
+ for j in range(i):
16
+ pairs.append((imgs[i], imgs[j]))
17
+ elif scene_graph.startswith('swin'):
18
+ iscyclic = not scene_graph.endswith('noncyclic')
19
+ try:
20
+ winsize = int(scene_graph.split('-')[1])
21
+ except Exception as e:
22
+ winsize = 3
23
+ pairsid = set()
24
+ for i in range(len(imgs)):
25
+ for j in range(1, winsize + 1):
26
+ idx = (i + j)
27
+ if iscyclic:
28
+ idx = idx % len(imgs) # explicit loop closure
29
+ if idx >= len(imgs):
30
+ continue
31
+ pairsid.add((i, idx) if i < idx else (idx, i))
32
+ for i, j in pairsid:
33
+ pairs.append((imgs[i], imgs[j]))
34
+ elif scene_graph.startswith('logwin'):
35
+ iscyclic = not scene_graph.endswith('noncyclic')
36
+ try:
37
+ winsize = int(scene_graph.split('-')[1])
38
+ except Exception as e:
39
+ winsize = 3
40
+ offsets = [2**i for i in range(winsize)]
41
+ pairsid = set()
42
+ for i in range(len(imgs)):
43
+ ixs_l = [i - off for off in offsets]
44
+ ixs_r = [i + off for off in offsets]
45
+ for j in ixs_l + ixs_r:
46
+ if iscyclic:
47
+ j = j % len(imgs) # Explicit loop closure
48
+ if j < 0 or j >= len(imgs) or j == i:
49
+ continue
50
+ pairsid.add((i, j) if i < j else (j, i))
51
+ for i, j in pairsid:
52
+ pairs.append((imgs[i], imgs[j]))
53
+ elif scene_graph.startswith('oneref'):
54
+ refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
55
+ for j in range(len(imgs)):
56
+ if j != refid:
57
+ pairs.append((imgs[refid], imgs[j]))
58
+ if symmetrize:
59
+ pairs += [(img2, img1) for img1, img2 in pairs]
60
+
61
+ # now, remove edges
62
+ if isinstance(prefilter, str) and prefilter.startswith('seq'):
63
+ pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
64
+
65
+ if isinstance(prefilter, str) and prefilter.startswith('cyc'):
66
+ pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
67
+
68
+ return pairs
69
+
70
+
71
+ def sel(x, kept):
72
+ if isinstance(x, dict):
73
+ return {k: sel(v, kept) for k, v in x.items()}
74
+ if isinstance(x, (torch.Tensor, np.ndarray)):
75
+ return x[kept]
76
+ if isinstance(x, (tuple, list)):
77
+ return type(x)([x[k] for k in kept])
78
+
79
+
80
+ def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
81
+ # number of images
82
+ n = max(max(e) for e in edges) + 1
83
+
84
+ kept = []
85
+ for e, (i, j) in enumerate(edges):
86
+ dis = abs(i - j)
87
+ if cyclic:
88
+ dis = min(dis, abs(i + n - j), abs(i - n - j))
89
+ if dis <= seq_dis_thr:
90
+ kept.append(e)
91
+ return kept
92
+
93
+
94
+ def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
95
+ edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
96
+ kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
97
+ return [pairs[i] for i in kept]
98
+
99
+
100
+ def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
101
+ edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
102
+ kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
103
+ print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
104
+ return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
path_to_dust3r.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # dust3r submodule import
6
+ # --------------------------------------------------------
7
+
8
+ import sys
9
+ import os.path as path
10
+ HERE_PATH = path.normpath(path.dirname(__file__))
11
+ DUSt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../dust3r'))
12
+ DUSt3R_LIB_PATH = path.join(DUSt3R_REPO_PATH, 'dust3r')
13
+ # check the presence of models directory in repo to be sure its cloned
14
+ if path.isdir(DUSt3R_LIB_PATH):
15
+ # workaround for sibling import
16
+ sys.path.insert(0, DUSt3R_REPO_PATH)
17
+ else:
18
+ raise ImportError(f"dust3r is not initialized, could not find: {DUSt3R_LIB_PATH}.\n "
19
+ "Did you forget to run 'git submodule update --init --recursive' ?")
viz.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Visualization utilities using trimesh
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import numpy as np
9
+ from scipy.spatial.transform import Rotation
10
+ import torch
11
+
12
+ from dust3r.utils.geometry import geotrf, get_med_dist_between_poses, depthmap_to_absolute_camera_coordinates
13
+ from dust3r.utils.device import to_numpy
14
+ from dust3r.utils.image import rgb, img_to_arr
15
+
16
+ try:
17
+ import trimesh
18
+ except ImportError:
19
+ print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
20
+
21
+
22
+
23
+ def cat_3d(vecs):
24
+ if isinstance(vecs, (np.ndarray, torch.Tensor)):
25
+ vecs = [vecs]
26
+ return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
27
+
28
+
29
+ def show_raw_pointcloud(pts3d, colors, point_size=2):
30
+ scene = trimesh.Scene()
31
+
32
+ pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
33
+ scene.add_geometry(pct)
34
+
35
+ scene.show(line_settings={'point_size': point_size})
36
+
37
+
38
+ def pts3d_to_trimesh(img, pts3d, valid=None):
39
+ H, W, THREE = img.shape
40
+ assert THREE == 3
41
+ assert img.shape == pts3d.shape
42
+
43
+ vertices = pts3d.reshape(-1, 3)
44
+
45
+ # make squares: each pixel == 2 triangles
46
+ idx = np.arange(len(vertices)).reshape(H, W)
47
+ idx1 = idx[:-1, :-1].ravel() # top-left corner
48
+ idx2 = idx[:-1, +1:].ravel() # right-left corner
49
+ idx3 = idx[+1:, :-1].ravel() # bottom-left corner
50
+ idx4 = idx[+1:, +1:].ravel() # bottom-right corner
51
+ faces = np.concatenate((
52
+ np.c_[idx1, idx2, idx3],
53
+ np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling)
54
+ np.c_[idx2, idx3, idx4],
55
+ np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling)
56
+ ), axis=0)
57
+
58
+ # prepare triangle colors
59
+ face_colors = np.concatenate((
60
+ img[:-1, :-1].reshape(-1, 3),
61
+ img[:-1, :-1].reshape(-1, 3),
62
+ img[+1:, +1:].reshape(-1, 3),
63
+ img[+1:, +1:].reshape(-1, 3)
64
+ ), axis=0)
65
+
66
+ # remove invalid faces
67
+ if valid is not None:
68
+ assert valid.shape == (H, W)
69
+ valid_idxs = valid.ravel()
70
+ valid_faces = valid_idxs[faces].all(axis=-1)
71
+ faces = faces[valid_faces]
72
+ face_colors = face_colors[valid_faces]
73
+
74
+ assert len(faces) == len(face_colors)
75
+ return dict(vertices=vertices, face_colors=face_colors, faces=faces)
76
+
77
+
78
+ def cat_meshes(meshes):
79
+ vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
80
+ n_vertices = np.cumsum([0]+[len(v) for v in vertices])
81
+ for i in range(len(faces)):
82
+ faces[i][:] += n_vertices[i]
83
+
84
+ vertices = np.concatenate(vertices)
85
+ colors = np.concatenate(colors)
86
+ faces = np.concatenate(faces)
87
+ return dict(vertices=vertices, face_colors=colors, faces=faces)
88
+
89
+
90
+ def show_duster_pairs(view1, view2, pred1, pred2):
91
+ import matplotlib.pyplot as pl
92
+ pl.ion()
93
+
94
+ for e in range(len(view1['instance'])):
95
+ i = view1['idx'][e]
96
+ j = view2['idx'][e]
97
+ img1 = rgb(view1['img'][e])
98
+ img2 = rgb(view2['img'][e])
99
+ conf1 = pred1['conf'][e].squeeze()
100
+ conf2 = pred2['conf'][e].squeeze()
101
+ score = conf1.mean()*conf2.mean()
102
+ print(f">> Showing pair #{e} {i}-{j} {score=:g}")
103
+ pl.clf()
104
+ pl.subplot(221).imshow(img1)
105
+ pl.subplot(223).imshow(img2)
106
+ pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
107
+ pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
108
+ pts1 = pred1['pts3d'][e]
109
+ pts2 = pred2['pts3d_in_other_view'][e]
110
+ pl.subplots_adjust(0, 0, 1, 1, 0, 0)
111
+ if input('show pointcloud? (y/n) ') == 'y':
112
+ show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
113
+
114
+
115
+ def auto_cam_size(im_poses):
116
+ return 0.1 * get_med_dist_between_poses(im_poses)
117
+
118
+
119
+ class SceneViz:
120
+ def __init__(self):
121
+ self.scene = trimesh.Scene()
122
+
123
+ def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None):
124
+ image = img_to_arr(image)
125
+
126
+ # make up some intrinsics
127
+ if intrinsics is None:
128
+ H, W, THREE = image.shape
129
+ focal = max(H, W)
130
+ intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]])
131
+
132
+ # compute 3d points
133
+ pts3d = depthmap_to_pts3d(depth, intrinsics, cam2world=cam2world)
134
+
135
+ return self.add_pointcloud(pts3d, image, mask=(depth<zfar) if mask is None else mask)
136
+
137
+ def add_pointcloud(self, pts3d, color=(0,0,0), mask=None, denoise=False):
138
+ pts3d = to_numpy(pts3d)
139
+ mask = to_numpy(mask)
140
+ if not isinstance(pts3d, list):
141
+ pts3d = [pts3d.reshape(-1,3)]
142
+ if mask is not None:
143
+ mask = [mask.ravel()]
144
+ if not isinstance(color, (tuple,list)):
145
+ color = [color.reshape(-1,3)]
146
+ if mask is None:
147
+ mask = [slice(None)] * len(pts3d)
148
+
149
+ pts = np.concatenate([p[m] for p,m in zip(pts3d,mask)])
150
+ pct = trimesh.PointCloud(pts)
151
+
152
+ if isinstance(color, (list, np.ndarray, torch.Tensor)):
153
+ color = to_numpy(color)
154
+ col = np.concatenate([p[m] for p,m in zip(color,mask)])
155
+ assert col.shape == pts.shape, bb()
156
+ pct.visual.vertex_colors = uint8(col.reshape(-1,3))
157
+ else:
158
+ assert len(color) == 3
159
+ pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
160
+
161
+ if denoise:
162
+ # remove points which are noisy
163
+ centroid = np.median(pct.vertices, axis=0)
164
+ dist_to_centroid = np.linalg.norm( pct.vertices - centroid, axis=-1)
165
+ dist_thr = np.quantile(dist_to_centroid, 0.99)
166
+ valid = (dist_to_centroid < dist_thr)
167
+ # new cleaned pointcloud
168
+ pct = trimesh.PointCloud(pct.vertices[valid], color=pct.visual.vertex_colors[valid])
169
+
170
+ self.scene.add_geometry(pct)
171
+ return self
172
+
173
+ def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None):
174
+ # make up some intrinsics
175
+ if intrinsics is None:
176
+ H, W, THREE = image.shape
177
+ focal = max(H, W)
178
+ intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]])
179
+
180
+ # compute 3d points
181
+ pts3d, mask2 = depthmap_to_absolute_camera_coordinates(depth, intrinsics, cam2world)
182
+ mask2 &= (depth<zfar)
183
+
184
+ # combine with provided mask if any
185
+ if mask is not None:
186
+ mask2 &= mask
187
+
188
+ return self.add_pointcloud(pts3d, image, mask=mask2)
189
+
190
+ def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
191
+ pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
192
+ image = img_to_arr(image)
193
+ if isinstance(focal, np.ndarray) and focal.shape == (3,3):
194
+ intrinsics = focal
195
+ focal = (intrinsics[0,0] * intrinsics[1,1]) ** 0.5
196
+ if imsize is None:
197
+ imsize = (2*intrinsics[0,2], 2*intrinsics[1,2])
198
+
199
+ add_scene_cam(self.scene, pose_c2w, color, image, focal, imsize=imsize, screen_width=cam_size, marker=None)
200
+ return self
201
+
202
+ def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
203
+ get = lambda arr,idx: None if arr is None else arr[idx]
204
+ for i, pose_c2w in enumerate(poses):
205
+ self.add_camera(pose_c2w, get(focals,i), image=get(images,i), color=get(colors,i), imsize=get(imsizes,i), **kw)
206
+ return self
207
+
208
+ def show(self, point_size=2):
209
+ self.scene.show(line_settings= {'point_size': point_size})
210
+
211
+
212
+ def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
213
+ point_size=2, cam_size=0.05, cam_color=None):
214
+ """ Visualization of a pointcloud with cameras
215
+ imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
216
+ pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
217
+ focals = (N,) or N-size list of [focal, ...]
218
+ cams2world = (N,4,4) or N-size list of [(4,4), ...]
219
+ """
220
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
221
+ pts3d = to_numpy(pts3d)
222
+ imgs = to_numpy(imgs)
223
+ focals = to_numpy(focals)
224
+ cams2world = to_numpy(cams2world)
225
+
226
+ scene = trimesh.Scene()
227
+
228
+ # full pointcloud
229
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
230
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
231
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
232
+ scene.add_geometry(pct)
233
+
234
+ # add each camera
235
+ for i, pose_c2w in enumerate(cams2world):
236
+ if isinstance(cam_color, list):
237
+ camera_edge_color = cam_color[i]
238
+ else:
239
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
240
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
241
+ imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)
242
+
243
+ scene.show(line_settings={'point_size': point_size})
244
+
245
+
246
+ def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None,
247
+ screen_width=0.03, marker=None):
248
+ if image is not None:
249
+ image = np.asarray(image)
250
+ H, W, THREE = image.shape
251
+ assert THREE == 3
252
+ if image.dtype != np.uint8:
253
+ image = np.uint8(255*image)
254
+ elif imsize is not None:
255
+ W, H = imsize
256
+ elif focal is not None:
257
+ H = W = focal / 1.1
258
+ else:
259
+ H = W = 1
260
+
261
+ if isinstance(focal, np.ndarray):
262
+ focal = focal[0]
263
+ if not focal:
264
+ focal = min(H,W) * 1.1 # default value
265
+
266
+ # create fake camera
267
+ height = max( screen_width/10, focal * screen_width / H )
268
+ width = screen_width * 0.5**0.5
269
+ rot45 = np.eye(4)
270
+ rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
271
+ rot45[2, 3] = -height # set the tip of the cone = optical center
272
+ aspect_ratio = np.eye(4)
273
+ aspect_ratio[0, 0] = W/H
274
+ transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
275
+ cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
276
+
277
+ # this is the image
278
+ if image is not None:
279
+ vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
280
+ faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
281
+ img = trimesh.Trimesh(vertices=vertices, faces=faces)
282
+ uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
283
+ img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
284
+ scene.add_geometry(img)
285
+
286
+ # this is the camera mesh
287
+ rot2 = np.eye(4)
288
+ rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
289
+ vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
290
+ vertices = geotrf(transform, vertices)
291
+ faces = []
292
+ for face in cam.faces:
293
+ if 0 in face:
294
+ continue
295
+ a, b, c = face
296
+ a2, b2, c2 = face + len(cam.vertices)
297
+ a3, b3, c3 = face + 2*len(cam.vertices)
298
+
299
+ # add 3 pseudo-edges
300
+ faces.append((a, b, b2))
301
+ faces.append((a, a2, c))
302
+ faces.append((c2, b, c))
303
+
304
+ faces.append((a, b, b3))
305
+ faces.append((a, a3, c))
306
+ faces.append((c3, b, c))
307
+
308
+ # no culling
309
+ faces += [(c, b, a) for a, b, c in faces]
310
+
311
+ cam = trimesh.Trimesh(vertices=vertices, faces=faces)
312
+ cam.visual.face_colors[:, :3] = edge_color
313
+ scene.add_geometry(cam)
314
+
315
+ if marker == 'o':
316
+ marker = trimesh.creation.icosphere(3, radius=screen_width/4)
317
+ marker.vertices += pose_c2w[:3,3]
318
+ marker.visual.face_colors[:,:3] = edge_color
319
+ scene.add_geometry(marker)
320
+
321
+
322
+ def cat(a, b):
323
+ return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
324
+
325
+
326
+ OPENGL = np.array([[1, 0, 0, 0],
327
+ [0, -1, 0, 0],
328
+ [0, 0, -1, 0],
329
+ [0, 0, 0, 1]])
330
+
331
+
332
+ CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
333
+ (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
334
+
335
+
336
+ def uint8(colors):
337
+ if not isinstance(colors, np.ndarray):
338
+ colors = np.array(colors)
339
+ if np.issubdtype(colors.dtype, np.floating):
340
+ colors *= 255
341
+ assert 0 <= colors.min() and colors.max() < 256
342
+ return np.uint8(colors)
343
+
344
+
345
+ def segment_sky(image):
346
+ import cv2
347
+ from scipy import ndimage
348
+
349
+ # Convert to HSV
350
+ image = to_numpy(image)
351
+ if np.issubdtype(image.dtype, np.floating):
352
+ image = np.uint8(255*image.clip(min=0, max=1))
353
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
354
+
355
+ # Define range for blue color and create mask
356
+ lower_blue = np.array([0, 0, 100])
357
+ upper_blue = np.array([30, 255, 255])
358
+ mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
359
+
360
+ # add luminous gray
361
+ mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
362
+ mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
363
+ mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
364
+
365
+ # Morphological operations
366
+ kernel = np.ones((5, 5), np.uint8)
367
+ mask2 = ndimage.binary_opening(mask, structure=kernel)
368
+
369
+ # keep only largest CC
370
+ _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
371
+ cc_sizes = stats[1:, cv2.CC_STAT_AREA]
372
+ order = cc_sizes.argsort()[::-1] # bigger first
373
+ i = 0
374
+ selection = []
375
+ while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
376
+ selection.append(1 + order[i])
377
+ i += 1
378
+ mask3 = np.in1d(labels, selection).reshape(labels.shape)
379
+
380
+ # Apply mask
381
+ return torch.from_numpy(mask3)