Spaces:
Running
Running
import argparse | |
import pickle | |
from pathlib import Path | |
import cv2 | |
import h5py | |
import numpy as np | |
import pycolmap | |
import torch | |
from scipy.io import loadmat | |
from tqdm import tqdm | |
from . import logger | |
from .utils.parsers import names_to_pair, parse_retrieval | |
def interpolate_scan(scan, kp):
    """Sample a per-pixel scan at (sub-pixel) 2D keypoint locations.

    Bilinear interpolation is tried first. Wherever it yields NaN (a NaN
    neighbour in the scan), nearest-neighbour sampling is used instead. This
    maximizes the number of keypoints that receive a value.

    Args:
        scan: array of shape (H, W, C) with per-pixel values (e.g. XYZ
            coordinates); NaN marks missing data.
        kp: array of shape (N, 2) with (x, y) pixel coordinates, where
            0 <= x <= W - 1 and 0 <= y <= H - 1 (border pixels included).

    Returns:
        kp3d: array of shape (N, C) with the sampled values, in scan's dtype.
        valid: boolean array of shape (N,), True where no channel is NaN.

    Raises:
        ValueError: if any keypoint lies outside the scan.
    """
    h, w, _ = scan.shape
    # Map pixel coordinates to [-1, 1]. With align_corners=True the extremes
    # -1 and 1 are the centres of the border pixels, so they are valid.
    grid = np.asarray(kp) / np.array([[w - 1, h - 1]]) * 2 - 1
    if not (np.all(grid >= -1) and np.all(grid <= 1)):
        raise ValueError("Keypoints must lie inside the scan image bounds.")

    scan_t = torch.from_numpy(scan).permute(2, 0, 1)[None]  # (1, C, H, W)
    # grid_sample requires the grid and the input to share a dtype.
    grid_t = torch.from_numpy(grid).to(scan_t.dtype)[None, None]  # (1, 1, N, 2)

    grid_sample = torch.nn.functional.grid_sample
    # To maximize the number of points that have a value:
    # bilinear interpolation first, then nearest for the remaining points.
    interp_lin = grid_sample(
        scan_t, grid_t, align_corners=True, mode="bilinear"
    )[0, :, 0]
    interp_nn = grid_sample(
        scan_t, grid_t, align_corners=True, mode="nearest"
    )[0, :, 0]
    interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)  # (C, N)
    valid = ~torch.any(torch.isnan(interp), 0)
    return interp.T.numpy(), valid.numpy()
def get_scan_pose(dataset_dir, rpath):
    """Read the transformation matrix of the scan a database image belongs to.

    ``rpath`` must end in ``<floor>/<scan_id>/<image_name>``. The first three
    characters of the image name select the building. The matrix is parsed
    from lines 8-11 (0-based indices 7-10) of
    ``<dataset_dir>/database/alignments/<floor>/transformations/<building>_trans_<scan_id>.txt``.
    This is the block the code names ``P_after_GICP``.

    Returns:
        np.ndarray whose rows are the four parsed lines (4x4 for well-formed
        files with four numbers per line).
    """
    parts = rpath.split("/")
    floor_name = parts[-3]
    scan_id = parts[-2]
    building_name = parts[-1][:3]

    transform_file = (
        Path(dataset_dir)
        / "database"
        / "alignments"
        / floor_name
        / "transformations"
        / f"{building_name}_trans_{scan_id}.txt"
    )
    with transform_file.open() as handle:
        lines = handle.readlines()

    # Each of the four lines holds one whitespace-separated matrix row.
    rows = [np.fromstring(lines[idx], sep=" ") for idx in range(7, 11)]
    return np.array(rows)
def pose_from_cluster(
    dataset_dir, q, retrieved, feature_file, match_file, skip=None
):
    """Estimate the absolute pose of query ``q`` from its retrieved database images.

    For every retrieved database image ``r``:

    1. The database keypoints matched to query keypoints are lifted to 3D by
       sampling the scan ``<dataset_dir>/<r>.mat`` (key ``XYZcut``).
    2. The 3D points are moved with the rigid transform returned by
       ``get_scan_pose``.

    All 2D-3D correspondences are then passed to
    ``pycolmap.absolute_pose_estimation``.

    Args:
        dataset_dir: dataset root directory (str or Path).
        q: relative path of the query image.
        retrieved: relative paths of the retrieved database images.
        feature_file: mapping (an h5py file in ``main``) with
            ``<name>/keypoints`` entries.
        match_file: mapping with ``<pair>/matches0`` entries, where -1 marks an
            unmatched query keypoint.
        skip: if truthy, database images with fewer than ``skip`` matches are
            ignored.

    Returns:
        ``(ret, mkpq, mkpr, mkp3d, indices, num_matches)``:
        - ``ret``: the pycolmap result dict, with the camera config added
          under ``"cfg"``.
        - ``mkpq``, ``mkpr``, ``mkp3d``: the query keypoints, database
          keypoints and 3D points of all correspondences with a valid 3D
          point.
        - ``indices``: the index into ``retrieved`` of each correspondence.
        - ``num_matches``: the number of matches considered before
          3D-validity filtering.
        If every retrieved image is skipped, pycolmap is not called. In that
        case ``ret`` is ``{"success": False, "cfg": cfg}`` and the arrays are
        empty.

    Raises:
        FileNotFoundError: if the query image cannot be read.
    """
    query_path = Path(dataset_dir, q)
    image = cv2.imread(str(query_path))
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read query image {query_path}")
    height, width = image.shape[:2]
    cx = 0.5 * width
    cy = 0.5 * height
    # NOTE(review): hard-coded focal length, presumably a 28 mm (35 mm-equivalent,
    # 36 mm sensor width) lens at 4032 px image width. It does not scale with the
    # actual image width; confirm queries are used at full resolution.
    focal_length = 4032.0 * 28.0 / 36.0

    all_mkpq = []
    all_mkpr = []
    all_mkp3d = []
    all_indices = []
    kpq = feature_file[q]["keypoints"].__array__()
    num_matches = 0

    for i, r in enumerate(retrieved):
        kpr = feature_file[r]["keypoints"].__array__()
        pair = names_to_pair(q, r)
        m = match_file[pair]["matches0"].__array__()
        v = m > -1
        if skip and (np.count_nonzero(v) < skip):
            continue

        mkpq, mkpr = kpq[v], kpr[m[v]]
        num_matches += len(mkpq)

        scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"]
        mkp3d, valid = interpolate_scan(scan_r, mkpr)
        Tr = get_scan_pose(dataset_dir, r)
        # Apply the rigid transform: rotation Tr[:3, :3], translation Tr[:3, 3].
        mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T

        all_mkpq.append(mkpq[valid])
        all_mkpr.append(mkpr[valid])
        all_mkp3d.append(mkp3d[valid])
        all_indices.append(np.full(np.count_nonzero(valid), i))

    cfg = {
        "model": "SIMPLE_PINHOLE",
        "width": width,
        "height": height,
        "params": [focal_length, cx, cy],
    }

    if not all_mkpq:
        # Every retrieved image was skipped (or none were given).
        # np.concatenate cannot handle an empty list, so report an unsuccessful
        # estimate with empty, correctly-shaped correspondences instead.
        empty_2d = kpq[:0].copy()
        return (
            {"success": False, "cfg": cfg},
            empty_2d,
            empty_2d.copy(),
            np.empty((0, 3)),
            np.empty(0, dtype=int),
            num_matches,
        )

    all_mkpq = np.concatenate(all_mkpq, 0)
    all_mkpr = np.concatenate(all_mkpr, 0)
    all_mkp3d = np.concatenate(all_mkp3d, 0)
    all_indices = np.concatenate(all_indices, 0)

    # NOTE(review): 48.00 is passed positionally, presumably as pycolmap's
    # maximum reprojection error in pixels; confirm against the installed version.
    ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00)
    ret["cfg"] = cfg
    return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches
def main(dataset_dir, retrieval, features, matches, results, skip_matches=None):
    """Localize every query listed in ``retrieval``, then write poses and logs.

    Args:
        dataset_dir: dataset root, forwarded to ``pose_from_cluster``.
        retrieval: Path of the retrieval pairs file parsed by ``parse_retrieval``.
        features: Path of the HDF5 file with local features.
        matches: Path of the HDF5 file with matches.
        results: output path. It receives one ``<name> <qvec> <tvec>`` line
            per query whose pose was estimated; queries without a pose are
            logged as a warning and omitted. All queries are pickled to
            ``<results>_logs.pkl``.
        skip_matches: optional minimum number of matches a database image
            needs in order to be used.

    Raises:
        FileNotFoundError: if ``retrieval``, ``features`` or ``matches`` is missing.
    """
    # Explicit checks (not asserts) so validation survives ``python -O``.
    for required in (retrieval, features, matches):
        if not required.exists():
            raise FileNotFoundError(required)

    retrieval_dict = parse_retrieval(retrieval)
    queries = list(retrieval_dict.keys())

    poses = {}
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
    # Context managers close both HDF5 files even if localization raises.
    # pose_from_cluster returns numpy arrays (not HDF5 datasets), so the
    # entries stored in `logs` remain usable after the files are closed.
    with h5py.File(features, "r", libver="latest") as feature_file, h5py.File(
        matches, "r", libver="latest"
    ) as match_file:
        for q in tqdm(queries):
            db = retrieval_dict[q]
            ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster(
                dataset_dir, q, db, feature_file, match_file, skip_matches
            )
            if "qvec" in ret and "tvec" in ret:
                poses[q] = (ret["qvec"], ret["tvec"])
            else:
                # NOTE(review): assumes a failed estimate is a dict without
                # "qvec"/"tvec" (e.g. {"success": False}); previously this
                # raised KeyError and aborted the whole run.
                logger.warning(f"Could not localize {q}; it is omitted from {results}.")
            logs["loc"][q] = {
                "db": db,
                "PnP_ret": ret,
                "keypoints_query": mkpq,
                "keypoints_db": mkpr,
                "3d_points": mkp3d,
                "indices_db": indices,
                "num_matches": num_matches,
            }

    logger.info(f"Writing poses to {results}...")
    with open(results, "w") as f:
        for q in queries:
            if q not in poses:
                continue
            qvec, tvec = poses[q]
            qvec = " ".join(map(str, qvec))
            tvec = " ".join(map(str, tvec))
            name = q.split("/")[-1]
            f.write(f"{name} {qvec} {tvec}\n")

    logs_path = f"{results}_logs.pkl"
    logger.info(f"Writing logs to {logs_path}...")
    with open(logs_path, "wb") as f:
        pickle.dump(logs, f)
    logger.info("Done!")
if __name__ == "__main__":
    # CLI entry point. All path options are mandatory and converted to
    # pathlib.Path; --skip_matches is an optional integer threshold.
    cli = argparse.ArgumentParser()
    for option in (
        "--dataset_dir",
        "--retrieval",
        "--features",
        "--matches",
        "--results",
    ):
        cli.add_argument(option, type=Path, required=True)
    cli.add_argument("--skip_matches", type=int)
    main(**vars(cli.parse_args()))