|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import random |
|
import torch |
|
import torchvision.transforms as tvf |
|
import argparse |
|
from tqdm import tqdm |
|
from PIL import Image |
|
import math |
|
|
|
from mast3r.model import AsymmetricMASt3R |
|
from mast3r.fast_nn import fast_reciprocal_NNs |
|
from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice |
|
from mast3r.utils.collate import cat_collate, cat_collate_fn_map |
|
from mast3r.utils.misc import mkdir_for |
|
from mast3r.datasets.utils.cropping import crop_to_homography |
|
|
|
import mast3r.utils.path_to_dust3r |
|
from dust3r.inference import inference, loss_of_one_batch |
|
from dust3r.utils.geometry import geotrf, colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics |
|
from dust3r.datasets.utils.transforms import ImgNorm |
|
from dust3r_visloc.datasets import * |
|
from dust3r_visloc.localization import run_pnp |
|
from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results |
|
from dust3r_visloc.datasets.utils import get_HW_resolution, rescale_points3d |
|
|
|
|
|
def get_args_parser():
    """Build the command-line parser for the visual-localization evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing dataset / model selection, matching,
        coarse-to-fine, PnP and output options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval")

    # The model is selected either by a checkpoint path or by a pretrained model name (exactly one).
    weights_group = parser.add_mutually_exclusive_group(required=True)
    weights_group.add_argument("--weights", type=str, help="path to the model weights", default=None)
    weights_group.add_argument("--model_name", type=str, help="name of the model weights",
                               choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])

    # Matching. Matches are kept when their confidence is >= the threshold, so values strictly
    # lower than the threshold are the invalid ones (the previous help text said the opposite).
    parser.add_argument("--confidence_threshold", type=float, default=1.001,
                        help="confidence values lower than threshold are invalid")
    parser.add_argument('--pixel_tol', default=5, type=int)

    # Coarse-to-fine matching.
    parser.add_argument("--coarse_to_fine", action='store_true', default=False,
                        help="do the matching from coarse to fine")
    parser.add_argument("--max_image_size", type=int, default=None,
                        help="max image size for the fine resolution")
    parser.add_argument("--c2f_crop_with_homography", action='store_true', default=False,
                        help="when using coarse to fine, crop with homographies to keep cx, cy centered")

    # Device and PnP.
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'],
                        help="pnp lib to use")
    # The PnP threshold is either absolute (pixels) or relative to the image diagonal.
    reproj_group = parser.add_mutually_exclusive_group()
    reproj_group.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error")
    reproj_group.add_argument("--reprojection_error_diag_ratio", type=float, default=None,
                              help="pnp reprojection error as a ratio of the diagonal of the image")

    parser.add_argument("--max_batch_size", type=int, default=48,
                        help="max batch size for inference on crops when using coarse to fine")
    parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept")
    parser.add_argument("--viz_matches", type=int, default=0, help="debug matches")

    # Output.
    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--output_label", type=str, default='', help="prefix for results files")
    return parser
|
|
|
|
|
@torch.no_grad()
def coarse_matching(query_view, map_view, model, device, pixel_tol, fast_nn_params):
    """Match a query view against a map view using their rescaled ('rgb_rescaled') images.

    The model is run once on the (query, map) pair. Descriptors are then matched with
    reciprocal nearest neighbours.

    Args:
        query_view, map_view: view dicts. Both need 'rgb_rescaled' (a tensor whose
            ``shape[1:]`` is read as (H, W)) and 'to_orig' (transform passed to ``geotrf``).
            When ``pixel_tol > 0`` the map view also needs 'valid_rescaled' and
            'pts3d_rescaled'.
        model, device: forwarded to ``inference``.
        pixel_tol: 0 selects subsampled matching over the whole image (step 8), with
            matches within 3 px of either image border dropped. Otherwise matching is
            seeded from valid map pixels using this pixel tolerance.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs).
        Matched pixel coordinates are mapped through each view's 'to_orig' with a +0.5 /
        -0.5 half-pixel shift around the transform. When ``pixel_tol == 0``,
        ``valid_pts3d`` and ``matches_confs`` are empty lists. Four empty lists are
        returned if either descriptor map is empty.
    """
    rescaled_images = (query_view['rgb_rescaled'], map_view['rgb_rescaled'])
    pair = tuple(
        dict(img=image.unsqueeze(0), true_shape=np.int32([image.shape[1:]]), idx=i, instance=str(i))
        for i, image in enumerate(rescaled_images)
    )
    output = inference([pair], model, device, batch_size=1, verbose=False)
    pred_query, pred_map = output['pred1'], output['pred2']
    conf_query = pred_query['desc_conf'].squeeze(0).cpu().numpy()
    conf_map = pred_map['desc_conf'].squeeze(0).cpu().numpy()
    desc_query = pred_query['desc'].squeeze(0).detach()
    desc_map = pred_map['desc'].squeeze(0).detach()

    if len(desc_query) == 0 or len(desc_map) == 0:
        return [], [], [], []

    def _away_from_border(xy, height, width, margin=3):
        # True for (x, y) points at least `margin` pixels inside the image
        return ((xy[:, 0] >= margin) & (xy[:, 0] < width - margin)
                & (xy[:, 1] >= margin) & (xy[:, 1] < height - margin))

    if pixel_tol == 0:
        matches_im_map, matches_im_query = fast_reciprocal_NNs(desc_map, desc_query, subsample_or_initxy1=8,
                                                               **fast_nn_params)
        HM, WM = map_view['rgb_rescaled'].shape[1:]
        HQ, WQ = query_view['rgb_rescaled'].shape[1:]
        keep = _away_from_border(matches_im_map, HM, WM) & _away_from_border(matches_im_query, HQ, WQ)
        matches_im_map = matches_im_map[keep]
        matches_im_query = matches_im_query[keep]
        valid_pts3d, matches_confs = [], []
    else:
        rows, cols = torch.where(map_view['valid_rescaled'])
        matches_im_map, matches_im_query = fast_reciprocal_NNs(desc_map, desc_query, (cols, rows),
                                                               pixel_tol=pixel_tol, **fast_nn_params)
        map_x, map_y = matches_im_map[:, 0], matches_im_map[:, 1]
        query_x, query_y = matches_im_query[:, 0], matches_im_query[:, 1]
        valid_pts3d = map_view['pts3d_rescaled'].cpu().numpy()[map_y, map_x]
        matches_confs = np.minimum(conf_map[map_y, map_x], conf_query[query_y, query_x])

    def _to_original_frame(xy, to_orig):
        # shift to pixel-center coordinates, apply the transform, then shift back
        xy = xy.astype(np.float64)
        xy[:, :2] += 0.5
        xy = geotrf(to_orig, xy, norm=True)
        xy[:, :2] -= 0.5
        return xy

    matches_im_query = _to_original_frame(matches_im_query, query_view['to_orig'])
    matches_im_map = _to_original_frame(matches_im_map, map_view['to_orig'])
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs
|
|
|
|
|
@torch.no_grad()
def crops_inference(pairs, model, device, batch_size=48, verbose=True):
    """Run the model on a batch of image pairs, in chunks of at most ``batch_size`` pairs.

    Args:
        pairs: two view dicts (first and second image of every pair). Their values are
            indexed along the first (batch) dimension, and 'img' gives the batch size.
            Keys whose value is None are dropped when chunking.
        model, device: forwarded to ``loss_of_one_batch``.
        batch_size: maximum number of pairs per forward pass.
        verbose: unused; kept for signature compatibility.

    Returns:
        The ``loss_of_one_batch`` output directly when the batch is smaller than
        ``batch_size``. Otherwise, the per-chunk outputs merged with ``cat_collate``.
    """
    assert len(pairs) == 2, "Error, data should be a tuple of dicts containing the batch of image pairs"

    total = pairs[0]['img'].shape[0]
    if total < batch_size:
        return loss_of_one_batch(pairs, model, None, device=device, symmetrize_batch=False)

    def _take(view, window):
        # slice every non-None entry of a view dict along the batch dimension
        return {key: value[window] for key, value in view.items() if value is not None}

    chunk_outputs = []
    for start in range(0, total, batch_size):
        window = slice(start, min(total, start + batch_size))
        chunk = [_take(pairs[0], window), _take(pairs[1], window)]
        chunk_outputs.append(loss_of_one_batch(chunk, model, None, device=device, symmetrize_batch=False))

    return cat_collate(chunk_outputs, collate_fn_map=cat_collate_fn_map)
|
|
|
|
|
@torch.no_grad()
def fine_matching(query_views, map_views, model, device, max_batch_size, pixel_tol, fast_nn_params):
    """Match every (query crop, map crop) pair and gather 2D-3D correspondences.

    Args:
        query_views: batched crop dict with 'img' and 'to_orig' (one transform per crop).
        map_views: batched crop dict with 'img', 'to_orig', 'valid' (per-crop pixel mask)
            and 'pts3d' (per-crop 3D points indexed as [y, x]).
        model, device, max_batch_size: forwarded to ``crops_inference``.
        pixel_tol: strictly positive pixel tolerance for ``fast_reciprocal_NNs``.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs), concatenated over
        all crop pairs. Match coordinates are mapped out of each crop through its
        'to_orig'. Four empty lists are returned when there is no crop pair.
    """
    assert pixel_tol > 0
    output = crops_inference([query_views, map_views], model, device, batch_size=max_batch_size, verbose=False)
    pred_query, pred_map = output['pred1'], output['pred2']
    desc_query = pred_query['desc'].clone()
    desc_map = pred_map['desc'].clone()
    conf_query = pred_query['desc_conf'].clone()
    conf_map = pred_map['desc_conf'].clone()

    per_crop = []  # one (pts3d, query_xy, map_xy, confs) tuple per crop pair
    for i, (dq, dm, cq, cm) in enumerate(zip(desc_query, desc_map, conf_query, conf_map)):
        crop_pts3d = map_views['pts3d'][i].cpu().numpy()
        cq_np, cm_np = cq.cpu().numpy(), cm.cpu().numpy()

        # only valid map pixels seed the reciprocal nearest-neighbour search
        rows, cols = torch.where(map_views['valid'][i])
        map_xy, query_xy = fast_reciprocal_NNs(dm, dq, (cols, rows), pixel_tol=pixel_tol, **fast_nn_params)

        pts = crop_pts3d[map_xy[:, 1], map_xy[:, 0]]
        confs = np.minimum(cm_np[map_xy[:, 1], map_xy[:, 0]], cq_np[query_xy[:, 1], query_xy[:, 0]])

        # crop pixel coordinates -> coordinates in the image the crop was taken from
        map_xy = geotrf(map_views['to_orig'][i].cpu().numpy(), map_xy.copy(), norm=True)
        query_xy = geotrf(query_views['to_orig'][i].cpu().numpy(), query_xy.copy(), norm=True)
        per_crop.append((pts, query_xy, map_xy, confs))

    if not per_crop:
        return [], [], [], []

    valid_pts3d, matches_im_query, matches_im_map, matches_confs = (
        np.concatenate(parts, axis=0) for parts in zip(*per_crop))
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs
|
|
|
|
|
def _perspective_warp(array, size, coeffs, resample):
    """Warp a 2D/HxWxC numpy array with PIL's 8-coefficient PERSPECTIVE transform."""
    # PIL maps each *output* pixel through `coeffs` to a location in the input image.
    return np.array(Image.fromarray(array).transform(size, Image.Transform.PERSPECTIVE, coeffs,
                                                     resample=resample))


def crop(img, mask, pts3d, crop, intrinsics=None):
    """Crop an image, and optionally its validity mask and 3D points, to a crop region.

    Args:
        img: image tensor indexed as (H, W, C) with values in [-1, 1] (it is mapped to
            [0, 255] for the PIL warp and back).
        mask: optional per-pixel validity tensor, or None.
        pts3d: optional per-pixel 3D points indexed as (H, W, 3), or None.
        crop: crop specification consumed by ``crop_slice`` / ``crop_to_homography``.
            In the slicing branch its first two entries are used as the (x, y) offset.
        intrinsics: optional camera matrix. When given, the crop is a perspective warp
            built by ``crop_to_homography``. Otherwise it is plain array slicing.

    Returns:
        (cropped_img, cropped_mask, cropped_pts3d, to_orig). ``to_orig`` is a 3x3 tensor
        mapping crop pixel coordinates back to ``img`` pixel coordinates. It is float32
        for slicing and has the homography's dtype for warping. The mask and points are
        None when not provided.
    """
    out_cropped_mask = mask.clone() if mask is not None else None
    out_cropped_pts3d = pts3d.clone() if pts3d is not None else None
    to_orig = torch.eye(3, device=img.device)

    if intrinsics is not None:
        imsize, _, _, H = crop_to_homography(intrinsics, crop)
        H /= H[2, 2]
        homo8 = H.ravel().tolist()[:8]

        # Round (instead of truncating) when going back to 8 bits, and clamp so rounding
        # cannot wrap around. Truncation biased pixels down and produced off-by-one values.
        img_u8 = (255 * (img + 1.) / 2).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
        warped_img = _perspective_warp(img_u8, imsize, homo8, Image.Resampling.BICUBIC)
        out_cropped_img = 2. * torch.tensor(warped_img).to(img) / 255. - 1.

        if out_cropped_mask is not None:
            mask_u8 = (255 * out_cropped_mask).to(torch.uint8).cpu().numpy()
            warped_mask = _perspective_warp(mask_u8, imsize, homo8, Image.Resampling.NEAREST)
            out_cropped_mask = torch.from_numpy(warped_mask > 0).to(out_cropped_mask.dtype)

        if out_cropped_pts3d is not None:
            # PIL warps single-channel float images, so X, Y and Z are warped separately
            pts_np = out_cropped_pts3d.cpu().numpy()
            channels = [_perspective_warp(pts_np[:, :, c], imsize, homo8, Image.Resampling.NEAREST)
                        for c in range(3)]
            out_cropped_pts3d = torch.from_numpy(np.stack(channels, axis=-1))

        to_orig = torch.tensor(H, device=img.device)
    else:
        window = crop_slice(crop)
        out_cropped_img = img[window]
        if out_cropped_mask is not None:
            out_cropped_mask = out_cropped_mask[window]
        if out_cropped_pts3d is not None:
            out_cropped_pts3d = out_cropped_pts3d[window]
        # pure translation by the crop's top-left corner
        to_orig[:2, -1] = torch.tensor(crop[:2])

    return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig
|
|
|
|
|
def resize_image_to_max(max_image_size, rgb, K):
    """Normalize an image and, if needed, downscale it so its longer side is ``max_image_size``.

    Args:
        max_image_size: target length of the longer image side. A falsy value (None / 0)
            disables resizing.
        rgb: input image exposing ``.size`` as (width, height). Presumably a PIL image;
            it is passed to ``ImgNorm``.
        K: 3x3 intrinsics. The ``opencv_to_colmap_intrinsics`` round trip below implies
            the OpenCV pixel convention.

    Returns:
        (rgb_tensor, new_K, to_orig_max, to_resize_max, (height, width)):
        - ``rgb_tensor`` is the ``ImgNorm`` output, resized if needed and permuted to (H, W, C).
        - ``new_K`` is the intrinsics of that image.
        - ``to_orig_max`` and ``to_resize_max`` are 3x3 scalings from returned-image pixels
          to original pixels, and back.
        - The last item is the returned image size.
        Without resizing, ``K`` is returned as-is (not copied) and both transforms are identities.
    """
    W, H = rgb.size
    if not max_image_size or max(W, H) <= max_image_size:
        # small enough, or resizing disabled: only normalize
        return ImgNorm(rgb).permute(1, 2, 0), K, np.eye(3), np.eye(3), (H, W)

    # keep the aspect ratio; the longer side becomes exactly max_image_size
    if W >= H:
        WMax = max_image_size
        HMax = int(H * (WMax / W))
    else:
        HMax = max_image_size
        WMax = int(W * (HMax / H))

    normalized = ImgNorm(rgb)
    rgb_tensor = tvf.Resize(size=[HMax, WMax])(normalized).permute(1, 2, 0)

    scale_x, scale_y = WMax / W, HMax / H
    to_orig_max = np.diag([W / WMax, H / HMax, 1.0])
    to_resize_max = np.diag([scale_x, scale_y, 1.0])

    # scale the intrinsics in the COLMAP convention, then convert back to OpenCV
    new_K = opencv_to_colmap_intrinsics(K)
    new_K[0, :] *= scale_x
    new_K[1, :] *= scale_y
    new_K = colmap_to_opencv_intrinsics(new_K)
    return rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)
|
|
|
|
|
def _show_matches(query_rgb, map_rgb, matches_im_query, matches_im_map, n_viz):
    """Debug helper: draw up to ``n_viz`` evenly spaced matches between two images side by side.

    Args:
        query_rgb, map_rgb: images convertible with ``np.array`` (the views' 'rgb' entries).
        matches_im_query, matches_im_map: matching (x, y) pixel coordinates in each image.
            Either may be an empty list.
        n_viz: number of matches to draw.
    """
    num_matches = len(matches_im_map)
    print(f'found {num_matches} matches')
    if num_matches == 0 or n_viz <= 0:
        # nothing to draw (previously crashed on empty lists / empty arrays)
        return

    from matplotlib import pyplot as pl

    viz_imgs = [np.array(query_rgb), np.array(map_rgb)]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im_query = matches_im_query[match_idx_to_viz]
    viz_matches_im_map = matches_im_map[match_idx_to_viz]

    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T
        # max(..., 1) avoids the ZeroDivisionError the original hit with n_viz == 1
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / max(n_viz - 1, 1)), scalex=False, scaley=False)
    pl.show(block=True)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    conf_thr = args.confidence_threshold
    device = args.device
    pnp_mode = args.pnp_mode
    if args.pixel_tol <= 0:
        # explicit validation rather than `assert`, which is stripped under `python -O`
        parser.error('--pixel_tol must be strictly positive')
    reprojection_error = args.reprojection_error
    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio
    pnp_max_points = args.pnp_max_points
    viz_matches = args.viz_matches

    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    fast_nn_params = dict(device=device, dist='dot', block_size=2**13)
    # NOTE(security): --dataset is evaluated as Python code. Presumably it is an expression
    # building one of the classes star-imported from dust3r_visloc.datasets. eval() runs
    # arbitrary code, so only pass trusted command lines.
    dataset = eval(args.dataset)
    dataset.set_resolution(model)

    # depends only on the model, so computed once instead of once per query
    maxdim = max(model.patch_embed.img_size)

    query_names = []
    poses_pred = []
    pose_errors = []
    angular_errors = []
    # tag of the matching configuration: used in the match-cache path and the results label
    params_str = f'tol_{args.pixel_tol}' + ("_c2f" if args.coarse_to_fine else '')
    if args.max_image_size is not None:
        params_str = params_str + f'_{args.max_image_size}'
    if args.coarse_to_fine and args.c2f_crop_with_homography:
        params_str = params_str + '_with_homography'

    for idx in tqdm(range(len(dataset))):
        views = dataset[idx]
        # views[0] is the query; the remaining views are map images
        query_view = views[0]
        map_views = views[1:]
        query_names.append(query_view['image_name'])

        query_pts2d = []
        query_pts3d = []
        query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max, (HQ, WQ) = resize_image_to_max(
            args.max_image_size, query_view['rgb'], query_view['intrinsics'])
        query_resolution = get_HW_resolution(HQ, WQ, maxdim=maxdim, patchsize=model.patch_embed.patch_size)

        for map_view in map_views:
            if args.output_dir is not None:
                cache_file = os.path.join(args.output_dir, 'matches', params_str,
                                          query_view['image_name'], map_view['image_name'] + '.npz')
            else:
                cache_file = None

            if cache_file is not None and os.path.isfile(cache_file):
                matches = np.load(cache_file)
                valid_pts3d = matches['valid_pts3d']
                matches_im_query = matches['matches_im_query']
                matches_im_map = matches['matches_im_map']
                matches_conf = matches['matches_conf']
            else:
                if args.coarse_to_fine and (maxdim < max(WQ, HQ)):
                    # coarse pass over the whole images (pixel_tol=0: no 3D points, no confidences)
                    _, coarse_matches_im0, coarse_matches_im1, _ = coarse_matching(query_view, map_view, model,
                                                                                   device, 0, fast_nn_params)
                    if viz_matches > 0:
                        _show_matches(query_view['rgb'], map_view['rgb'],
                                      coarse_matches_im0, coarse_matches_im1, viz_matches)

                    valid_all = map_view['valid']
                    pts3d = map_view['pts3d']

                    WM_full, HM_full = map_view['rgb'].size
                    map_rgb_tensor, map_K, map_to_orig_max, map_to_resize_max, (HM, WM) = resize_image_to_max(
                        args.max_image_size, map_view['rgb'], map_view['intrinsics'])
                    if WM_full != WM or HM_full != HM:
                        # the map image was downscaled: resample its valid mask and 3D points to match
                        y_full, x_full = torch.where(valid_all)
                        pos2d_cv2 = torch.stack([x_full, y_full], dim=-1).cpu().numpy().astype(np.float64)
                        sparse_pts3d = pts3d[y_full, x_full].cpu().numpy()
                        _, _, pts3d_max, valid_max = rescale_points3d(
                            pos2d_cv2, sparse_pts3d, map_to_resize_max, HM, WM)
                        pts3d = torch.from_numpy(pts3d_max)
                        valid_all = torch.from_numpy(valid_max)

                    # express the coarse matches in the (possibly downscaled) fine-resolution frames
                    coarse_matches_im0 = geotrf(query_to_resize_max, coarse_matches_im0, norm=True)
                    coarse_matches_im1 = geotrf(map_to_resize_max, coarse_matches_im1, norm=True)

                    if not args.c2f_crop_with_homography:
                        # without intrinsics, crop() uses plain slicing instead of a homography
                        map_K = None
                        query_K = None

                    crops1, crops2 = [], []
                    crops_v1, crops_p1 = [], []
                    to_orig1, to_orig2 = [], []
                    map_resolution = get_HW_resolution(HM, WM, maxdim=maxdim, patchsize=model.patch_embed.patch_size)
                    for crop_map, crop_query, _pair_tag in select_pairs_of_crops(
                            map_rgb_tensor, query_rgb_tensor, coarse_matches_im1, coarse_matches_im0,
                            maxdim=maxdim, overlap=.5, forced_resolution=[map_resolution, query_resolution]):
                        c1, v1, p1, trf1 = crop(map_rgb_tensor, valid_all, pts3d, crop_map, map_K)
                        c2, _, _, trf2 = crop(query_rgb_tensor, None, None, crop_query, query_K)
                        crops1.append(c1)
                        crops2.append(c2)
                        crops_v1.append(v1)
                        crops_p1.append(p1)
                        to_orig1.append(trf1)
                        to_orig2.append(trf2)

                    if len(crops1) == 0 or len(crops2) == 0:
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = [], [], [], []
                    else:
                        crops1, crops2 = torch.stack(crops1), torch.stack(crops2)
                        if len(crops1.shape) == 3:
                            crops1, crops2 = crops1[None], crops2[None]
                        crops_v1 = torch.stack(crops_v1)
                        crops_p1 = torch.stack(crops_p1)
                        to_orig1, to_orig2 = torch.stack(to_orig1), torch.stack(to_orig2)
                        map_crop_view = dict(img=crops1.permute(0, 3, 1, 2),
                                             instance=['1' for _ in range(crops1.shape[0])],
                                             valid=crops_v1, pts3d=crops_p1,
                                             to_orig=to_orig1)
                        query_crop_view = dict(img=crops2.permute(0, 3, 1, 2),
                                               instance=['2' for _ in range(crops2.shape[0])],
                                               to_orig=to_orig2)

                        # fine matching on all crop pairs, then back to full original resolution
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = fine_matching(
                            query_crop_view, map_crop_view, model, device,
                            args.max_batch_size, args.pixel_tol, fast_nn_params)
                        matches_im_query = geotrf(query_to_orig_max, matches_im_query, norm=True)
                        matches_im_map = geotrf(map_to_orig_max, matches_im_map, norm=True)
                else:
                    # single pass at the model resolution, seeded from valid map pixels
                    valid_pts3d, matches_im_query, matches_im_map, matches_conf = coarse_matching(
                        query_view, map_view, model, device, args.pixel_tol, fast_nn_params)
                if cache_file is not None:
                    mkdir_for(cache_file)
                    np.savez(cache_file, valid_pts3d=valid_pts3d, matches_im_query=matches_im_query,
                             matches_im_map=matches_im_map, matches_conf=matches_conf)

            # keep only matches whose confidence reaches the threshold
            if len(matches_conf) > 0:
                mask = matches_conf >= conf_thr
                valid_pts3d = valid_pts3d[mask]
                matches_im_query = matches_im_query[mask]
                matches_im_map = matches_im_map[mask]
                matches_conf = matches_conf[mask]

            if viz_matches > 0:
                _show_matches(query_view['rgb'], map_view['rgb'], matches_im_query, matches_im_map, viz_matches)

            if len(valid_pts3d) > 0:
                query_pts3d.append(valid_pts3d)
                query_pts2d.append(matches_im_query)

        if len(query_pts2d) == 0:
            success = False
            pr_querycam_to_world = None
        else:
            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)
            query_pts3d = np.concatenate(query_pts3d, axis=0)
            if len(query_pts2d) > pnp_max_points:
                # random subset to bound the PnP cost (`random` is not seeded in this script)
                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)
                query_pts3d = query_pts3d[idxs]
                query_pts2d = query_pts2d[idxs]

            W, H = query_view['rgb'].size
            if reprojection_error_diag_ratio is not None:
                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)
            else:
                reprojection_error_img = reprojection_error
            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,
                                                    query_view['intrinsics'], query_view['distortion'],
                                                    pnp_mode, reprojection_error_img, img_size=[W, H])

        if not success:
            abs_transl_error = float('inf')
            abs_angular_error = float('inf')
        else:
            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])

        pose_errors.append(abs_transl_error)
        angular_errors.append(abs_angular_error)
        poses_pred.append(pr_querycam_to_world)

    xp_label = params_str + f'_conf_{conf_thr}'
    if args.output_label:
        xp_label = args.output_label + "_" + xp_label
    if reprojection_error_diag_ratio is not None:
        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'
    else:
        xp_label = xp_label + f'_reproj_err_{reprojection_error}'
    export_results(args.output_dir, xp_label, query_names, poses_pred)
    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)
    print(out_string)
|
|