# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

import torch

from vidar.metrics.base import BaseEvaluation
from vidar.metrics.utils import create_crop_mask, scale_output
from vidar.utils.config import cfg_has
from vidar.utils.data import dict_remove_nones
from vidar.utils.depth import post_process_depth
from vidar.utils.distributed import on_rank_0
from vidar.utils.logging import pcolor
from vidar.utils.types import is_dict

class DepthEvaluation(BaseEvaluation):
    """
    Depth evaluation metrics

    Parameters
    ----------
    cfg : Config
        Configuration file
    """
    def __init__(self, cfg):
        super().__init__(cfg,
            name='depth', task='depth',
            metrics=('abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'silog', 'a1', 'a2', 'a3'),
        )

        self.min_depth = cfg.min_depth
        self.max_depth = cfg.max_depth
        self.crop = cfg_has(cfg, 'crop', '')
        self.scale_output = cfg_has(cfg, 'scale_output', 'resize')
        self.post_process = cfg_has(cfg, 'post_process', False)
        self.median_scaling = cfg_has(cfg, 'median_scaling', False)
        self.valid_threshold = cfg_has(cfg, 'valid_threshold', None)

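        # Extra evaluation modes: 'pp' uses predictions post-processed with a
        # horizontally flipped pass, 'gt' applies ground-truth median scaling,
        # and 'pp_gt' combines both. Each enabled mode produces its own row
        # in the printed metrics table.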
        if self.post_process:
            self.modes += ['pp']
        if self.median_scaling:
            self.modes += ['gt']
        if self.post_process and self.median_scaling:
            self.modes += ['pp_gt']

    @staticmethod
    def reduce_fn(metrics, seen):
        """Reduce metrics across batches, averaging only over entries that were seen"""
        valid = seen.view(-1) > 0
        return (metrics[valid] / seen.view(-1, 1)[valid]).mean(0)
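
    # Illustrative example (not part of the original code): with
    # metrics = torch.ones(2, 8) and seen = torch.tensor([2, 0]), only the
    # first row survives the mask and is divided by its count before the
    # mean; the second row is dropped because it was never evaluated.
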
    def populate_metrics_dict(self, metrics, metrics_dict, prefix):
        """Populate a metrics dictionary with one scalar entry per depth metric"""
        for metric in metrics:
            if metric.startswith(self.name):
                name, suffix = metric.split('|')
                for i, key in enumerate(self.metrics):
                    metrics_dict[f'{prefix}-{name}|{key}_{suffix}'] = \
                        metrics[metric][i].item()
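
    # Illustrative key flow (the prefix comes from the caller, e.g. 'val'):
    # an input entry 'depth|pp(0)' expands to 'val-depth|abs_rel_pp(0)',
    # 'val-depth|sqr_rel_pp(0)', ..., one scalar per metric in self.metrics.
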
    @on_rank_0
    def print(self, reduced_data, prefixes):
        """Print depth metrics as a formatted table (only on rank 0)"""
        print()
        print(self.horz_line)
        print(self.metr_line.format(*((self.name.upper(),) + self.metrics)))
        for n, metrics in enumerate(reduced_data):
            if sum([self.name in key for key in metrics.keys()]) == 0:
                continue
            print(self.horz_line)
            print(self.wrap(pcolor('*** {:<114}'.format(prefixes[n]), **self.font1)))
            print(self.horz_line)
            for key, metric in sorted(metrics.items()):
                if self.name in key:
                    print(self.wrap(pcolor(self.outp_line.format(
                        *((key.upper(),) + tuple(metric.tolist()))), **self.font2)))
        print(self.horz_line)
        print()

    def compute(self, gt, pred, use_gt_scale=True, mask=None):
        """
        Compute depth metrics

        Parameters
        ----------
        gt : torch.Tensor
            Ground-truth depth maps [B,1,H,W]
        pred : torch.Tensor
            Predicted depth maps [B,1,H,W]
        use_gt_scale : Bool
            Whether to apply ground-truth median scaling
        mask : torch.Tensor or None
            Mask to remove pixels from evaluation

        Returns
        -------
        metrics : torch.Tensor
            Depth metrics [B,8], one row per batch sample
        """
        # Match predicted depth map to ground-truth resolution
        pred = scale_output(pred, gt, self.scale_output)
        # Create crop mask if requested
        crop_mask = create_crop_mask(self.crop, gt)
        # For each batch sample
        metrics = []
        for i, (pred_i, gt_i) in enumerate(zip(pred, gt)):
            # Squeeze GT and PRED
            gt_i, pred_i = torch.squeeze(gt_i), torch.squeeze(pred_i)
            mask_i = torch.squeeze(mask[i]) if mask is not None else None
            # Keep valid pixels (min/max depth and crop)
            valid = (gt_i > self.min_depth) & (gt_i < self.max_depth)
            # Remove invalid predicted pixels as well
            valid = valid & (pred_i > 0)
            # Apply crop mask if requested
            valid = valid & crop_mask.bool() if crop_mask is not None else valid
            # Apply provided mask if available
            valid = valid & mask_i.bool() if mask is not None else valid
            # Stop if there are not enough valid pixels for evaluation
            if self.valid_threshold is not None and valid.sum() < self.valid_threshold:
                return None
            # Keep only valid pixels
            gt_i, pred_i = gt_i[valid], pred_i[valid]
            # GT median scaling if needed
            if use_gt_scale:
                pred_i = pred_i * torch.median(gt_i) / torch.median(pred_i)
            # Clamp PRED depth values to min/max values
            pred_i = pred_i.clamp(self.min_depth, self.max_depth)
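
            # Metric definitions, matching the computations below:
            #   a_k      = fraction of pixels with max(gt/pred, pred/gt) < 1.25^k
            #   abs_rel  = mean(|gt - pred| / gt)
            #   sqr_rel  = mean((gt - pred)^2 / gt)
            #   rmse     = sqrt(mean((gt - pred)^2))
            #   rmse_log = sqrt(mean((log gt - log pred)^2))
            #   silog    = 100 * sqrt(mean(err^2) - mean(err)^2), err = log(pred/gt)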
            # Calculate depth metrics
            thresh = torch.max((gt_i / pred_i), (pred_i / gt_i))
            a1 = (thresh < 1.25).float().mean()
            a2 = (thresh < 1.25 ** 2).float().mean()
            a3 = (thresh < 1.25 ** 3).float().mean()

            diff_i = gt_i - pred_i
            abs_rel = torch.mean(torch.abs(diff_i) / gt_i)
            sq_rel = torch.mean(diff_i ** 2 / gt_i)
            rmse = torch.sqrt(torch.mean(diff_i ** 2))
            rmse_log = torch.sqrt(torch.mean((torch.log(gt_i) - torch.log(pred_i)) ** 2))

            err = torch.log(pred_i) - torch.log(gt_i)
            silog = torch.sqrt(torch.mean(err ** 2) - torch.mean(err) ** 2) * 100

            metrics.append([abs_rel, sq_rel, rmse, rmse_log, silog, a1, a2, a3])
        # Return metrics
        return torch.tensor(metrics, dtype=gt.dtype)

    def evaluate(self, batch, output, flipped_output=None):
        """
        Evaluate predictions

        Parameters
        ----------
        batch : Dict
            Dictionary containing ground-truth information
        output : Dict
            Dictionary containing predictions
        flipped_output : Dict or None
            Optional predictions for horizontally flipped inputs, used for post-processing

        Returns
        -------
        metrics : Dict
            Dictionary with calculated metrics
        predictions : Dict
            Dictionary with additional predictions
        """
        metrics, predictions = {}, {}
        if self.name not in batch:
            return metrics, predictions
        # For each output item
        for key, val in output.items():
            # If it corresponds to this task
            if key.startswith(self.name) and 'debug' not in key:
                # Loop over every context
                val = val if is_dict(val) else {0: val}
                for ctx in val.keys():
                    # Loop over every scale
                    for i in range(1 if self.only_first else len(val[ctx])):
                        pred = val[ctx][i]
                        gt = batch[self.name][ctx]
                        if self.post_process:
                            pred_flipped = flipped_output[key][ctx][i]
                            pred_pp = post_process_depth(pred, pred_flipped, method='mean')
                        else:
                            pred_pp = None
                        if i > 0:
                            pred = self.interp_nearest(pred, val[ctx][0])
                            if self.post_process:
                                pred_pp = self.interp_nearest(pred_pp, val[ctx][0])
                        if pred.dim() == 4:
                            suffix = '(%s)' % str(ctx) + ('_%d' % i if not self.only_first else '')
                            for mode in self.modes:
                                metrics[f'{key}|{mode}{suffix}'] = \
                                    self.compute(
                                        gt=gt,
                                        pred=pred_pp if 'pp' in mode else pred,
                                        use_gt_scale='gt' in mode,
                                        mask=None,
                                    )
                        elif pred.dim() == 5:
                            for j in range(pred.shape[1]):
                                suffix = '(%s_%d)' % (str(ctx), j) + ('_%d' % i if not self.only_first else '')
                                for mode in self.modes:
                                    metrics[f'{key}|{mode}{suffix}'] = self.compute(
                                        gt=gt[:, j],
                                        pred=pred_pp[:, j] if 'pp' in mode else pred[:, j],
                                        use_gt_scale='gt' in mode,
                                        mask=None,
                                    )
        return dict_remove_nones(metrics), predictions
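

if __name__ == '__main__':
    # Illustrative sketch, not part of the original module: reproduces the
    # core metric math from DepthEvaluation.compute on synthetic tensors, so
    # the formulas can be sanity-checked without building a vidar Config.
    torch.manual_seed(0)
    gt = torch.rand(1000) * 79.5 + 0.5                            # depths in [0.5, 80.0]
    pred = gt * (1.0 + 0.05 * torch.randn(1000)).clamp(0.5, 1.5)  # noisy prediction
    thresh = torch.max(gt / pred, pred / gt)
    print('a1      ', (thresh < 1.25).float().mean().item())
    print('abs_rel ', torch.mean(torch.abs(gt - pred) / gt).item())
    print('rmse    ', torch.sqrt(torch.mean((gt - pred) ** 2)).item())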