SRPose / eval_baselines.py
FrickYinn's picture
Upload 53 files
e170a8e verified
import numpy as np
import argparse
from tqdm import tqdm
import torch
import pandas as pd
# from lightglue.utils import load_image
from configs.default import get_cfg_defaults
from datasets import dataset_dict
from baselines.pose import PoseRecover
from utils.metrics import relative_pose_error, rotation_angular_error, error_auc, add, adi, compute_continuous_auc
def main(args):
config = get_cfg_defaults()
config.merge_from_file(args.config)
task = config.DATASET.TASK
dataset = config.DATASET.DATA_SOURCE
# try:
# data_root = config.DATASET.TEST.DATA_ROOT
# except:
# data_root = config.DATASET.DATA_ROOT
build_fn = dataset_dict[task][dataset]
testset = build_fn('test', config)
testloader = torch.utils.data.DataLoader(testset, batch_size=1)
device = args.device
img_resize = args.resize
poseRec = PoseRecover(matcher=args.matcher, solver=args.solver, img_resize=img_resize, device=device)
preprocess_times, extract_times, match_times, recover_times = [], [], [], []
R_errs, t_errs = [], []
ts_errs = []
adds, adis = [], []
for i, data in enumerate(tqdm(testloader)):
if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
continue
image0, image1 = data['images'][0].to(device)
# if dataset == 'megadepth':
# image0 = load_image(os.path.join(data_root, data['pair_names'][0][0])).to(device)
# image1 = load_image(os.path.join(data_root, data['pair_names'][1][0])).to(device)
# else:
# image0, image1 = data['images'][0].to(device)
bbox0, bbox1 = None, None
if task == 'object':
bbox0, bbox1 = data['bboxes'][0]
x1, y1, x2, y2 = bbox0
u1, v1, u2, v2 = bbox1
image0 = image0[:, y1:y2, x1:x2]
image1 = image1[:, v1:v2, u1:u2]
mask0, mask1 = None, None
if args.mask:
mask0, mask1 = data['masks'][0].to(device)
depth0, depth1 = None, None
if args.depth:
depth0, depth1 = data['depths'][0]
K0, K1 = data['intrinsics'][0]
T = torch.eye(4)
T[:3, :3] = data['rotation'][0]
T[:3, 3] = data['translation'][0]
T = T.numpy()
R, t, points0, points1, preprocess_time, extract_time, match_time, recover_time = poseRec.recover(image0, image1, K0, K1, bbox0, bbox1, mask0, mask1, depth0, depth1)
preprocess_times.append(preprocess_time)
extract_times.append(extract_time)
match_times.append(match_time)
recover_times.append(recover_time)
if np.isnan(R).any():
R_err = 180
R = np.identity(3)
t_err = 180
t = np.array([0., 0., 0.])
else:
t_err, R_err = relative_pose_error(T, R, t, ignore_gt_t_thr=0.0)
R_errs.append(R_err)
t_errs.append(t_err)
if args.depth:
t = np.nan_to_num(t)
ts_errs.append(torch.tensor(T[:3, 3] - t).norm(2))
if task == 'object':
if np.isnan(R).any():
adds.append(1.)
adis.append(1.)
else:
adds.append(add(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
adis.append(adi(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
metrics = []
values = []
preprocess_times = np.array(preprocess_time) * 1000
extract_times = np.array(extract_time) * 1000
match_times = np.array(match_times) * 1000
recover_times = np.array(recover_time) * 1000
metrics.append('Extracting Time (ms)')
values.append(f'{np.mean(extract_times):.1f}')
metrics.append('Matching Time (ms)')
values.append(f'{np.mean(match_times):.1f}')
metrics.append('Recovering Time (ms)')
values.append(f'{np.mean(recover_times):.1f}')
metrics.append('Total Time (ms)')
values.append(f'{np.mean(extract_times) + np.mean(match_times) + np.mean(recover_times):.1f}')
# pose auc
angular_thresholds = [5, 10, 20]
pose_errors = np.max(np.stack([R_errs, t_errs]), axis=0)
aucs = error_auc(pose_errors, angular_thresholds, mode='Pose estimation') # (auc@5, auc@10, auc@20)
for k in aucs:
metrics.append(k)
values.append(f'{aucs[k] * 100:.2f}')
R_errs = torch.tensor(R_errs)
t_errs = torch.tensor(t_errs)
metrics.append('Rotation Avg. Error (°)')
values.append(f'{R_errs.mean():.2f}')
metrics.append('Rotation Med. Error (°)')
values.append(f'{R_errs.median():.2f}')
metrics.append('Rotation @30° ACC')
values.append(f'{(R_errs < 30).float().mean() * 100:.1f}')
metrics.append('Rotation @15° ACC')
values.append(f'{(R_errs < 15).float().mean() * 100:.1f}')
if args.depth:
ts_errs = torch.tensor(ts_errs)
metrics.append('Translation Avg. Error (m)')
values.append(f'{ts_errs.mean():.4f}')
metrics.append('Translation Med. Error (m)')
values.append(f'{ts_errs.median():.4f}')
metrics.append('Translation @1m ACC')
values.append(f'{(ts_errs < 1.0).float().mean() * 100:.1f}')
metrics.append('Translation @10cm ACC')
values.append(f'{(ts_errs < 0.1).float().mean() * 100:.1f}')
if task == 'object':
metrics.append('Object ADD')
values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')
metrics.append('Object ADD-S')
values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')
res = pd.DataFrame({'Metrics': metrics, 'Values': values})
print(res)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str, help='.yaml configure file path')
parser.add_argument('matcher', type=str)
parser.add_argument('--solver', type=str, default='procrustes')
parser.add_argument('--resize', type=int, default=None)
parser.add_argument('--depth', action='store_true')
parser.add_argument('--mask', action='store_true')
parser.add_argument('--obj_name', type=str, default=None)
parser.add_argument('--device', type=str, default='cuda:0')
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)