# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# cropping/match extraction
# --------------------------------------------------------
import numpy as np
import mast3r.utils.path_to_dust3r # noqa
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import inv, geotrf
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """ Keep only cycle-consistent matches between two 1d correspondence maps.

    A pixel i of image1 is kept iff corres_2_to_1[corres_1_to_2[i]] == i.
    Returns (pos1, pos2) — matched flat indices in image1 and image2 —
    optionally preceded by the boolean reciprocity mask over image1.
    """
    idx1 = np.arange(len(corres_1_to_2))
    is_reciprocal1 = corres_2_to_1[corres_1_to_2] == idx1
    pos1 = idx1[is_reciprocal1]
    pos2 = corres_1_to_2[pos1]
    if ret_recip:
        return is_reciprocal1, pos1, pos2
    return pos1, pos2
def extract_correspondences_from_pts3d(view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0):
    """ Extract pixel correspondences between two views from their 3d points.

    Pixels of each view are reprojected into the other view; pairs that map
    onto each other in both directions (reciprocal matches) become positive
    correspondences, optionally padded with random negative (false) pairs.

    Args:
        view1, view2: dicts with 'pts3d', 'camera_intrinsics' and 'camera_pose'.
        target_n_corres: total number of pairs to return, or None to return
            every reciprocal match.
        rng: random generator used for sub-sampling and negative mining.
        ret_xy: if True, convert flat indices (x + W*y) back to (x, y) pairs.
        nneg: target fraction of negatives in the output.

    Returns:
        (pos1, pos2) if target_n_corres is None, else (pos1, pos2, valid)
        where valid[i] is True for real matches and False for negatives.
    """
    view1, view2 = to_numpy((view1, view2))

    # project pixels from image1 --> 3d points --> image2 pixels
    shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
    shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)

    # compute reciprocal correspondences:
    # pos1 == valid pixels (correspondences) in image1
    is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, ret_recip=True)
    is_reciprocal2 = (corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # should be really rare => when there are not enough negatives
        # in that case, break nneg and add a few more positives ?
        n_positives = target_n_corres - n_negatives

    # sanity checks (the first one also covers the rare branch above,
    # which previously duplicated it)
    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # random sub-sampling of valid correspondences
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # add false correspondences if not enough, drawn uniformly
        # among the non-reciprocal pixels of each image
        def norm(p): return p / p.sum()
        pos1 = np.r_[pos1, rng.choice(shape1[0] * shape1[1], size=n_negatives, replace=False, p=norm(~is_reciprocal1))]
        pos2 = np.r_[pos2, rng.choice(shape2[0] * shape2[1], size=n_negatives, replace=False, p=norm(~is_reciprocal2))]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    # convert (x+W*y) back to 2d (x,y) coordinates
    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
def reproject_view(pts3d, view2):
    """ Reproject a map of 3d points into the camera of view2.

    Returns the (H, W) of the input pts3d map and, per input pixel, the
    quantized raveled index (x + W2*y) of its projection in view2's grid.
    """
    K = view2['camera_intrinsics']
    world2cam = inv(view2['camera_pose'])
    grid_shape = view2['pts3d'].shape[:2]
    return reproject(pts3d, K, world2cam, grid_shape)
def reproject(pts3d, K, world2cam, shape):
    """ Project an (H, W, 3) map of world points into a camera.

    Returns the input map's (H, W) together with the quantized, raveled
    pixel index (x + W'*y) of each projection in a grid of size `shape`.
    """
    H, W, THREE = pts3d.shape
    assert THREE == 3

    # projection matrix; invalid divisions (points at z=0) become nan/inf
    # and get clipped to the border during quantization
    P = K @ world2cam[:3]
    with np.errstate(divide='ignore', invalid='ignore'):
        pos = geotrf(P, pts3d, norm=1, ncol=2)

    # quantize to pixel positions
    return (H, W), ravel_xy(pos, shape)
def ravel_xy(pos, shape):
    """ Convert float (x, y) positions into clipped raveled indices x + W*y. """
    H, W = shape
    # nan positions survive rounding/casting (silenced) and end up clipped
    with np.errstate(invalid='ignore'):
        quant = pos.reshape(-1, 2).round().astype(np.int32)
    qx = quant[:, 0].clip(min=0, max=W - 1)
    qy = quant[:, 1].clip(min=0, max=H - 1)
    return qx + W * qy
def unravel_xy(pos, shape):
    """ Convert raveled indices (x + W*y) back to 2d (x, y) coordinates.

    Args:
        pos: 1d array of flat indices into an (H, W) grid, encoded as x + W*y.
        shape: the (H, W) grid size.

    Returns:
        (N, 2) integer array of (x, y) pixel coordinates.

    Note: the previous implementation reached into the undocumented `.base`
    attribute of the arrays returned by np.unravel_index — a numpy
    implementation detail; this version uses only the public API.
    """
    y, x = np.unravel_index(pos, shape)
    return np.stack((x, y), axis=-1)
def _rotation_origin_to_pt(target):
""" Align the origin (0,0,1) with the target point (x,y,1) in projective space.
Method: rotate z to put target on (x'+,0,1), then rotate on Y to get (0,0,1) and un-rotate z.
"""
from scipy.spatial.transform import Rotation
x, y = target
rot_z = np.arctan2(y, x)
rot_y = np.arctan(np.linalg.norm(target))
R = Rotation.from_euler('ZYZ', [rot_z, rot_y, -rot_z]).as_matrix()
return R
def _dotmv(Trf, pts, ncol=None, norm=False):
    # Apply a (possibly batched) matrix transform to an array of points.
    # Handles both affine transforms (Trf is (d+1)x(d+1), last column holds
    # the translation) and pure linear ones (Trf is d x d).
    #   ncol: number of output coordinates to keep (default: same as input).
    #   norm: if truthy, divide by the last (homogeneous) coordinate;
    #         if norm != 1, additionally rescale by `norm`.
    assert Trf.ndim >= 2
    ncol = ncol or pts.shape[-1]  # note: ncol=0 falls back to pts.shape[-1]
    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    if Trf.ndim >= 3:
        # batched transforms: flatten leading batch dims and match pts to them
        n = Trf.ndim - 2
        assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
        Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

        if pts.ndim > Trf.ndim:
            # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
            pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
        elif pts.ndim == 2:
            # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
            pts = pts[:, None, :]

    if pts.shape[-1] + 1 == Trf.shape[-1]:
        # affine case: split Trf into linear part and translation row
        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
        pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
    elif pts.shape[-1] == Trf.shape[-1]:
        # pure linear transform
        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
        pts = pts @ Trf
    else:
        # fallback: project to a lower/higher dimension, e.g. 3x4 @ pts.T
        pts = Trf @ pts.T
        if pts.ndim >= 2:
            pts = pts.swapaxes(-1, -2)

    if norm:
        # homogeneous normalization by the last coordinate
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
def crop_to_homography(K, crop, target_size=None):
    """ Given an image and its intrinsics,
    we want to replicate a rectangular crop with an homography,
    so that the principal point of the new 'crop' is centered.
    """
    crop = np.round(crop)
    crop_size = crop[2:] - crop[:2]

    # intrinsics of the virtual crop: same focal, perfectly centered principal point
    K2 = K.copy()
    K2[:2, 2] = crop_size / 2

    # independently for x and y, pick the crop corner coordinate farthest
    # from the current principal point, so that the final homography does
    # not go over the image borders
    corners = crop.reshape(2, 2)
    corner_idx = np.abs(corners - K[:2, 2]).argmax(axis=0)
    corner = np.array([corners[corner_idx[0], 0], corners[corner_idx[1], 1]])

    # the matching corner in the target (crop) view: 0 or the crop extent
    corner2 = np.where(corner_idx == 0, 0.0, crop_size)

    # rotation that maps the target-view corner direction onto the source one
    old_pt = _dotmv(np.linalg.inv(K), corner, norm=1)
    new_pt = _dotmv(np.linalg.inv(K2), corner2, norm=1)
    R = _rotation_origin_to_pt(old_pt) @ np.linalg.inv(_rotation_origin_to_pt(new_pt))

    if target_size is None:
        imsize = tuple(np.int32(crop_size).tolist())
    else:
        imsize = target_size
        target_size = np.asarray(target_size)
        scaling = min(target_size / crop_size)
        K2[:2] *= scaling
        K2[:2, 2] = target_size / 2

    return imsize, K2, R, K @ R @ np.linalg.inv(K2)
def gen_random_crops(imsize, n_crops, resolution, aug_crop, rng=np.random):
    """ Generate random crops of size=resolution,
    for an input image upscaled to (imsize + randint(0 , aug_crop))
    """
    # crop extent in input pixels, preserving the aspect ratio of `resolution`
    fit = min(np.array(imsize) / resolution)
    resolution_crop = np.array(resolution) * fit

    # (virtually) upscale the input image by a log-uniform factor
    max_log = np.log(1 + aug_crop / min(imsize))
    scaling = np.exp(rng.uniform(0, max_log))
    imsize2 = np.int32(np.array(imsize) * scaling)

    # draw random top-left corners inside the upscaled image
    topleft = rng.random((n_crops, 2)) * (imsize2 - resolution_crop)
    crops = np.c_[topleft, topleft + resolution_crop]

    # reduce the resolution to come back to original size
    return crops / scaling
def in2d_rect(corres, crops):
    """ Test which 2d points fall inside which rectangles.

    corres: (N, 2) points (x, y); crops: (M, 4) rectangles (l, t, r, b).
    Returns an (N, M) boolean matrix; right/bottom edges are exclusive.
    """
    above_min = corres[:, None] >= crops[None, :, 0:2]
    below_max = corres[:, None] < crops[None, :, 2:4]
    return np.all(above_min & below_max, axis=-1)