kxhit committed
Commit a4e1ae5 · 1 Parent(s): 319afdb

space gpu for carvekit

app.py CHANGED
@@ -131,6 +131,7 @@ def sam_init():
     predictor = SamPredictor(sam)
     return predictor
 
+@spaces.GPU
 def create_carvekit_interface():
     # Check doc strings for more information
     interface = HiInterface(object_type="object",  # Can be "object" or "hairs-like".
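For context on the change above: `@spaces.GPU` is the Hugging Face ZeroGPU decorator (from the `spaces` package already imported in this repo) that attaches a GPU only while the decorated function runs. Below is a minimal sketch of the usage pattern; the function names and bodies are illustrative placeholders, not the Space's actual code.

```python
# Minimal sketch, assuming a Hugging Face ZeroGPU Space with the `spaces` package available.
import spaces
import torch

@spaces.GPU  # GPU is attached only for the duration of this call (default duration)
def init_background_remover():
    # illustrative stand-in for create_carvekit_interface(); the real code builds a CarveKit HiInterface
    return torch.nn.Identity().to("cuda")

@spaces.GPU(duration=120)  # longer allocation, matching the decorator used on run_eschernet in this repo
def heavy_inference(batch):
    # illustrative placeholder for the diffusion-pipeline call
    return batch
```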
app_bk.py DELETED
@@ -1,786 +0,0 @@
1
- import spaces
2
- import torch
3
- print("cuda is available: ", torch.cuda.is_available())
4
-
5
- import gradio as gr
6
- import os
7
- import shutil
8
- import rembg
9
- import numpy as np
10
- import math
11
- import open3d as o3d
12
- from PIL import Image
13
- import torchvision
14
- import trimesh
15
- from skimage.io import imsave
16
- import imageio
17
- import cv2
18
- import matplotlib.pyplot as pl
19
- pl.ion()
20
-
21
- CaPE_TYPE = "6DoF"
22
- device = 'cuda' #if torch.cuda.is_available() else 'cpu'
23
- weight_dtype = torch.float16
24
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
-
26
- # EscherNet
27
- # create angles in archimedean spiral with N steps
28
- def get_archimedean_spiral(sphere_radius, num_steps=250):
29
- # x-z plane, around upper y
30
- '''
31
- https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
32
- '''
33
- a = 40
34
- r = sphere_radius
35
-
36
- translations = []
37
- angles = []
38
-
39
- # i = a / 2
40
- i = 0.01
41
- while i < a:
42
- theta = i / a * math.pi
43
- x = r * math.sin(theta) * math.cos(-i)
44
- z = r * math.sin(-theta + math.pi) * math.sin(-i)
45
- y = r * - math.cos(theta)
46
-
47
- # translations.append((x, y, z)) # origin
48
- translations.append((x, z, -y))
49
- angles.append([np.rad2deg(-i), np.rad2deg(theta)])
50
-
51
- # i += a / (2 * num_steps)
52
- i += a / (1 * num_steps)
53
-
54
- return np.array(translations), np.stack(angles)
55
-
56
- def look_at(origin, target, up):
57
- forward = (target - origin)
58
- forward = forward / np.linalg.norm(forward)
59
- right = np.cross(up, forward)
60
- right = right / np.linalg.norm(right)
61
- new_up = np.cross(forward, right)
62
- rotation_matrix = np.column_stack((right, new_up, -forward, target))
63
- matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
64
- return matrix
65
-
66
- import einops
67
- import sys
68
-
69
- sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
- # use the customized diffusers modules
71
- from diffusers import DDIMScheduler
72
- from dataset import get_pose
73
- from CN_encoder import CN_encoder
74
- from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
- from segment_anything import sam_model_registry, SamPredictor
76
-
77
- # import rembg
78
- from carvekit.api.high import HiInterface
79
-
80
-
81
- pretrained_model_name_or_path = "kxic/EscherNet_demo"
82
- resolution = 256
83
- h,w = resolution,resolution
84
- guidance_scale = 3.0
85
- radius = 2.2
86
- bg_color = [1., 1., 1., 1.]
87
- image_transforms = torchvision.transforms.Compose(
88
- [
89
- torchvision.transforms.Resize((resolution, resolution)), # 256, 256
90
- torchvision.transforms.ToTensor(),
91
- torchvision.transforms.Normalize([0.5], [0.5])
92
- ]
93
- )
94
- xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
95
- # only half toop
96
- xyzs_spiral = xyzs_spiral[:100]
97
- angles_spiral = angles_spiral[:100]
98
-
99
- # Init pipeline
100
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
101
- image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
102
- pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
103
- pretrained_model_name_or_path,
104
- revision=None,
105
- scheduler=scheduler,
106
- image_encoder=None,
107
- safety_checker=None,
108
- feature_extractor=None,
109
- torch_dtype=weight_dtype,
110
- )
111
- pipeline.image_encoder = image_encoder.to(weight_dtype)
112
-
113
- pipeline.set_progress_bar_config(disable=False)
114
-
115
- pipeline = pipeline.to(device)
116
-
117
- # pipeline.enable_xformers_memory_efficient_attention()
118
- # enable vae slicing
119
- pipeline.enable_vae_slicing()
120
- # pipeline.enable_xformers_memory_efficient_attention()
121
-
122
-
123
- #### object segmentation
124
- def sam_init():
125
- sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
126
- if os.path.exists(sam_checkpoint) is False:
127
- os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
128
- model_type = "vit_h"
129
-
130
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
131
- predictor = SamPredictor(sam)
132
- return predictor
133
-
134
- def create_carvekit_interface():
135
- # Check doc strings for more information
136
- interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
- batch_size_seg=6,
138
- batch_size_matting=1,
139
- device="cpu",
140
- seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
- matting_mask_size=2048,
142
- trimap_prob_threshold=231,
143
- trimap_dilation=30,
144
- trimap_erosion_iters=5,
145
- fp16=True)
146
-
147
- return interface
148
-
149
-
150
- # rembg_session = rembg.new_session()
151
- rembg_session = create_carvekit_interface()
152
- predictor = sam_init()
153
-
154
-
155
-
156
- @spaces.GPU(duration=120)
157
- def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
158
- # set the random seed
159
- generator = torch.Generator(device=device).manual_seed(sample_seed)
160
- # generator = None
161
- T_out = nvs_num
162
- T_in = len(eschernet_input_dict['imgs'])
163
- ####### output pose
164
- # TODO choose T_out number of poses sequentially from the spiral
165
- xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
166
- angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
167
-
168
- ####### input's max radius for translation scaling
169
- radii = eschernet_input_dict['radii']
170
- max_t = np.max(radii)
171
- min_t = np.min(radii)
172
-
173
- ####### input pose
174
- pose_in = []
175
- for T_in_index in range(T_in):
176
- pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
177
- pose[1:3, :] *= -1 # coordinate system conversion
178
- pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
179
- pose_in.append(torch.from_numpy(pose))
180
-
181
- ####### input image
182
- img = eschernet_input_dict['imgs'] / 255.
183
- img[img[:, :, :, -1] == 0.] = bg_color
184
- # TODO batch image_transforms
185
- input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
186
-
187
- ####### nvs pose
188
- pose_out = []
189
- for T_out_index in range(T_out):
190
- azimuth, polar = angles_out[T_out_index]
191
- if CaPE_TYPE == "4DoF":
192
- pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
193
- elif CaPE_TYPE == "6DoF":
194
- pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
195
- pose = np.linalg.inv(pose)
196
- pose[2, :] *= -1
197
- pose_out.append(torch.from_numpy(get_pose(pose)))
198
-
199
-
200
-
201
- # [B, T, C, H, W]
202
- input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
203
- # [B, T, 4]
204
- pose_in = np.stack(pose_in)
205
- pose_out = np.stack(pose_out)
206
-
207
- if CaPE_TYPE == "6DoF":
208
- pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
209
- pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
210
- pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
211
- pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
212
-
213
- pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
214
- pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
215
-
216
- input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
217
- assert T_in == input_image.shape[0]
218
- assert T_in == pose_in.shape[1]
219
- assert T_out == pose_out.shape[1]
220
-
221
- # run inference
222
- # pipeline.to(device)
223
- pipeline.enable_xformers_memory_efficient_attention()
224
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
225
- poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
226
- height=h, width=w, T_in=T_in, T_out=T_out,
227
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
228
- output_type="numpy").images
229
-
230
- # save output image
231
- output_dir = os.path.join(tmpdirname, "eschernet")
232
- if os.path.exists(output_dir):
233
- shutil.rmtree(output_dir)
234
- os.makedirs(output_dir, exist_ok=True)
235
- # # save to N imgs
236
- # for i in range(T_out):
237
- # imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
238
- # make a gif
239
- frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
240
- # frame_one = frames[0]
241
- # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
242
- # save_all=True, duration=50, loop=1)
243
-
244
- # get a video
245
- video_path = os.path.join(output_dir, "output.mp4")
246
- imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
247
-
248
-
249
- return video_path
250
-
251
- # TODO mesh it
252
- @spaces.GPU(duration=120)
253
- def make3d():
254
- pass
255
-
256
-
257
-
258
- ############################ Dust3r as Pose Estimation ############################
259
- from scipy.spatial.transform import Rotation
260
- import copy
261
-
262
- from dust3r.inference import inference
263
- from dust3r.model import AsymmetricCroCo3DStereo
264
- from dust3r.image_pairs import make_pairs
265
- from dust3r.utils.image import load_images, rgb
266
- from dust3r.utils.device import to_numpy
267
- from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
268
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
- import math
270
-
271
- @spaces.GPU(duration=120)
272
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
273
- cam_color=None, as_pointcloud=False,
274
- transparent_cams=False, silent=False, same_focals=False):
275
- assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
276
- if not same_focals:
277
- assert (len(cams2world) == len(focals))
278
- pts3d = to_numpy(pts3d)
279
- imgs = to_numpy(imgs)
280
- focals = to_numpy(focals)
281
- cams2world = to_numpy(cams2world)
282
-
283
- scene = trimesh.Scene()
284
-
285
- # add axes
286
- scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
287
-
288
- # full pointcloud
289
- if as_pointcloud:
290
- pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
291
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
292
- pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
293
- scene.add_geometry(pct)
294
- else:
295
- meshes = []
296
- for i in range(len(imgs)):
297
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
298
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
299
- scene.add_geometry(mesh)
300
-
301
- # add each camera
302
- for i, pose_c2w in enumerate(cams2world):
303
- if isinstance(cam_color, list):
304
- camera_edge_color = cam_color[i]
305
- else:
306
- camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
307
- if same_focals:
308
- focal = focals[0]
309
- else:
310
- focal = focals[i]
311
- add_scene_cam(scene, pose_c2w, camera_edge_color,
312
- None if transparent_cams else imgs[i], focal,
313
- imsize=imgs[i].shape[1::-1], screen_width=cam_size)
314
-
315
- rot = np.eye(4)
316
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
317
- scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
318
- outfile = os.path.join(outdir, 'scene.glb')
319
- if not silent:
320
- print('(exporting 3D scene to', outfile, ')')
321
- scene.export(file_obj=outfile)
322
- return outfile
323
-
324
- @spaces.GPU(duration=120)
325
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
326
- clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
327
- """
328
- extract 3D_model (glb file) from a reconstructed scene
329
- """
330
- if scene is None:
331
- return None
332
- # post processes
333
- if clean_depth:
334
- scene = scene.clean_pointcloud()
335
- if mask_sky:
336
- scene = scene.mask_sky()
337
-
338
- # get optimized values from scene
339
- rgbimg = to_numpy(scene.imgs)
340
- focals = to_numpy(scene.get_focals().cpu())
341
- # cams2world = to_numpy(scene.get_im_poses().cpu())
342
- # TODO use the vis_poses
343
- cams2world = scene.vis_poses
344
-
345
- # 3D pointcloud from depthmap, poses and intrinsics
346
- # pts3d = to_numpy(scene.get_pts3d())
347
- # TODO use the vis_poses
348
- pts3d = scene.vis_pts3d
349
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
350
- msk = to_numpy(scene.get_masks())
351
-
352
- return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
353
- transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
354
- same_focals=same_focals)
355
-
356
- @spaces.GPU(duration=120)
357
- def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
358
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
359
- scenegraph_type, winsize, refid, same_focals):
360
- """
361
- from a list of images, run dust3r inference, global aligner.
362
- then run get_3D_model_from_scene
363
- """
364
- silent = False
365
- image_size = 224
366
- # remove the directory if it already exists
367
- outdir = tmpdirname
368
- if os.path.exists(outdir):
369
- shutil.rmtree(outdir)
370
- os.makedirs(outdir, exist_ok=True)
371
- imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
372
- if len(imgs) == 1:
373
- imgs = [imgs[0], copy.deepcopy(imgs[0])]
374
- imgs[1]['idx'] = 1
375
- if scenegraph_type == "swin":
376
- scenegraph_type = scenegraph_type + "-" + str(winsize)
377
- elif scenegraph_type == "oneref":
378
- scenegraph_type = scenegraph_type + "-" + str(refid)
379
-
380
- pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
381
- output = inference(pairs, model, device, batch_size=1, verbose=not silent)
382
-
383
- mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
384
- scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
385
- lr = 0.01
386
-
387
- if mode == GlobalAlignerMode.PointCloudOptimizer:
388
- loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
389
-
390
- # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
391
- # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
392
-
393
- # also return rgb, depth and confidence imgs
394
- # depth is normalized with the max value for all images
395
- # we apply the jet colormap on the confidence maps
396
- rgbimg = scene.imgs
397
- # depths = to_numpy(scene.get_depthmaps())
398
- # confs = to_numpy([c for c in scene.im_conf])
399
- # cmap = pl.get_cmap('jet')
400
- # depths_max = max([d.max() for d in depths])
401
- # depths = [d / depths_max for d in depths]
402
- # confs_max = max([d.max() for d in confs])
403
- # confs = [cmap(d / confs_max) for d in confs]
404
-
405
- imgs = []
406
- rgbaimg = []
407
- for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
408
- imgs.append(rgbimg[i])
409
- # imgs.append(rgb(depths[i]))
410
- # imgs.append(rgb(confs[i]))
411
- # imgs.append(imgs_rgba[i])
412
- if len(imgs_rgba) == 1 and i == 1:
413
- imgs.append(imgs_rgba[0])
414
- rgbaimg.append(np.array(imgs_rgba[0]))
415
- else:
416
- imgs.append(imgs_rgba[i])
417
- rgbaimg.append(np.array(imgs_rgba[i]))
418
-
419
- rgbaimg = np.array(rgbaimg)
420
-
421
- # for eschernet
422
- # get optimized values from scene
423
- rgbimg = to_numpy(scene.imgs)
424
- # focals = to_numpy(scene.get_focals().cpu())
425
- cams2world = to_numpy(scene.get_im_poses().cpu())
426
-
427
- # 3D pointcloud from depthmap, poses and intrinsics
428
- pts3d = to_numpy(scene.get_pts3d())
429
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
430
- msk = to_numpy(scene.get_masks())
431
- obj_mask = rgbaimg[..., 3] > 0
432
-
433
- # TODO set global coordinate system at the center of the scene, z-axis is up
434
- pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
435
- pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
436
- centroid = np.mean(pts_obj, axis=0) # obj center
437
- obj2world = np.eye(4)
438
- obj2world[:3, 3] = -centroid # T_wc
439
-
440
- # get z_up vector
441
- # TODO fit a plane and get the normal vector
442
- pcd = o3d.geometry.PointCloud()
443
- pcd.points = o3d.utility.Vector3dVector(pts)
444
- plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
445
- # get the normalised normal vector dim = 3
446
- normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
447
- # the normal direction should be pointing up
448
- if normal[1] < 0:
449
- normal = -normal
450
- # print("normal", normal)
451
-
452
- # # TODO z-up 180
453
- # z_up = np.array([[1,0,0,0],
454
- # [0,-1,0,0],
455
- # [0,0,-1,0],
456
- # [0,0,0,1]])
457
- # obj2world = z_up @ obj2world
458
-
459
- # # avg the y
460
- # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
461
- # # import pdb; pdb.set_trace()
462
- # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
463
- # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
464
- # rot = Rotation.from_rotvec(rot_angle * rot_axis)
465
- # z_up = np.eye(4)
466
- # z_up[:3, :3] = rot.as_matrix()
467
-
468
- # get the rotation matrix from normal to z-axis
469
- z_axis = np.array([0, 0, 1])
470
- rot_axis = np.cross(normal, z_axis)
471
- rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
472
- rot = Rotation.from_rotvec(rot_angle * rot_axis)
473
- z_up = np.eye(4)
474
- z_up[:3, :3] = rot.as_matrix()
475
- obj2world = z_up @ obj2world
476
- # flip 180
477
- flip_rot = np.array([[1, 0, 0, 0],
478
- [0, -1, 0, 0],
479
- [0, 0, -1, 0],
480
- [0, 0, 0, 1]])
481
- obj2world = flip_rot @ obj2world
482
-
483
- # get new cams2obj
484
- cams2obj = []
485
- for i, cam2world in enumerate(cams2world):
486
- cams2obj.append(obj2world @ cam2world)
487
- # TODO transform pts3d to the new coordinate system
488
- for i, pts in enumerate(pts3d):
489
- pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
490
- -1)) \
491
- .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
492
- cams2world = np.array(cams2obj)
493
- # TODO rewrite hack
494
- scene.vis_poses = cams2world.copy()
495
- scene.vis_pts3d = pts3d.copy()
496
-
497
- # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
498
- for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
499
- np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
500
- pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
501
- pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
502
- # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
503
- # save the min/max radius of camera
504
- radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
505
- np.save(os.path.join(outdir, "radii.npy"), radii)
506
-
507
- eschernet_input = {"poses": cams2world,
508
- "radii": radii,
509
- "imgs": rgbaimg}
510
- print("got eschernet input")
511
- outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
512
- clean_depth, transparent_cams, cam_size, same_focals=same_focals)
513
-
514
- return scene, outfile, imgs, eschernet_input
515
-
516
-
517
- def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
518
- num_files = len(inputfiles) if inputfiles is not None else 1
519
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
520
- if scenegraph_type == "swin":
521
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
522
- minimum=1, maximum=max_winsize, step=1, visible=True)
523
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
524
- maximum=num_files - 1, step=1, visible=False)
525
- elif scenegraph_type == "oneref":
526
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
527
- minimum=1, maximum=max_winsize, step=1, visible=False)
528
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
529
- maximum=num_files - 1, step=1, visible=True)
530
- else:
531
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
532
- minimum=1, maximum=max_winsize, step=1, visible=False)
533
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
534
- maximum=num_files - 1, step=1, visible=False)
535
- return winsize, refid
536
-
537
-
538
- def get_examples(path):
539
- objs = []
540
- for obj_name in sorted(os.listdir(path)):
541
- img_files = []
542
- for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
543
- img_files.append(os.path.join(path, obj_name, img_file))
544
- objs.append([img_files])
545
- print("objs = ", objs)
546
- return objs
547
-
548
- def preview_input(inputfiles):
549
- if inputfiles is None:
550
- return None
551
- imgs = []
552
- for img_file in inputfiles:
553
- img = pl.imread(img_file)
554
- imgs.append(img)
555
- return imgs
556
-
557
- # def main():
558
- # dustr init
559
- silent = False
560
- image_size = 224
561
- weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
562
- model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
563
- # dust3r will write the 3D model inside tmpdirname
564
- # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
565
- tmpdirname = os.path.join('logs/user_object')
566
- # remove the directory if it already exists
567
- if os.path.exists(tmpdirname):
568
- shutil.rmtree(tmpdirname)
569
- os.makedirs(tmpdirname, exist_ok=True)
570
- if not silent:
571
- print('Outputing stuff in', tmpdirname)
572
-
573
- _HEADER_ = '''
574
- <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
575
- <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
576
-
577
- Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
578
-
579
- <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
580
- <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
581
- <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
582
-
583
- <h4><b>Tips:</b></h4>
584
-
585
- - Our model can take <b>any number input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
586
-
587
- - Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.
588
-
589
- - The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.
590
-
591
- - The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
592
-
593
- '''
594
-
595
- _CITE_ = r"""
596
- 📝 <b>Citation</b>:
597
- ```bibtex
598
- @article{kong2024eschernet,
599
- title={EscherNet: A Generative Model for Scalable View Synthesis},
600
- author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
601
- journal={arXiv preprint arXiv:2402.03908},
602
- year={2024}
603
- }
604
- ```
605
- """
606
-
607
- with gr.Blocks() as demo:
608
- gr.Markdown(_HEADER_)
609
- # mv_images = gr.State()
610
- scene = gr.State(None)
611
- eschernet_input = gr.State(None)
612
- with gr.Row(variant="panel"):
613
- # left column
614
- with gr.Column():
615
- with gr.Row():
616
- input_image = gr.File(file_count="multiple")
617
- with gr.Row():
618
- run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
619
- with gr.Row():
620
- processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
621
- with gr.Row(variant="panel"):
622
- # input examples under "examples" folder
623
- gr.Examples(
624
- examples=get_examples('examples'),
625
- inputs=[input_image],
626
- label="Examples (click one set of images to start!)",
627
- examples_per_page=20
628
- )
629
-
630
-
631
-
632
-
633
-
634
- # right column
635
- with gr.Column():
636
-
637
- with gr.Row():
638
- outmodel = gr.Model3D()
639
-
640
- with gr.Row():
641
- gr.Markdown('''
642
- <h4><b>Check if the pose (blue is axis is estimated z-up direction) and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>
643
- ''')
644
-
645
- with gr.Row():
646
- with gr.Group():
647
- do_remove_background = gr.Checkbox(
648
- label="Remove Background", value=True
649
- )
650
- sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
651
-
652
- sample_steps = gr.Slider(
653
- label="Sample Steps",
654
- minimum=30,
655
- maximum=75,
656
- value=50,
657
- step=5,
658
- visible=False
659
- )
660
-
661
- nvs_num = gr.Slider(
662
- label="Number of Novel Views",
663
- minimum=5,
664
- maximum=100,
665
- value=30,
666
- step=1
667
- )
668
-
669
- nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
670
- value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
671
-
672
- with gr.Row():
673
- gr.Markdown('''
674
- <h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>
675
- ''')
676
-
677
- with gr.Row():
678
- submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
679
-
680
- with gr.Row():
681
- with gr.Column():
682
- output_video = gr.Video(
683
- label="video", format="mp4",
684
- width=379,
685
- autoplay=True,
686
- interactive=False
687
- )
688
-
689
- with gr.Row():
690
- gr.Markdown('''
691
- <h4><b>The novel views are generated on an archimedean spiral (rotating around z-up axis and looking at the object center). You can download the video.</b></h4>
692
- ''')
693
-
694
- gr.Markdown(_CITE_)
695
-
696
- # set dust3r parameter invisible to be clean
697
- with gr.Column():
698
- with gr.Row():
699
- schedule = gr.Dropdown(["linear", "cosine"],
700
- value='linear', label="schedule", info="For global alignment!", visible=False)
701
- niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
702
- label="num_iterations", info="For global alignment!", visible=False)
703
- scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
704
- value='complete', label="Scenegraph",
705
- info="Define how to make pairs",
706
- interactive=True, visible=False)
707
- same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
708
- winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
709
- minimum=1, maximum=1, step=1, visible=False)
710
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
711
-
712
- with gr.Row():
713
- # adjust the confidence threshold
714
- min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
715
- # adjust the camera size in the output pointcloud
716
- cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
717
- with gr.Row():
718
- as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
719
- # two post process implemented
720
- mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
721
- clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
722
- transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
723
-
724
- # events
725
- # scenegraph_type.change(set_scenegraph_options,
726
- # inputs=[input_image, winsize, refid, scenegraph_type],
727
- # outputs=[winsize, refid])
728
- # min_conf_thr.release(fn=model_from_scene_fun,
729
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
730
- # clean_depth, transparent_cams, cam_size, same_focals],
731
- # outputs=outmodel)
732
- # cam_size.change(fn=model_from_scene_fun,
733
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
734
- # clean_depth, transparent_cams, cam_size, same_focals],
735
- # outputs=outmodel)
736
- # as_pointcloud.change(fn=model_from_scene_fun,
737
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
738
- # clean_depth, transparent_cams, cam_size, same_focals],
739
- # outputs=outmodel)
740
- # mask_sky.change(fn=model_from_scene_fun,
741
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
742
- # clean_depth, transparent_cams, cam_size, same_focals],
743
- # outputs=outmodel)
744
- # clean_depth.change(fn=model_from_scene_fun,
745
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
746
- # clean_depth, transparent_cams, cam_size, same_focals],
747
- # outputs=outmodel)
748
- # transparent_cams.change(model_from_scene_fun,
749
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
750
- # clean_depth, transparent_cams, cam_size, same_focals],
751
- # outputs=outmodel)
752
- # run_dust3r.click(fn=recon_fun,
753
- # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
754
- # mask_sky, clean_depth, transparent_cams, cam_size,
755
- # scenegraph_type, winsize, refid, same_focals],
756
- # outputs=[scene, outmodel, processed_image, eschernet_input])
757
-
758
- # events
759
- input_image.change(set_scenegraph_options,
760
- inputs=[input_image, winsize, refid, scenegraph_type],
761
- outputs=[winsize, refid])
762
- run_dust3r.click(fn=get_reconstructed_scene,
763
- inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
764
- mask_sky, clean_depth, transparent_cams, cam_size,
765
- scenegraph_type, winsize, refid, same_focals],
766
- outputs=[scene, outmodel, processed_image, eschernet_input])
767
-
768
-
769
- # events
770
- input_image.change(fn=preview_input,
771
- inputs=[input_image],
772
- outputs=[processed_image])
773
-
774
- submit.click(fn=run_eschernet,
775
- inputs=[eschernet_input, sample_steps, sample_seed,
776
- nvs_num, nvs_mode],
777
- outputs=[output_video])
778
-
779
-
780
-
781
- # demo.queue(max_size=10)
782
- # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
783
- demo.queue(max_size=10).launch()
784
-
785
- # if __name__ == '__main__':
786
- # main()
app_mini.py DELETED
@@ -1,773 +0,0 @@
1
- import spaces
2
- import torch
3
- print("cuda is available: ", torch.cuda.is_available())
4
-
5
- import gradio as gr
6
- import os
7
- import shutil
8
- import rembg
9
- import numpy as np
10
- import math
11
- import open3d as o3d
12
- from PIL import Image
13
- import torchvision
14
- import trimesh
15
- from skimage.io import imsave
16
- import imageio
17
- import cv2
18
- import matplotlib.pyplot as pl
19
- pl.ion()
20
-
21
- CaPE_TYPE = "6DoF"
22
- device = 'cuda' #if torch.cuda.is_available() else 'cpu'
23
- weight_dtype = torch.float16
24
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
-
26
- # EscherNet
27
- # create angles in archimedean spiral with N steps
28
- def get_archimedean_spiral(sphere_radius, num_steps=250):
29
- # x-z plane, around upper y
30
- '''
31
- https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
32
- '''
33
- a = 40
34
- r = sphere_radius
35
-
36
- translations = []
37
- angles = []
38
-
39
- # i = a / 2
40
- i = 0.01
41
- while i < a:
42
- theta = i / a * math.pi
43
- x = r * math.sin(theta) * math.cos(-i)
44
- z = r * math.sin(-theta + math.pi) * math.sin(-i)
45
- y = r * - math.cos(theta)
46
-
47
- # translations.append((x, y, z)) # origin
48
- translations.append((x, z, -y))
49
- angles.append([np.rad2deg(-i), np.rad2deg(theta)])
50
-
51
- # i += a / (2 * num_steps)
52
- i += a / (1 * num_steps)
53
-
54
- return np.array(translations), np.stack(angles)
55
-
56
- def look_at(origin, target, up):
57
- forward = (target - origin)
58
- forward = forward / np.linalg.norm(forward)
59
- right = np.cross(up, forward)
60
- right = right / np.linalg.norm(right)
61
- new_up = np.cross(forward, right)
62
- rotation_matrix = np.column_stack((right, new_up, -forward, target))
63
- matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
64
- return matrix
65
-
66
- import einops
67
- import sys
68
-
69
- sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
- # use the customized diffusers modules
71
- from diffusers import DDIMScheduler
72
- from dataset import get_pose
73
- from CN_encoder import CN_encoder
74
- from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
- from segment_anything import sam_model_registry, SamPredictor
76
-
77
- # import rembg
78
- from carvekit.api.high import HiInterface
79
-
80
-
81
- pretrained_model_name_or_path = "kxic/EscherNet_demo"
82
- resolution = 256
83
- h,w = resolution,resolution
84
- guidance_scale = 3.0
85
- radius = 2.2
86
- bg_color = [1., 1., 1., 1.]
87
- image_transforms = torchvision.transforms.Compose(
88
- [
89
- torchvision.transforms.Resize((resolution, resolution)), # 256, 256
90
- torchvision.transforms.ToTensor(),
91
- torchvision.transforms.Normalize([0.5], [0.5])
92
- ]
93
- )
94
- xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
95
- # only half toop
96
- xyzs_spiral = xyzs_spiral[:100]
97
- angles_spiral = angles_spiral[:100]
98
-
99
- # Init pipeline
100
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
101
- image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
102
- pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
103
- pretrained_model_name_or_path,
104
- revision=None,
105
- scheduler=scheduler,
106
- image_encoder=None,
107
- safety_checker=None,
108
- feature_extractor=None,
109
- torch_dtype=weight_dtype,
110
- )
111
- pipeline.image_encoder = image_encoder.to(weight_dtype)
112
-
113
- pipeline.set_progress_bar_config(disable=False)
114
-
115
- pipeline = pipeline.to(device)
116
-
117
- # pipeline.enable_xformers_memory_efficient_attention()
118
- # enable vae slicing
119
- pipeline.enable_vae_slicing()
120
- # pipeline.enable_xformers_memory_efficient_attention()
121
-
122
-
123
- #### object segmentation
124
- def sam_init():
125
- sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
126
- if os.path.exists(sam_checkpoint) is False:
127
- os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
128
- model_type = "vit_h"
129
-
130
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
131
- predictor = SamPredictor(sam)
132
- return predictor
133
-
134
- def create_carvekit_interface():
135
- # Check doc strings for more information
136
- interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
- batch_size_seg=6,
138
- batch_size_matting=1,
139
- device="cpu",
140
- seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
- matting_mask_size=2048,
142
- trimap_prob_threshold=231,
143
- trimap_dilation=30,
144
- trimap_erosion_iters=5,
145
- fp16=True)
146
-
147
- return interface
148
-
149
-
150
- # rembg_session = rembg.new_session()
151
- rembg_session = create_carvekit_interface()
152
- predictor = sam_init()
153
-
154
-
155
-
156
- @spaces.GPU(duration=120)
157
- def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
158
- # set the random seed
159
- generator = torch.Generator(device=device).manual_seed(sample_seed)
160
- # generator = None
161
- T_out = nvs_num
162
- T_in = len(eschernet_input_dict['imgs'])
163
- ####### output pose
164
- # TODO choose T_out number of poses sequentially from the spiral
165
- xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
166
- angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
167
-
168
- ####### input's max radius for translation scaling
169
- radii = eschernet_input_dict['radii']
170
- max_t = np.max(radii)
171
- min_t = np.min(radii)
172
-
173
- ####### input pose
174
- pose_in = []
175
- for T_in_index in range(T_in):
176
- pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
177
- pose[1:3, :] *= -1 # coordinate system conversion
178
- pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
179
- pose_in.append(torch.from_numpy(pose))
180
-
181
- ####### input image
182
- img = eschernet_input_dict['imgs'] / 255.
183
- img[img[:, :, :, -1] == 0.] = bg_color
184
- # TODO batch image_transforms
185
- input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
186
-
187
- ####### nvs pose
188
- pose_out = []
189
- for T_out_index in range(T_out):
190
- azimuth, polar = angles_out[T_out_index]
191
- if CaPE_TYPE == "4DoF":
192
- pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
193
- elif CaPE_TYPE == "6DoF":
194
- pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
195
- pose = np.linalg.inv(pose)
196
- pose[2, :] *= -1
197
- pose_out.append(torch.from_numpy(get_pose(pose)))
198
-
199
-
200
-
201
- # [B, T, C, H, W]
202
- input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
203
- # [B, T, 4]
204
- pose_in = np.stack(pose_in)
205
- pose_out = np.stack(pose_out)
206
-
207
- if CaPE_TYPE == "6DoF":
208
- pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
209
- pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
210
- pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
211
- pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
212
-
213
- pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
214
- pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
215
-
216
- input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
217
- assert T_in == input_image.shape[0]
218
- assert T_in == pose_in.shape[1]
219
- assert T_out == pose_out.shape[1]
220
-
221
- # run inference
222
- # pipeline.to(device)
223
- pipeline.enable_xformers_memory_efficient_attention()
224
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
225
- poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
226
- height=h, width=w, T_in=T_in, T_out=T_out,
227
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
228
- output_type="numpy").images
229
-
230
- # save output image
231
- output_dir = os.path.join(tmpdirname, "eschernet")
232
- if os.path.exists(output_dir):
233
- shutil.rmtree(output_dir)
234
- os.makedirs(output_dir, exist_ok=True)
235
- # # save to N imgs
236
- # for i in range(T_out):
237
- # imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
238
- # make a gif
239
- frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
240
- # frame_one = frames[0]
241
- # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
242
- # save_all=True, duration=50, loop=1)
243
-
244
- # get a video
245
- video_path = os.path.join(output_dir, "output.mp4")
246
- imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
247
-
248
-
249
- return video_path
250
-
251
- # TODO mesh it
252
- @spaces.GPU(duration=120)
253
- def make3d():
254
- pass
255
-
256
-
257
-
258
- ############################ Dust3r as Pose Estimation ############################
259
- from scipy.spatial.transform import Rotation
260
- import copy
261
-
262
- from dust3r.inference import inference
263
- from dust3r.model import AsymmetricCroCo3DStereo
264
- from dust3r.image_pairs import make_pairs
265
- from dust3r.utils.image import load_images, rgb
266
- from dust3r.utils.device import to_numpy
267
- from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
268
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
- import math
270
-
271
- from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
272
- from mini_dust3r.model import AsymmetricCroCo3DStereo
273
-
274
- # @spaces.GPU(duration=120)
275
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
276
- cam_color=None, as_pointcloud=False,
277
- transparent_cams=False, silent=False, same_focals=False):
278
- assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
279
- if not same_focals:
280
- assert (len(cams2world) == len(focals))
281
- pts3d = to_numpy(pts3d)
282
- imgs = to_numpy(imgs)
283
- focals = to_numpy(focals)
284
- cams2world = to_numpy(cams2world)
285
-
286
- scene = trimesh.Scene()
287
-
288
- # add axes
289
- scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
290
-
291
- # full pointcloud
292
- if as_pointcloud:
293
- pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
294
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
295
- pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
296
- scene.add_geometry(pct)
297
- else:
298
- meshes = []
299
- for i in range(len(imgs)):
300
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
301
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
302
- scene.add_geometry(mesh)
303
-
304
- # add each camera
305
- for i, pose_c2w in enumerate(cams2world):
306
- if isinstance(cam_color, list):
307
- camera_edge_color = cam_color[i]
308
- else:
309
- camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
310
- if same_focals:
311
- focal = focals[0]
312
- else:
313
- focal = focals[i]
314
- add_scene_cam(scene, pose_c2w, camera_edge_color,
315
- None if transparent_cams else imgs[i], focal,
316
- imsize=imgs[i].shape[1::-1], screen_width=cam_size)
317
-
318
- rot = np.eye(4)
319
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
320
- scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
321
- outfile = os.path.join(outdir, 'scene.glb')
322
- if not silent:
323
- print('(exporting 3D scene to', outfile, ')')
324
- scene.export(file_obj=outfile)
325
- return outfile
326
-
327
- # @spaces.GPU(duration=120)
328
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
329
- clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
330
- """
331
- extract 3D_model (glb file) from a reconstructed scene
332
- """
333
- if scene is None:
334
- return None
335
- # post processes
336
- if clean_depth:
337
- scene = scene.clean_pointcloud()
338
- if mask_sky:
339
- scene = scene.mask_sky()
340
-
341
- # get optimized values from scene
342
- rgbimg = to_numpy(scene.imgs)
343
- focals = to_numpy(scene.get_focals().cpu())
344
- # cams2world = to_numpy(scene.get_im_poses().cpu())
345
- # TODO use the vis_poses
346
- cams2world = scene.vis_poses
347
-
348
- # 3D pointcloud from depthmap, poses and intrinsics
349
- # pts3d = to_numpy(scene.get_pts3d())
350
- # TODO use the vis_poses
351
- pts3d = scene.vis_pts3d
352
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
353
- msk = to_numpy(scene.get_masks())
354
-
355
- return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
356
- transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
357
- same_focals=same_focals)
358
-
359
- @spaces.GPU(duration=120)
360
- def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
361
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
362
- scenegraph_type, winsize, refid, same_focals):
363
- """
364
- from a list of images, run dust3r inference, global aligner.
365
- then run get_3D_model_from_scene
366
- """
367
- silent = False
368
- image_size = 224
369
- # remove the directory if it already exists
370
- outdir = tmpdirname
371
- if os.path.exists(outdir):
372
- shutil.rmtree(outdir)
373
- os.makedirs(outdir, exist_ok=True)
374
- # imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
375
-
376
- optimized_results: OptimizedResult = inferece_dust3r(
377
- image_dir_or_list=filelist,
378
- model=model,
379
- device=device,
380
- batch_size=1,
381
- )
382
- rgbimg = optimized_results.rgb_hw3_list
383
- imgs_rgba = rgbimg
384
- cams2world = optimized_results.world_T_cam_b44
385
- pts3d = optimized_results.point_cloud
386
- pts_obj = pts3d
387
- outfile = os.path.join(outdir, 'scene.glb')
388
- # save point cloud trimesh.PointCloud to .ply
389
- pts3d.export(os.path.join(outdir, 'scene.glb'))
390
-
391
-
392
-
393
- # rgbimg = to_numpy(scene.imgs)
394
-
395
- imgs = []
396
- rgbaimg = []
397
- for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
398
- imgs.append(rgbimg[i])
399
- # imgs.append(rgb(depths[i]))
400
- # imgs.append(rgb(confs[i]))
401
- # imgs.append(imgs_rgba[i])
402
- if len(imgs_rgba) == 1 and i == 1:
403
- imgs.append(imgs_rgba[0])
404
- rgbaimg.append(np.array(imgs_rgba[0]))
405
- else:
406
- imgs.append(imgs_rgba[i])
407
- rgbaimg.append(np.array(imgs_rgba[i]))
408
-
409
- rgbaimg = np.array(rgbaimg)
410
-
411
- # for eschernet
412
- # cams2world = to_numpy(scene.get_im_poses().cpu())
413
- # pts3d = to_numpy(scene.get_pts3d())
414
- # scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
415
- # msk = to_numpy(scene.get_masks())
416
- # obj_mask = rgbaimg[..., 3] > 0
417
-
418
- # # TODO set global coordinate system at the center of the scene, z-axis is up
419
- # # pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
420
- # # pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
421
- # centroid = np.mean(pts_obj, axis=0) # obj center
422
- # obj2world = np.eye(4)
423
- # obj2world[:3, 3] = -centroid # T_wc
424
- #
425
- # # get z_up vector
426
- # # TODO fit a plane and get the normal vector
427
- # pcd = o3d.geometry.PointCloud()
428
- # pcd.points = o3d.utility.Vector3dVector(pts)
429
- # plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
430
- # # get the normalised normal vector dim = 3
431
- # normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
432
- # # the normal direction should be pointing up
433
- # if normal[1] < 0:
434
- # normal = -normal
435
- # # print("normal", normal)
436
- #
437
- # # # TODO z-up 180
438
- # # z_up = np.array([[1,0,0,0],
439
- # # [0,-1,0,0],
440
- # # [0,0,-1,0],
441
- # # [0,0,0,1]])
442
- # # obj2world = z_up @ obj2world
443
- #
444
- # # # avg the y
445
- # # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
446
- # # # import pdb; pdb.set_trace()
447
- # # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
448
- # # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
449
- # # rot = Rotation.from_rotvec(rot_angle * rot_axis)
450
- # # z_up = np.eye(4)
451
- # # z_up[:3, :3] = rot.as_matrix()
452
- #
453
- # # get the rotation matrix from normal to z-axis
454
- # z_axis = np.array([0, 0, 1])
455
- # rot_axis = np.cross(normal, z_axis)
456
- # rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
457
- # rot = Rotation.from_rotvec(rot_angle * rot_axis)
458
- # z_up = np.eye(4)
459
- # z_up[:3, :3] = rot.as_matrix()
460
- # obj2world = z_up @ obj2world
461
- # # flip 180
462
- # flip_rot = np.array([[1, 0, 0, 0],
463
- # [0, -1, 0, 0],
464
- # [0, 0, -1, 0],
465
- # [0, 0, 0, 1]])
466
- # obj2world = flip_rot @ obj2world
467
- #
468
- # # get new cams2obj
469
- # cams2obj = []
470
- # for i, cam2world in enumerate(cams2world):
471
- # cams2obj.append(obj2world @ cam2world)
472
- # # TODO transform pts3d to the new coordinate system
473
- # for i, pts in enumerate(pts3d):
474
- # pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
475
- # -1)) \
476
- # .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
477
- # cams2world = np.array(cams2obj)
478
- # # TODO rewrite hack
479
- # scene.vis_poses = cams2world.copy()
480
- # scene.vis_pts3d = pts3d.copy()
481
-
482
- # # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
483
- # for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
484
- # np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
485
- # pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
486
- # pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
487
- # # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
488
- # save the min/max radius of camera
489
- radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
490
- # np.save(os.path.join(outdir, "radii.npy"), radii)
491
-
492
- eschernet_input = {"poses": cams2world,
493
- "radii": radii,
494
- "imgs": rgbaimg}
495
- print("got eschernet input")
496
- # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
497
- # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
498
-
499
- return scene, outfile, imgs, eschernet_input
500
-
501
-
502
-
503
-
504
- def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
505
- num_files = len(inputfiles) if inputfiles is not None else 1
506
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
507
- if scenegraph_type == "swin":
508
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
509
- minimum=1, maximum=max_winsize, step=1, visible=True)
510
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
511
- maximum=num_files - 1, step=1, visible=False)
512
- elif scenegraph_type == "oneref":
513
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
514
- minimum=1, maximum=max_winsize, step=1, visible=False)
515
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
516
- maximum=num_files - 1, step=1, visible=True)
517
- else:
518
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
519
- minimum=1, maximum=max_winsize, step=1, visible=False)
520
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
521
- maximum=num_files - 1, step=1, visible=False)
522
- return winsize, refid
523
-
524
-
525
- def get_examples(path):
526
- objs = []
527
- for obj_name in sorted(os.listdir(path)):
528
- img_files = []
529
- for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
530
- img_files.append(os.path.join(path, obj_name, img_file))
531
- objs.append([img_files])
532
- print("objs = ", objs)
533
- return objs
534
-
535
- def preview_input(inputfiles):
536
- if inputfiles is None:
537
- return None
538
- imgs = []
539
- for img_file in inputfiles:
540
- img = pl.imread(img_file)
541
- imgs.append(img)
542
- return imgs
543
-
544
- # def main():
545
- # dustr init
546
- silent = False
547
- image_size = 224
548
- weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
549
- model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
550
- # dust3r will write the 3D model inside tmpdirname
551
- # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
552
- tmpdirname = os.path.join('logs/user_object')
553
- # remove the directory if it already exists
554
- if os.path.exists(tmpdirname):
555
- shutil.rmtree(tmpdirname)
556
- os.makedirs(tmpdirname, exist_ok=True)
557
- if not silent:
558
- print('Outputing stuff in', tmpdirname)
559
-
560
- _HEADER_ = '''
561
- <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
562
- <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
563
-
564
- Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
565
-
566
- <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
567
- <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
568
- <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
569
-
570
- <h4><b>Tips:</b></h4>
571
-
572
- - Our model can take <b>any number input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
573
-
574
- - Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.
575
-
576
- - The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.
577
-
578
- - The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
579
-
580
- '''
581
-
582
- _CITE_ = r"""
583
- 📝 <b>Citation</b>:
584
- ```bibtex
585
- @article{kong2024eschernet,
586
- title={EscherNet: A Generative Model for Scalable View Synthesis},
587
- author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
588
- journal={arXiv preprint arXiv:2402.03908},
589
- year={2024}
590
- }
591
- ```
592
- """
593
-
594
- with gr.Blocks() as demo:
595
- gr.Markdown(_HEADER_)
596
- # mv_images = gr.State()
597
- scene = gr.State(None)
598
- eschernet_input = gr.State(None)
599
- with gr.Row(variant="panel"):
600
- # left column
601
- with gr.Column():
602
- with gr.Row():
603
- input_image = gr.File(file_count="multiple")
604
- with gr.Row():
605
- run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
606
- with gr.Row():
607
- processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
608
- with gr.Row(variant="panel"):
609
- # input examples under "examples" folder
610
- gr.Examples(
611
- examples=get_examples('examples'),
612
- inputs=[input_image],
613
- label="Examples (click one set of images to start!)",
614
- examples_per_page=20
615
- )
616
-
617
-
618
-
619
-
620
-
621
- # right column
622
- with gr.Column():
623
-
624
- with gr.Row():
625
- outmodel = gr.Model3D()
626
-
627
- with gr.Row():
628
- gr.Markdown('''
629
- <h4><b>Check whether the pose (the blue axis is the estimated z-up direction) and segmentation look correct. If not, remove the incorrect images and try again.</b></h4>
630
- ''')
631
-
632
- with gr.Row():
633
- with gr.Group():
634
- do_remove_background = gr.Checkbox(
635
- label="Remove Background", value=True
636
- )
637
- sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
638
-
639
- sample_steps = gr.Slider(
640
- label="Sample Steps",
641
- minimum=30,
642
- maximum=75,
643
- value=50,
644
- step=5,
645
- visible=False
646
- )
647
-
648
- nvs_num = gr.Slider(
649
- label="Number of Novel Views",
650
- minimum=5,
651
- maximum=100,
652
- value=30,
653
- step=1
654
- )
655
-
656
- nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
657
- value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
658
-
659
- with gr.Row():
660
- gr.Markdown('''
661
- <h4><b>Choose the desired number of novel view poses and generate! The more output images, the longer it takes.</b></h4>
662
- ''')
663
-
664
- with gr.Row():
665
- submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
666
-
667
- with gr.Row():
668
- with gr.Column():
669
- output_video = gr.Video(
670
- label="video", format="mp4",
671
- width=379,
672
- autoplay=True,
673
- interactive=False
674
- )
675
-
676
- with gr.Row():
677
- gr.Markdown('''
678
- <h4><b>The novel views are generated on an archimedean spiral (rotating around the z-up axis and looking at the object center). You can download the video.</b></h4>
679
- ''')
680
-
681
- gr.Markdown(_CITE_)
682
-
683
- # set dust3r parameters to invisible to keep the UI clean
684
- with gr.Column():
685
- with gr.Row():
686
- schedule = gr.Dropdown(["linear", "cosine"],
687
- value='linear', label="schedule", info="For global alignment!", visible=False)
688
- niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
689
- label="num_iterations", info="For global alignment!", visible=False)
690
- scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
691
- value='complete', label="Scenegraph",
692
- info="Define how to make pairs",
693
- interactive=True, visible=False)
694
- same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
695
- winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
696
- minimum=1, maximum=1, step=1, visible=False)
697
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
698
-
699
- with gr.Row():
700
- # adjust the confidence threshold
701
- min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
702
- # adjust the camera size in the output pointcloud
703
- cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
704
- with gr.Row():
705
- as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
706
- # two post process implemented
707
- mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
708
- clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
709
- transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
710
-
711
- # events
712
- # scenegraph_type.change(set_scenegraph_options,
713
- # inputs=[input_image, winsize, refid, scenegraph_type],
714
- # outputs=[winsize, refid])
715
- # min_conf_thr.release(fn=model_from_scene_fun,
716
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
717
- # clean_depth, transparent_cams, cam_size, same_focals],
718
- # outputs=outmodel)
719
- # cam_size.change(fn=model_from_scene_fun,
720
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
721
- # clean_depth, transparent_cams, cam_size, same_focals],
722
- # outputs=outmodel)
723
- # as_pointcloud.change(fn=model_from_scene_fun,
724
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
725
- # clean_depth, transparent_cams, cam_size, same_focals],
726
- # outputs=outmodel)
727
- # mask_sky.change(fn=model_from_scene_fun,
728
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
729
- # clean_depth, transparent_cams, cam_size, same_focals],
730
- # outputs=outmodel)
731
- # clean_depth.change(fn=model_from_scene_fun,
732
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
733
- # clean_depth, transparent_cams, cam_size, same_focals],
734
- # outputs=outmodel)
735
- # transparent_cams.change(model_from_scene_fun,
736
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
737
- # clean_depth, transparent_cams, cam_size, same_focals],
738
- # outputs=outmodel)
739
- # run_dust3r.click(fn=recon_fun,
740
- # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
741
- # mask_sky, clean_depth, transparent_cams, cam_size,
742
- # scenegraph_type, winsize, refid, same_focals],
743
- # outputs=[scene, outmodel, processed_image, eschernet_input])
744
-
745
- # events
746
- input_image.change(set_scenegraph_options,
747
- inputs=[input_image, winsize, refid, scenegraph_type],
748
- outputs=[winsize, refid])
749
- run_dust3r.click(fn=get_reconstructed_scene,
750
- inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
751
- mask_sky, clean_depth, transparent_cams, cam_size,
752
- scenegraph_type, winsize, refid, same_focals],
753
- outputs=[scene, outmodel, processed_image, eschernet_input])
754
-
755
-
756
- # events
757
- input_image.change(fn=preview_input,
758
- inputs=[input_image],
759
- outputs=[processed_image])
760
-
761
- submit.click(fn=run_eschernet,
762
- inputs=[eschernet_input, sample_steps, sample_seed,
763
- nvs_num, nvs_mode],
764
- outputs=[output_video])
765
-
766
-
767
-
768
- # demo.queue(max_size=10)
769
- # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
770
- demo.queue(max_size=10).launch()
771
-
772
- # if __name__ == '__main__':
773
- # main()
mini_dust3r/__init__.py DELETED
File without changes
mini_dust3r/api/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .inference import inferece_dust3r, OptimizedResult, log_optimized_result
2
-
3
- __all__ = ["inferece_dust3r", "OptimizedResult", "log_optimized_result"]
mini_dust3r/api/inference.py DELETED
@@ -1,225 +0,0 @@
1
- import rerun as rr
2
- from pathlib import Path
3
- from typing import Literal
4
- import copy
5
- import torch
6
- import numpy as np
7
- from jaxtyping import Float32, Bool
8
- import trimesh
9
- from tqdm import tqdm
10
-
11
- from mini_dust3r.utils.image import load_images, ImageDict
12
- from mini_dust3r.inference import inference, Dust3rResult
13
- from mini_dust3r.model import AsymmetricCroCo3DStereo
14
- from mini_dust3r.image_pairs import make_pairs
15
- from mini_dust3r.cloud_opt import global_aligner, GlobalAlignerMode
16
- from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
17
- from mini_dust3r.viz import pts3d_to_trimesh, cat_meshes
18
- from dataclasses import dataclass
19
-
20
-
21
- @dataclass
22
- class OptimizedResult:
23
- K_b33: Float32[np.ndarray, "b 3 3"]
24
- world_T_cam_b44: Float32[np.ndarray, "b 4 4"]
25
- rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]]
26
- depth_hw_list: list[Float32[np.ndarray, "h w"]]
27
- conf_hw_list: list[Float32[np.ndarray, "h w"]]
28
- masks_list: Bool[np.ndarray, "h w"]
29
- point_cloud: trimesh.PointCloud
30
- mesh: trimesh.Trimesh
31
-
32
-
33
- def log_optimized_result(
34
- optimized_result: OptimizedResult, parent_log_path: Path
35
- ) -> None:
36
- rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
37
- # log pointcloud
38
- rr.log(
39
- f"{parent_log_path}/pointcloud",
40
- rr.Points3D(
41
- positions=optimized_result.point_cloud.vertices,
42
- colors=optimized_result.point_cloud.colors,
43
- ),
44
- timeless=True,
45
- )
46
-
47
- mesh = optimized_result.mesh
48
- rr.log(
49
- f"{parent_log_path}/mesh",
50
- rr.Mesh3D(
51
- vertex_positions=mesh.vertices,
52
- vertex_colors=mesh.visual.vertex_colors,
53
- indices=mesh.faces,
54
- ),
55
- timeless=True,
56
- )
57
- pbar = tqdm(
58
- zip(
59
- optimized_result.rgb_hw3_list,
60
- optimized_result.depth_hw_list,
61
- optimized_result.K_b33,
62
- optimized_result.world_T_cam_b44,
63
- ),
64
- total=len(optimized_result.rgb_hw3_list),
65
- )
66
- for i, (rgb_hw3, depth_hw, k_33, world_T_cam_44) in enumerate(pbar):
67
- camera_log_path = f"{parent_log_path}/camera_{i}"
68
- height, width, _ = rgb_hw3.shape
69
- rr.log(
70
- f"{camera_log_path}",
71
- rr.Transform3D(
72
- translation=world_T_cam_44[:3, 3],
73
- mat3x3=world_T_cam_44[:3, :3],
74
- from_parent=False,
75
- ),
76
- )
77
- rr.log(
78
- f"{camera_log_path}/pinhole",
79
- rr.Pinhole(
80
- image_from_camera=k_33,
81
- height=height,
82
- width=width,
83
- camera_xyz=rr.ViewCoordinates.RDF,
84
- ),
85
- )
86
- rr.log(
87
- f"{camera_log_path}/pinhole/rgb",
88
- rr.Image(rgb_hw3),
89
- )
90
- rr.log(
91
- f"{camera_log_path}/pinhole/depth",
92
- rr.DepthImage(depth_hw),
93
- )
94
-
95
-
96
- def scene_to_results(scene: BasePCOptimizer, min_conf_thr: int) -> OptimizedResult:
97
- ### get camera parameters K and T
98
- K_b33: Float32[np.ndarray, "b 3 3"] = scene.get_intrinsics().numpy(force=True)
99
- world_T_cam_b44: Float32[np.ndarray, "b 4 4"] = scene.get_im_poses().numpy(
100
- force=True
101
- )
102
- ### image, confidence, depths
103
- rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]] = scene.imgs
104
- depth_hw_list: list[Float32[np.ndarray, "h w"]] = [
105
- depth.numpy(force=True) for depth in scene.get_depthmaps()
106
- ]
107
- # normalized depth
108
- # depth_hw_list = [depth_hw / depth_hw.max() for depth_hw in depth_hw_list]
109
-
110
- conf_hw_list: list[Float32[np.ndarray, "h w"]] = [
111
- c.numpy(force=True) for c in scene.im_conf
112
- ]
113
- # normalize confidence
114
- # conf_hw_list = [conf_hw / conf_hw.max() for conf_hw in conf_hw_list]
115
-
116
- # point cloud, mesh
117
- pts3d_list: list[Float32[np.ndarray, "h w 3"]] = [
118
- pt3d.numpy(force=True) for pt3d in scene.get_pts3d()
119
- ]
120
- # get log confidence
121
- log_conf_trf: Float32[torch.Tensor, ""] = scene.conf_trf(torch.tensor(min_conf_thr))
122
- # set the minimum confidence threshold
123
- scene.min_conf_thr = float(log_conf_trf)
124
- masks_list: Bool[np.ndarray, "h w"] = [
125
- mask.numpy(force=True) for mask in scene.get_masks()
126
- ]
127
-
128
- point_cloud: Float32[np.ndarray, "num_points 3"] = np.concatenate(
129
- [p[m] for p, m in zip(pts3d_list, masks_list)]
130
- )
131
- colors: Float32[np.ndarray, "num_points 3"] = np.concatenate(
132
- [p[m] for p, m in zip(rgb_hw3_list, masks_list)]
133
- )
134
- point_cloud = trimesh.PointCloud(
135
- point_cloud.reshape(-1, 3), colors=colors.reshape(-1, 3)
136
- )
137
-
138
- meshes = []
139
- pbar = tqdm(zip(rgb_hw3_list, pts3d_list, masks_list), total=len(rgb_hw3_list))
140
- for rgb_hw3, pts3d, mask in pbar:
141
- meshes.append(pts3d_to_trimesh(rgb_hw3, pts3d, mask))
142
-
143
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
144
- optimised_result = OptimizedResult(
145
- K_b33=K_b33,
146
- world_T_cam_b44=world_T_cam_b44,
147
- rgb_hw3_list=rgb_hw3_list,
148
- depth_hw_list=depth_hw_list,
149
- conf_hw_list=conf_hw_list,
150
- masks_list=masks_list,
151
- point_cloud=point_cloud,
152
- mesh=mesh,
153
- )
154
- return optimised_result
155
-
156
-
157
- def inferece_dust3r(
158
- image_dir_or_list: Path | list[Path],
159
- model: AsymmetricCroCo3DStereo,
160
- device: Literal["cpu", "cuda", "mps"],
161
- batch_size: int = 1,
162
- image_size: Literal[224, 512] = 512,
163
- niter: int = 100,
164
- schedule: Literal["linear", "cosine"] = "linear",
165
- min_conf_thr: float = 10,
166
- ) -> OptimizedResult:
167
- """
168
- Perform inference using the Dust3r algorithm.
169
-
170
- Args:
171
- image_dir_or_list (Union[Path, List[Path]]): Path to the directory containing images or a list of image paths.
172
- model (AsymmetricCroCo3DStereo): The Dust3r model to use for inference.
173
- device (Literal["cpu", "cuda", "mps"]): The device to use for inference ("cpu", "cuda", or "mps").
174
- batch_size (int, optional): The batch size for inference. Defaults to 1.
175
- image_size (Literal[224, 512], optional): The size of the input images. Defaults to 512.
176
- niter (int, optional): The number of iterations for the global alignment optimization. Defaults to 100.
177
- schedule (Literal["linear", "cosine"], optional): The learning rate schedule for the global alignment optimization. Defaults to "linear".
178
- min_conf_thr (float, optional): The minimum confidence threshold for the optimized result. Defaults to 10.
179
-
180
- Returns:
181
- OptimizedResult: The optimized result containing the RGB, depth, and confidence images.
182
-
183
- Raises:
184
- ValueError: If `image_dir_or_list` is neither a list of paths nor a path.
185
- """
186
- if isinstance(image_dir_or_list, list):
187
- imgs: list[ImageDict] = load_images(
188
- folder_or_list=image_dir_or_list, size=image_size, verbose=True
189
- )
190
- elif isinstance(image_dir_or_list, Path):
191
- imgs: list[ImageDict] = load_images(
192
- folder_or_list=str(image_dir_or_list), size=image_size, verbose=True
193
- )
194
- else:
195
- raise ValueError("image_dir_or_list should be a list of paths or a path")
196
-
197
- # if only one image was loaded, duplicate it to feed into stereo network
198
- if len(imgs) == 1:
199
- imgs = [imgs[0], copy.deepcopy(imgs[0])]
200
- imgs[1]["idx"] = 1
201
-
202
- pairs: list[tuple[ImageDict, ImageDict]] = make_pairs(
203
- imgs, scene_graph="complete", prefilter=None, symmetrize=True
204
- )
205
- output: Dust3rResult = inference(pairs, model, device, batch_size=batch_size)
206
-
207
- mode = (
208
- GlobalAlignerMode.PointCloudOptimizer
209
- if len(imgs) > 2
210
- else GlobalAlignerMode.PairViewer
211
- )
212
- scene: BasePCOptimizer = global_aligner(
213
- dust3r_output=output, device=device, mode=mode
214
- )
215
-
216
- lr = 0.01
217
-
218
- if mode == GlobalAlignerMode.PointCloudOptimizer:
219
- loss = scene.compute_global_alignment(
220
- init="mst", niter=niter, schedule=schedule, lr=lr
221
- )
222
-
223
- # get the optimized result from the scene
224
- optimized_result: OptimizedResult = scene_to_results(scene, min_conf_thr)
225
- return optimized_result
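Taken together with the deleted app code, `inferece_dust3r` is the single entry point of this module: it loads images, runs pairwise inference, globally aligns the result, and packs everything into an `OptimizedResult`. A hedged usage sketch; the image folder below is illustrative, not taken from the repository layout:

```python
# Illustrative usage of the deleted helper (note the original spelling "inferece_dust3r").
from pathlib import Path
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.api import inferece_dust3r

model = AsymmetricCroCo3DStereo.from_pretrained(
    "checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth"   # checkpoint used by the demo
).to("cuda")

result = inferece_dust3r(
    image_dir_or_list=Path("examples/obj"),   # or a list of image Paths
    model=model,
    device="cuda",
    image_size=224,
    niter=300,
    min_conf_thr=3.0,
)
print(result.world_T_cam_b44.shape)           # (n_views, 4, 4) cam-to-world poses
result.point_cloud.export("scene.ply")        # colored point cloud as a PLY file
```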
mini_dust3r/cloud_opt/__init__.py DELETED
@@ -1,44 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # global alignment optimization wrapper function
6
- # --------------------------------------------------------
7
- from enum import Enum
8
-
9
- from .optimizer import PointCloudOptimizer
10
- from .modular_optimizer import ModularPointCloudOptimizer
11
- from .pair_viewer import PairViewer
12
- from mini_dust3r.inference import Dust3rResult
13
- from typing import Literal
14
-
15
-
16
- class GlobalAlignerMode(Enum):
17
- PointCloudOptimizer = "PointCloudOptimizer"
18
- ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
19
- PairViewer = "PairViewer"
20
-
21
-
22
- def global_aligner(
23
- dust3r_output: Dust3rResult,
24
- device: Literal["cpu", "cuda", "mps"],
25
- mode: GlobalAlignerMode = GlobalAlignerMode.PointCloudOptimizer,
26
- **optim_kw,
27
- ):
28
- # extract all inputs
29
- view1, view2, pred1, pred2 = [
30
- dust3r_output[k] for k in "view1 view2 pred1 pred2".split()
31
- ]
32
- # build the optimizer
33
- if mode == GlobalAlignerMode.PointCloudOptimizer:
34
- net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
35
- elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:
36
- net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(
37
- device
38
- )
39
- elif mode == GlobalAlignerMode.PairViewer:
40
- net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
41
- else:
42
- raise NotImplementedError(f"Unknown mode {mode}")
43
-
44
- return net
mini_dust3r/cloud_opt/base_opt.py DELETED
@@ -1,390 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Base class for the global alignment procedure
6
- # --------------------------------------------------------
7
- from copy import deepcopy
8
-
9
- import numpy as np
10
- import torch
11
- import torch.nn as nn
12
- import roma
13
- from copy import deepcopy
14
- import tqdm
15
-
16
- from mini_dust3r.utils.geometry import inv, geotrf
17
- from mini_dust3r.utils.device import to_numpy
18
- from mini_dust3r.utils.image import rgb
19
- from mini_dust3r.viz import SceneViz, segment_sky, auto_cam_size
20
- from mini_dust3r.optim_factory import adjust_learning_rate_by_lr
21
-
22
- from mini_dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
23
- cosine_schedule, linear_schedule, get_conf_trf)
24
- import mini_dust3r.cloud_opt.init_im_poses as init_fun
25
-
26
-
27
- class BasePCOptimizer (nn.Module):
28
- """ Optimize a global scene, given a list of pairwise observations.
29
- Graph node: images
30
- Graph edges: observations = (pred1, pred2)
31
- """
32
-
33
- def __init__(self, *args, **kwargs):
34
- if len(args) == 1 and len(kwargs) == 0:
35
- other = deepcopy(args[0])
36
- attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
37
- min_conf_thr conf_thr conf_i conf_j im_conf
38
- base_scale norm_pw_scale POSE_DIM pw_poses
39
- pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
40
- self.__dict__.update({k: other[k] for k in attrs})
41
- else:
42
- self._init_from_views(*args, **kwargs)
43
-
44
- def _init_from_views(self, view1, view2, pred1, pred2,
45
- dist='l1',
46
- conf='log',
47
- min_conf_thr=3,
48
- base_scale=0.5,
49
- allow_pw_adaptors=False,
50
- pw_break=20,
51
- rand_pose=torch.randn,
52
- iterationsCount=None,
53
- verbose=True):
54
- super().__init__()
55
- if not isinstance(view1['idx'], list):
56
- view1['idx'] = view1['idx'].tolist()
57
- if not isinstance(view2['idx'], list):
58
- view2['idx'] = view2['idx'].tolist()
59
- self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
60
- self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
61
- self.dist = ALL_DISTS[dist]
62
- self.verbose = verbose
63
-
64
- self.n_imgs = self._check_edges()
65
-
66
- # input data
67
- pred1_pts = pred1['pts3d']
68
- pred2_pts = pred2['pts3d_in_other_view']
69
- self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
70
- self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
71
- self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
72
-
73
- # work in log-scale with conf
74
- pred1_conf = pred1['conf']
75
- pred2_conf = pred2['conf']
76
- self.min_conf_thr = min_conf_thr
77
- self.conf_trf = get_conf_trf(conf)
78
-
79
- self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
80
- self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
81
- self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
82
-
83
- # pairwise pose parameters
84
- self.base_scale = base_scale
85
- self.norm_pw_scale = True
86
- self.pw_break = pw_break
87
- self.POSE_DIM = 7
88
- self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM))) # pairwise poses
89
- self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2))) # slight xy/z adaptation
90
- self.pw_adaptors.requires_grad_(allow_pw_adaptors)
91
- self.has_im_poses = False
92
- self.rand_pose = rand_pose
93
-
94
- # possibly store images for show_pointcloud
95
- self.imgs = None
96
- if 'img' in view1 and 'img' in view2:
97
- imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
98
- for v in range(len(self.edges)):
99
- idx = view1['idx'][v]
100
- imgs[idx] = view1['img'][v]
101
- idx = view2['idx'][v]
102
- imgs[idx] = view2['img'][v]
103
- self.imgs = rgb(imgs)
104
-
105
- @property
106
- def n_edges(self):
107
- return len(self.edges)
108
-
109
- @property
110
- def str_edges(self):
111
- return [edge_str(i, j) for i, j in self.edges]
112
-
113
- @property
114
- def imsizes(self):
115
- return [(w, h) for h, w in self.imshapes]
116
-
117
- @property
118
- def device(self):
119
- return next(iter(self.parameters())).device
120
-
121
- def state_dict(self, trainable=True):
122
- all_params = super().state_dict()
123
- return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
124
-
125
- def load_state_dict(self, data):
126
- return super().load_state_dict(self.state_dict(trainable=False) | data)
127
-
128
- def _check_edges(self):
129
- indices = sorted({i for edge in self.edges for i in edge})
130
- assert indices == list(range(len(indices))), 'bad pair indices: missing values '
131
- return len(indices)
132
-
133
- @torch.no_grad()
134
- def _compute_img_conf(self, pred1_conf, pred2_conf):
135
- im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
136
- for e, (i, j) in enumerate(self.edges):
137
- im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
138
- im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
139
- return im_conf
140
-
141
- def get_adaptors(self):
142
- adapt = self.pw_adaptors
143
- adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z)
144
- if self.norm_pw_scale: # normalize so that the product == 1
145
- adapt = adapt - adapt.mean(dim=1, keepdim=True)
146
- return (adapt / self.pw_break).exp()
147
-
148
- def _get_poses(self, poses):
149
- # normalize rotation
150
- Q = poses[:, :4]
151
- T = signed_expm1(poses[:, 4:7])
152
- RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
153
- return RT
154
-
155
- def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
156
- # all poses == cam-to-world
157
- pose = poses[idx]
158
- if not (pose.requires_grad or force):
159
- return pose
160
-
161
- if R.shape == (4, 4):
162
- assert T is None
163
- T = R[:3, 3]
164
- R = R[:3, :3]
165
-
166
- if R is not None:
167
- pose.data[0:4] = roma.rotmat_to_unitquat(R)
168
- if T is not None:
169
- pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale
170
-
171
- if scale is not None:
172
- assert poses.shape[-1] in (8, 13)
173
- pose.data[-1] = np.log(float(scale))
174
- return pose
175
-
176
- def get_pw_norm_scale_factor(self):
177
- if self.norm_pw_scale:
178
- # normalize scales so that things cannot go south
179
- # we want that exp(scale) ~= self.base_scale
180
- return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
181
- else:
182
- return 1 # don't norm scale for known poses
183
-
184
- def get_pw_scale(self):
185
- scale = self.pw_poses[:, -1].exp() # (n_edges,)
186
- scale = scale * self.get_pw_norm_scale_factor()
187
- return scale
188
-
189
- def get_pw_poses(self): # cam to world
190
- RT = self._get_poses(self.pw_poses)
191
- scaled_RT = RT.clone()
192
- scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation
193
- return scaled_RT
194
-
195
- def get_masks(self):
196
- return [(conf > self.min_conf_thr) for conf in self.im_conf]
197
-
198
- def depth_to_pts3d(self):
199
- raise NotImplementedError()
200
-
201
- def get_pts3d(self, raw=False):
202
- res = self.depth_to_pts3d()
203
- if not raw:
204
- res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
205
- return res
206
-
207
- def _set_focal(self, idx, focal, force=False):
208
- raise NotImplementedError()
209
-
210
- def get_focals(self):
211
- raise NotImplementedError()
212
-
213
- def get_known_focal_mask(self):
214
- raise NotImplementedError()
215
-
216
- def get_principal_points(self):
217
- raise NotImplementedError()
218
-
219
- def get_conf(self, mode=None):
220
- trf = self.conf_trf if mode is None else get_conf_trf(mode)
221
- return [trf(c) for c in self.im_conf]
222
-
223
- def get_im_poses(self):
224
- raise NotImplementedError()
225
-
226
- def _set_depthmap(self, idx, depth, force=False):
227
- raise NotImplementedError()
228
-
229
- def get_depthmaps(self, raw=False):
230
- raise NotImplementedError()
231
-
232
- @torch.no_grad()
233
- def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
234
- """ Method:
235
- 1) express all 3d points in each camera coordinate frame
236
- 2) if they're in front of a depthmap --> then lower their confidence
237
- """
238
- assert 0 <= tol < 1
239
- cams = inv(self.get_im_poses())
240
- K = self.get_intrinsics()
241
- depthmaps = self.get_depthmaps()
242
- res = deepcopy(self)
243
-
244
- for i, pts3d in enumerate(self.depth_to_pts3d()):
245
- for j in range(self.n_imgs):
246
- if i == j:
247
- continue
248
-
249
- # project 3dpts in other view
250
- Hi, Wi = self.imshapes[i]
251
- Hj, Wj = self.imshapes[j]
252
- proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
253
- proj_depth = proj[:, :, 2]
254
- u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
255
-
256
- # check which points are actually in the visible cone
257
- msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
258
- msk_j = v[msk_i], u[msk_i]
259
-
260
- # find bad points = those in front but less confident
261
- bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
262
- ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
263
-
264
- bad_msk_i = msk_i.clone()
265
- bad_msk_i[msk_i] = bad_points
266
- res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
267
-
268
- return res
269
-
270
- def forward(self, ret_details=False):
271
- pw_poses = self.get_pw_poses() # cam-to-world
272
- pw_adapt = self.get_adaptors()
273
- proj_pts3d = self.get_pts3d()
274
- # pre-compute pixel weights
275
- weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
276
- weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
277
-
278
- loss = 0
279
- if ret_details:
280
- details = -torch.ones((self.n_imgs, self.n_imgs))
281
-
282
- for e, (i, j) in enumerate(self.edges):
283
- i_j = edge_str(i, j)
284
- # distance in image i and j
285
- aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
286
- aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
287
- li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
288
- lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
289
- loss = loss + li + lj
290
-
291
- if ret_details:
292
- details[i, j] = li + lj
293
- loss /= self.n_edges # average over all pairs
294
-
295
- if ret_details:
296
- return loss, details
297
- return loss
298
-
299
- @torch.cuda.amp.autocast(enabled=False)
300
- def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
301
- if init is None:
302
- pass
303
- elif init == 'msp' or init == 'mst':
304
- init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
305
- elif init == 'known_poses':
306
- init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
307
- niter_PnP=niter_PnP)
308
- else:
309
- raise ValueError(f'bad value for {init=}')
310
-
311
- return global_alignment_loop(self, **kw)
312
-
313
- @torch.no_grad()
314
- def mask_sky(self):
315
- res = deepcopy(self)
316
- for i in range(self.n_imgs):
317
- sky = segment_sky(self.imgs[i])
318
- res.im_conf[i][sky] = 0
319
- return res
320
-
321
- def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
322
- viz = SceneViz()
323
- if self.imgs is None:
324
- colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
325
- colors = list(map(tuple, colors.tolist()))
326
- for n in range(self.n_imgs):
327
- viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
328
- else:
329
- viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
330
- colors = np.random.randint(256, size=(self.n_imgs, 3))
331
-
332
- # camera poses
333
- im_poses = to_numpy(self.get_im_poses())
334
- if cam_size is None:
335
- cam_size = auto_cam_size(im_poses)
336
- viz.add_cameras(im_poses, self.get_focals(), colors=colors,
337
- images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
338
- if show_pw_cams:
339
- pw_poses = self.get_pw_poses()
340
- viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
341
-
342
- if show_pw_pts3d:
343
- pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
344
- viz.add_pointcloud(pts, (128, 0, 128))
345
-
346
- viz.show(**kw)
347
- return viz
348
-
349
-
350
- def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
351
- params = [p for p in net.parameters() if p.requires_grad]
352
- if not params:
353
- return net
354
-
355
- verbose = net.verbose
356
- if verbose:
357
- print('Global alignment - optimizing for:')
358
- print([name for name, value in net.named_parameters() if value.requires_grad])
359
-
360
- lr_base = lr
361
- optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
362
-
363
- loss = float('inf')
364
- if verbose:
365
- with tqdm.tqdm(total=niter) as bar:
366
- while bar.n < bar.total:
367
- loss = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
368
- bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
369
- bar.update()
370
- else:
371
- for n in range(niter):
372
- loss = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)
373
- return loss
374
-
375
-
376
- def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
377
- t = cur_iter / niter
378
- if schedule == 'cosine':
379
- lr = cosine_schedule(t, lr_base, lr_min)
380
- elif schedule == 'linear':
381
- lr = linear_schedule(t, lr_base, lr_min)
382
- else:
383
- raise ValueError(f'bad lr {schedule=}')
384
- adjust_learning_rate_by_lr(optimizer, lr)
385
- optimizer.zero_grad()
386
- loss = net()
387
- loss.backward()
388
- optimizer.step()
389
-
390
- return float(loss)
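For reference, the objective driven by `global_alignment_loop` is the one assembled in `BasePCOptimizer.forward` above; the notation below is mine, not from the source. With E the set of pairwise edges, X_i the current global 3D points of image i, x_i^e the pairwise prediction for edge e = (i, j), a_e the per-edge xy/z adaptors, P_e the scaled cam-to-world pose from `get_pw_poses()`, w the confidence weights `conf_trf(conf)`, and d the l1/l2 distance from commons.py:

    loss = (1/|E|) * sum over e=(i,j) in E of
           [ mean_p  w_i^e(p) * d( X_i(p), P_e(a_e * x_i^e(p)) )
           + mean_p  w_j^e(p) * d( X_j(p), P_e(a_e * x_j^e(p)) ) ]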
mini_dust3r/cloud_opt/commons.py DELETED
@@ -1,90 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utility functions for global alignment
6
- # --------------------------------------------------------
7
- import torch
8
- import torch.nn as nn
9
- import numpy as np
10
-
11
-
12
- def edge_str(i, j):
13
- return f'{i}_{j}'
14
-
15
-
16
- def i_j_ij(ij):
17
- return edge_str(*ij), ij
18
-
19
-
20
- def edge_conf(conf_i, conf_j, edge):
21
- return float(conf_i[edge].mean() * conf_j[edge].mean())
22
-
23
-
24
- def compute_edge_scores(edges, conf_i, conf_j):
25
- return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
26
-
27
-
28
- def NoGradParamDict(x):
29
- assert isinstance(x, dict)
30
- return nn.ParameterDict(x).requires_grad_(False)
31
-
32
-
33
- def get_imshapes(edges, pred_i, pred_j):
34
- n_imgs = max(max(e) for e in edges) + 1
35
- imshapes = [None] * n_imgs
36
- for e, (i, j) in enumerate(edges):
37
- shape_i = tuple(pred_i[e].shape[0:2])
38
- shape_j = tuple(pred_j[e].shape[0:2])
39
- if imshapes[i]:
40
- assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
41
- if imshapes[j]:
42
- assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
43
- imshapes[i] = shape_i
44
- imshapes[j] = shape_j
45
- return imshapes
46
-
47
-
48
- def get_conf_trf(mode):
49
- if mode == 'log':
50
- def conf_trf(x): return x.log()
51
- elif mode == 'sqrt':
52
- def conf_trf(x): return x.sqrt()
53
- elif mode == 'm1':
54
- def conf_trf(x): return x-1
55
- elif mode in ('id', 'none'):
56
- def conf_trf(x): return x
57
- else:
58
- raise ValueError(f'bad mode for {mode=}')
59
- return conf_trf
60
-
61
-
62
- def l2_dist(a, b, weight):
63
- return ((a - b).square().sum(dim=-1) * weight)
64
-
65
-
66
- def l1_dist(a, b, weight):
67
- return ((a - b).norm(dim=-1) * weight)
68
-
69
-
70
- ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
71
-
72
-
73
- def signed_log1p(x):
74
- sign = torch.sign(x)
75
- return sign * torch.log1p(torch.abs(x))
76
-
77
-
78
- def signed_expm1(x):
79
- sign = torch.sign(x)
80
- return sign * torch.expm1(torch.abs(x))
81
-
82
-
83
- def cosine_schedule(t, lr_start, lr_end):
84
- assert 0 <= t <= 1
85
- return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
86
-
87
-
88
- def linear_schedule(t, lr_start, lr_end):
89
- assert 0 <= t <= 1
90
- return lr_start + (lr_end - lr_start) * t
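Both schedules above run from `lr_start` down to `lr_end`; the cosine variant simply stays near `lr_start` longer before dropping. A small self-contained check (the two functions are copied verbatim from above), using the defaults of `global_alignment_loop` (lr=0.01, lr_min=1e-6):

```python
import numpy as np

def cosine_schedule(t, lr_start, lr_end):
    return lr_end + (lr_start - lr_end) * (1 + np.cos(t * np.pi)) / 2

def linear_schedule(t, lr_start, lr_end):
    return lr_start + (lr_end - lr_start) * t

for t in (0.0, 0.25, 0.5, 1.0):
    print(t, linear_schedule(t, 1e-2, 1e-6), cosine_schedule(t, 1e-2, 1e-6))
# t=0.00 -> both 1e-2
# t=0.25 -> linear ~7.5e-3, cosine ~8.5e-3  (cosine decays later)
# t=0.50 -> both ~5.0e-3
# t=1.00 -> both 1e-6
```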
mini_dust3r/cloud_opt/init_im_poses.py DELETED
@@ -1,316 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Initialization functions for global alignment
6
- # --------------------------------------------------------
7
- from functools import cache
8
-
9
- import numpy as np
10
- import scipy.sparse as sp
11
- import torch
12
- import cv2
13
- import roma
14
- from tqdm import tqdm
15
-
16
- from mini_dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
17
- from mini_dust3r.post_process import estimate_focal_knowing_depth
18
- from mini_dust3r.viz import to_numpy
19
-
20
- from mini_dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
21
-
22
-
23
- @torch.no_grad()
24
- def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
25
- device = self.device
26
-
27
- # indices of known poses
28
- nkp, known_poses_msk, known_poses = get_known_poses(self)
29
- assert nkp == self.n_imgs, 'not all poses are known'
30
-
31
- # get all focals
32
- nkf, _, im_focals = get_known_focals(self)
33
- assert nkf == self.n_imgs
34
- im_pp = self.get_principal_points()
35
-
36
- best_depthmaps = {}
37
- # init all pairwise poses
38
- for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
39
- i_j = edge_str(i, j)
40
-
41
- # find relative pose for this pair
42
- P1 = torch.eye(4, device=device)
43
- msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
44
- _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
45
- pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)
46
-
47
- # align the two predicted camera with the two gt cameras
48
- s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
49
- # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
50
- # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
51
- self._set_pose(self.pw_poses, e, R, T, scale=s)
52
-
53
- # remember if this is a good depthmap
54
- score = float(self.conf_i[i_j].mean())
55
- if score > best_depthmaps.get(i, (0,))[0]:
56
- best_depthmaps[i] = score, i_j, s
57
-
58
- # init all image poses
59
- for n in range(self.n_imgs):
60
- assert known_poses_msk[n]
61
- _, i_j, scale = best_depthmaps[n]
62
- depth = self.pred_i[i_j][:, :, 2]
63
- self._set_depthmap(n, depth * scale)
64
-
65
-
66
- @torch.no_grad()
67
- def init_minimum_spanning_tree(self, **kw):
68
- """ Init all camera poses (image-wise and pairwise poses) given
69
- an initial set of pairwise estimations.
70
- """
71
- device = self.device
72
- pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
73
- self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
74
- device, has_im_poses=self.has_im_poses, verbose=self.verbose,
75
- **kw)
76
-
77
- return init_from_pts3d(self, pts3d, im_focals, im_poses)
78
-
79
-
80
- def init_from_pts3d(self, pts3d, im_focals, im_poses):
81
- # init poses
82
- nkp, known_poses_msk, known_poses = get_known_poses(self)
83
- if nkp == 1:
84
- raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
85
- elif nkp > 1:
86
- # global rigid SE3 alignment
87
- s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
88
- trf = sRT_to_4x4(s, R, T, device=known_poses.device)
89
-
90
- # rotate everything
91
- im_poses = trf @ im_poses
92
- im_poses[:, :3, :3] /= s # undo scaling on the rotation part
93
- for img_pts3d in pts3d:
94
- img_pts3d[:] = geotrf(trf, img_pts3d)
95
-
96
- # set all pairwise poses
97
- for e, (i, j) in enumerate(self.edges):
98
- i_j = edge_str(i, j)
99
- # compute transform that goes from cam to world
100
- s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
101
- self._set_pose(self.pw_poses, e, R, T, scale=s)
102
-
103
- # take into account the scale normalization
104
- s_factor = self.get_pw_norm_scale_factor()
105
- im_poses[:, :3, 3] *= s_factor # apply downscaling factor
106
- for img_pts3d in pts3d:
107
- img_pts3d *= s_factor
108
-
109
- # init all image poses
110
- if self.has_im_poses:
111
- for i in range(self.n_imgs):
112
- cam2world = im_poses[i]
113
- depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
114
- self._set_depthmap(i, depth)
115
- self._set_pose(self.im_poses, i, cam2world)
116
- if im_focals[i] is not None:
117
- self._set_focal(i, im_focals[i])
118
-
119
- if self.verbose:
120
- print(' init loss =', float(self()))
121
-
122
-
123
- def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
124
- device, has_im_poses=True, niter_PnP=10, verbose=True):
125
- n_imgs = len(imshapes)
126
- sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
127
- msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()
128
-
129
- # temp variable to store 3d points
130
- pts3d = [None] * len(imshapes)
131
-
132
- todo = sorted(zip(-msp.data, msp.row, msp.col)) # sorted edges
133
- im_poses = [None] * n_imgs
134
- im_focals = [None] * n_imgs
135
-
136
- # init with strongest edge
137
- score, i, j = todo.pop()
138
- if verbose:
139
- print(f' init edge ({i}*,{j}*) {score=}')
140
- i_j = edge_str(i, j)
141
- pts3d[i] = pred_i[i_j].clone()
142
- pts3d[j] = pred_j[i_j].clone()
143
- done = {i, j}
144
- if has_im_poses:
145
- im_poses[i] = torch.eye(4, device=device)
146
- im_focals[i] = estimate_focal(pred_i[i_j])
147
-
148
- # set initial pointcloud based on pairwise graph
149
- msp_edges = [(i, j)]
150
- while todo:
151
- # each time, predict the next one
152
- score, i, j = todo.pop()
153
-
154
- if im_focals[i] is None:
155
- im_focals[i] = estimate_focal(pred_i[i_j])
156
-
157
- if i in done:
158
- if verbose:
159
- print(f' init edge ({i},{j}*) {score=}')
160
- assert j not in done
161
- # align pred[i] with pts3d[i], and then set j accordingly
162
- i_j = edge_str(i, j)
163
- s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
164
- trf = sRT_to_4x4(s, R, T, device)
165
- pts3d[j] = geotrf(trf, pred_j[i_j])
166
- done.add(j)
167
- msp_edges.append((i, j))
168
-
169
- if has_im_poses and im_poses[i] is None:
170
- im_poses[i] = sRT_to_4x4(1, R, T, device)
171
-
172
- elif j in done:
173
- if verbose:
174
- print(f' init edge ({i}*,{j}) {score=}')
175
- assert i not in done
176
- i_j = edge_str(i, j)
177
- s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
178
- trf = sRT_to_4x4(s, R, T, device)
179
- pts3d[i] = geotrf(trf, pred_i[i_j])
180
- done.add(i)
181
- msp_edges.append((i, j))
182
-
183
- if has_im_poses and im_poses[i] is None:
184
- im_poses[i] = sRT_to_4x4(1, R, T, device)
185
- else:
186
- # let's try again later
187
- todo.insert(0, (score, i, j))
188
-
189
- if has_im_poses:
190
- # complete all missing information
191
- pair_scores = list(sparse_graph.values()) # already negative scores: less is best
192
- edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
193
- for i, j in edges_from_best_to_worse.tolist():
194
- if im_focals[i] is None:
195
- im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])
196
-
197
- for i in range(n_imgs):
198
- if im_poses[i] is None:
199
- msk = im_conf[i] > min_conf_thr
200
- res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
201
- if res:
202
- im_focals[i], im_poses[i] = res
203
- if im_poses[i] is None:
204
- im_poses[i] = torch.eye(4, device=device)
205
- im_poses = torch.stack(im_poses)
206
- else:
207
- im_poses = im_focals = None
208
-
209
- return pts3d, msp_edges, im_focals, im_poses
210
-
211
-
212
- def dict_to_sparse_graph(dic):
213
- n_imgs = max(max(e) for e in dic) + 1
214
- res = sp.dok_array((n_imgs, n_imgs))
215
- for edge, value in dic.items():
216
- res[edge] = value
217
- return res
218
-
219
-
220
- def rigid_points_registration(pts1, pts2, conf):
221
- R, T, s = roma.rigid_points_registration(
222
- pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
223
- return s, R, T # return un-scaled (R, T)
224
-
225
-
226
- def sRT_to_4x4(scale, R, T, device):
227
- trf = torch.eye(4, device=device)
228
- trf[:3, :3] = R * scale
229
- trf[:3, 3] = T.ravel() # doesn't need scaling
230
- return trf
231
-
232
-
233
- def estimate_focal(pts3d_i, pp=None):
234
- if pp is None:
235
- H, W, THREE = pts3d_i.shape
236
- assert THREE == 3
237
- pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
238
- focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
239
- return float(focal)
240
-
241
-
242
- @cache
243
- def pixel_grid(H, W):
244
- return np.mgrid[:W, :H].T.astype(np.float32)
245
-
246
-
247
- def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
248
- # extract camera poses and focals with RANSAC-PnP
249
- if msk.sum() < 4:
250
- return None # we need at least 4 points for PnP
251
- pts3d, msk = map(to_numpy, (pts3d, msk))
252
-
253
- H, W, THREE = pts3d.shape
254
- assert THREE == 3
255
- pixels = pixel_grid(H, W)
256
-
257
- if focal is None:
258
- S = max(W, H)
259
- tentative_focals = np.geomspace(S/2, S*3, 21)
260
- else:
261
- tentative_focals = [focal]
262
-
263
- if pp is None:
264
- pp = (W/2, H/2)
265
- else:
266
- pp = to_numpy(pp)
267
-
268
- best = 0,
269
- for focal in tentative_focals:
270
- K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
271
-
272
- success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
273
- iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
274
- if not success:
275
- continue
276
-
277
- score = len(inliers)
278
- if success and score > best[0]:
279
- best = score, R, T, focal
280
-
281
- if not best[0]:
282
- return None
283
-
284
- _, R, T, best_focal = best
285
- R = cv2.Rodrigues(R)[0] # world to cam
286
- R, T = map(torch.from_numpy, (R, T))
287
- return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world
288
-
289
-
290
- def get_known_poses(self):
291
- if self.has_im_poses:
292
- known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
293
- known_poses = self.get_im_poses()
294
- return known_poses_msk.sum(), known_poses_msk, known_poses
295
- else:
296
- return 0, None, None
297
-
298
-
299
- def get_known_focals(self):
300
- if self.has_im_poses:
301
- known_focal_msk = self.get_known_focal_mask()
302
- known_focals = self.get_focals()
303
- return known_focal_msk.sum(), known_focal_msk, known_focals
304
- else:
305
- return 0, None, None
306
-
307
-
308
- def align_multiple_poses(src_poses, target_poses):
309
- N = len(src_poses)
310
- assert src_poses.shape == target_poses.shape == (N, 4, 4)
311
-
312
- def center_and_z(poses):
313
- eps = get_med_dist_between_poses(poses) / 100
314
- return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
315
- R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
316
- return s, R, T
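When `fast_pnp` is called without a focal estimate, it sweeps a log-spaced grid of candidates between S/2 and 3S (S = max image side) and keeps whichever focal yields the most PnP-RANSAC inliers. A quick look at that grid for a 512-pixel-wide image (my illustration; values are approximate):

```python
import numpy as np

S = 512                                      # max(W, H) of the image
tentative_focals = np.geomspace(S / 2, S * 3, 21)
print(tentative_focals.round(1))
# 21 log-spaced candidates from 256 up to 1536,
# each about 9.4% larger than the previous one (ratio 6 ** (1/20)).
```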
mini_dust3r/cloud_opt/modular_optimizer.py DELETED
@@ -1,145 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Slower implementation of the global alignment that allows freezing partial poses/intrinsics
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
-
11
- from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
12
- from mini_dust3r.utils.geometry import geotrf
13
- from mini_dust3r.utils.device import to_cpu, to_numpy
14
- from mini_dust3r.utils.geometry import depthmap_to_pts3d
15
-
16
-
17
- class ModularPointCloudOptimizer (BasePCOptimizer):
18
- """ Optimize a global scene, given a list of pairwise observations.
19
- Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)
20
- Graph node: images
21
- Graph edges: observations = (pred1, pred2)
22
- """
23
-
24
- def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
25
- super().__init__(*args, **kwargs)
26
- self.has_im_poses = True # by definition of this class
27
- self.focal_brake = focal_brake
28
-
29
- # adding thing to optimize
30
- self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth)
31
- self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses
32
- default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
33
- self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [
34
- f]) for f in default_focals) # camera intrinsics
35
- self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics
36
- self.im_pp.requires_grad_(optimize_pp)
37
-
38
- def preset_pose(self, known_poses, pose_msk=None): # cam-to-world
39
- if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
40
- known_poses = [known_poses]
41
- for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
42
- if self.verbose:
43
- print(f' (setting pose #{idx} = {pose[:3,3]})')
44
- self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True))
45
-
46
- # normalize scale if there's less than 1 known pose
47
- n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
48
- self.norm_pw_scale = (n_known_poses <= 1)
49
-
50
- def preset_intrinsics(self, known_intrinsics, msk=None):
51
- if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:
52
- known_intrinsics = [known_intrinsics]
53
- for K in known_intrinsics:
54
- assert K.shape == (3, 3)
55
- self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
56
- self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)
57
-
58
- def preset_focal(self, known_focals, msk=None):
59
- for idx, focal in zip(self._get_msk_indices(msk), known_focals):
60
- if self.verbose:
61
- print(f' (setting focal #{idx} = {focal})')
62
- self._no_grad(self._set_focal(idx, focal, force=True))
63
-
64
- def preset_principal_point(self, known_pp, msk=None):
65
- for idx, pp in zip(self._get_msk_indices(msk), known_pp):
66
- if self.verbose:
67
- print(f' (setting principal point #{idx} = {pp})')
68
- self._no_grad(self._set_principal_point(idx, pp, force=True))
69
-
70
- def _no_grad(self, tensor):
71
- return tensor.requires_grad_(False)
72
-
73
- def _get_msk_indices(self, msk):
74
- if msk is None:
75
- return range(self.n_imgs)
76
- elif isinstance(msk, int):
77
- return [msk]
78
- elif isinstance(msk, (tuple, list)):
79
- return self._get_msk_indices(np.array(msk))
80
- elif msk.dtype in (bool, torch.bool, np.bool_):
81
- assert len(msk) == self.n_imgs
82
- return np.where(msk)[0]
83
- elif np.issubdtype(msk.dtype, np.integer):
84
- return msk
85
- else:
86
- raise ValueError(f'bad {msk=}')
87
-
88
- def _set_focal(self, idx, focal, force=False):
89
- param = self.im_focals[idx]
90
- if param.requires_grad or force: # can only init a parameter not already initialized
91
- param.data[:] = self.focal_brake * np.log(focal)
92
- return param
93
-
94
- def get_focals(self):
95
- log_focals = torch.stack(list(self.im_focals), dim=0)
96
- return (log_focals / self.focal_brake).exp()
97
-
98
- def _set_principal_point(self, idx, pp, force=False):
99
- param = self.im_pp[idx]
100
- H, W = self.imshapes[idx]
101
- if param.requires_grad or force: # can only init a parameter not already initialized
102
- param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
103
- return param
104
-
105
- def get_principal_points(self):
106
- return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])
107
-
108
- def get_intrinsics(self):
109
- K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
110
- focals = self.get_focals().view(self.n_imgs, -1)
111
- K[:, 0, 0] = focals[:, 0]
112
- K[:, 1, 1] = focals[:, -1]
113
- K[:, :2, 2] = self.get_principal_points()
114
- K[:, 2, 2] = 1
115
- return K
116
-
117
- def get_im_poses(self): # cam to world
118
- cam2world = self._get_poses(torch.stack(list(self.im_poses)))
119
- return cam2world
120
-
121
- def _set_depthmap(self, idx, depth, force=False):
122
- param = self.im_depthmaps[idx]
123
- if param.requires_grad or force: # can only init a parameter not already initialized
124
- param.data[:] = depth.log().nan_to_num(neginf=0)
125
- return param
126
-
127
- def get_depthmaps(self):
128
- return [d.exp() for d in self.im_depthmaps]
129
-
130
- def depth_to_pts3d(self):
131
- # Get depths and projection params if not provided
132
- focals = self.get_focals()
133
- pp = self.get_principal_points()
134
- im_poses = self.get_im_poses()
135
- depth = self.get_depthmaps()
136
-
137
- # convert focal to (1,2,H,W) constant field
138
- def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])
139
- # get pointmaps in camera frame
140
- rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]
141
- # project to world frame
142
- return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]
143
-
144
- def get_pts3d(self):
145
- return self.depth_to_pts3d()
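What `ModularPointCloudOptimizer` adds over the stacked `PointCloudOptimizer` below is per-image freezing of poses, focals, and principal points before optimization. A hedged sketch of that workflow; `pairs_output` stands for the pairwise prediction dict returned by `mini_dust3r.inference.inference(...)` as in the api module above, and the pose/intrinsics values are placeholders:

```python
import torch
from mini_dust3r.cloud_opt import global_aligner, GlobalAlignerMode

scene = global_aligner(dust3r_output=pairs_output, device="cuda",
                       mode=GlobalAlignerMode.ModularPointCloudOptimizer)

known_pose = torch.eye(4)                        # cam-to-world, placeholder
known_K = torch.tensor([[500.0,   0.0, 112.0],
                        [  0.0, 500.0, 112.0],
                        [  0.0,   0.0,   1.0]])
scene.preset_pose([known_pose], pose_msk=[0])    # constrain image #0 only
scene.preset_intrinsics([known_K], msk=[0])
scene.compute_global_alignment(init="mst", niter=300, schedule="cosine", lr=0.01)
```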
mini_dust3r/cloud_opt/optimizer.py DELETED
@@ -1,248 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Main class for the implementation of the global alignment
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
-
11
- from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
12
- from mini_dust3r.utils.geometry import xy_grid, geotrf
13
- from mini_dust3r.utils.device import to_cpu, to_numpy
14
-
15
-
16
- class PointCloudOptimizer(BasePCOptimizer):
17
- """ Optimize a global scene, given a list of pairwise observations.
18
- Graph node: images
19
- Graph edges: observations = (pred1, pred2)
20
- """
21
-
22
- def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
23
- super().__init__(*args, **kwargs)
24
-
25
- self.has_im_poses = True # by definition of this class
26
- self.focal_break = focal_break
27
-
28
- # adding thing to optimize
29
- self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth)
30
- self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses
31
- self.im_focals = nn.ParameterList(torch.FloatTensor(
32
- [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics
33
- self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics
34
- self.im_pp.requires_grad_(optimize_pp)
35
-
36
- self.imshape = self.imshapes[0]
37
- im_areas = [h*w for h, w in self.imshapes]
38
- self.max_area = max(im_areas)
39
-
40
- # adding thing to optimize
41
- self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
42
- self.im_poses = ParameterStack(self.im_poses, is_param=True)
43
- self.im_focals = ParameterStack(self.im_focals, is_param=True)
44
- self.im_pp = ParameterStack(self.im_pp, is_param=True)
45
- self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
46
- self.register_buffer('_grid', ParameterStack(
47
- [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
48
-
49
- # pre-compute pixel weights
50
- self.register_buffer('_weight_i', ParameterStack(
51
- [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
52
- self.register_buffer('_weight_j', ParameterStack(
53
- [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
54
-
55
- # precompute aa
56
- self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
57
- self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
58
- self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
59
- self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
60
- self.total_area_i = sum([im_areas[i] for i, j in self.edges])
61
- self.total_area_j = sum([im_areas[j] for i, j in self.edges])
62
-
63
- def _check_all_imgs_are_selected(self, msk):
64
- assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
65
-
66
- def preset_pose(self, known_poses, pose_msk=None): # cam-to-world
67
- self._check_all_imgs_are_selected(pose_msk)
68
-
69
- if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
70
- known_poses = [known_poses]
71
- for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
72
- if self.verbose:
73
- print(f' (setting pose #{idx} = {pose[:3,3]})')
74
- self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
75
-
76
- # normalize scale if there's less than 1 known pose
77
- n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
78
- self.norm_pw_scale = (n_known_poses <= 1)
79
-
80
- self.im_poses.requires_grad_(False)
81
- self.norm_pw_scale = False
82
-
83
- def preset_focal(self, known_focals, msk=None):
84
- self._check_all_imgs_are_selected(msk)
85
-
86
- for idx, focal in zip(self._get_msk_indices(msk), known_focals):
87
- if self.verbose:
88
- print(f' (setting focal #{idx} = {focal})')
89
- self._no_grad(self._set_focal(idx, focal))
90
-
91
- self.im_focals.requires_grad_(False)
92
-
93
- def preset_principal_point(self, known_pp, msk=None):
94
- self._check_all_imgs_are_selected(msk)
95
-
96
- for idx, pp in zip(self._get_msk_indices(msk), known_pp):
97
- if self.verbose:
98
- print(f' (setting principal point #{idx} = {pp})')
99
- self._no_grad(self._set_principal_point(idx, pp))
100
-
101
- self.im_pp.requires_grad_(False)
102
-
103
- def _get_msk_indices(self, msk):
104
- if msk is None:
105
- return range(self.n_imgs)
106
- elif isinstance(msk, int):
107
- return [msk]
108
- elif isinstance(msk, (tuple, list)):
109
- return self._get_msk_indices(np.array(msk))
110
- elif msk.dtype in (bool, torch.bool, np.bool_):
111
- assert len(msk) == self.n_imgs
112
- return np.where(msk)[0]
113
- elif np.issubdtype(msk.dtype, np.integer):
114
- return msk
115
- else:
116
- raise ValueError(f'bad {msk=}')
117
-
118
- def _no_grad(self, tensor):
119
- assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
120
-
121
- def _set_focal(self, idx, focal, force=False):
122
- param = self.im_focals[idx]
123
- if param.requires_grad or force: # can only init a parameter not already initialized
124
- param.data[:] = self.focal_break * np.log(focal)
125
- return param
126
-
127
- def get_focals(self):
128
- log_focals = torch.stack(list(self.im_focals), dim=0)
129
- return (log_focals / self.focal_break).exp()
130
-
131
- def get_known_focal_mask(self):
132
- return torch.tensor([not (p.requires_grad) for p in self.im_focals])
133
-
134
- def _set_principal_point(self, idx, pp, force=False):
135
- param = self.im_pp[idx]
136
- H, W = self.imshapes[idx]
137
- if param.requires_grad or force: # can only init a parameter not already initialized
138
- param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
139
- return param
140
-
141
- def get_principal_points(self):
142
- return self._pp + 10 * self.im_pp
143
-
144
- def get_intrinsics(self):
145
- K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
146
- focals = self.get_focals().flatten()
147
- K[:, 0, 0] = K[:, 1, 1] = focals
148
- K[:, :2, 2] = self.get_principal_points()
149
- K[:, 2, 2] = 1
150
- return K
151
-
152
- def get_im_poses(self): # cam to world
153
- cam2world = self._get_poses(self.im_poses)
154
- return cam2world
155
-
156
- def _set_depthmap(self, idx, depth, force=False):
157
- depth = _ravel_hw(depth, self.max_area)
158
-
159
- param = self.im_depthmaps[idx]
160
- if param.requires_grad or force: # can only init a parameter not already initialized
161
- param.data[:] = depth.log().nan_to_num(neginf=0)
162
- return param
163
-
164
- def get_depthmaps(self, raw=False):
165
- res = self.im_depthmaps.exp()
166
- if not raw:
167
- res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
168
- return res
169
-
170
- def depth_to_pts3d(self):
171
- # Get depths and projection params if not provided
172
- focals = self.get_focals()
173
- pp = self.get_principal_points()
174
- im_poses = self.get_im_poses()
175
- depth = self.get_depthmaps(raw=True)
176
-
177
- # get pointmaps in camera frame
178
- rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
179
- # project to world frame
180
- return geotrf(im_poses, rel_ptmaps)
181
-
182
- def get_pts3d(self, raw=False):
183
- res = self.depth_to_pts3d()
184
- if not raw:
185
- res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
186
- return res
187
-
188
- def forward(self):
189
- pw_poses = self.get_pw_poses() # cam-to-world
190
- pw_adapt = self.get_adaptors().unsqueeze(1)
191
- proj_pts3d = self.get_pts3d(raw=True)
192
-
193
- # rotate pairwise prediction according to pw_poses
194
- aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
195
- aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
196
-
197
-         # compute the loss
198
- li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
199
- lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
200
-
201
- return li + lj
202
-
203
-
204
- def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
205
- pp = pp.unsqueeze(1)
206
- focal = focal.unsqueeze(1)
207
- assert focal.shape == (len(depth), 1, 1)
208
- assert pp.shape == (len(depth), 1, 2)
209
- assert pixel_grid.shape == depth.shape + (2,)
210
- depth = depth.unsqueeze(-1)
211
- return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
212
-
213
-
214
- def ParameterStack(params, keys=None, is_param=None, fill=0):
215
- if keys is not None:
216
- params = [params[k] for k in keys]
217
-
218
- if fill > 0:
219
- params = [_ravel_hw(p, fill) for p in params]
220
-
221
- requires_grad = params[0].requires_grad
222
- assert all(p.requires_grad == requires_grad for p in params)
223
-
224
- params = torch.stack(list(params)).float().detach()
225
- if is_param or requires_grad:
226
- params = nn.Parameter(params)
227
- params.requires_grad_(requires_grad)
228
- return params
229
-
230
-
231
- def _ravel_hw(tensor, fill=0):
232
- # ravel H,W
233
- tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
234
-
235
- if len(tensor) < fill:
236
- tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
237
- return tensor
238
-
239
-
240
- def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
241
- focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515
242
- return minf*focal_base, maxf*focal_base
243
-
244
-
245
- def apply_mask(img, msk):
246
- img = img.copy()
247
- img[msk] = 0
248
- return img
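For context on _fast_depthmap_to_pts3d above: a pixel (u, v) with depth d, principal point (cx, cy) and focal f unprojects to (d*(u-cx)/f, d*(v-cy)/f, d). A small stand-alone check of that formula without the batching and padding logic (shapes and values here are made up):

import torch

H, W = 3, 4
focal = torch.tensor(100.0)
pp = torch.tensor([W / 2, H / 2])                      # principal point (cx, cy)
u, v = torch.meshgrid(torch.arange(W, dtype=torch.float32),
                      torch.arange(H, dtype=torch.float32), indexing="xy")
grid = torch.stack([u, v], dim=-1)                     # (H, W, 2) pixel coordinates
depth = torch.full((H, W), 2.0)

# same expression as the deleted helper, per pixel: (d*(u-cx)/f, d*(v-cy)/f, d)
pts3d = torch.cat([depth[..., None] * (grid - pp) / focal, depth[..., None]], dim=-1)
print(pts3d.shape)   # torch.Size([3, 4, 3])
print(pts3d[0, 0])   # tensor([-0.0400, -0.0300, 2.0000]) for pixel (0, 0)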
 
mini_dust3r/cloud_opt/pair_viewer.py DELETED
@@ -1,127 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Dummy optimizer for visualizing pairs
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
- import cv2
11
-
12
- from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
13
- from mini_dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
14
- from mini_dust3r.cloud_opt.commons import edge_str
15
- from mini_dust3r.post_process import estimate_focal_knowing_depth
16
-
17
-
18
- class PairViewer (BasePCOptimizer):
19
- """
20
-     This is a dummy optimizer.
21
- To use only when the goal is to visualize the results for a pair of images (with is_symmetrized)
22
- """
23
-
24
- def __init__(self, *args, **kwargs):
25
- super().__init__(*args, **kwargs)
26
- assert self.is_symmetrized and self.n_edges == 2
27
- self.has_im_poses = True
28
-
29
- # compute all parameters directly from raw input
30
- self.focals = []
31
- self.pp = []
32
- rel_poses = []
33
- confs = []
34
- for i in range(self.n_imgs):
35
- conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
36
- if self.verbose:
37
- print(f' - {conf=:.3} for edge {i}-{1-i}')
38
- confs.append(conf)
39
-
40
- H, W = self.imshapes[i]
41
- pts3d = self.pred_i[edge_str(i, 1-i)]
42
- pp = torch.tensor((W/2, H/2))
43
- focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
44
- self.focals.append(focal)
45
- self.pp.append(pp)
46
-
47
- # estimate the pose of pts1 in image 2
48
- pixels = np.mgrid[:W, :H].T.astype(np.float32)
49
- pts3d = self.pred_j[edge_str(1-i, i)].numpy()
50
- assert pts3d.shape[:2] == (H, W)
51
- msk = self.get_masks()[i].numpy()
52
- K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
53
-
54
- try:
55
- res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
56
- iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
57
- success, R, T, inliers = res
58
- assert success
59
-
60
- R = cv2.Rodrigues(R)[0] # world to cam
61
- pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world
62
- except:
63
- pose = np.eye(4)
64
- rel_poses.append(torch.from_numpy(pose.astype(np.float32)))
65
-
66
- # let's use the pair with the most confidence
67
- if confs[0] > confs[1]:
68
- # ptcloud is expressed in camera1
69
- self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1
70
- self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
71
- else:
72
- # ptcloud is expressed in camera2
73
- self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2
74
- self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]
75
-
76
- self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
77
- self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
78
- self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
79
- self.depth = nn.ParameterList(self.depth)
80
- for p in self.parameters():
81
- p.requires_grad = False
82
-
83
- def _set_depthmap(self, idx, depth, force=False):
84
- if self.verbose:
85
- print('_set_depthmap is ignored in PairViewer')
86
- return
87
-
88
- def get_depthmaps(self, raw=False):
89
- depth = [d.to(self.device) for d in self.depth]
90
- return depth
91
-
92
- def _set_focal(self, idx, focal, force=False):
93
- self.focals[idx] = focal
94
-
95
- def get_focals(self):
96
- return self.focals
97
-
98
- def get_known_focal_mask(self):
99
- return torch.tensor([not (p.requires_grad) for p in self.focals])
100
-
101
- def get_principal_points(self):
102
- return self.pp
103
-
104
- def get_intrinsics(self):
105
- focals = self.get_focals()
106
- pps = self.get_principal_points()
107
- K = torch.zeros((len(focals), 3, 3), device=self.device)
108
- for i in range(len(focals)):
109
- K[i, 0, 0] = K[i, 1, 1] = focals[i]
110
- K[i, :2, 2] = pps[i]
111
- K[i, 2, 2] = 1
112
- return K
113
-
114
- def get_im_poses(self):
115
- return self.im_poses
116
-
117
- def depth_to_pts3d(self):
118
- pts3d = []
119
- for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
120
- pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
121
- intrinsics.cpu().numpy(),
122
- im_pose.cpu().numpy())
123
- pts3d.append(torch.from_numpy(pts).to(device=self.device))
124
- return pts3d
125
-
126
- def forward(self):
127
- return float('nan')
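The pose recovery above hinges on cv2.solvePnPRansac: given 3D points expressed in one camera and their pixel locations in the other image, it returns a world-to-camera rotation and translation, which PairViewer then inverts into a cam-to-world pose. A synthetic sanity check of that step (assumes an OpenCV build that provides SOLVEPNP_SQPNP, as the deleted code does; all values are made up):

import cv2
import numpy as np

# Synthetic scene: random 3D points in front of a camera at the origin.
rng = np.random.default_rng(0)
pts3d = rng.uniform([-1, -1, 4], [1, 1, 6], size=(50, 3)).astype(np.float32)
K = np.array([[500, 0, 320], [0, 500, 240], [0, 0, 1]], dtype=np.float32)

# Project with the identity pose: pixel = (K @ X) / z.
proj = (K @ pts3d.T).T
pixels = (proj[:, :2] / proj[:, 2:3]).astype(np.float32)

ok, rvec, tvec, inliers = cv2.solvePnPRansac(
    pts3d, pixels, K, None,
    iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)

R = cv2.Rodrigues(rvec)[0]                          # world-to-cam rotation
world_to_cam = np.r_[np.c_[R, tvec], [(0, 0, 0, 1)]]
cam_to_world = np.linalg.inv(world_to_cam)          # the pose PairViewer would keep
print(ok, np.allclose(cam_to_world, np.eye(4), atol=1e-3))   # True True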
 
mini_dust3r/croco/blocks.py DELETED
@@ -1,241 +0,0 @@
1
- # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
-
5
- # --------------------------------------------------------
6
- # Main encoder/decoder blocks
7
- # --------------------------------------------------------
8
- # References:
9
- # timm
10
- # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
- # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
- # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
- # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
- # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
-
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from itertools import repeat
21
- import collections.abc
22
-
23
-
24
- def _ntuple(n):
25
- def parse(x):
26
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
27
- return x
28
- return tuple(repeat(x, n))
29
- return parse
30
- to_2tuple = _ntuple(2)
31
-
32
- def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
33
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34
- """
35
- if drop_prob == 0. or not training:
36
- return x
37
- keep_prob = 1 - drop_prob
38
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
39
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
40
- if keep_prob > 0.0 and scale_by_keep:
41
- random_tensor.div_(keep_prob)
42
- return x * random_tensor
43
-
44
- class DropPath(nn.Module):
45
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
- """
47
- def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
48
- super(DropPath, self).__init__()
49
- self.drop_prob = drop_prob
50
- self.scale_by_keep = scale_by_keep
51
-
52
- def forward(self, x):
53
- return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54
-
55
- def extra_repr(self):
56
- return f'drop_prob={round(self.drop_prob,3):0.3f}'
57
-
58
- class Mlp(nn.Module):
59
- """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
60
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
61
- super().__init__()
62
- out_features = out_features or in_features
63
- hidden_features = hidden_features or in_features
64
- bias = to_2tuple(bias)
65
- drop_probs = to_2tuple(drop)
66
-
67
- self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
68
- self.act = act_layer()
69
- self.drop1 = nn.Dropout(drop_probs[0])
70
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
71
- self.drop2 = nn.Dropout(drop_probs[1])
72
-
73
- def forward(self, x):
74
- x = self.fc1(x)
75
- x = self.act(x)
76
- x = self.drop1(x)
77
- x = self.fc2(x)
78
- x = self.drop2(x)
79
- return x
80
-
81
- class Attention(nn.Module):
82
-
83
- def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
84
- super().__init__()
85
- self.num_heads = num_heads
86
- head_dim = dim // num_heads
87
- self.scale = head_dim ** -0.5
88
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89
- self.attn_drop = nn.Dropout(attn_drop)
90
- self.proj = nn.Linear(dim, dim)
91
- self.proj_drop = nn.Dropout(proj_drop)
92
- self.rope = rope
93
-
94
- def forward(self, x, xpos):
95
- B, N, C = x.shape
96
-
97
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
98
- q, k, v = [qkv[:,:,i] for i in range(3)]
99
- # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
100
-
101
- if self.rope is not None:
102
- q = self.rope(q, xpos)
103
- k = self.rope(k, xpos)
104
-
105
- attn = (q @ k.transpose(-2, -1)) * self.scale
106
- attn = attn.softmax(dim=-1)
107
- attn = self.attn_drop(attn)
108
-
109
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
110
- x = self.proj(x)
111
- x = self.proj_drop(x)
112
- return x
113
-
114
- class Block(nn.Module):
115
-
116
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
117
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
118
- super().__init__()
119
- self.norm1 = norm_layer(dim)
120
- self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
121
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
122
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
123
- self.norm2 = norm_layer(dim)
124
- mlp_hidden_dim = int(dim * mlp_ratio)
125
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
126
-
127
- def forward(self, x, xpos):
128
- x = x + self.drop_path(self.attn(self.norm1(x), xpos))
129
- x = x + self.drop_path(self.mlp(self.norm2(x)))
130
- return x
131
-
132
- class CrossAttention(nn.Module):
133
-
134
- def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
135
- super().__init__()
136
- self.num_heads = num_heads
137
- head_dim = dim // num_heads
138
- self.scale = head_dim ** -0.5
139
-
140
- self.projq = nn.Linear(dim, dim, bias=qkv_bias)
141
- self.projk = nn.Linear(dim, dim, bias=qkv_bias)
142
- self.projv = nn.Linear(dim, dim, bias=qkv_bias)
143
- self.attn_drop = nn.Dropout(attn_drop)
144
- self.proj = nn.Linear(dim, dim)
145
- self.proj_drop = nn.Dropout(proj_drop)
146
-
147
- self.rope = rope
148
-
149
- def forward(self, query, key, value, qpos, kpos):
150
- B, Nq, C = query.shape
151
- Nk = key.shape[1]
152
- Nv = value.shape[1]
153
-
154
- q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
155
- k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
156
- v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
157
-
158
- if self.rope is not None:
159
- q = self.rope(q, qpos)
160
- k = self.rope(k, kpos)
161
-
162
- attn = (q @ k.transpose(-2, -1)) * self.scale
163
- attn = attn.softmax(dim=-1)
164
- attn = self.attn_drop(attn)
165
-
166
- x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
167
- x = self.proj(x)
168
- x = self.proj_drop(x)
169
- return x
170
-
171
- class DecoderBlock(nn.Module):
172
-
173
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
174
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
175
- super().__init__()
176
- self.norm1 = norm_layer(dim)
177
- self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
178
- self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
179
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
180
- self.norm2 = norm_layer(dim)
181
- self.norm3 = norm_layer(dim)
182
- mlp_hidden_dim = int(dim * mlp_ratio)
183
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
184
- self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
185
-
186
- def forward(self, x, y, xpos, ypos):
187
- x = x + self.drop_path(self.attn(self.norm1(x), xpos))
188
- y_ = self.norm_y(y)
189
- x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
190
- x = x + self.drop_path(self.mlp(self.norm3(x)))
191
- return x, y
192
-
193
-
194
- # patch embedding
195
- class PositionGetter(object):
196
- """ return positions of patches """
197
-
198
- def __init__(self):
199
- self.cache_positions = {}
200
-
201
- def __call__(self, b, h, w, device):
202
- if not (h,w) in self.cache_positions:
203
- x = torch.arange(w, device=device)
204
- y = torch.arange(h, device=device)
205
- self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
206
- pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
207
- return pos
208
-
209
- class PatchEmbed(nn.Module):
210
- """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
211
-
212
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
213
- super().__init__()
214
- img_size = to_2tuple(img_size)
215
- patch_size = to_2tuple(patch_size)
216
- self.img_size = img_size
217
- self.patch_size = patch_size
218
- self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
219
- self.num_patches = self.grid_size[0] * self.grid_size[1]
220
- self.flatten = flatten
221
-
222
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
223
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
224
-
225
- self.position_getter = PositionGetter()
226
-
227
- def forward(self, x):
228
- B, C, H, W = x.shape
229
- torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
230
- torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
231
- x = self.proj(x)
232
- pos = self.position_getter(B, x.size(2), x.size(3), x.device)
233
- if self.flatten:
234
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
235
- x = self.norm(x)
236
- return x, pos
237
-
238
- def _init_weights(self):
239
- w = self.proj.weight.data
240
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
241
-
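A note on drop_path above: it zeroes whole samples with probability drop_prob and rescales the survivors by 1/keep_prob, so the expected activation is unchanged. A quick numerical check of that property, re-implementing just those lines stand-alone (shapes are arbitrary):

import torch

def drop_path(x, drop_prob=0.0, training=False, scale_by_keep=True):
    # same per-sample stochastic-depth logic as the deleted blocks.py helper
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor

torch.manual_seed(0)
x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.3, training=True)
print(round(y.mean().item(), 2))                       # ~1.0: expectation preserved
print(round((y[:, 0] == 0).float().mean().item(), 2))  # ~0.3: dropped fraction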
 
mini_dust3r/croco/croco.py DELETED
@@ -1,249 +0,0 @@
1
- # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
-
5
- # --------------------------------------------------------
6
- # CroCo model during pretraining
7
- # --------------------------------------------------------
8
-
9
-
10
-
11
- import torch
12
- import torch.nn as nn
13
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
- from functools import partial
15
-
16
- from mini_dust3r.croco.blocks import Block, DecoderBlock, PatchEmbed
17
- from mini_dust3r.croco.pos_embed import get_2d_sincos_pos_embed, RoPE2D
18
- from mini_dust3r.croco.masking import RandomMask
19
-
20
-
21
- class CroCoNet(nn.Module):
22
-
23
- def __init__(self,
24
- img_size=224, # input image size
25
- patch_size=16, # patch_size
26
- mask_ratio=0.9, # ratios of masked tokens
27
- enc_embed_dim=768, # encoder feature dimension
28
- enc_depth=12, # encoder depth
29
- enc_num_heads=12, # encoder number of heads in the transformer block
30
- dec_embed_dim=512, # decoder feature dimension
31
- dec_depth=8, # decoder depth
32
- dec_num_heads=16, # decoder number of heads in the transformer block
33
- mlp_ratio=4,
34
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
35
- norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
36
- pos_embed='cosine', # positional embedding (either cosine or RoPE100)
37
- ):
38
-
39
- super(CroCoNet, self).__init__()
40
-
41
- # patch embeddings (with initialization done as in MAE)
42
- self._set_patch_embed(img_size, patch_size, enc_embed_dim)
43
-
44
- # mask generations
45
- self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
46
-
47
- self.pos_embed = pos_embed
48
- if pos_embed=='cosine':
49
- # positional embedding of the encoder
50
- enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
51
- self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
52
- # positional embedding of the decoder
53
- dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
54
- self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
55
- # pos embedding in each block
56
- self.rope = None # nothing for cosine
57
- elif pos_embed.startswith('RoPE'): # eg RoPE100
58
- self.enc_pos_embed = None # nothing to add in the encoder with RoPE
59
- self.dec_pos_embed = None # nothing to add in the decoder with RoPE
60
- if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
61
- freq = float(pos_embed[len('RoPE'):])
62
- self.rope = RoPE2D(freq=freq)
63
- else:
64
- raise NotImplementedError('Unknown pos_embed '+pos_embed)
65
-
66
- # transformer for the encoder
67
- self.enc_depth = enc_depth
68
- self.enc_embed_dim = enc_embed_dim
69
- self.enc_blocks = nn.ModuleList([
70
- Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
71
- for i in range(enc_depth)])
72
- self.enc_norm = norm_layer(enc_embed_dim)
73
-
74
- # masked tokens
75
- self._set_mask_token(dec_embed_dim)
76
-
77
- # decoder
78
- self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
79
-
80
- # prediction head
81
- self._set_prediction_head(dec_embed_dim, patch_size)
82
-
83
-         # initialize weights
84
- self.initialize_weights()
85
-
86
- def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
87
- self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
88
-
89
- def _set_mask_generator(self, num_patches, mask_ratio):
90
- self.mask_generator = RandomMask(num_patches, mask_ratio)
91
-
92
- def _set_mask_token(self, dec_embed_dim):
93
- self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
94
-
95
- def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
96
- self.dec_depth = dec_depth
97
- self.dec_embed_dim = dec_embed_dim
98
- # transfer from encoder to decoder
99
- self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
100
- # transformer for the decoder
101
- self.dec_blocks = nn.ModuleList([
102
- DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
103
- for i in range(dec_depth)])
104
- # final norm layer
105
- self.dec_norm = norm_layer(dec_embed_dim)
106
-
107
- def _set_prediction_head(self, dec_embed_dim, patch_size):
108
- self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
109
-
110
-
111
- def initialize_weights(self):
112
- # patch embed
113
- self.patch_embed._init_weights()
114
- # mask tokens
115
- if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
116
- # linears and layer norms
117
- self.apply(self._init_weights)
118
-
119
- def _init_weights(self, m):
120
- if isinstance(m, nn.Linear):
121
- # we use xavier_uniform following official JAX ViT:
122
- torch.nn.init.xavier_uniform_(m.weight)
123
- if isinstance(m, nn.Linear) and m.bias is not None:
124
- nn.init.constant_(m.bias, 0)
125
- elif isinstance(m, nn.LayerNorm):
126
- nn.init.constant_(m.bias, 0)
127
- nn.init.constant_(m.weight, 1.0)
128
-
129
- def _encode_image(self, image, do_mask=False, return_all_blocks=False):
130
- """
131
- image has B x 3 x img_size x img_size
132
- do_mask: whether to perform masking or not
133
- return_all_blocks: if True, return the features at the end of every block
134
- instead of just the features from the last block (eg for some prediction heads)
135
- """
136
- # embed the image into patches (x has size B x Npatches x C)
137
-         # and get the position of each returned patch (pos has size B x Npatches x 2)
138
- x, pos = self.patch_embed(image)
139
- # add positional embedding without cls token
140
- if self.enc_pos_embed is not None:
141
- x = x + self.enc_pos_embed[None,...]
142
- # apply masking
143
- B,N,C = x.size()
144
- if do_mask:
145
- masks = self.mask_generator(x)
146
- x = x[~masks].view(B, -1, C)
147
- posvis = pos[~masks].view(B, -1, 2)
148
- else:
149
- B,N,C = x.size()
150
- masks = torch.zeros((B,N), dtype=bool)
151
- posvis = pos
152
- # now apply the transformer encoder and normalization
153
- if return_all_blocks:
154
- out = []
155
- for blk in self.enc_blocks:
156
- x = blk(x, posvis)
157
- out.append(x)
158
- out[-1] = self.enc_norm(out[-1])
159
- return out, pos, masks
160
- else:
161
- for blk in self.enc_blocks:
162
- x = blk(x, posvis)
163
- x = self.enc_norm(x)
164
- return x, pos, masks
165
-
166
- def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
167
- """
168
- return_all_blocks: if True, return the features at the end of every block
169
- instead of just the features from the last block (eg for some prediction heads)
170
-
171
- masks1 can be None => assume image1 fully visible
172
- """
173
- # encoder to decoder layer
174
- visf1 = self.decoder_embed(feat1)
175
- f2 = self.decoder_embed(feat2)
176
- # append masked tokens to the sequence
177
- B,Nenc,C = visf1.size()
178
- if masks1 is None: # downstreams
179
- f1_ = visf1
180
- else: # pretraining
181
- Ntotal = masks1.size(1)
182
- f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
183
- f1_[~masks1] = visf1.view(B * Nenc, C)
184
- # add positional embedding
185
- if self.dec_pos_embed is not None:
186
- f1_ = f1_ + self.dec_pos_embed
187
- f2 = f2 + self.dec_pos_embed
188
- # apply Transformer blocks
189
- out = f1_
190
- out2 = f2
191
- if return_all_blocks:
192
- _out, out = out, []
193
- for blk in self.dec_blocks:
194
- _out, out2 = blk(_out, out2, pos1, pos2)
195
- out.append(_out)
196
- out[-1] = self.dec_norm(out[-1])
197
- else:
198
- for blk in self.dec_blocks:
199
- out, out2 = blk(out, out2, pos1, pos2)
200
- out = self.dec_norm(out)
201
- return out
202
-
203
- def patchify(self, imgs):
204
- """
205
- imgs: (B, 3, H, W)
206
- x: (B, L, patch_size**2 *3)
207
- """
208
- p = self.patch_embed.patch_size[0]
209
- assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
210
-
211
- h = w = imgs.shape[2] // p
212
- x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
213
- x = torch.einsum('nchpwq->nhwpqc', x)
214
- x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
215
-
216
- return x
217
-
218
- def unpatchify(self, x, channels=3):
219
- """
220
- x: (N, L, patch_size**2 *channels)
221
- imgs: (N, 3, H, W)
222
- """
223
- patch_size = self.patch_embed.patch_size[0]
224
- h = w = int(x.shape[1]**.5)
225
- assert h * w == x.shape[1]
226
- x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
227
- x = torch.einsum('nhwpqc->nchpwq', x)
228
- imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
229
- return imgs
230
-
231
- def forward(self, img1, img2):
232
- """
233
- img1: tensor of size B x 3 x img_size x img_size
234
- img2: tensor of size B x 3 x img_size x img_size
235
-
236
- out will be B x N x (3*patch_size*patch_size)
237
- masks are also returned as B x N just in case
238
- """
239
- # encoder of the masked first image
240
- feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
241
- # encoder of the second image
242
- feat2, pos2, _ = self._encode_image(img2, do_mask=False)
243
- # decoder
244
- decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
245
- # prediction head
246
- out = self.prediction_head(decfeat)
247
- # get target
248
- target = self.patchify(img1)
249
- return out, mask1, target
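patchify and unpatchify above are exact inverses: they move between (B, 3, H, W) images and (B, L, p*p*3) token sequences with a reshape and an einsum permutation. A stand-alone round-trip check of those two mappings (patch and image sizes here are arbitrary):

import torch

def patchify(imgs, p):
    # (B, C, H, W) -> (B, L, p*p*C), same reshaping as CroCoNet.patchify
    B, C, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(B, C, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(B, h * w, p * p * C)

def unpatchify(x, p, channels=3):
    # inverse mapping, as in CroCoNet.unpatchify (square images assumed)
    B, L, _ = x.shape
    h = w = int(L ** 0.5)
    x = x.reshape(B, h, w, p, p, channels)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(B, channels, h * p, w * p)

imgs = torch.randn(2, 3, 224, 224)
tokens = patchify(imgs, p=16)
print(tokens.shape)                                    # torch.Size([2, 196, 768])
print(torch.allclose(unpatchify(tokens, p=16), imgs))  # True: lossless round trip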
 
mini_dust3r/croco/dpt_block.py DELETED
@@ -1,450 +0,0 @@
1
- # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
- # --------------------------------------------------------
5
- # DPT head for ViTs
6
- # --------------------------------------------------------
7
- # References:
8
- # https://github.com/isl-org/DPT
9
- # https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from einops import rearrange, repeat
15
- from typing import Union, Tuple, Iterable, List, Optional, Dict
16
-
17
- def pair(t):
18
- return t if isinstance(t, tuple) else (t, t)
19
-
20
- def make_scratch(in_shape, out_shape, groups=1, expand=False):
21
- scratch = nn.Module()
22
-
23
- out_shape1 = out_shape
24
- out_shape2 = out_shape
25
- out_shape3 = out_shape
26
- out_shape4 = out_shape
27
- if expand == True:
28
- out_shape1 = out_shape
29
- out_shape2 = out_shape * 2
30
- out_shape3 = out_shape * 4
31
- out_shape4 = out_shape * 8
32
-
33
- scratch.layer1_rn = nn.Conv2d(
34
- in_shape[0],
35
- out_shape1,
36
- kernel_size=3,
37
- stride=1,
38
- padding=1,
39
- bias=False,
40
- groups=groups,
41
- )
42
- scratch.layer2_rn = nn.Conv2d(
43
- in_shape[1],
44
- out_shape2,
45
- kernel_size=3,
46
- stride=1,
47
- padding=1,
48
- bias=False,
49
- groups=groups,
50
- )
51
- scratch.layer3_rn = nn.Conv2d(
52
- in_shape[2],
53
- out_shape3,
54
- kernel_size=3,
55
- stride=1,
56
- padding=1,
57
- bias=False,
58
- groups=groups,
59
- )
60
- scratch.layer4_rn = nn.Conv2d(
61
- in_shape[3],
62
- out_shape4,
63
- kernel_size=3,
64
- stride=1,
65
- padding=1,
66
- bias=False,
67
- groups=groups,
68
- )
69
-
70
- scratch.layer_rn = nn.ModuleList([
71
- scratch.layer1_rn,
72
- scratch.layer2_rn,
73
- scratch.layer3_rn,
74
- scratch.layer4_rn,
75
- ])
76
-
77
- return scratch
78
-
79
- class ResidualConvUnit_custom(nn.Module):
80
- """Residual convolution module."""
81
-
82
- def __init__(self, features, activation, bn):
83
- """Init.
84
- Args:
85
- features (int): number of features
86
- """
87
- super().__init__()
88
-
89
- self.bn = bn
90
-
91
- self.groups = 1
92
-
93
- self.conv1 = nn.Conv2d(
94
- features,
95
- features,
96
- kernel_size=3,
97
- stride=1,
98
- padding=1,
99
- bias=not self.bn,
100
- groups=self.groups,
101
- )
102
-
103
- self.conv2 = nn.Conv2d(
104
- features,
105
- features,
106
- kernel_size=3,
107
- stride=1,
108
- padding=1,
109
- bias=not self.bn,
110
- groups=self.groups,
111
- )
112
-
113
- if self.bn == True:
114
- self.bn1 = nn.BatchNorm2d(features)
115
- self.bn2 = nn.BatchNorm2d(features)
116
-
117
- self.activation = activation
118
-
119
- self.skip_add = nn.quantized.FloatFunctional()
120
-
121
- def forward(self, x):
122
- """Forward pass.
123
- Args:
124
- x (tensor): input
125
- Returns:
126
- tensor: output
127
- """
128
-
129
- out = self.activation(x)
130
- out = self.conv1(out)
131
- if self.bn == True:
132
- out = self.bn1(out)
133
-
134
- out = self.activation(out)
135
- out = self.conv2(out)
136
- if self.bn == True:
137
- out = self.bn2(out)
138
-
139
- if self.groups > 1:
140
- out = self.conv_merge(out)
141
-
142
- return self.skip_add.add(out, x)
143
-
144
- class FeatureFusionBlock_custom(nn.Module):
145
- """Feature fusion block."""
146
-
147
- def __init__(
148
- self,
149
- features,
150
- activation,
151
- deconv=False,
152
- bn=False,
153
- expand=False,
154
- align_corners=True,
155
- width_ratio=1,
156
- ):
157
- """Init.
158
- Args:
159
- features (int): number of features
160
- """
161
- super(FeatureFusionBlock_custom, self).__init__()
162
- self.width_ratio = width_ratio
163
-
164
- self.deconv = deconv
165
- self.align_corners = align_corners
166
-
167
- self.groups = 1
168
-
169
- self.expand = expand
170
- out_features = features
171
- if self.expand == True:
172
- out_features = features // 2
173
-
174
- self.out_conv = nn.Conv2d(
175
- features,
176
- out_features,
177
- kernel_size=1,
178
- stride=1,
179
- padding=0,
180
- bias=True,
181
- groups=1,
182
- )
183
-
184
- self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
185
- self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
186
-
187
- self.skip_add = nn.quantized.FloatFunctional()
188
-
189
- def forward(self, *xs):
190
- """Forward pass.
191
- Returns:
192
- tensor: output
193
- """
194
- output = xs[0]
195
-
196
- if len(xs) == 2:
197
- res = self.resConfUnit1(xs[1])
198
- if self.width_ratio != 1:
199
- res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
200
-
201
- output = self.skip_add.add(output, res)
202
- # output += res
203
-
204
- output = self.resConfUnit2(output)
205
-
206
- if self.width_ratio != 1:
207
- # and output.shape[3] < self.width_ratio * output.shape[2]
208
- #size=(image.shape[])
209
- if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
210
- shape = 3 * output.shape[3]
211
- else:
212
- shape = int(self.width_ratio * 2 * output.shape[2])
213
- output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
214
- else:
215
- output = nn.functional.interpolate(output, scale_factor=2,
216
- mode="bilinear", align_corners=self.align_corners)
217
- output = self.out_conv(output)
218
- return output
219
-
220
- def make_fusion_block(features, use_bn, width_ratio=1):
221
- return FeatureFusionBlock_custom(
222
- features,
223
- nn.ReLU(False),
224
- deconv=False,
225
- bn=use_bn,
226
- expand=False,
227
- align_corners=True,
228
- width_ratio=width_ratio,
229
- )
230
-
231
- class Interpolate(nn.Module):
232
- """Interpolation module."""
233
-
234
- def __init__(self, scale_factor, mode, align_corners=False):
235
- """Init.
236
- Args:
237
- scale_factor (float): scaling
238
- mode (str): interpolation mode
239
- """
240
- super(Interpolate, self).__init__()
241
-
242
- self.interp = nn.functional.interpolate
243
- self.scale_factor = scale_factor
244
- self.mode = mode
245
- self.align_corners = align_corners
246
-
247
- def forward(self, x):
248
- """Forward pass.
249
- Args:
250
- x (tensor): input
251
- Returns:
252
- tensor: interpolated data
253
- """
254
-
255
- x = self.interp(
256
- x,
257
- scale_factor=self.scale_factor,
258
- mode=self.mode,
259
- align_corners=self.align_corners,
260
- )
261
-
262
- return x
263
-
264
- class DPTOutputAdapter(nn.Module):
265
- """DPT output adapter.
266
-
267
-     :param num_channels: Number of output channels
268
-     :param stride_level: Stride level compared to the full-sized image.
269
- E.g. 4 for 1/4th the size of the image.
270
- :param patch_size_full: Int or tuple of the patch size over the full image size.
271
- Patch size for smaller inputs will be computed accordingly.
272
- :param hooks: Index of intermediate layers
273
- :param layer_dims: Dimension of intermediate layers
274
- :param feature_dim: Feature dimension
275
- :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
276
- :param use_bn: If set to True, activates batch norm
277
- :param dim_tokens_enc: Dimension of tokens coming from encoder
278
- """
279
-
280
- def __init__(self,
281
- num_channels: int = 1,
282
- stride_level: int = 1,
283
- patch_size: Union[int, Tuple[int, int]] = 16,
284
- main_tasks: Iterable[str] = ('rgb',),
285
- hooks: List[int] = [2, 5, 8, 11],
286
- layer_dims: List[int] = [96, 192, 384, 768],
287
- feature_dim: int = 256,
288
- last_dim: int = 32,
289
- use_bn: bool = False,
290
- dim_tokens_enc: Optional[int] = None,
291
- head_type: str = 'regression',
292
- output_width_ratio=1,
293
- **kwargs):
294
- super().__init__()
295
- self.num_channels = num_channels
296
- self.stride_level = stride_level
297
- self.patch_size = pair(patch_size)
298
- self.main_tasks = main_tasks
299
- self.hooks = hooks
300
- self.layer_dims = layer_dims
301
- self.feature_dim = feature_dim
302
- self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
303
- self.head_type = head_type
304
-
305
- # Actual patch height and width, taking into account stride of input
306
- self.P_H = max(1, self.patch_size[0] // stride_level)
307
- self.P_W = max(1, self.patch_size[1] // stride_level)
308
-
309
- self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
310
-
311
- self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
312
- self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
313
- self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
314
- self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
315
-
316
- if self.head_type == 'regression':
317
- # The "DPTDepthModel" head
318
- self.head = nn.Sequential(
319
- nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
320
- Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
321
- nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
322
- nn.ReLU(True),
323
- nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
324
- )
325
- elif self.head_type == 'semseg':
326
- # The "DPTSegmentationModel" head
327
- self.head = nn.Sequential(
328
- nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
329
- nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
330
- nn.ReLU(True),
331
- nn.Dropout(0.1, False),
332
- nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
333
- Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
334
- )
335
- else:
336
- raise ValueError('DPT head_type must be "regression" or "semseg".')
337
-
338
- if self.dim_tokens_enc is not None:
339
- self.init(dim_tokens_enc=dim_tokens_enc)
340
-
341
- def init(self, dim_tokens_enc=768):
342
- """
343
- Initialize parts of decoder that are dependent on dimension of encoder tokens.
344
- Should be called when setting up MultiMAE.
345
-
346
- :param dim_tokens_enc: Dimension of tokens coming from encoder
347
- """
348
- #print(dim_tokens_enc)
349
-
350
- # Set up activation postprocessing layers
351
- if isinstance(dim_tokens_enc, int):
352
- dim_tokens_enc = 4 * [dim_tokens_enc]
353
-
354
- self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
355
-
356
- self.act_1_postprocess = nn.Sequential(
357
- nn.Conv2d(
358
- in_channels=self.dim_tokens_enc[0],
359
- out_channels=self.layer_dims[0],
360
- kernel_size=1, stride=1, padding=0,
361
- ),
362
- nn.ConvTranspose2d(
363
- in_channels=self.layer_dims[0],
364
- out_channels=self.layer_dims[0],
365
- kernel_size=4, stride=4, padding=0,
366
- bias=True, dilation=1, groups=1,
367
- )
368
- )
369
-
370
- self.act_2_postprocess = nn.Sequential(
371
- nn.Conv2d(
372
- in_channels=self.dim_tokens_enc[1],
373
- out_channels=self.layer_dims[1],
374
- kernel_size=1, stride=1, padding=0,
375
- ),
376
- nn.ConvTranspose2d(
377
- in_channels=self.layer_dims[1],
378
- out_channels=self.layer_dims[1],
379
- kernel_size=2, stride=2, padding=0,
380
- bias=True, dilation=1, groups=1,
381
- )
382
- )
383
-
384
- self.act_3_postprocess = nn.Sequential(
385
- nn.Conv2d(
386
- in_channels=self.dim_tokens_enc[2],
387
- out_channels=self.layer_dims[2],
388
- kernel_size=1, stride=1, padding=0,
389
- )
390
- )
391
-
392
- self.act_4_postprocess = nn.Sequential(
393
- nn.Conv2d(
394
- in_channels=self.dim_tokens_enc[3],
395
- out_channels=self.layer_dims[3],
396
- kernel_size=1, stride=1, padding=0,
397
- ),
398
- nn.Conv2d(
399
- in_channels=self.layer_dims[3],
400
- out_channels=self.layer_dims[3],
401
- kernel_size=3, stride=2, padding=1,
402
- )
403
- )
404
-
405
- self.act_postprocess = nn.ModuleList([
406
- self.act_1_postprocess,
407
- self.act_2_postprocess,
408
- self.act_3_postprocess,
409
- self.act_4_postprocess
410
- ])
411
-
412
- def adapt_tokens(self, encoder_tokens):
413
- # Adapt tokens
414
- x = []
415
- x.append(encoder_tokens[:, :])
416
- x = torch.cat(x, dim=-1)
417
- return x
418
-
419
- def forward(self, encoder_tokens: List[torch.Tensor], image_size):
420
- #input_info: Dict):
421
- assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
422
- H, W = image_size
423
-
424
- # Number of patches in height and width
425
- N_H = H // (self.stride_level * self.P_H)
426
- N_W = W // (self.stride_level * self.P_W)
427
-
428
- # Hook decoder onto 4 layers from specified ViT layers
429
- layers = [encoder_tokens[hook] for hook in self.hooks]
430
-
431
- # Extract only task-relevant tokens and ignore global tokens.
432
- layers = [self.adapt_tokens(l) for l in layers]
433
-
434
- # Reshape tokens to spatial representation
435
- layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
436
-
437
- layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
438
- # Project layers to chosen feature dim
439
- layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
440
-
441
- # Fuse layers using refinement stages
442
- path_4 = self.scratch.refinenet4(layers[3])
443
- path_3 = self.scratch.refinenet3(path_4, layers[2])
444
- path_2 = self.scratch.refinenet2(path_3, layers[1])
445
- path_1 = self.scratch.refinenet1(path_2, layers[0])
446
-
447
- # Output head
448
- out = self.head(path_1)
449
-
450
- return out
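In the forward pass above, each hooked ViT layer arrives as a (B, N, C) token sequence and is turned back into a (B, C, N_H, N_W) feature map before the coarse-to-fine refinenet fusion. That token-to-map step is just the rearrange below (sizes are illustrative; einops is assumed available, as it is in the file above):

import torch
from einops import rearrange

B, C = 2, 768
H = W = 224
patch = 16
N_H, N_W = H // patch, W // patch        # 14 x 14 patches

tokens = torch.randn(B, N_H * N_W, C)    # one hooked ViT layer: (B, N, C)
fmap = rearrange(tokens, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W)
print(fmap.shape)                        # torch.Size([2, 768, 14, 14])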
 
mini_dust3r/croco/masking.py DELETED
@@ -1,25 +0,0 @@
1
- # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
-
5
- # --------------------------------------------------------
6
- # Masking utils
7
- # --------------------------------------------------------
8
-
9
- import torch
10
- import torch.nn as nn
11
-
12
- class RandomMask(nn.Module):
13
- """
14
- random masking
15
- """
16
-
17
- def __init__(self, num_patches, mask_ratio):
18
- super().__init__()
19
- self.num_patches = num_patches
20
- self.num_mask = int(mask_ratio * self.num_patches)
21
-
22
- def __call__(self, x):
23
- noise = torch.rand(x.size(0), self.num_patches, device=x.device)
24
- argsort = torch.argsort(noise, dim=1)
25
- return argsort < self.num_mask
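RandomMask above draws i.i.d. uniform noise per patch; the argsort of that noise is a uniformly random permutation, so comparing it to num_mask selects a random subset with exactly num_mask masked patches in every sample. A stand-alone check of that invariant (same three lines, outside the module):

import torch

num_patches, mask_ratio = 196, 0.9
num_mask = int(mask_ratio * num_patches)         # 176 masked patches per sample

x = torch.randn(4, num_patches, 768)             # dummy token batch
noise = torch.rand(x.size(0), num_patches)
masks = torch.argsort(noise, dim=1) < num_mask
print(masks.shape, masks.sum(dim=1))             # every row sums to exactly 176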
 
mini_dust3r/croco/pos_embed.py DELETED
@@ -1,159 +0,0 @@
1
- # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
-
5
- # --------------------------------------------------------
6
- # Position embedding utils
7
- # --------------------------------------------------------
8
-
9
-
10
-
11
- import numpy as np
12
-
13
- import torch
14
-
15
- # --------------------------------------------------------
16
- # 2D sine-cosine position embedding
17
- # References:
18
- # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
- # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20
- # MoCo v3: https://github.com/facebookresearch/moco-v3
21
- # --------------------------------------------------------
22
- def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
23
- """
24
- grid_size: int of the grid height and width
25
- return:
26
- pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
27
- """
28
- grid_h = np.arange(grid_size, dtype=np.float32)
29
- grid_w = np.arange(grid_size, dtype=np.float32)
30
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
31
- grid = np.stack(grid, axis=0)
32
-
33
- grid = grid.reshape([2, 1, grid_size, grid_size])
34
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35
- if n_cls_token>0:
36
- pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
37
- return pos_embed
38
-
39
-
40
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
41
- assert embed_dim % 2 == 0
42
-
43
- # use half of dimensions to encode grid_h
44
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
45
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
46
-
47
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
48
- return emb
49
-
50
-
51
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
52
- """
53
- embed_dim: output dimension for each position
54
- pos: a list of positions to be encoded: size (M,)
55
- out: (M, D)
56
- """
57
- assert embed_dim % 2 == 0
58
- omega = np.arange(embed_dim // 2, dtype=float)
59
- omega /= embed_dim / 2.
60
- omega = 1. / 10000**omega # (D/2,)
61
-
62
- pos = pos.reshape(-1) # (M,)
63
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
64
-
65
- emb_sin = np.sin(out) # (M, D/2)
66
- emb_cos = np.cos(out) # (M, D/2)
67
-
68
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
69
- return emb
70
-
71
-
72
- # --------------------------------------------------------
73
- # Interpolate position embeddings for high-resolution
74
- # References:
75
- # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
76
- # DeiT: https://github.com/facebookresearch/deit
77
- # --------------------------------------------------------
78
- def interpolate_pos_embed(model, checkpoint_model):
79
- if 'pos_embed' in checkpoint_model:
80
- pos_embed_checkpoint = checkpoint_model['pos_embed']
81
- embedding_size = pos_embed_checkpoint.shape[-1]
82
- num_patches = model.patch_embed.num_patches
83
- num_extra_tokens = model.pos_embed.shape[-2] - num_patches
84
- # height (== width) for the checkpoint position embedding
85
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
86
- # height (== width) for the new position embedding
87
- new_size = int(num_patches ** 0.5)
88
- # class_token and dist_token are kept unchanged
89
- if orig_size != new_size:
90
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
91
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
92
- # only the position tokens are interpolated
93
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
94
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
95
- pos_tokens = torch.nn.functional.interpolate(
96
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
97
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
98
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
99
- checkpoint_model['pos_embed'] = new_pos_embed
100
-
101
-
102
- #----------------------------------------------------------
103
- # RoPE2D: RoPE implementation in 2D
104
- #----------------------------------------------------------
105
-
106
- try:
107
- from mini_dust3r.croco.curope import cuRoPE2D
108
- RoPE2D = cuRoPE2D
109
- except ImportError:
110
- print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
111
-
112
- class RoPE2D(torch.nn.Module):
113
-
114
- def __init__(self, freq=100.0, F0=1.0):
115
- super().__init__()
116
- self.base = freq
117
- self.F0 = F0
118
- self.cache = {}
119
-
120
- def get_cos_sin(self, D, seq_len, device, dtype):
121
- if (D,seq_len,device,dtype) not in self.cache:
122
- inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
123
- t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
124
- freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
125
- freqs = torch.cat((freqs, freqs), dim=-1)
126
- cos = freqs.cos() # (Seq, Dim)
127
- sin = freqs.sin()
128
- self.cache[D,seq_len,device,dtype] = (cos,sin)
129
- return self.cache[D,seq_len,device,dtype]
130
-
131
- @staticmethod
132
- def rotate_half(x):
133
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
134
- return torch.cat((-x2, x1), dim=-1)
135
-
136
- def apply_rope1d(self, tokens, pos1d, cos, sin):
137
- assert pos1d.ndim==2
138
- cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
139
- sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
140
- return (tokens * cos) + (self.rotate_half(tokens) * sin)
141
-
142
- def forward(self, tokens, positions):
143
- """
144
- input:
145
- * tokens: batch_size x nheads x ntokens x dim
146
- * positions: batch_size x ntokens x 2 (y and x position of each token)
147
- output:
148
-             * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
149
- """
150
- assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
151
- D = tokens.size(3) // 2
152
- assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
153
- cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
154
- # split features into two along the feature dimension, and apply rope1d on each half
155
- y, x = tokens.chunk(2, dim=-1)
156
- y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
157
- x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
158
- tokens = torch.cat((y, x), dim=-1)
159
- return tokens
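get_1d_sincos_pos_embed_from_grid above is the classic Transformer encoding: each position is projected onto a geometric ladder of frequencies (base 10000) and written out as sines then cosines. A compact stand-alone version for a few positions (re-implemented here for illustration, not imported from the deleted module):

import numpy as np

def sincos_1d(embed_dim, pos):
    # embed_dim must be even; pos is a 1-D array of positions
    omega = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega                  # (D/2,) frequencies
    out = np.einsum('m,d->md', pos, omega)        # (M, D/2) outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)

emb = sincos_1d(8, np.arange(4, dtype=float))
print(emb.shape)   # (4, 8)
print(emb[0])      # position 0 -> all sines are 0, all cosines are 1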
 
mini_dust3r/heads/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # head factory
6
- # --------------------------------------------------------
7
- from .linear_head import LinearPts3d
8
- from .dpt_head import create_dpt_head
9
-
10
-
11
- def head_factory(head_type, output_mode, net, has_conf=False):
12
-     """ build a prediction head for the decoder
13
- """
14
- if head_type == 'linear' and output_mode == 'pts3d':
15
- return LinearPts3d(net, has_conf)
16
- elif head_type == 'dpt' and output_mode == 'pts3d':
17
- return create_dpt_head(net, has_conf=has_conf)
18
- else:
19
- raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
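head_factory above is a plain dispatch on the pair (head_type, output_mode): two supported combinations, and NotImplementedError for everything else. The same pattern in isolation, with placeholder builders so it runs outside the repo (the real LinearPts3d / create_dpt_head live in the sibling modules):

def head_factory(head_type, output_mode, net=None, has_conf=False):
    # placeholder builders; the real heads come from linear_head.py / dpt_head.py
    builders = {
        ('linear', 'pts3d'): lambda net, has_conf: ('LinearPts3d', has_conf),
        ('dpt', 'pts3d'): lambda net, has_conf: ('create_dpt_head', has_conf),
    }
    key = (head_type, output_mode)
    if key not in builders:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
    return builders[key](net, has_conf)

print(head_factory('linear', 'pts3d'))                  # ('LinearPts3d', False)
print(head_factory('dpt', 'pts3d', has_conf=True))      # ('create_dpt_head', True)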
 
mini_dust3r/heads/dpt_head.py DELETED
@@ -1,114 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # dpt head implementation for DUST3R
6
- # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
7
- # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
8
- # the forward function also takes as input a dictionary img_info with key "height" and "width"
9
- # for PixelwiseTask, the output will be of dimension B x num_channels x H x W
10
- # --------------------------------------------------------
11
- from einops import rearrange
12
- from typing import List
13
- import torch
14
- import torch.nn as nn
15
- from mini_dust3r.heads.postprocess import postprocess
16
- from mini_dust3r.croco.dpt_block import DPTOutputAdapter
17
-
18
-
19
- class DPTOutputAdapter_fix(DPTOutputAdapter):
20
- """
21
- Adapt croco's DPTOutputAdapter implementation for dust3r:
22
- remove duplicated weights, and fix forward for dust3r
23
- """
24
-
25
- def init(self, dim_tokens_enc=768):
26
- super().init(dim_tokens_enc)
27
- # these are duplicated weights
28
- del self.act_1_postprocess
29
- del self.act_2_postprocess
30
- del self.act_3_postprocess
31
- del self.act_4_postprocess
32
-
33
- def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
34
- assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
35
- # H, W = input_info['image_size']
36
- image_size = self.image_size if image_size is None else image_size
37
- H, W = image_size
38
- # Number of patches in height and width
39
- N_H = H // (self.stride_level * self.P_H)
40
- N_W = W // (self.stride_level * self.P_W)
41
-
42
- # Hook decoder onto 4 layers from specified ViT layers
43
- layers = [encoder_tokens[hook] for hook in self.hooks]
44
-
45
- # Extract only task-relevant tokens and ignore global tokens.
46
- layers = [self.adapt_tokens(l) for l in layers]
47
-
48
- # Reshape tokens to spatial representation
49
- layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
50
-
51
- layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
52
- # Project layers to chosen feature dim
53
- layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
54
-
55
- # Fuse layers using refinement stages
56
- path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
57
- path_3 = self.scratch.refinenet3(path_4, layers[2])
58
- path_2 = self.scratch.refinenet2(path_3, layers[1])
59
- path_1 = self.scratch.refinenet1(path_2, layers[0])
60
-
61
- # Output head
62
- out = self.head(path_1)
63
-
64
- return out
65
-
66
-
67
- class PixelwiseTaskWithDPT(nn.Module):
68
- """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
69
-
70
- def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
71
- output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
72
- super(PixelwiseTaskWithDPT, self).__init__()
73
- self.return_all_layers = True # backbone needs to return all layers
74
- self.postprocess = postprocess
75
- self.depth_mode = depth_mode
76
- self.conf_mode = conf_mode
77
-
78
- assert n_cls_token == 0, "Not implemented"
79
- dpt_args = dict(output_width_ratio=output_width_ratio,
80
- num_channels=num_channels,
81
- **kwargs)
82
- if hooks_idx is not None:
83
- dpt_args.update(hooks=hooks_idx)
84
- self.dpt = DPTOutputAdapter_fix(**dpt_args)
85
- dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
86
- self.dpt.init(**dpt_init_args)
87
-
88
- def forward(self, x, img_info):
89
- out = self.dpt(x, image_size=(img_info[0], img_info[1]))
90
- if self.postprocess:
91
- out = self.postprocess(out, self.depth_mode, self.conf_mode)
92
- return out
93
-
94
-
95
- def create_dpt_head(net, has_conf=False):
96
- """
97
- return PixelwiseTaskWithDPT for given net params
98
- """
99
- assert net.dec_depth > 9
100
- l2 = net.dec_depth
101
- feature_dim = 256
102
- last_dim = feature_dim//2
103
- out_nchan = 3
104
- ed = net.enc_embed_dim
105
- dd = net.dec_embed_dim
106
- return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
107
- feature_dim=feature_dim,
108
- last_dim=last_dim,
109
- hooks_idx=[0, l2*2//4, l2*3//4, l2],
110
- dim_tokens=[ed, dd, dd, dd],
111
- postprocess=postprocess,
112
- depth_mode=net.depth_mode,
113
- conf_mode=net.conf_mode,
114
- head_type='regression')
 
mini_dust3r/heads/linear_head.py DELETED
@@ -1,41 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # linear head implementation for DUST3R
6
- # --------------------------------------------------------
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from mini_dust3r.heads.postprocess import postprocess
10
-
11
-
12
- class LinearPts3d (nn.Module):
13
- """
14
- Linear head for dust3r
15
- Each token outputs: - 16x16 3D points (+ confidence)
16
- """
17
-
18
- def __init__(self, net, has_conf=False):
19
- super().__init__()
20
- self.patch_size = net.patch_embed.patch_size[0]
21
- self.depth_mode = net.depth_mode
22
- self.conf_mode = net.conf_mode
23
- self.has_conf = has_conf
24
-
25
- self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
26
-
27
- def setup(self, croconet):
28
- pass
29
-
30
- def forward(self, decout, img_shape):
31
- H, W = img_shape
32
- tokens = decout[-1]
33
- B, S, D = tokens.shape
34
-
35
- # extract 3D points
36
- feat = self.proj(tokens) # B,S,D
37
- feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
38
- feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
39
-
40
- # permute + norm depth
41
- return postprocess(feat, self.depth_mode, self.conf_mode)
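
A shape-flow sketch (an assumption, with a SimpleNamespace stub standing in for the real network) of what this head does: each decoder token is projected to (3 + conf) * patch_size**2 values and pixel-shuffled back to a dense H x W map:

import types
import torch
from mini_dust3r.heads.linear_head import LinearPts3d   # the module as it existed before this commit

stub = types.SimpleNamespace(
    patch_embed=types.SimpleNamespace(patch_size=(16, 16)),
    depth_mode=("exp", -float("inf"), float("inf")),
    conf_mode=("exp", 1, float("inf")),
    dec_embed_dim=768,
)
head = LinearPts3d(stub, has_conf=True)
H = W = 224
tokens = torch.randn(2, (H // 16) * (W // 16), 768)      # B x S x D decoder tokens
out = head([tokens], (H, W))
print(out["pts3d"].shape, out["conf"].shape)             # (2, 224, 224, 3) and (2, 224, 224)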
 
mini_dust3r/heads/postprocess.py DELETED
@@ -1,58 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # post process function for all heads: extract 3D points/confidence from output
6
- # --------------------------------------------------------
7
- import torch
8
-
9
-
10
- def postprocess(out, depth_mode, conf_mode):
11
- """
12
- extract 3D points/confidence from prediction head output
13
- """
14
- fmap = out.permute(0, 2, 3, 1) # B,H,W,3
15
- res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
16
-
17
- if conf_mode is not None:
18
- res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
19
- return res
20
-
21
-
22
- def reg_dense_depth(xyz, mode):
23
- """
24
- extract 3D points from prediction head output
25
- """
26
- mode, vmin, vmax = mode
27
-
28
- no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
29
- assert no_bounds
30
-
31
- if mode == 'linear':
32
- if no_bounds:
33
- return xyz # [-inf, +inf]
34
- return xyz.clip(min=vmin, max=vmax)
35
-
36
- # distance to origin
37
- d = xyz.norm(dim=-1, keepdim=True)
38
- xyz = xyz / d.clip(min=1e-8)
39
-
40
- if mode == 'square':
41
- return xyz * d.square()
42
-
43
- if mode == 'exp':
44
- return xyz * torch.expm1(d)
45
-
46
- raise ValueError(f'bad {mode=}')
47
-
48
-
49
- def reg_dense_conf(x, mode):
50
- """
51
- extract confidence from prediction head output
52
- """
53
- mode, vmin, vmax = mode
54
- if mode == 'exp':
55
- return vmin + x.exp().clip(max=vmax-vmin)
56
- if mode == 'sigmoid':
57
- return (vmax - vmin) * torch.sigmoid(x) + vmin
58
- raise ValueError(f'bad {mode=}')
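
A tiny numeric sketch (an assumption, not from the repo) of the two confidence activations above; the mode tuples follow the (mode, vmin, vmax) convention used by the model defaults:

import torch
from mini_dust3r.heads.postprocess import reg_dense_conf

x = torch.tensor([0.0, 1.0, 5.0])
print(reg_dense_conf(x, mode=("exp", 1, float("inf"))))   # 1 + exp(x): [2.00, 3.72, 149.41]
print(reg_dense_conf(x, mode=("sigmoid", 1, 10)))         # 9 * sigmoid(x) + 1: [5.50, 7.58, 9.94]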
 
mini_dust3r/image_pairs.py DELETED
@@ -1,85 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utilities needed to load image pairs
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import torch
9
- from mini_dust3r.utils.image import ImageDict
10
-
11
-
12
- def make_pairs(
13
- imgs: list[ImageDict],
14
- scene_graph: str = "complete",
15
- prefilter=None,
16
- symmetrize=True,
17
- ) -> list[tuple[ImageDict, ImageDict]]:
18
- pairs = []
19
- if scene_graph == "complete": # complete graph
20
- for i in range(len(imgs)):
21
- for j in range(i):
22
- pairs.append((imgs[i], imgs[j]))
23
- elif scene_graph.startswith("swin"):
24
- winsize = int(scene_graph.split("-")[1]) if "-" in scene_graph else 3
25
- pairsid = set()
26
- for i in range(len(imgs)):
27
- for j in range(1, winsize + 1):
28
- idx = (i + j) % len(imgs) # explicit loop closure
29
- pairsid.add((i, idx) if i < idx else (idx, i))
30
- for i, j in pairsid:
31
- pairs.append((imgs[i], imgs[j]))
32
- elif scene_graph.startswith("oneref"):
33
- refid = int(scene_graph.split("-")[1]) if "-" in scene_graph else 0
34
- for j in range(len(imgs)):
35
- if j != refid:
36
- pairs.append((imgs[refid], imgs[j]))
37
- if symmetrize:
38
- pairs += [(img2, img1) for img1, img2 in pairs]
39
-
40
- # now, remove edges
41
- if isinstance(prefilter, str) and prefilter.startswith("seq"):
42
- pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
43
-
44
- if isinstance(prefilter, str) and prefilter.startswith("cyc"):
45
- pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
46
-
47
- return pairs
48
-
49
-
50
- def sel(x, kept):
51
- if isinstance(x, dict):
52
- return {k: sel(v, kept) for k, v in x.items()}
53
- if isinstance(x, (torch.Tensor, np.ndarray)):
54
- return x[kept]
55
- if isinstance(x, (tuple, list)):
56
- return type(x)([x[k] for k in kept])
57
-
58
-
59
- def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
60
- # number of images
61
- n = max(max(e) for e in edges) + 1
62
-
63
- kept = []
64
- for e, (i, j) in enumerate(edges):
65
- dis = abs(i - j)
66
- if cyclic:
67
- dis = min(dis, abs(i + n - j), abs(i - n - j))
68
- if dis <= seq_dis_thr:
69
- kept.append(e)
70
- return kept
71
-
72
-
73
- def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
74
- edges = [(img1["idx"], img2["idx"]) for img1, img2 in pairs]
75
- kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
76
- return [pairs[i] for i in kept]
77
-
78
-
79
- def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
80
- edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])]
81
- kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
82
- print(
83
- f">> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges"
84
- )
85
- return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
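
A small usage sketch (an assumption) of the pairing strategies implemented by make_pairs above; the dict stand-ins only carry the fields the function touches:

from mini_dust3r.image_pairs import make_pairs   # the module as it existed before this commit

imgs = [dict(idx=i, instance=str(i)) for i in range(4)]   # stand-ins for ImageDict entries
print(len(make_pairs(imgs, scene_graph="complete")))      # 12: all 6 unordered pairs, symmetrized
print(len(make_pairs(imgs, scene_graph="swin-1")))        # 8: window of 1 neighbour with loop closure
print(len(make_pairs(imgs, scene_graph="oneref-0")))      # 6: image 0 against the other three, symmetrized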
 
mini_dust3r/inference.py DELETED
@@ -1,204 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utilities needed for the inference
6
- # --------------------------------------------------------
7
- import tqdm
8
- import torch
9
- from mini_dust3r.utils.device import to_cpu, collate_with_cat
10
- from mini_dust3r.utils.misc import invalid_to_nans
11
- from mini_dust3r.utils.geometry import depthmap_to_pts3d, geotrf
12
- from mini_dust3r.utils.image import ImageDict
13
- from mini_dust3r.model import AsymmetricCroCo3DStereo
14
-
15
- from typing import Literal, TypedDict, Optional
16
- from jaxtyping import Float32
17
-
18
-
19
- class Dust3rPred1(TypedDict):
20
- pts3d: Float32[torch.Tensor, "b h w c"]
21
- conf: Float32[torch.Tensor, "b h w"]
22
-
23
-
24
- class Dust3rPred2(TypedDict):
25
- pts3d_in_other_view: Float32[torch.Tensor, "b h w c"]
26
- conf: Float32[torch.Tensor, "b h w"]
27
-
28
-
29
- class Dust3rResult(TypedDict):
30
- view1: ImageDict
31
- view2: ImageDict
32
- pred1: Dust3rPred1
33
- pred2: Dust3rPred2
34
- loss: Optional[int]
35
-
36
-
37
- def _interleave_imgs(img1, img2):
38
- res = {}
39
- for key, value1 in img1.items():
40
- value2 = img2[key]
41
- if isinstance(value1, torch.Tensor):
42
- value = torch.stack((value1, value2), dim=1).flatten(0, 1)
43
- else:
44
- value = [x for pair in zip(value1, value2) for x in pair]
45
- res[key] = value
46
- return res
47
-
48
-
49
- def make_batch_symmetric(batch):
50
- view1, view2 = batch
51
- view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
52
- return view1, view2
53
-
54
-
55
- def loss_of_one_batch(
56
- batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None
57
- ):
58
- view1, view2 = batch
59
- for view in batch:
60
- for name in (
61
- "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split()
62
- ): # pseudo_focal
63
- if name not in view:
64
- continue
65
- view[name] = view[name].to(device, non_blocking=True)
66
-
67
- if symmetrize_batch:
68
- view1, view2 = make_batch_symmetric(batch)
69
-
70
- with torch.cuda.amp.autocast(enabled=bool(use_amp)):
71
- pred1, pred2 = model(view1, view2)
72
-
73
- # loss is supposed to be symmetric
74
- with torch.cuda.amp.autocast(enabled=False):
75
- loss = (
76
- criterion(view1, view2, pred1, pred2) if criterion is not None else None
77
- )
78
-
79
- result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
80
- return result[ret] if ret else result
81
-
82
-
83
- @torch.no_grad()
84
- def inference(
85
- pairs: list[tuple[ImageDict, ImageDict]],
86
- model: AsymmetricCroCo3DStereo,
87
- device: Literal["cpu", "cuda", "mps"],
88
- batch_size: int = 8,
89
- verbose: bool = True,
90
- ) -> Dust3rResult:
91
- if verbose:
92
- print(f">> Inference with model on {len(pairs)} image pairs")
93
- result = []
94
-
95
- # first, check if all images have the same size
96
- multiple_shapes = not (check_if_same_size(pairs))
97
- if multiple_shapes: # force bs=1
98
- batch_size = 1
99
-
100
- for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
101
- res: Dust3rResult = loss_of_one_batch(
102
- collate_with_cat(pairs[i : i + batch_size]), model, None, device
103
- )
104
- result.append(to_cpu(res))
105
-
106
- result = collate_with_cat(result, lists=multiple_shapes)
107
-
108
- return result
109
-
110
-
111
- def check_if_same_size(pairs):
112
- shapes1 = [img1["img"].shape[-2:] for img1, img2 in pairs]
113
- shapes2 = [img2["img"].shape[-2:] for img1, img2 in pairs]
114
- return all(shapes1[0] == s for s in shapes1) and all(
115
- shapes2[0] == s for s in shapes2
116
- )
117
-
118
-
119
- def get_pred_pts3d(gt, pred, use_pose=False):
120
- if "depth" in pred and "pseudo_focal" in pred:
121
- try:
122
- pp = gt["camera_intrinsics"][..., :2, 2]
123
- except KeyError:
124
- pp = None
125
- pts3d = depthmap_to_pts3d(**pred, pp=pp)
126
-
127
- elif "pts3d" in pred:
128
- # pts3d from my camera
129
- pts3d = pred["pts3d"]
130
-
131
- elif "pts3d_in_other_view" in pred:
132
- # pts3d from the other camera, already transformed
133
- assert use_pose is True
134
- return pred["pts3d_in_other_view"] # return!
135
-
136
- if use_pose:
137
- camera_pose = pred.get("camera_pose")
138
- assert camera_pose is not None
139
- pts3d = geotrf(camera_pose, pts3d)
140
-
141
- return pts3d
142
-
143
-
144
- def find_opt_scaling(
145
- gt_pts1,
146
- gt_pts2,
147
- pr_pts1,
148
- pr_pts2=None,
149
- fit_mode="weiszfeld_stop_grad",
150
- valid1=None,
151
- valid2=None,
152
- ):
153
- assert gt_pts1.ndim == pr_pts1.ndim == 4
154
- assert gt_pts1.shape == pr_pts1.shape
155
- if gt_pts2 is not None:
156
- assert gt_pts2.ndim == pr_pts2.ndim == 4
157
- assert gt_pts2.shape == pr_pts2.shape
158
-
159
- # concat the pointcloud
160
- nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
161
- nan_gt_pts2 = (
162
- invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
163
- )
164
-
165
- pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
166
- pr_pts2 = (
167
- invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
168
- )
169
-
170
- all_gt = (
171
- torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1)
172
- if gt_pts2 is not None
173
- else nan_gt_pts1
174
- )
175
- all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1
176
-
177
- dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
178
- dot_gt_gt = all_gt.square().sum(dim=-1)
179
-
180
- if fit_mode.startswith("avg"):
181
- # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
182
- scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
183
- elif fit_mode.startswith("median"):
184
- scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
185
- elif fit_mode.startswith("weiszfeld"):
186
- # init scaling with l2 closed form
187
- scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
188
- # iterative re-weighted least-squares
189
- for iter in range(10):
190
- # re-weighting by inverse of distance
191
- dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
192
- # print(dis.nanmean(-1))
193
- w = dis.clip_(min=1e-8).reciprocal()
194
- # update the scaling with the new weights
195
- scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
196
- else:
197
- raise ValueError(f"bad {fit_mode=}")
198
-
199
- if fit_mode.endswith("stop_grad"):
200
- scaling = scaling.detach()
201
-
202
- scaling = scaling.clip(min=1e-3)
203
- # assert scaling.isfinite().all(), bb()
204
- return scaling
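
Putting the pieces together, a hedged end-to-end sketch of how inference() above was typically driven before this commit removed the package; the checkpoint identifier and image folder are placeholders, not values from the repo:

import torch
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.image import load_images
from mini_dust3r.image_pairs import make_pairs
from mini_dust3r.inference import inference

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AsymmetricCroCo3DStereo.from_pretrained("<checkpoint-id-or-path>").to(device)  # placeholder
imgs = load_images("<folder-with-images>", size=512)                                   # placeholder
pairs = make_pairs(imgs, scene_graph="complete")
out = inference(pairs, model, device, batch_size=1)
pts3d = out["pred1"]["pts3d"]   # B x H x W x 3 point maps, see the Dust3rResult TypedDict above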
 
mini_dust3r/model.py DELETED
@@ -1,259 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # DUSt3R model class
6
- # --------------------------------------------------------
7
- from copy import deepcopy
8
- import torch
9
- import os
10
- from packaging import version
11
- import huggingface_hub
12
-
13
- from .utils.misc import (
14
- fill_default_args,
15
- freeze_all_params,
16
- is_symmetrized,
17
- interleave,
18
- transpose_to_landscape,
19
- )
20
- from .heads import head_factory
21
- from mini_dust3r.patch_embed import get_patch_embed
22
-
23
- from mini_dust3r.croco.croco import CroCoNet
24
-
25
- inf = float("inf")
26
-
27
- hf_version_number = huggingface_hub.__version__
28
- assert version.parse(hf_version_number) >= version.parse(
29
- "0.22.0"
30
- ), "Outdated huggingface_hub version, please reinstall requirements.txt"
31
-
32
-
33
- def load_model(model_path, device, verbose=True):
34
- if verbose:
35
- print("... loading model from", model_path)
36
- ckpt = torch.load(model_path, map_location="cpu")
37
- args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
38
- if "landscape_only" not in args:
39
- args = args[:-1] + ", landscape_only=False)"
40
- else:
41
- args = args.replace(" ", "").replace(
42
- "landscape_only=True", "landscape_only=False"
43
- )
44
- assert "landscape_only=False" in args
45
- if verbose:
46
- print(f"instantiating : {args}")
47
- net = eval(args)
48
- s = net.load_state_dict(ckpt["model"], strict=False)
49
- if verbose:
50
- print(s)
51
- return net.to(device)
52
-
53
-
54
- class AsymmetricCroCo3DStereo(
55
- CroCoNet,
56
- huggingface_hub.PyTorchModelHubMixin,
57
- library_name="dust3r",
58
- repo_url="https://github.com/naver/dust3r",
59
- tags=["image-to-3d"],
60
- ):
61
- """Two siamese encoders, followed by two decoders.
62
- The goal is to output 3d points directly, both images in view1's frame
63
- (hence the asymmetry).
64
- """
65
-
66
- def __init__(
67
- self,
68
- output_mode="pts3d",
69
- head_type="linear",
70
- depth_mode=("exp", -inf, inf),
71
- conf_mode=("exp", 1, inf),
72
- freeze="none",
73
- landscape_only=True,
74
- patch_embed_cls="PatchEmbedDust3R", # PatchEmbedDust3R or ManyAR_PatchEmbed
75
- **croco_kwargs,
76
- ):
77
- self.patch_embed_cls = patch_embed_cls
78
- self.croco_args = fill_default_args(croco_kwargs, super().__init__)
79
- super().__init__(**croco_kwargs)
80
-
81
- # dust3r specific initialization
82
- self.dec_blocks2 = deepcopy(self.dec_blocks)
83
- self.set_downstream_head(
84
- output_mode,
85
- head_type,
86
- landscape_only,
87
- depth_mode,
88
- conf_mode,
89
- **croco_kwargs,
90
- )
91
- self.set_freeze(freeze)
92
-
93
- @classmethod
94
- def from_pretrained(cls, pretrained_model_name_or_path, **kw):
95
- if os.path.isfile(pretrained_model_name_or_path):
96
- return load_model(pretrained_model_name_or_path, device="cpu")
97
- else:
98
- return super(AsymmetricCroCo3DStereo, cls).from_pretrained(
99
- pretrained_model_name_or_path, **kw
100
- )
101
-
102
- def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
103
- self.patch_embed = get_patch_embed(
104
- self.patch_embed_cls, img_size, patch_size, enc_embed_dim
105
- )
106
-
107
- def load_state_dict(self, ckpt, **kw):
108
- # duplicate all weights for the second decoder if not present
109
- new_ckpt = dict(ckpt)
110
- if not any(k.startswith("dec_blocks2") for k in ckpt):
111
- for key, value in ckpt.items():
112
- if key.startswith("dec_blocks"):
113
- new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
114
- return super().load_state_dict(new_ckpt, **kw)
115
-
116
- def set_freeze(self, freeze): # this is for use by downstream models
117
- self.freeze = freeze
118
- to_be_frozen = {
119
- "none": [],
120
- "mask": [self.mask_token],
121
- "encoder": [self.mask_token, self.patch_embed, self.enc_blocks],
122
- }
123
- freeze_all_params(to_be_frozen[freeze])
124
-
125
- def _set_prediction_head(self, *args, **kwargs):
126
- """No prediction head"""
127
- return
128
-
129
- def set_downstream_head(
130
- self,
131
- output_mode,
132
- head_type,
133
- landscape_only,
134
- depth_mode,
135
- conf_mode,
136
- patch_size,
137
- img_size,
138
- **kw,
139
- ):
140
- assert (
141
- img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
142
- ), f"{img_size=} must be multiple of {patch_size=}"
143
- self.output_mode = output_mode
144
- self.head_type = head_type
145
- self.depth_mode = depth_mode
146
- self.conf_mode = conf_mode
147
- # allocate heads
148
- self.downstream_head1 = head_factory(
149
- head_type, output_mode, self, has_conf=bool(conf_mode)
150
- )
151
- self.downstream_head2 = head_factory(
152
- head_type, output_mode, self, has_conf=bool(conf_mode)
153
- )
154
- # magic wrapper
155
- self.head1 = transpose_to_landscape(
156
- self.downstream_head1, activate=landscape_only
157
- )
158
- self.head2 = transpose_to_landscape(
159
- self.downstream_head2, activate=landscape_only
160
- )
161
-
162
- def _encode_image(self, image, true_shape):
163
- # embed the image into patches (x has size B x Npatches x C)
164
- x, pos = self.patch_embed(image, true_shape=true_shape)
165
-
166
- # add positional embedding without cls token
167
- assert self.enc_pos_embed is None
168
-
169
- # now apply the transformer encoder and normalization
170
- for blk in self.enc_blocks:
171
- x = blk(x, pos)
172
-
173
- x = self.enc_norm(x)
174
- return x, pos, None
175
-
176
- def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
177
- if img1.shape[-2:] == img2.shape[-2:]:
178
- out, pos, _ = self._encode_image(
179
- torch.cat((img1, img2), dim=0),
180
- torch.cat((true_shape1, true_shape2), dim=0),
181
- )
182
- out, out2 = out.chunk(2, dim=0)
183
- pos, pos2 = pos.chunk(2, dim=0)
184
- else:
185
- out, pos, _ = self._encode_image(img1, true_shape1)
186
- out2, pos2, _ = self._encode_image(img2, true_shape2)
187
- return out, out2, pos, pos2
188
-
189
- def _encode_symmetrized(self, view1, view2):
190
- img1 = view1["img"]
191
- img2 = view2["img"]
192
- B = img1.shape[0]
193
- # Recover true_shape when available, otherwise assume that the img shape is the true one
194
- shape1 = view1.get(
195
- "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
196
- )
197
- shape2 = view2.get(
198
- "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
199
- )
200
- # warning! maybe the images have different portrait/landscape orientations
201
-
202
- if is_symmetrized(view1, view2):
203
- # computing half of forward pass!'
204
- feat1, feat2, pos1, pos2 = self._encode_image_pairs(
205
- img1[::2], img2[::2], shape1[::2], shape2[::2]
206
- )
207
- feat1, feat2 = interleave(feat1, feat2)
208
- pos1, pos2 = interleave(pos1, pos2)
209
- else:
210
- feat1, feat2, pos1, pos2 = self._encode_image_pairs(
211
- img1, img2, shape1, shape2
212
- )
213
-
214
- return (shape1, shape2), (feat1, feat2), (pos1, pos2)
215
-
216
- def _decoder(self, f1, pos1, f2, pos2):
217
- final_output = [(f1, f2)] # before projection
218
-
219
- # project to decoder dim
220
- f1 = self.decoder_embed(f1)
221
- f2 = self.decoder_embed(f2)
222
-
223
- final_output.append((f1, f2))
224
- for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
225
- # img1 side
226
- f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
227
- # img2 side
228
- f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
229
- # store the result
230
- final_output.append((f1, f2))
231
-
232
- # normalize last output
233
- del final_output[1] # duplicate with final_output[0]
234
- final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
235
- return zip(*final_output)
236
-
237
- def _downstream_head(self, head_num, decout, img_shape):
238
- B, S, D = decout[-1].shape
239
- # img_shape = tuple(map(int, img_shape))
240
- head = getattr(self, f"head{head_num}")
241
- return head(decout, img_shape)
242
-
243
- def forward(self, view1, view2):
244
- # encode the two images --> B,S,D
245
- (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(
246
- view1, view2
247
- )
248
-
249
- # combine all ref images into object-centric representation
250
- dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)
251
-
252
- with torch.cuda.amp.autocast(enabled=False):
253
- res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
254
- res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)
255
-
256
- res2["pts3d_in_other_view"] = res2.pop(
257
- "pts3d"
258
- ) # predict view2's pts3d in view1's frame
259
- return res1, res2
 
mini_dust3r/optim_factory.py DELETED
@@ -1,14 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # optimization functions
6
- # --------------------------------------------------------
7
-
8
-
9
- def adjust_learning_rate_by_lr(optimizer, lr):
10
- for param_group in optimizer.param_groups:
11
- if "lr_scale" in param_group:
12
- param_group["lr"] = lr * param_group["lr_scale"]
13
- else:
14
- param_group["lr"] = lr
 
mini_dust3r/patch_embed.py DELETED
@@ -1,69 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # PatchEmbed implementation for DUST3R,
6
- # in particular ManyAR_PatchEmbed that handles images with non-square aspect ratio
7
- # --------------------------------------------------------
8
- import torch
9
- from mini_dust3r.croco.blocks import PatchEmbed
10
-
11
-
12
- def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
13
- assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
14
- patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
15
- return patch_embed
16
-
17
-
18
- class PatchEmbedDust3R(PatchEmbed):
19
- def forward(self, x, **kw):
20
- B, C, H, W = x.shape
21
- assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
22
- assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
23
- x = self.proj(x)
24
- pos = self.position_getter(B, x.size(2), x.size(3), x.device)
25
- if self.flatten:
26
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
27
- x = self.norm(x)
28
- return x, pos
29
-
30
-
31
- class ManyAR_PatchEmbed (PatchEmbed):
32
- """ Handle images with non-square aspect ratio.
33
- All images in the same batch have the same aspect ratio.
34
- true_shape = [(height, width) ...] indicates the actual shape of each image.
35
- """
36
-
37
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
38
- self.embed_dim = embed_dim
39
- super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
40
-
41
- def forward(self, img, true_shape):
42
- B, C, H, W = img.shape
43
- assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
44
- assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
45
- assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
46
- assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"
47
-
48
- # size expressed in tokens
49
- W //= self.patch_size[0]
50
- H //= self.patch_size[1]
51
- n_tokens = H * W
52
-
53
- height, width = true_shape.T
54
- is_landscape = (width >= height)
55
- is_portrait = ~is_landscape
56
-
57
- # allocate result
58
- x = img.new_zeros((B, n_tokens, self.embed_dim))
59
- pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)
60
-
61
- # linear projection, transposed if necessary
62
- x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
63
- x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()
64
-
65
- pos[is_landscape] = self.position_getter(1, H, W, pos.device)
66
- pos[is_portrait] = self.position_getter(1, W, H, pos.device)
67
-
68
- x = self.norm(x)
69
- return x, pos
 
mini_dust3r/post_process.py DELETED
@@ -1,60 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utilities for interpreting the DUST3R output
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import torch
9
- from mini_dust3r.utils.geometry import xy_grid
10
-
11
-
12
- def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf):
13
- """ Reprojection method, for when the absolute depth is known:
14
- 1) estimate the camera focal using a robust estimator
15
- 2) reproject points onto true rays, minimizing a certain error
16
- """
17
- B, H, W, THREE = pts3d.shape
18
- assert THREE == 3
19
-
20
- # centered pixel grid
21
- pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2
22
- pts3d = pts3d.flatten(1, 2) # (B, HW, 3)
23
-
24
- if focal_mode == 'median':
25
- with torch.no_grad():
26
- # direct estimation of focal
27
- u, v = pixels.unbind(dim=-1)
28
- x, y, z = pts3d.unbind(dim=-1)
29
- fx_votes = (u * z) / x
30
- fy_votes = (v * z) / y
31
-
32
- # assume square pixels, hence same focal for X and Y
33
- f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
34
- focal = torch.nanmedian(f_votes, dim=-1).values
35
-
36
- elif focal_mode == 'weiszfeld':
37
- # init focal with l2 closed form
38
- # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
39
- xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1)
40
-
41
- dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
42
- dot_xy_xy = xy_over_z.square().sum(dim=-1)
43
-
44
- focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)
45
-
46
- # iterative re-weighted least-squares
47
- for iter in range(10):
48
- # re-weighting by inverse of distance
49
- dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
50
- # print(dis.nanmean(-1))
51
- w = dis.clip(min=1e-8).reciprocal()
52
- # update the scaling with the new weights
53
- focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
54
- else:
55
- raise ValueError(f'bad {focal_mode=}')
56
-
57
- focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515
58
- focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
59
- # print(focal)
60
- return focal
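
A worked example (an assumption, reusing xy_grid from the geometry utilities further below) showing that the 'median' mode recovers a known focal length when the point map is built from that focal:

import torch
from mini_dust3r.post_process import estimate_focal_knowing_depth
from mini_dust3r.utils.geometry import xy_grid

H, W, f = 64, 96, 120.0
pp = torch.tensor([[W / 2.0, H / 2.0]])                              # principal point
uv = xy_grid(W, H, device=pp.device).float() - pp.view(1, 1, 2)      # centred pixel grid, H x W x 2
z = torch.full((H, W, 1), 2.0)                                       # constant depth
pts3d = torch.cat((uv * z / f, z), dim=-1)[None]                     # back-project with focal f, add batch dim
print(estimate_focal_knowing_depth(pts3d, pp, focal_mode="median"))  # ~120.0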
 
mini_dust3r/utils/device.py DELETED
@@ -1,76 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utility functions for DUSt3R
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import torch
9
-
10
-
11
- def todevice(batch, device, callback=None, non_blocking=False):
12
- ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
13
-
14
- batch: list, tuple, dict of tensors or other things
15
- device: pytorch device or 'numpy'
16
- callback: function that would be called on every sub-elements.
17
- '''
18
- if callback:
19
- batch = callback(batch)
20
-
21
- if isinstance(batch, dict):
22
- return {k: todevice(v, device) for k, v in batch.items()}
23
-
24
- if isinstance(batch, (tuple, list)):
25
- return type(batch)(todevice(x, device) for x in batch)
26
-
27
- x = batch
28
- if device == 'numpy':
29
- if isinstance(x, torch.Tensor):
30
- x = x.detach().cpu().numpy()
31
- elif x is not None:
32
- if isinstance(x, np.ndarray):
33
- x = torch.from_numpy(x)
34
- if torch.is_tensor(x):
35
- x = x.to(device, non_blocking=non_blocking)
36
- return x
37
-
38
-
39
- to_device = todevice # alias
40
-
41
-
42
- def to_numpy(x): return todevice(x, 'numpy')
43
- def to_cpu(x): return todevice(x, 'cpu')
44
- def to_cuda(x): return todevice(x, 'cuda')
45
-
46
-
47
- def collate_with_cat(whatever, lists=False):
48
- if isinstance(whatever, dict):
49
- return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
50
-
51
- elif isinstance(whatever, (tuple, list)):
52
- if len(whatever) == 0:
53
- return whatever
54
- elem = whatever[0]
55
- T = type(whatever)
56
-
57
- if elem is None:
58
- return None
59
- if isinstance(elem, (bool, float, int, str)):
60
- return whatever
61
- if isinstance(elem, tuple):
62
- return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
63
- if isinstance(elem, dict):
64
- return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
65
-
66
- if isinstance(elem, torch.Tensor):
67
- return listify(whatever) if lists else torch.cat(whatever)
68
- if isinstance(elem, np.ndarray):
69
- return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
70
-
71
- # otherwise, we just chain lists
72
- return sum(whatever, T())
73
-
74
-
75
- def listify(elems):
76
- return [x for e in elems for x in e]
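
A short sketch (an assumption) of the two helpers above that the inference loop relies on: moving a nested batch to numpy and concatenating a list of batch dicts:

import torch
from mini_dust3r.utils.device import to_numpy, collate_with_cat

batch = {"img": torch.zeros(1, 3, 4, 4), "idx": [0]}
print(type(to_numpy(batch)["img"]))         # <class 'numpy.ndarray'>
merged = collate_with_cat([batch, batch])   # tensors are concatenated along dim 0, lists are chained
print(merged["img"].shape, merged["idx"])   # torch.Size([2, 3, 4, 4]) [0, 0]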
 
mini_dust3r/utils/geometry.py DELETED
@@ -1,361 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # geometry utility functions
6
- # --------------------------------------------------------
7
- import torch
8
- import numpy as np
9
- from scipy.spatial import cKDTree as KDTree
10
-
11
- from mini_dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
12
- from mini_dust3r.utils.device import to_numpy
13
-
14
-
15
- def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
16
- """ Output a (H,W,2) array of int32
17
- with output[j,i,0] = i + origin[0]
18
- output[j,i,1] = j + origin[1]
19
- """
20
- if device is None:
21
- # numpy
22
- arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
23
- else:
24
- # torch
25
- arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
26
- meshgrid, stack = torch.meshgrid, torch.stack
27
- ones = lambda *a: torch.ones(*a, device=device)
28
-
29
- tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
30
- grid = meshgrid(tw, th, indexing='xy')
31
- if homogeneous:
32
- grid = grid + (ones((H, W)),)
33
- if unsqueeze is not None:
34
- grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
35
- if cat_dim is not None:
36
- grid = stack(grid, cat_dim)
37
- return grid
38
-
39
-
40
- def geotrf(Trf, pts, ncol=None, norm=False):
41
- """ Apply a geometric transformation to a list of 3-D points.
42
-
43
- H: 3x3 or 4x4 projection matrix (typically a Homography)
44
- p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
45
-
46
- ncol: int. number of columns of the result (2 or 3)
47
- norm: float. if != 0, the result is projected on the z=norm plane.
48
-
49
- Returns an array of projected 2d points.
50
- """
51
- assert Trf.ndim >= 2
52
- if isinstance(Trf, np.ndarray):
53
- pts = np.asarray(pts)
54
- elif isinstance(Trf, torch.Tensor):
55
- pts = torch.as_tensor(pts, dtype=Trf.dtype)
56
-
57
- # adapt shape if necessary
58
- output_reshape = pts.shape[:-1]
59
- ncol = ncol or pts.shape[-1]
60
-
61
- # optimized code
62
- if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
63
- Trf.ndim == 3 and pts.ndim == 4):
64
- d = pts.shape[3]
65
- if Trf.shape[-1] == d:
66
- pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
67
- elif Trf.shape[-1] == d+1:
68
- pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
69
- else:
70
- raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
71
- else:
72
- if Trf.ndim >= 3:
73
- n = Trf.ndim-2
74
- assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
75
- Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
76
-
77
- if pts.ndim > Trf.ndim:
78
- # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
79
- pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
80
- elif pts.ndim == 2:
81
- # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
82
- pts = pts[:, None, :]
83
-
84
- if pts.shape[-1]+1 == Trf.shape[-1]:
85
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
86
- pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
87
- elif pts.shape[-1] == Trf.shape[-1]:
88
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
89
- pts = pts @ Trf
90
- else:
91
- pts = Trf @ pts.T
92
- if pts.ndim >= 2:
93
- pts = pts.swapaxes(-1, -2)
94
-
95
- if norm:
96
- pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
97
- if norm != 1:
98
- pts *= norm
99
-
100
- res = pts[..., :ncol].reshape(*output_reshape, ncol)
101
- return res
102
-
103
-
104
- def inv(mat):
105
- """ Invert a torch or numpy matrix
106
- """
107
- if isinstance(mat, torch.Tensor):
108
- return torch.linalg.inv(mat)
109
- if isinstance(mat, np.ndarray):
110
- return np.linalg.inv(mat)
111
- raise ValueError(f'bad matrix type = {type(mat)}')
112
-
113
-
114
- def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
115
- """
116
- Args:
117
- - depthmap (BxHxW array):
118
- - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
119
- Returns:
120
- pointmap of absolute coordinates (BxHxWx3 array)
121
- """
122
-
123
- if len(depth.shape) == 4:
124
- B, H, W, n = depth.shape
125
- else:
126
- B, H, W = depth.shape
127
- n = None
128
-
129
- if len(pseudo_focal.shape) == 3: # [B,H,W]
130
- pseudo_focalx = pseudo_focaly = pseudo_focal
131
- elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
132
- pseudo_focalx = pseudo_focal[:, 0]
133
- if pseudo_focal.shape[1] == 2:
134
- pseudo_focaly = pseudo_focal[:, 1]
135
- else:
136
- pseudo_focaly = pseudo_focalx
137
- else:
138
- raise NotImplementedError("Error, unknown input focal shape format.")
139
-
140
- assert pseudo_focalx.shape == depth.shape[:3]
141
- assert pseudo_focaly.shape == depth.shape[:3]
142
- grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
143
-
144
- # set principal point
145
- if pp is None:
146
- grid_x = grid_x - (W-1)/2
147
- grid_y = grid_y - (H-1)/2
148
- else:
149
- grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
150
- grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
151
-
152
- if n is None:
153
- pts3d = torch.empty((B, H, W, 3), device=depth.device)
154
- pts3d[..., 0] = depth * grid_x / pseudo_focalx
155
- pts3d[..., 1] = depth * grid_y / pseudo_focaly
156
- pts3d[..., 2] = depth
157
- else:
158
- pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
159
- pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
160
- pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
161
- pts3d[..., 2, :] = depth
162
- return pts3d
163
-
164
-
165
- def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
166
- """
167
- Args:
168
- - depthmap (HxW array):
169
- - camera_intrinsics: a 3x3 matrix
170
- Returns:
171
- pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
172
- """
173
- camera_intrinsics = np.float32(camera_intrinsics)
174
- H, W = depthmap.shape
175
-
176
- # Compute 3D ray associated with each pixel
177
- # Strong assumption: there are no skew terms
178
- assert camera_intrinsics[0, 1] == 0.0
179
- assert camera_intrinsics[1, 0] == 0.0
180
- if pseudo_focal is None:
181
- fu = camera_intrinsics[0, 0]
182
- fv = camera_intrinsics[1, 1]
183
- else:
184
- assert pseudo_focal.shape == (H, W)
185
- fu = fv = pseudo_focal
186
- cu = camera_intrinsics[0, 2]
187
- cv = camera_intrinsics[1, 2]
188
-
189
- u, v = np.meshgrid(np.arange(W), np.arange(H))
190
- z_cam = depthmap
191
- x_cam = (u - cu) * z_cam / fu
192
- y_cam = (v - cv) * z_cam / fv
193
- X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
194
-
195
- # Mask for valid coordinates
196
- valid_mask = (depthmap > 0.0)
197
- return X_cam, valid_mask
198
-
199
-
200
- def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
201
- """
202
- Args:
203
- - depthmap (HxW array):
204
- - camera_intrinsics: a 3x3 matrix
205
- - camera_pose: a 4x3 or 4x4 cam2world matrix
206
- Returns:
207
- pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
208
- X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
209
-
210
- # R_cam2world = np.float32(camera_params["R_cam2world"])
211
- # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
212
- R_cam2world = camera_pose[:3, :3]
213
- t_cam2world = camera_pose[:3, 3]
214
-
215
- # Express in absolute coordinates (invalid depth values)
216
- X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
217
- return X_world, valid_mask
218
-
219
-
220
- def colmap_to_opencv_intrinsics(K):
221
- """
222
- Modify camera intrinsics to follow a different convention.
223
- Coordinates of the center of the top-left pixels are by default:
224
- - (0.5, 0.5) in Colmap
225
- - (0,0) in OpenCV
226
- """
227
- K = K.copy()
228
- K[0, 2] -= 0.5
229
- K[1, 2] -= 0.5
230
- return K
231
-
232
-
233
- def opencv_to_colmap_intrinsics(K):
234
- """
235
- Modify camera intrinsics to follow a different convention.
236
- Coordinates of the center of the top-left pixels are by default:
237
- - (0.5, 0.5) in Colmap
238
- - (0,0) in OpenCV
239
- """
240
- K = K.copy()
241
- K[0, 2] += 0.5
242
- K[1, 2] += 0.5
243
- return K
244
-
245
-
246
- def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
247
- """ renorm pointmaps pts1, pts2 with norm_mode
248
- """
249
- assert pts1.ndim >= 3 and pts1.shape[-1] == 3
250
- assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
251
- norm_mode, dis_mode = norm_mode.split('_')
252
-
253
- if norm_mode == 'avg':
254
- # gather all points together (joint normalization)
255
- nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
256
- nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
257
- all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
258
-
259
- # compute distance to origin
260
- all_dis = all_pts.norm(dim=-1)
261
- if dis_mode == 'dis':
262
- pass # do nothing
263
- elif dis_mode == 'log1p':
264
- all_dis = torch.log1p(all_dis)
265
- elif dis_mode == 'warp-log1p':
266
- # actually warp input points before normalizing them
267
- log_dis = torch.log1p(all_dis)
268
- warp_factor = log_dis / all_dis.clip(min=1e-8)
269
- H1, W1 = pts1.shape[1:-1]
270
- pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
271
- if pts2 is not None:
272
- H2, W2 = pts2.shape[1:-1]
273
- pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
274
- all_dis = log_dis # this is their true distance afterwards
275
- else:
276
- raise ValueError(f'bad {dis_mode=}')
277
-
278
- norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
279
- else:
280
- # gather all points together (joint normalization)
281
- nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
282
- nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
283
- all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
284
-
285
- # compute distance to origin
286
- all_dis = all_pts.norm(dim=-1)
287
-
288
- if norm_mode == 'avg':
289
- norm_factor = all_dis.nanmean(dim=1)
290
- elif norm_mode == 'median':
291
- norm_factor = all_dis.nanmedian(dim=1).values.detach()
292
- elif norm_mode == 'sqrt':
293
- norm_factor = all_dis.sqrt().nanmean(dim=1)**2
294
- else:
295
- raise ValueError(f'bad {norm_mode=}')
296
-
297
- norm_factor = norm_factor.clip(min=1e-8)
298
- while norm_factor.ndim < pts1.ndim:
299
- norm_factor.unsqueeze_(-1)
300
-
301
- res = pts1 / norm_factor
302
- if pts2 is not None:
303
- res = (res, pts2 / norm_factor)
304
- return res
305
-
306
-
307
- @torch.no_grad()
308
- def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
309
- # set invalid points to NaN
310
- _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
311
- _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
312
- _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
313
-
314
- # compute median depth overall (ignoring nans)
315
- if quantile == 0.5:
316
- shift_z = torch.nanmedian(_z, dim=-1).values
317
- else:
318
- shift_z = torch.nanquantile(_z, quantile, dim=-1)
319
- return shift_z # (B,)
320
-
321
-
322
- @torch.no_grad()
323
- def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
324
- # set invalid points to NaN
325
- _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
326
- _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
327
- _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
328
-
329
- # compute median center
330
- _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
331
- if z_only:
332
- _center[..., :2] = 0 # do not center X and Y
333
-
334
- # compute median norm
335
- _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
336
- scale = torch.nanmedian(_norm, dim=1).values
337
- return _center[:, None, :, :], scale[:, None, None, None]
338
-
339
-
340
- def find_reciprocal_matches(P1, P2):
341
- """
342
- returns 3 values:
343
- 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
344
- 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
345
- 3 - reciprocal_in_P2.sum(): the number of matches
346
- """
347
- tree1 = KDTree(P1)
348
- tree2 = KDTree(P2)
349
-
350
- _, nn1_in_P2 = tree2.query(P1, workers=8)
351
- _, nn2_in_P1 = tree1.query(P2, workers=8)
352
-
353
- reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
354
- reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
355
- assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
356
- return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
357
-
358
-
359
- def get_med_dist_between_poses(poses):
360
- from scipy.spatial.distance import pdist
361
- return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
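
A numeric sketch (an assumption) of the pinhole back-projection implemented by depthmap_to_camera_coordinates above, i.e. x = (u - cx) * z / fx and likewise for y:

import numpy as np
from mini_dust3r.utils.geometry import depthmap_to_camera_coordinates

depth = np.ones((4, 6), dtype=np.float32)        # toy 4 x 6 depth map, 1m everywhere
K = np.array([[100.0, 0.0, 3.0],
              [0.0, 100.0, 2.0],
              [0.0, 0.0, 1.0]])                  # fx = fy = 100, principal point (3, 2)
X_cam, valid = depthmap_to_camera_coordinates(depth, K)
print(X_cam.shape, X_cam[0, 0], valid.all())     # (4, 6, 3) [-0.03 -0.02 1.] True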
 
mini_dust3r/utils/image.py DELETED
@@ -1,141 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utility functions about images (loading/converting...)
6
- # --------------------------------------------------------
7
- import os
8
- import torch
9
- import numpy as np
10
- import PIL.Image
11
- from PIL.ImageOps import exif_transpose
12
- import torchvision.transforms as tvf
13
-
14
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
15
- import cv2 # noqa
16
- from typing import Literal, TypedDict
17
- from jaxtyping import Float32, Int32
18
-
19
- try:
20
- from pillow_heif import register_heif_opener # noqa
21
-
22
- register_heif_opener()
23
- heif_support_enabled = True
24
- except ImportError:
25
- heif_support_enabled = False
26
-
27
- ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
28
-
29
-
30
- class ImageDict(TypedDict):
31
- img: Float32[torch.Tensor, "b c h w"]
32
- true_shape: tuple[int, int] | Int32[torch.Tensor, "b 2"]
33
- idx: int | list[int]
34
- instance: str | list[str]
35
-
36
-
37
- def imread_cv2(path, options=cv2.IMREAD_COLOR):
38
- """Open an image or a depthmap with opencv-python."""
39
- if path.endswith((".exr", "EXR")):
40
- options = cv2.IMREAD_ANYDEPTH
41
- img = cv2.imread(path, options)
42
- if img is None:
43
- raise IOError(f"Could not load image={path} with {options=}")
44
- if img.ndim == 3:
45
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
46
- return img
47
-
48
-
49
- def rgb(ftensor, true_shape=None):
50
- if isinstance(ftensor, list):
51
- return [rgb(x, true_shape=true_shape) for x in ftensor]
52
- if isinstance(ftensor, torch.Tensor):
53
- ftensor = ftensor.detach().cpu().numpy() # H,W,3
54
- if ftensor.ndim == 3 and ftensor.shape[0] == 3:
55
- ftensor = ftensor.transpose(1, 2, 0)
56
- elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
57
- ftensor = ftensor.transpose(0, 2, 3, 1)
58
- if true_shape is not None:
59
- H, W = true_shape
60
- ftensor = ftensor[:H, :W]
61
- if ftensor.dtype == np.uint8:
62
- img = np.float32(ftensor) / 255
63
- else:
64
- img = (ftensor * 0.5) + 0.5
65
- return img.clip(min=0, max=1)
66
-
67
-
68
- def _resize_pil_image(img, long_edge_size):
69
- S = max(img.size)
70
- if S > long_edge_size:
71
- interp = PIL.Image.LANCZOS
72
- elif S <= long_edge_size:
73
- interp = PIL.Image.BICUBIC
74
- new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
75
- return img.resize(new_size, interp)
76
-
77
-
78
- def load_images(
79
- folder_or_list: str | list,
80
- size: Literal[224, 512],
81
- square_ok: bool = False,
82
- verbose: bool = True,
83
- ) -> list[ImageDict]:
84
- """open and convert all images in a list or folder to proper input format for DUSt3R"""
85
- if isinstance(folder_or_list, str):
86
- if verbose:
87
- print(f">> Loading images from {folder_or_list}")
88
- root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
89
-
90
- elif isinstance(folder_or_list, list):
91
- if verbose:
92
- print(f">> Loading a list of {len(folder_or_list)} images")
93
- root, folder_content = "", folder_or_list
94
-
95
- else:
96
- raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")
97
-
98
- supported_images_extensions = [".jpg", ".jpeg", ".png"]
99
- if heif_support_enabled:
100
- supported_images_extensions += [".heic", ".heif"]
101
- supported_images_extensions = tuple(supported_images_extensions)
102
-
103
- imgs = []
104
- for path in folder_content:
105
- if not path.lower().endswith(supported_images_extensions):
106
- continue
107
- img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
108
- W1, H1 = img.size
109
- if size == 224:
110
- # resize short side to 224 (then crop)
111
- img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
112
- else:
113
- # resize long side to 512
114
- img = _resize_pil_image(img, size)
115
- W, H = img.size
116
- cx, cy = W // 2, H // 2
117
- if size == 224:
118
- half = min(cx, cy)
119
- img = img.crop((cx - half, cy - half, cx + half, cy + half))
120
- else:
121
- halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
122
- if not (square_ok) and W == H:
123
- halfh = 3 * halfw / 4
124
- img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
125
-
126
- W2, H2 = img.size
127
- if verbose:
128
- print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
129
- imgs.append(
130
- dict(
131
- img=ImgNorm(img)[None],
132
- true_shape=np.int32([img.size[::-1]]),
133
- idx=len(imgs),
134
- instance=str(len(imgs)),
135
- )
136
- )
137
-
138
- assert imgs, "no images found at " + root
139
- if verbose:
140
- print(f" (Found {len(imgs)} images)")
141
- return imgs
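
Finally, a small sketch (an assumption, writing two dummy JPEGs to a temporary folder) of the resizing and cropping behaviour of load_images above; with size=512 the long side is resized to 512 and both sides are cropped to multiples of 16:

import os, tempfile
import numpy as np
import PIL.Image
from mini_dust3r.utils.image import load_images

tmp = tempfile.mkdtemp()
for name in ("a.jpg", "b.jpg"):
    PIL.Image.fromarray(np.zeros((480, 640, 3), np.uint8)).save(os.path.join(tmp, name))
imgs = load_images(tmp, size=512)
print(imgs[0]["img"].shape, imgs[0]["true_shape"])   # torch.Size([1, 3, 384, 512]) [[384 512]]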
 
mini_dust3r/utils/misc.py DELETED
@@ -1,121 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utility functions for DUSt3R
6
- # --------------------------------------------------------
7
- import torch
8
-
9
-
10
- def fill_default_args(kwargs, func):
11
- import inspect # a bit hacky but it works reliably
12
- signature = inspect.signature(func)
13
-
14
- for k, v in signature.parameters.items():
15
- if v.default is inspect.Parameter.empty:
16
- continue
17
- kwargs.setdefault(k, v.default)
18
-
19
- return kwargs
20
-
21
-
22
- def freeze_all_params(modules):
23
- for module in modules:
24
- try:
25
- for n, param in module.named_parameters():
26
- param.requires_grad = False
27
- except AttributeError:
28
- # module is directly a parameter
29
- module.requires_grad = False
30
-
31
-
32
- def is_symmetrized(gt1, gt2):
33
- x = gt1['instance']
34
- y = gt2['instance']
35
- if len(x) == len(y) and len(x) == 1:
36
- return False # special case of batchsize 1
37
- ok = True
38
- for i in range(0, len(x), 2):
39
- ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
40
- return ok
41
-
42
-
43
- def flip(tensor):
44
- """ flip so that tensor[0::2] <=> tensor[1::2] """
45
- return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
46
-
47
-
48
- def interleave(tensor1, tensor2):
49
- res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
50
- res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
51
- return res1, res2
52
-
53
-
54
- def transpose_to_landscape(head, activate=True):
55
- """ Predict in the correct aspect-ratio,
56
- then transpose the result in landscape
57
- and stack everything back together.
58
- """
59
- def wrapper_no(decout, true_shape):
60
- B = len(true_shape)
61
- assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
62
- H, W = true_shape[0].cpu().tolist()
63
- res = head(decout, (H, W))
64
- return res
65
-
66
- def wrapper_yes(decout, true_shape):
67
- B = len(true_shape)
68
- # by definition, the batch is in landscape mode so W >= H
69
- H, W = int(true_shape.min()), int(true_shape.max())
70
-
71
- height, width = true_shape.T
72
- is_landscape = (width >= height)
73
- is_portrait = ~is_landscape
74
-
75
- # true_shape = true_shape.cpu()
76
- if is_landscape.all():
77
- return head(decout, (H, W))
78
- if is_portrait.all():
79
- return transposed(head(decout, (W, H)))
80
-
81
- # batch is a mix of both portraint & landscape
82
- def selout(ar): return [d[ar] for d in decout]
83
- l_result = head(selout(is_landscape), (H, W))
84
- p_result = transposed(head(selout(is_portrait), (W, H)))
85
-
86
- # allocate full result
87
- result = {}
88
- for k in l_result | p_result:
89
- x = l_result[k].new(B, *l_result[k].shape[1:])
90
- x[is_landscape] = l_result[k]
91
- x[is_portrait] = p_result[k]
92
- result[k] = x
93
-
94
- return result
95
-
96
- return wrapper_yes if activate else wrapper_no
97
-
98
-
99
- def transposed(dic):
100
- return {k: v.swapaxes(1, 2) for k, v in dic.items()}
101
-
102
-
103
- def invalid_to_nans(arr, valid_mask, ndim=999):
104
- if valid_mask is not None:
105
- arr = arr.clone()
106
- arr[~valid_mask] = float('nan')
107
- if arr.ndim > ndim:
108
- arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
109
- return arr
110
-
111
-
112
- def invalid_to_zeros(arr, valid_mask, ndim=999):
113
- if valid_mask is not None:
114
- arr = arr.clone()
115
- arr[~valid_mask] = 0
116
- nnz = valid_mask.view(len(valid_mask), -1).sum(1)
117
- else:
118
- nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
119
- if arr.ndim > ndim:
120
- arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
121
- return arr, nnz
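
Note on the removal above: misc.py bundled small tensor helpers for DUSt3R's symmetrized image pairs. A minimal sketch of the flip/interleave behaviour, re-implemented inline here since the module itself is deleted by this commit; the values are purely illustrative:

    import torch

    a = torch.tensor([0, 1, 2, 3])
    b = torch.tensor([10, 11, 12, 13])

    # interleave(a, b): alternate the two batches element-wise
    res1 = torch.stack((a, b), dim=1).flatten(0, 1)   # [0, 10, 1, 11, 2, 12, 3, 13]
    res2 = torch.stack((b, a), dim=1).flatten(0, 1)   # [10, 0, 11, 1, 12, 2, 13, 3]

    # flip(res1): swap even/odd entries so tensor[0::2] <=> tensor[1::2]
    flipped = torch.stack((res1[1::2], res1[0::2]), dim=1).flatten(0, 1)
    assert torch.equal(flipped, res2)
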
mini_dust3r/viz.py DELETED
@@ -1,320 +0,0 @@
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
- #
- # --------------------------------------------------------
- # Visualization utilities using trimesh
- # --------------------------------------------------------
- import PIL.Image
- import numpy as np
- from scipy.spatial.transform import Rotation
- import torch
-
- from mini_dust3r.utils.geometry import geotrf, get_med_dist_between_poses
- from mini_dust3r.utils.device import to_numpy
- from mini_dust3r.utils.image import rgb
-
- try:
-     import trimesh
- except ImportError:
-     print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
-
-
- def cat_3d(vecs):
-     if isinstance(vecs, (np.ndarray, torch.Tensor)):
-         vecs = [vecs]
-     return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
-
-
- def show_raw_pointcloud(pts3d, colors, point_size=2):
-     scene = trimesh.Scene()
-
-     pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
-     scene.add_geometry(pct)
-
-     scene.show(line_settings={'point_size': point_size})
-
-
- def pts3d_to_trimesh(img, pts3d, valid=None):
-     H, W, THREE = img.shape
-     assert THREE == 3
-     assert img.shape == pts3d.shape
-
-     vertices = pts3d.reshape(-1, 3)
-
-     # make squares: each pixel == 2 triangles
-     idx = np.arange(len(vertices)).reshape(H, W)
-     idx1 = idx[:-1, :-1].ravel()  # top-left corner
-     idx2 = idx[:-1, +1:].ravel()  # right-left corner
-     idx3 = idx[+1:, :-1].ravel()  # bottom-left corner
-     idx4 = idx[+1:, +1:].ravel()  # bottom-right corner
-     faces = np.concatenate((
-         np.c_[idx1, idx2, idx3],
-         np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)
-         np.c_[idx2, idx3, idx4],
-         np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)
-     ), axis=0)
-
-     # prepare triangle colors
-     face_colors = np.concatenate((
-         img[:-1, :-1].reshape(-1, 3),
-         img[:-1, :-1].reshape(-1, 3),
-         img[+1:, +1:].reshape(-1, 3),
-         img[+1:, +1:].reshape(-1, 3)
-     ), axis=0)
-
-     # remove invalid faces
-     if valid is not None:
-         assert valid.shape == (H, W)
-         valid_idxs = valid.ravel()
-         valid_faces = valid_idxs[faces].all(axis=-1)
-         faces = faces[valid_faces]
-         face_colors = face_colors[valid_faces]
-
-     assert len(faces) == len(face_colors)
-     return dict(vertices=vertices, face_colors=face_colors, faces=faces)
-
-
- def cat_meshes(meshes):
-     vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
-     n_vertices = np.cumsum([0]+[len(v) for v in vertices])
-     for i in range(len(faces)):
-         faces[i][:] += n_vertices[i]
-
-     vertices = np.concatenate(vertices)
-     colors = np.concatenate(colors)
-     faces = np.concatenate(faces)
-     return dict(vertices=vertices, face_colors=colors, faces=faces)
-
-
- def show_duster_pairs(view1, view2, pred1, pred2):
-     import matplotlib.pyplot as pl
-     pl.ion()
-
-     for e in range(len(view1['instance'])):
-         i = view1['idx'][e]
-         j = view2['idx'][e]
-         img1 = rgb(view1['img'][e])
-         img2 = rgb(view2['img'][e])
-         conf1 = pred1['conf'][e].squeeze()
-         conf2 = pred2['conf'][e].squeeze()
-         score = conf1.mean()*conf2.mean()
-         print(f">> Showing pair #{e} {i}-{j} {score=:g}")
-         pl.clf()
-         pl.subplot(221).imshow(img1)
-         pl.subplot(223).imshow(img2)
-         pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
-         pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
-         pts1 = pred1['pts3d'][e]
-         pts2 = pred2['pts3d_in_other_view'][e]
-         pl.subplots_adjust(0, 0, 1, 1, 0, 0)
-         if input('show pointcloud? (y/n) ') == 'y':
-             show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
-
-
- def auto_cam_size(im_poses):
-     return 0.1 * get_med_dist_between_poses(im_poses)
-
-
- class SceneViz:
-     def __init__(self):
-         self.scene = trimesh.Scene()
-
-     def add_pointcloud(self, pts3d, color, mask=None):
-         pts3d = to_numpy(pts3d)
-         mask = to_numpy(mask)
-         if mask is None:
-             mask = [slice(None)] * len(pts3d)
-         pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
-         pct = trimesh.PointCloud(pts.reshape(-1, 3))
-
-         if isinstance(color, (list, np.ndarray, torch.Tensor)):
-             color = to_numpy(color)
-             col = np.concatenate([p[m] for p, m in zip(color, mask)])
-             assert col.shape == pts.shape
-             pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
-         else:
-             assert len(color) == 3
-             pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
-
-         self.scene.add_geometry(pct)
-         return self
-
-     def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
-         pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
-         add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
-         return self
-
-     def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
-         def get(arr, idx): return None if arr is None else arr[idx]
-         for i, pose_c2w in enumerate(poses):
-             self.add_camera(pose_c2w, get(focals, i), image=get(images, i),
-                             color=get(colors, i), imsize=get(imsizes, i), **kw)
-         return self
-
-     def show(self, point_size=2):
-         self.scene.show(line_settings={'point_size': point_size})
-
-
- def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
-                                   point_size=2, cam_size=0.05, cam_color=None):
-     """ Visualization of a pointcloud with cameras
-     imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
-     pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
-     focals = (N,) or N-size list of [focal, ...]
-     cams2world = (N,4,4) or N-size list of [(4,4), ...]
-     """
-     assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
-     pts3d = to_numpy(pts3d)
-     imgs = to_numpy(imgs)
-     focals = to_numpy(focals)
-     cams2world = to_numpy(cams2world)
-
-     scene = trimesh.Scene()
-
-     # full pointcloud
-     pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
-     col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
-     pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
-     scene.add_geometry(pct)
-
-     # add each camera
-     for i, pose_c2w in enumerate(cams2world):
-         if isinstance(cam_color, list):
-             camera_edge_color = cam_color[i]
-         else:
-             camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
-         add_scene_cam(scene, pose_c2w, camera_edge_color,
-                       imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)
-
-     scene.show(line_settings={'point_size': point_size})
-
-
- def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
-
-     if image is not None:
-         H, W, THREE = image.shape
-         assert THREE == 3
-         if image.dtype != np.uint8:
-             image = np.uint8(255*image)
-     elif imsize is not None:
-         W, H = imsize
-     elif focal is not None:
-         H = W = focal / 1.1
-     else:
-         H = W = 1
-
-     if focal is None:
-         focal = min(H, W) * 1.1  # default value
-     elif isinstance(focal, np.ndarray):
-         focal = focal[0]
-
-     # create fake camera
-     height = focal * screen_width / H
-     width = screen_width * 0.5**0.5
-     rot45 = np.eye(4)
-     rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
-     rot45[2, 3] = -height  # set the tip of the cone = optical center
-     aspect_ratio = np.eye(4)
-     aspect_ratio[0, 0] = W/H
-     transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
-     cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)
-
-     # this is the image
-     if image is not None:
-         vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
-         faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
-         img = trimesh.Trimesh(vertices=vertices, faces=faces)
-         uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
-         img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
-         scene.add_geometry(img)
-
-     # this is the camera mesh
-     rot2 = np.eye(4)
-     rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
-     vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
-     vertices = geotrf(transform, vertices)
-     faces = []
-     for face in cam.faces:
-         if 0 in face:
-             continue
-         a, b, c = face
-         a2, b2, c2 = face + len(cam.vertices)
-         a3, b3, c3 = face + 2*len(cam.vertices)
-
-         # add 3 pseudo-edges
-         faces.append((a, b, b2))
-         faces.append((a, a2, c))
-         faces.append((c2, b, c))
-
-         faces.append((a, b, b3))
-         faces.append((a, a3, c))
-         faces.append((c3, b, c))
-
-     # no culling
-     faces += [(c, b, a) for a, b, c in faces]
-
-     cam = trimesh.Trimesh(vertices=vertices, faces=faces)
-     cam.visual.face_colors[:, :3] = edge_color
-     scene.add_geometry(cam)
-
-
- def cat(a, b):
-     return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
-
-
- OPENGL = np.array([[1, 0, 0, 0],
-                    [0, -1, 0, 0],
-                    [0, 0, -1, 0],
-                    [0, 0, 0, 1]])
-
-
- CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
-               (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
-
-
- def uint8(colors):
-     if not isinstance(colors, np.ndarray):
-         colors = np.array(colors)
-     if np.issubdtype(colors.dtype, np.floating):
-         colors *= 255
-     assert 0 <= colors.min() and colors.max() < 256
-     return np.uint8(colors)
-
-
- def segment_sky(image):
-     import cv2
-     from scipy import ndimage
-
-     # Convert to HSV
-     image = to_numpy(image)
-     if np.issubdtype(image.dtype, np.floating):
-         image = np.uint8(255*image.clip(min=0, max=1))
-     hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
-
-     # Define range for blue color and create mask
-     lower_blue = np.array([0, 0, 100])
-     upper_blue = np.array([30, 255, 255])
-     mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
-
-     # add luminous gray
-     mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
-     mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
-     mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
-
-     # Morphological operations
-     kernel = np.ones((5, 5), np.uint8)
-     mask2 = ndimage.binary_opening(mask, structure=kernel)
-
-     # keep only largest CC
-     _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
-     cc_sizes = stats[1:, cv2.CC_STAT_AREA]
-     order = cc_sizes.argsort()[::-1]  # bigger first
-     i = 0
-     selection = []
-     while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
-         selection.append(1 + order[i])
-         i += 1
-     mask3 = np.in1d(labels, selection).reshape(labels.shape)
-
-     # Apply mask
-     return torch.from_numpy(mask3)
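
Note on the removal above: viz.py provided trimesh-based inspection of DUSt3R outputs (colored point clouds plus camera frusta). A minimal usage sketch of the SceneViz flow as it existed before this commit, assuming the pre-commit mini_dust3r.viz module; the arrays and the focal value are placeholder data for illustration only:

    # Sketch only: imports a module deleted in this commit.
    import numpy as np
    from mini_dust3r.viz import SceneViz

    H, W = 32, 32
    img = np.random.rand(H, W, 3)          # per-pixel colors in [0, 1]
    pts3d = np.random.rand(H, W, 3)        # per-pixel 3D points
    valid = np.ones((H, W), dtype=bool)    # confidence mask

    viz = SceneViz()
    viz.add_pointcloud([pts3d], [img], mask=[valid])
    viz.add_camera(np.eye(4), focal=300.0, image=np.uint8(255 * img), cam_size=0.05)
    viz.show(point_size=2)  # opens an interactive trimesh window
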