# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dataset utilities
# --------------------------------------------------------
import numpy as np
import quaternion
import torchvision.transforms as tvf

from dust3r.utils.geometry import geotrf


def cam_to_world_from_kapture(kdata, timestamp, camera_id):
    """Build a 4x4 float32 pose matrix from a kapture trajectory entry.

    The pose stored in ``kdata.trajectories[timestamp, camera_id]`` is inverted
    and converted to a homogeneous matrix (rotation from the quaternion ``r``,
    translation from ``t_raw``).
    NOTE(review): the name implies the stored pose is world-to-camera, so the
    inverse is camera-to-world; confirm against the kapture convention.

    Args:
        kdata: kapture data object exposing ``trajectories``.
        timestamp: trajectory timestamp key.
        camera_id: sensor/camera identifier key.

    Returns:
        np.ndarray of shape (4, 4), dtype float32.
    """
    camera_to_world = kdata.trajectories[timestamp, camera_id].inverse()
    camera_pose = np.eye(4, dtype=np.float32)
    camera_pose[:3, :3] = quaternion.as_rotation_matrix(camera_to_world.r)
    camera_pose[:3, 3] = camera_to_world.t_raw
    return camera_pose


# For each supported max dimension: aspect ratio (long side / short side) ->
# [long side, short side] target resolution. All values are multiples of 16.
ratios_resolutions = {
    224: {1.0: [224, 224]},
    512: {4 / 3: [512, 384], 32 / 21: [512, 336], 16 / 9: [512, 288], 2 / 1: [512, 256], 16 / 5: [512, 160]}
}


def get_HW_resolution(H, W, maxdim, patchsize=16):
    """Pick the supported (height, width) resolution closest in aspect ratio to H x W.

    Args:
        H: input image height.
        W: input image width.
        maxdim: maximum output dimension; must be a key of ``ratios_resolutions``
            (224 or 512).
        patchsize: an int, or a square ``(p, p)`` tuple of ints. It is validated
            but not otherwise used; the tabulated resolutions are all multiples of 16.

    Returns:
        A new list ``[height, width]`` whose long side equals ``maxdim`` and whose
        orientation (landscape/portrait) matches the input.

    Raises:
        AssertionError: on an unsupported ``maxdim`` or a malformed/non-square
            tuple ``patchsize``.
    """
    assert maxdim in ratios_resolutions, "Error, maxdim can only be 224 or 512 for now. Other maxdims not implemented yet."
    ratios_resolutions_maxdim = ratios_resolutions[maxdim]
    mindims = {min(res) for res in ratios_resolutions_maxdim.values()}
    ratio = W / H
    ref_ratios = np.array(list(ratios_resolutions_maxdim))
    islandscape = (W >= H)
    # reference ratios are long/short; for portrait images compare against short/long
    if islandscape:
        diff = np.abs(ratio - ref_ratios)
    else:
        diff = np.abs(ratio - (1 / ref_ratios))
    selkey = ref_ratios[np.argmin(diff)]
    res = ratios_resolutions_maxdim[selkey]
    # validate the patchsize format (int or square tuple) and normalize it to an int
    if isinstance(patchsize, tuple):
        assert len(patchsize) == 2 and isinstance(patchsize[0], int) and isinstance(
            patchsize[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
        assert patchsize[0] == patchsize[1], "Error, non square patches not managed"
        patchsize = patchsize[0]
    assert max(res) == maxdim
    assert min(res) in mindims
    # table entries are [long, short] == [W, H] for landscape; return HW.
    # Always return a copy so callers cannot mutate the shared table.
    return res[::-1] if islandscape else list(res)


def _center_crop_offset(size, crop_size):
    """Return the number of leading pixels torchvision's CenterCrop removes along one axis."""
    # torchvision's center_crop uses int(round((size - crop) / 2.0)); using the same
    # integer offset keeps the returned matrices consistent with the actual crop.
    return float(int(round((size - crop_size) / 2.0)))


def get_resize_function(maxdim, patch_size, H, W, is_mask=False):
    """Build a center-crop + resize transform to a supported resolution.

    The transform also comes with the 3x3 matrices that map 2D pixel coordinates
    between the original and the rescaled image.

    Args:
        maxdim: maximum output dimension (a key of ``ratios_resolutions``).
        patch_size: forwarded to ``get_HW_resolution`` as ``patchsize``.
        H: input image height.
        W: input image width.
        is_mask: if True, resize with nearest-exact interpolation so that label
            values are not blended.

    Returns:
        A tuple ``(op, to_rescaled, to_orig)``:

        * ``op`` is a callable applying the crop and resize. It is the identity
          when (H, W) already matches a tabulated resolution.
        * ``to_rescaled`` is a 3x3 matrix mapping original coordinates to
          rescaled coordinates.
        * ``to_orig`` is its inverse mapping.
    """
    if [max(H, W), min(H, W)] in ratios_resolutions[maxdim].values():
        return lambda x: x, np.eye(3), np.eye(3)

    target_HW = get_HW_resolution(H, W, maxdim=maxdim, patchsize=patch_size)

    ratio = W / H
    target_ratio = target_HW[1] / target_HW[0]
    to_orig_crop = np.eye(3)
    to_rescaled_crop = np.eye(3)
    if abs(ratio - target_ratio) < np.finfo(np.float32).eps:
        # aspect ratio already matches: no crop needed
        crop_W = W
        crop_H = H
    elif ratio < target_ratio:
        # image is relatively taller than the target: keep full width, crop rows
        crop_W = W
        crop_H = int(W / target_ratio)
        offset = _center_crop_offset(H, crop_H)
        to_orig_crop[1, 2] = offset
        to_rescaled_crop[1, 2] = -offset
    else:
        # image is relatively wider than the target: keep full height, crop columns
        crop_W = int(H * target_ratio)
        crop_H = H
        offset = _center_crop_offset(W, crop_W)
        to_orig_crop[0, 2] = offset
        to_rescaled_crop[0, 2] = -offset

    crop_op = tvf.CenterCrop([crop_H, crop_W])
    if is_mask:
        resize_op = tvf.Resize(size=target_HW, interpolation=tvf.InterpolationMode.NEAREST_EXACT)
    else:
        resize_op = tvf.Resize(size=target_HW)
    # pure scaling between the cropped image and the target resolution
    to_orig_resize = np.array([[crop_W / target_HW[1], 0, 0],
                               [0, crop_H / target_HW[0], 0],
                               [0, 0, 1]])
    to_rescaled_resize = np.array([[target_HW[1] / crop_W, 0, 0],
                                   [0, target_HW[0] / crop_H, 0],
                                   [0, 0, 1]])
    op = tvf.Compose([crop_op, resize_op])

    return op, to_rescaled_resize @ to_rescaled_crop, to_orig_crop @ to_orig_resize


def rescale_points3d(pts2d, pts3d, to_resize, HR, WR):
    """Reproject sparse 2D/3D correspondences onto a rescaled image grid.

    Args:
        pts2d: array of shape (N, 2) holding x, y pixel coordinates. The +0.5
            shift assumes the OpenCV convention (pixel centers at integers).
            NOTE(review): assumed to be a float array, because the in-place
            += 0.5 would fail on an integer dtype.
        pts3d: array of shape (N, 3), one 3D point per 2D point.
        to_resize: matrix applied with ``geotrf(..., norm=True)``. Presumably the
            3x3 ``to_rescaled`` matrix from ``get_resize_function``; confirm
            against the caller.
        HR: height of the rescaled image.
        WR: width of the rescaled image.

    Returns:
        A tuple ``(pts2d_rescaled, pts2d_rescaled_int, pts3d_rescaled, valid)``:

        * ``pts2d_rescaled``: (N, 2) float rescaled coordinates for all points,
          not filtered.
        * ``pts2d_rescaled_int``: (M, 2) int64 pixel indices, keeping only
          points that fall inside the HR x WR image.
        * ``pts3d_rescaled``: (HR, WR, 3) float32 map, NaN where no point landed.
          If several points land on the same pixel, only one is kept.
        * ``valid``: (HR, WR) bool mask of pixels holding a finite 3D point.
    """
    # rescale pts2d as floats
    # shift to the COLMAP convention so the image spans [0, D] -> [0, NewD]
    pts2d = pts2d.copy()
    pts2d[:, 0] += 0.5
    pts2d[:, 1] += 0.5
    pts2d_rescaled = geotrf(to_resize, pts2d, norm=True)

    pts2d_rescaled_int = pts2d_rescaled.copy()
    # convert back to the OpenCV convention before rounding, so roughly [-0.5, 0.5] -> pixel 0
    pts2d_rescaled_int[:, 0] -= 0.5
    pts2d_rescaled_int[:, 1] -= 0.5
    pts2d_rescaled_int = pts2d_rescaled_int.round().astype(np.int64)

    # update valid (remove points that fell into cropped-away regions)
    valid_rescaled = (pts2d_rescaled_int[:, 0] >= 0) & (pts2d_rescaled_int[:, 0] < WR) & (
        pts2d_rescaled_int[:, 1] >= 0) & (pts2d_rescaled_int[:, 1] < HR)

    pts2d_rescaled_int = pts2d_rescaled_int[valid_rescaled]

    # rebuild a dense pts3d map on the rescaled grid from the rescaled 2D positions
    pts3d_rescaled = np.full((HR, WR, 3), np.nan, dtype=np.float32)  # pts3d in 512 x something
    pts3d_rescaled[pts2d_rescaled_int[:, 1], pts2d_rescaled_int[:, 0]] = pts3d[valid_rescaled]

    return pts2d_rescaled, pts2d_rescaled_int, pts3d_rescaled, np.isfinite(pts3d_rescaled.sum(axis=-1))