|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
import os

import numpy as np
import torch
import torchvision
from scipy.cluster.hierarchy import DisjointSet
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

from mast3r.utils.misc import hash_md5
from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns

# Imported for its side effect only (presumably makes `dust3r` importable);
# keep it ahead of the dust3r imports below.
import mast3r.utils.path_to_dust3r  # noqa: F401
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf
|
|
|
|
|
def _show_matches(img0, img1, matches_im0, matches_im1, n_viz=100):
    """Plot ``n_viz`` evenly sampled matches over both images side by side (blocking)."""
    from matplotlib import pyplot as pl

    # The stored tensors are mapped through x * 0.5 + 0.5 before conversion to
    # PIL images (presumably undoing a mean/std = 0.5 input normalisation).
    image_mean = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    image_std = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    rgbs = []
    for view in (img0, img1):
        rgb = view['img'] * image_std + image_mean
        rgb = torchvision.transforms.functional.to_pil_image(rgb[0])
        rgbs.append(np.array(rgb))

    num_matches = matches_im0.shape[0]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im0 = matches_im0[match_idx_to_viz]
    viz_matches_im1 = matches_im1[match_idx_to_viz]

    # Pad the shorter image at the bottom so both can be concatenated horizontally.
    H0, W0, H1, W1 = *rgbs[0].shape[:2], *rgbs[1].shape[:2]
    rgb0 = np.pad(rgbs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    rgb1 = np.pad(rgbs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    canvas = np.concatenate((rgb0, rgb1), axis=1)

    pl.figure()
    pl.imshow(canvas)
    cmap = pl.get_cmap('jet')
    for ii in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T
        # Points of the second image are shifted right by the first image's width.
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii / (n_viz - 1)), scalex=False, scaley=False)
    pl.show(block=True)


def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
    """Turn pixel matches between two views into pairs of flattened keypoint indices.

    Match coordinates are rounded to the nearest pixel, clamped to the view's
    ``true_shape`` and flattened as ``x + W * y``. Every flattened index is
    counted in ``im_keypoints[view['idx']]``, which is modified in place.

    Args:
        img0, img1: view dicts providing ``'idx'`` and ``'true_shape'``
            (and ``'img'`` when ``viz`` is set).
        image_to_colmap: maps a view index to a dict holding ``'colmap_imid'``.
        im_keypoints: maps a view index to a ``{flat index: count}`` dict.
        matches_im0, matches_im1: ``(N, 2)`` arrays of ``(x, y)`` coordinates,
            row ``k`` of one matching row ``k`` of the other.
        viz: when truthy, show the matches with matplotlib before converting.

    Returns:
        ``(imidx0, imidx1, colmap_matches)`` where the pair is ordered so that
        the first view does not have the larger ``colmap_imid`` and
        ``colmap_matches`` holds the unique ``(flat index 0, flat index 1)`` rows.
    """
    if viz:
        _show_matches(img0, img1, matches_im0, matches_im1)

    ravel_matches = []
    for view, view_matches in zip((img0, img1), (matches_im0, matches_im1)):
        H, W = view['true_shape'][0]
        # Invalid (e.g. NaN) coordinates must not emit warnings during the cast.
        with np.errstate(invalid='ignore'):
            qx, qy = view_matches.astype(np.float64).round().astype(np.int32).T
            flat = np.clip(qx, 0, W - 1, out=qx) + W * np.clip(qy, 0, H - 1, out=qy)
        ravel_matches.append(flat)

        # Count in first-seen order: the dict order later defines keypoint order.
        counts = im_keypoints[view['idx']]
        for kp in flat:
            counts[kp] = counts.get(kp, 0) + 1

    imidx0, imidx1 = img0['idx'], img1['idx']
    imid0 = image_to_colmap[imidx0]['colmap_imid']
    imid1 = image_to_colmap[imidx1]['colmap_imid']
    if imid0 > imid1:
        # Report the pair with the smaller colmap image id first.
        imidx0, imidx1 = imidx1, imidx0
        ravel_matches.reverse()

    colmap_matches = np.unique(np.stack(ravel_matches, axis=-1), axis=0)
    return imidx0, imidx1, colmap_matches
|
|
|
|
|
def _threshold_sparse_corres(corres, conf_thr):
    """Keep the correspondences whose confidence reaches ``conf_thr``, as numpy arrays."""
    keep = corres[2] >= conf_thr
    return corres[0][keep].cpu().numpy(), corres[1][keep].cpu().numpy()


def _gather_confident_pixels(views, value_maps, confidences, conf_thr, channels):
    """Per view, return the pixel grid and map values where confidence >= ``conf_thr``."""
    pts2d_list, values_list = [], []
    for view, value_map, conf in zip(views, value_maps, confidences):
        keep = (conf >= conf_thr).cpu().numpy().flatten()
        true_shape = view['true_shape'][0]
        pts2d_list.append(xy_grid(true_shape[1], true_shape[0]).reshape(-1, 2)[keep])
        values_list.append(value_map.detach().cpu().numpy().reshape(-1, channels)[keep])
    return pts2d_list, values_list


def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
                   is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
    """Extract matches for every predicted pair and convert them to keypoint-index pairs.

    Descriptors (``pred*['desc']`` with ``pred*['desc_conf']``) are matched when
    ``pred1`` contains ``'desc'``; otherwise ``pred1['pts3d']`` is matched against
    ``pred2['pts3d_in_other_view']`` using ``pred*['conf']``.

    Args:
        pred1, pred2: prediction dicts indexed per pair (keys as listed above).
        pairs: sequence of ``(view0, view1)`` dicts aligned with the predictions.
        image_to_colmap, im_keypoints: forwarded to ``convert_im_matches_pairs``
            (``im_keypoints`` is updated in place there).
        conf_thr: minimum confidence for a correspondence / pixel to be kept.
        is_sparse: use ``extract_correspondences_nonsym``; otherwise run a dense
            reciprocal nearest-neighbour search over the confident pixels.
        subsample, pixel_tol: forwarded to ``extract_correspondences_nonsym``.
        viz: forwarded to ``convert_im_matches_pairs``.
        device: device handed to the matching helpers.

    Returns:
        dict mapping ``(imidx0, imidx1)`` to the array returned by
        ``convert_im_matches_pairs``; pairs without any match are left out.
    """
    use_descriptors = 'desc' in pred1
    im_matches = {}
    for i in range(len(pred1['pts3d'])):
        view0, view1 = pairs[i][0], pairs[i][1]

        if use_descriptors:
            value_maps = [pred1['desc'][i], pred2['desc'][i]]
            confidences = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
            sparse_kwargs = {}
        else:
            value_maps = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
            confidences = [pred1['conf'][i], pred2['conf'][i]]
            sparse_kwargs = {'ptmap_key': '3d'}

        if is_sparse:
            corres = extract_correspondences_nonsym(value_maps[0], value_maps[1],
                                                    confidences[0], confidences[1],
                                                    device=device, subsample=subsample,
                                                    pixel_tol=pixel_tol, **sparse_kwargs)
            matches_im0, matches_im1 = _threshold_sparse_corres(corres, conf_thr)
        else:
            channels = value_maps[0].shape[-1] if use_descriptors else 3
            pts2d_list, values_list = _gather_confident_pixels(
                (view0, view1), value_maps, confidences, conf_thr, channels)
            if len(values_list[0]) == 0 or len(values_list[1]) == 0:
                continue  # nothing confident enough in one of the views

            if use_descriptors:
                nn0, nn1 = bruteforce_reciprocal_nns(values_list[0], values_list[1],
                                                     device=device, dist='dot', block_size=2**13)
                # Keep points of view 0 whose nearest neighbour maps back to them.
                reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0)))
                matches_im0 = pts2d_list[0][reciprocal_in_P0]
                matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0]
            else:
                reciprocal_in_PM, nnM_in_PQ, _ = find_reciprocal_matches(
                    values_list[0], values_list[1])
                matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]
                matches_im1 = pts2d_list[1][reciprocal_in_PM]

        if len(matches_im0) == 0:
            continue
        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                                  image_to_colmap, im_keypoints,
                                                                  matches_im0, matches_im1, viz)
        im_matches[(imidx0, imidx1)] = colmap_matches
    return im_matches
|
|
|
|
|
def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
                              image_to_colmap, im_keypoints, conf_thr,
                              viz=False, device='cuda'):
    """Load cached correspondences for each pair and convert them to keypoint-index pairs.

    The cache file of a pair lives in
    ``<cache_path>/corres_conf=<desc_conf>_subsample=<subsample>/`` and is named
    ``<hash0>-<hash1>.pth``, the hashes being ``hash_md5`` of each view's
    ``'instance'``. When that file is missing the reversed name
    ``<hash1>-<hash0>.pth`` is loaded and its two coordinate sets are swapped.

    Args:
        pairs: sequence of ``(view0, view1)`` dicts.
        cache_path: root directory of the correspondence cache.
        desc_conf, subsample: values embedded in the cache directory name.
        image_to_colmap, im_keypoints: forwarded to ``convert_im_matches_pairs``
            (``im_keypoints`` is updated in place there).
        conf_thr: minimum confidence for a cached correspondence to be kept.
        viz: forwarded to ``convert_im_matches_pairs``.
        device: ``map_location`` used when loading the cache files.

    Returns:
        dict mapping ``(imidx0, imidx1)`` to the array returned by
        ``convert_im_matches_pairs``; pairs without any match are left out.

    Raises:
        Whatever ``torch.load`` raises when neither cache file exists.
    """
    corres_dir = os.path.join(cache_path, f'corres_conf={desc_conf}_{subsample=}')
    im_matches = {}
    for pair in pairs:
        view0, view1 = pair[0], pair[1]
        hash0 = hash_md5(view0['instance'])
        hash1 = hash_md5(view1['instance'])

        # NOTE(review): torch.load unpickles the file; only point this at caches
        # produced by trusted code.
        path_corres = os.path.join(corres_dir, f'{hash0}-{hash1}.pth')
        if os.path.isfile(path_corres):
            _, (xy0, xy1, confs) = torch.load(path_corres, map_location=device)
        else:
            # Fall back to the reversed ordering, swapping the coordinate sets.
            path_corres = os.path.join(corres_dir, f'{hash1}-{hash0}.pth')
            _, (xy1, xy0, confs) = torch.load(path_corres, map_location=device)

        keep = confs >= conf_thr
        matches_im0 = xy0[keep].cpu().numpy()
        matches_im1 = xy1[keep].cpu().numpy()
        if len(matches_im0) == 0:
            continue

        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                                  image_to_colmap, im_keypoints,
                                                                  matches_im0, matches_im1, viz)
        im_matches[(imidx0, imidx1)] = colmap_matches
    return im_matches
|
|
|
|
|
def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
    """Register one camera and one image per path in ``db``.

    Args:
        db: database object exposing ``add_camera`` and ``add_image``.
        images: per-image dicts providing ``"orig_shape"`` (H, W) and the
            ``"to_orig"`` matrix whose diagonal rescales focals / principal point.
        image_paths: image names, one per entry of ``images``.
        focals: ``None`` (focal defaults to ``1.2 * max(W, H)``), or per image
            either a 2D intrinsics array or a scalar focal length.
        ga_world_to_cam: ``None`` (zero pose priors) or per-image matrices whose
            top-left 3x3 block and last column give the rotation / translation prior.
        camera_model: one of ``"SIMPLE_PINHOLE"``, ``"PINHOLE"``,
            ``"SIMPLE_RADIAL"``, ``"OPENCV"``.

    Returns:
        ``(image_to_colmap, im_keypoints)``: the first maps an image index to
        its ``colmap_imid`` / ``colmap_camid``, the second maps every image
        index to an empty dict.

    Raises:
        ValueError: if ``camera_model`` is not one of the supported names.
    """
    model_ids = {"SIMPLE_PINHOLE": 0, "PINHOLE": 1, "SIMPLE_RADIAL": 2, "OPENCV": 4}

    image_to_colmap = {}
    im_keypoints = {}
    for idx, image_path in enumerate(image_paths):
        im_keypoints[idx] = {}
        image = images[idx]
        H, W = image["orig_shape"]

        prior_focal_length = focals is not None
        if not prior_focal_length:
            focal_x = focal_y = 1.2 * max(W, H)
            cx, cy = W / 2.0, H / 2.0
        else:
            focal = focals[idx]
            if isinstance(focal, np.ndarray) and focal.ndim == 2:
                # Full intrinsics matrix: the principal point comes from it too.
                focal_x, focal_y = focal[0, 0], focal[1, 1]
                cx = focal[0, 2] * image["to_orig"][0, 0]
                cy = focal[1, 2] * image["to_orig"][1, 1]
            else:
                focal_x = focal_y = float(focal)
                cx, cy = W / 2.0, H / 2.0
        focal_x = focal_x * image["to_orig"][0, 0]
        focal_y = focal_y * image["to_orig"][1, 1]

        if camera_model not in model_ids:
            raise ValueError(f"invalid camera model {camera_model}")
        model_id = model_ids[camera_model]
        if camera_model == "PINHOLE":
            values = [focal_x, focal_y, cx, cy]
        elif camera_model == "OPENCV":
            values = [focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0]
        else:
            # Single-focal models use the mean of both focals.
            mean_focal = (focal_x + focal_y) / 2.0
            values = [mean_focal, cx, cy]
            if camera_model == "SIMPLE_RADIAL":
                values.append(0.0)
        params = np.asarray(values, np.float64)

        width, height = int(W), int(H)
        camid = db.add_camera(model_id, width, height, params,
                              prior_focal_length=prior_focal_length)

        if ga_world_to_cam is None:
            prior_q, prior_t = np.zeros(4), np.zeros(3)
        else:
            pose = ga_world_to_cam[idx]
            quat = R.from_matrix(pose[:3, :3]).as_quat()
            prior_t = pose[:3, 3]
            # Move the last quaternion component to the front.
            prior_q = np.array([quat[k] for k in (-1, 0, 1, 2)])

        imid = db.add_image(image_path, camid, prior_q=prior_q, prior_t=prior_t)
        image_to_colmap[idx] = {'colmap_imid': imid, 'colmap_camid': camid}
    return image_to_colmap, im_keypoints
|
|
|
|
|
def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
    """Build keypoint tracks from pairwise matches and write keypoints / matches to ``db``.

    Steps:
      1. Chain matches into tracks; a match linking two existing tracks is
         recorded and the tracks are merged afterwards with a disjoint-set.
      2. Per image, keep the keypoints whose track holds at least
         ``min_len_track`` entries, convert their flattened index to (x, y)
         pixel centres, map them through ``images[i]['to_orig']`` with
         ``geotrf``, clamp them to ``orig_shape`` and store them.
      3. Per image pair, re-index the matches onto the kept keypoints and
         store them.

    Args:
        db: database object exposing ``add_keypoints``, ``add_matches`` and
            ``add_two_view_geometry``.
        images: per-image dicts with ``'true_shape'``, ``'to_orig'``,
            ``'orig_shape'`` and ``'instance'``.
        image_to_colmap: maps an image index to a dict holding ``'colmap_imid'``.
        im_keypoints: maps an image index to a dict keyed by flattened keypoint index.
        im_matches: maps ``(imidx0, imidx1)`` to an ``(N, 2)`` array of
            flattened keypoint indices.
        min_len_track: minimum number of keypoints for a track to be exported.
        skip_geometric_verification: when truthy, the matches are also stored
            through ``db.add_two_view_geometry``.

    Returns:
        list of ``(instance0, instance1)`` for every pair with exported matches.
    """
    colmap_image_pairs = []

    print("building tracks")
    keypoints_to_track_id = {}
    track_id_to_kpt_list = []
    to_merge = []
    for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
        tracks0 = keypoints_to_track_id.setdefault(imidx0, {})
        tracks1 = keypoints_to_track_id.setdefault(imidx1, {})

        for kp0, kp1 in colmap_matches:
            known0 = kp0 in tracks0
            known1 = kp1 in tracks1
            if not known0 and not known1:
                # Neither keypoint is tracked yet: start a new track.
                track_idx = len(track_id_to_kpt_list)
                tracks0[kp0] = track_idx
                tracks1[kp1] = track_idx
                track_id_to_kpt_list.append([(imidx0, kp0), (imidx1, kp1)])
            elif not known1:
                # Extend the track of kp0 with kp1.
                track_idx = tracks0[kp0]
                tracks1[kp1] = track_idx
                track_id_to_kpt_list[track_idx].append((imidx1, kp1))
            elif not known0:
                # Extend the track of kp1 with kp0.
                track_idx = tracks1[kp1]
                tracks0[kp0] = track_idx
                track_id_to_kpt_list[track_idx].append((imidx0, kp0))
            else:
                # Both already tracked: schedule a merge if the tracks differ.
                track_idx0 = tracks0[kp0]
                track_idx1 = tracks1[kp1]
                if track_idx0 != track_idx1:
                    to_merge.append((track_idx0, track_idx1))

    print("merging tracks")
    tree = DisjointSet(np.unique(to_merge))
    for track_idx0, track_idx1 in tqdm(to_merge):
        tree.merge(track_idx0, track_idx1)

    print("applying merge")
    for setvals in tqdm(tree.subsets()):
        new_trackid = len(track_id_to_kpt_list)
        kpt_list = []
        for track_idx in setvals:
            kpt_list.extend(track_id_to_kpt_list[track_idx])
            for imidx, kpid in track_id_to_kpt_list[track_idx]:
                keypoints_to_track_id[imidx][kpid] = new_trackid
            # The old track is superseded by the merged one; empty it so it is
            # not counted as a separate valid track below.
            track_id_to_kpt_list[track_idx] = []
        track_id_to_kpt_list.append(kpt_list)

    num_valid_tracks = sum(1 for track in track_id_to_kpt_list if len(track) >= min_len_track)

    keypoints_to_idx = {}
    print(f"squashing keypoints - {num_valid_tracks} valid tracks")
    for imidx, keypoints_imid in tqdm(im_keypoints.items()):
        imid = image_to_colmap[imidx]['colmap_imid']
        tracks = keypoints_to_track_id.get(imidx, {})
        keypoints_kept = [kp for kp in keypoints_imid
                          if kp in tracks
                          and len(track_id_to_kpt_list[tracks[kp]]) >= min_len_track]
        keypoints_to_idx[imidx] = {kp: pos for pos, kp in enumerate(keypoints_kept)}
        if not keypoints_kept:
            continue

        # Flattened index -> (row, col), then store as (x, y) = (col, row).
        rows, cols = np.unravel_index(np.array(keypoints_kept), images[imidx]['true_shape'][0])
        keypoints_xy = np.stack([cols, rows], axis=-1).astype(np.float32)
        keypoints_xy += 0.5  # shift from pixel index to pixel centre
        keypoints_xy = geotrf(images[imidx]['to_orig'], keypoints_xy, norm=True)

        H, W = images[imidx]['orig_shape']
        keypoints_xy[:, 0] = keypoints_xy[:, 0].clip(min=0, max=W - 0.01)
        keypoints_xy[:, 1] = keypoints_xy[:, 1].clip(min=0, max=H - 0.01)

        db.add_keypoints(imid, keypoints_xy)

    print("exporting im_matches")
    for (imidx0, imidx1), colmap_matches in im_matches.items():
        imid0 = image_to_colmap[imidx0]['colmap_imid']
        imid1 = image_to_colmap[imidx1]['colmap_imid']
        # Invariant: pairs are expected to be keyed with the smaller image id first.
        assert imid0 < imid1
        idx0, idx1 = keypoints_to_idx[imidx0], keypoints_to_idx[imidx1]
        final_matches = np.array([[idx0[kp0], idx1[kp1]]
                                  for kp0, kp1 in colmap_matches
                                  if kp0 in idx0 and kp1 in idx1])
        if len(final_matches) > 0:
            colmap_image_pairs.append(
                (images[imidx0]['instance'], images[imidx1]['instance']))
            db.add_matches(imid0, imid1, final_matches)
            if skip_geometric_verification:
                db.add_two_view_geometry(imid0, imid1, final_matches)
    return colmap_image_pairs
|
|