|
import numpy as np
|
|
import argparse
|
|
from tqdm import tqdm
|
|
import torch
|
|
import pandas as pd
|
|
|
|
|
|
from configs.default import get_cfg_defaults
|
|
from datasets import dataset_dict
|
|
from baselines.pose import PoseRecover
|
|
from utils.metrics import relative_pose_error, rotation_angular_error, error_auc, add, adi, compute_continuous_auc
|
|
|
|
|
|
def main(args):
    """Evaluate a matcher + pose solver on a test split and print a metrics table.

    Loads the YAML config at ``args.config`` on top of the defaults, builds the
    ``'test'`` split of the dataset selected by ``DATASET.TASK`` /
    ``DATASET.DATA_SOURCE``, runs ``PoseRecover.recover`` on every image pair and
    prints a two-column ``pandas.DataFrame`` (``Metrics`` / ``Values``) holding:

    * mean extracting / matching / recovering times and their sum (ms),
    * pose AUC at 5/10/20 (from ``error_auc``),
    * rotation mean / median error and accuracy at 30 and 15,
    * translation mean / median error and accuracy at 1 m / 10 cm (``--depth`` only),
    * ADD / ADD-S continuous AUC (``object`` task only).

    Args:
        args: namespace produced by ``get_parser().parse_args()``; reads
            ``config``, ``matcher``, ``solver``, ``resize``, ``depth``, ``mask``,
            ``obj_name`` and ``device``.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    device = args.device
    poseRec = PoseRecover(matcher=args.matcher, solver=args.solver, img_resize=args.resize, device=device)

    # Per-sample timings as returned by PoseRecover.recover (scaled by 1000 below).
    extract_times, match_times, recover_times = [], [], []
    R_errs, t_errs = [], []
    ts_errs = []  # metric translation errors, only filled with --depth
    adds, adis = [], []  # ADD / ADD-S errors, only filled for the object task

    for data in tqdm(testloader):
        # Optionally restrict HO3D evaluation to a single object.
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue

        # Batch size is 1; the first item unpacks into the two views of the pair.
        image0, image1 = data['images'][0].to(device)

        bbox0, bbox1 = None, None
        if task == 'object':
            # Crop each image to its box: rows y1:y2, columns x1:x2
            # (indexing assumes channel-first images — confirm against the dataset).
            bbox0, bbox1 = data['bboxes'][0]
            x1, y1, x2, y2 = bbox0
            u1, v1, u2, v2 = bbox1
            image0 = image0[:, y1:y2, x1:x2]
            image1 = image1[:, v1:v2, u1:u2]

        mask0, mask1 = None, None
        if args.mask:
            mask0, mask1 = data['masks'][0].to(device)

        depth0, depth1 = None, None
        if args.depth:
            depth0, depth1 = data['depths'][0]

        K0, K1 = data['intrinsics'][0]

        # Ground-truth relative pose as a 4x4 homogeneous matrix.
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        # Preprocessing time is returned as well but is not part of the report.
        R, t, points0, points1, _preprocess_time, extract_time, match_time, recover_time = poseRec.recover(
            image0, image1, K0, K1, bbox0, bbox1, mask0, mask1, depth0, depth1)
        extract_times.append(extract_time)
        match_times.append(match_time)
        recover_times.append(recover_time)

        # A NaN rotation marks a solver failure: charge the maximum error and
        # fall back to an identity / zero pose for the remaining metrics.
        pose_failed = np.isnan(R).any()
        if pose_failed:
            R_err = 180
            t_err = 180
            R = np.identity(3)
            t = np.array([0., 0., 0.])
        else:
            t_err, R_err = relative_pose_error(T, R, t, ignore_gt_t_thr=0.0)

        R_errs.append(R_err)
        t_errs.append(t_err)

        if args.depth:
            t = np.nan_to_num(t)
            ts_errs.append(torch.tensor(T[:3, 3] - t).norm(2))

        if task == 'object':
            if pose_failed:
                # Record failures with the fixed error value 1.0. Previously this
                # branch was unreachable because R had already been reset above.
                adds.append(1.)
                adis.append(1.)
            else:
                point_cloud = data['point_cloud'][0].numpy()
                adds.append(add(R, t, T[:3, :3], T[:3, 3], point_cloud))
                adis.append(adi(R, t, T[:3, :3], T[:3, 3], point_cloud))

    if not R_errs:
        # Every sample was filtered out (or the split is empty): nothing to report.
        print('No samples were evaluated; check the dataset config and --obj_name.')
        return

    # Average the full per-sample lists (the previous code used only the last
    # iteration's value for extracting/recovering time).
    extract_ms = np.array(extract_times) * 1000
    match_ms = np.array(match_times) * 1000
    recover_ms = np.array(recover_times) * 1000

    rows = [
        ('Extracting Time (ms)', f'{np.mean(extract_ms):.1f}'),
        ('Matching Time (ms)', f'{np.mean(match_ms):.1f}'),
        ('Recovering Time (ms)', f'{np.mean(recover_ms):.1f}'),
        ('Total Time (ms)', f'{np.mean(extract_ms) + np.mean(match_ms) + np.mean(recover_ms):.1f}'),
    ]

    angular_thresholds = [5, 10, 20]
    # Each pair's pose error is the worse of its rotation and translation errors.
    pose_errors = np.max(np.stack([R_errs, t_errs]), axis=0)
    aucs = error_auc(pose_errors, angular_thresholds, mode='Pose estimation')
    for k in aucs:
        rows.append((k, f'{aucs[k] * 100:.2f}'))

    R_errs = torch.tensor(R_errs)
    t_errs = torch.tensor(t_errs)

    rows += [
        ('Rotation Avg. Error (°)', f'{R_errs.mean():.2f}'),
        ('Rotation Med. Error (°)', f'{R_errs.median():.2f}'),
        ('Rotation @30° ACC', f'{(R_errs < 30).float().mean() * 100:.1f}'),
        ('Rotation @15° ACC', f'{(R_errs < 15).float().mean() * 100:.1f}'),
    ]

    if args.depth:
        ts_errs = torch.tensor(ts_errs)
        rows += [
            ('Translation Avg. Error (m)', f'{ts_errs.mean():.4f}'),
            ('Translation Med. Error (m)', f'{ts_errs.median():.4f}'),
            ('Translation @1m ACC', f'{(ts_errs < 1.0).float().mean() * 100:.1f}'),
            ('Translation @10cm ACC', f'{(ts_errs < 0.1).float().mean() * 100:.1f}'),
        ]

    if task == 'object':
        rows += [
            ('Object ADD', f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}'),
            ('Object ADD-S', f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}'),
        ]

    res = pd.DataFrame(rows, columns=['Metrics', 'Values'])
    print(res)
|
|
|
|
|
|
def get_parser():
    """Create the argument parser for this evaluation script.

    Positional arguments are the YAML config path and the matcher name. The
    optional flags select the pose solver and an optional resize value. They
    also toggle whether depth maps and masks are passed to the solver, give an
    optional object-name filter (applied to the ``ho3d`` source), and set the
    torch device.

    Returns:
        argparse.ArgumentParser: a parser whose ``parse_args()`` result is
        consumed by ``main``.
    """
    cli = argparse.ArgumentParser()

    # Positional arguments.
    cli.add_argument('config', type=str, help='.yaml configure file path')
    cli.add_argument('matcher', type=str)

    # Solver choice and image resize (None keeps the original size).
    cli.add_argument('--solver', type=str, default='procrustes')
    cli.add_argument('--resize', type=int, default=None)

    # Boolean switches, registered in the same order as before.
    for switch in ('--depth', '--mask'):
        cli.add_argument(switch, action='store_true')

    cli.add_argument('--obj_name', type=str, default=None)
    cli.add_argument('--device', type=str, default='cuda:0')
    return cli
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the evaluation.
    cli_args = get_parser().parse_args()
    main(cli_args)
|
|
|