# Jiading Fang
# add define
# fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from vidar.metrics.base import BaseEvaluation
from vidar.metrics.utils import create_crop_mask, scale_output
from vidar.utils.config import cfg_has
from vidar.utils.data import dict_remove_nones
from vidar.utils.depth import post_process_depth
from vidar.utils.distributed import on_rank_0
from vidar.utils.logging import pcolor
from vidar.utils.types import is_dict
class DepthEvaluation(BaseEvaluation):
    """
    Depth evaluation metrics

    Parameters
    ----------
    cfg : Config
        Configuration file. `min_depth` and `max_depth` are read directly
        (required); `crop`, `scale_output`, `post_process`, `median_scaling`
        and `valid_threshold` are optional and fall back to defaults.
    """
    def __init__(self, cfg):
        super().__init__(
            cfg,
            name='depth', task='depth',
            metrics=('abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'silog', 'a1', 'a2', 'a3'),
        )

        # Ground-truth depth range considered valid (required entries)
        self.min_depth = cfg.min_depth
        self.max_depth = cfg.max_depth

        # Optional entries, all read through the same helper for consistency
        self.crop = cfg_has(cfg, 'crop', '')
        self.scale_output = cfg_has(cfg, 'scale_output', 'resize')
        self.post_process = cfg_has(cfg, 'post_process', False)
        self.median_scaling = cfg_has(cfg, 'median_scaling', False)
        self.valid_threshold = cfg_has(cfg, 'valid_threshold', None)

        # Extra evaluation modes on top of the ones already in self.modes
        # (presumably initialized by BaseEvaluation -- not visible here):
        # 'pp' = post-processed prediction, 'gt' = median-scaled prediction
        if self.post_process:
            self.modes += ['pp']
        if self.median_scaling:
            self.modes += ['gt']
        if self.post_process and self.median_scaling:
            self.modes += ['pp_gt']

    @staticmethod
    def reduce_fn(metrics, seen):
        """
        Reduce function

        Divides each accumulated metrics row by its `seen` counter and
        averages over the rows whose counter is positive.
        NOTE(review): assumes `metrics` is [N,M] and `seen` has N entries -- confirm with caller.
        """
        valid = seen.view(-1) > 0
        return (metrics[valid] / seen.view(-1, 1)[valid]).mean(0)

    def populate_metrics_dict(self, metrics, metrics_dict, prefix):
        """
        Populate metrics function

        For every entry of `metrics` whose key starts with this task name
        (keys have the form '<name>|<suffix>'), writes one scalar per metric
        into `metrics_dict` under '<prefix>-<name>|<metric>_<suffix>'.
        """
        for metric in metrics:
            if metric.startswith(self.name):
                name, suffix = metric.split('|')
                for i, key in enumerate(self.metrics):
                    metrics_dict[f'{prefix}-{name}|{key}_{suffix}'] = \
                        metrics[metric][i].item()

    @on_rank_0
    def print(self, reduced_data, prefixes):
        """
        Print function

        Prints a table with one section per entry of `reduced_data`
        (labelled by the matching entry of `prefixes`), skipping entries
        that contain no key for this task.
        """
        print()
        print(self.horz_line)
        print(self.metr_line.format(*((self.name.upper(),) + self.metrics)))
        for n, metrics in enumerate(reduced_data):
            # Nothing to show for this dataset if no key belongs to this task
            if not any(self.name in key for key in metrics.keys()):
                continue
            print(self.horz_line)
            print(self.wrap(pcolor('*** {:<114}'.format(prefixes[n]), **self.font1)))
            print(self.horz_line)
            for key, metric in sorted(metrics.items()):
                if self.name in key:
                    print(self.wrap(pcolor(self.outp_line.format(
                        *((key.upper(),) + tuple(metric.tolist()))), **self.font2)))
        print(self.horz_line)
        print()

    def compute(self, gt, pred, use_gt_scale=True, mask=None):
        """
        Compute depth metrics

        Parameters
        ----------
        gt : torch.Tensor
            Ground-truth depth maps [B,1,H,W]
        pred : torch.Tensor
            Predicted depth map [B,1,H,W]
        use_gt_scale : Bool
            Use median-scaling
        mask : torch.Tensor or None
            Mask to remove pixels from evaluation

        Returns
        -------
        metrics : torch.Tensor or None
            Depth metrics, one row per batch sample, in the order
            (abs_rel, sq_rel, rmse, rmse_log, silog, a1, a2, a3).
            None if `valid_threshold` is set and any sample has fewer
            valid pixels than that (the whole batch is then discarded).
        """
        # Match predicted depth map to ground-truth resolution
        pred = scale_output(pred, gt, self.scale_output)
        # Create crop mask if requested (same mask for every sample, so convert once)
        crop_mask = create_crop_mask(self.crop, gt)
        if crop_mask is not None:
            crop_mask = crop_mask.bool()
        # For each batch sample
        metrics = []
        for i, (pred_i, gt_i) in enumerate(zip(pred, gt)):
            # Squeeze GT and PRED
            gt_i, pred_i = torch.squeeze(gt_i), torch.squeeze(pred_i)
            # Keep valid pixels (min/max depth) and remove invalid predictions
            valid = (gt_i > self.min_depth) & (gt_i < self.max_depth)
            valid = valid & (pred_i > 0)
            # Apply crop mask if requested
            if crop_mask is not None:
                valid = valid & crop_mask
            # Apply provided mask if available
            if mask is not None:
                valid = valid & torch.squeeze(mask[i]).bool()
            # Invalid evaluation
            if self.valid_threshold is not None and valid.sum() < self.valid_threshold:
                return None
            # Keep only valid pixels
            gt_i, pred_i = gt_i[valid], pred_i[valid]
            # GT median scaling if needed
            if use_gt_scale:
                pred_i = pred_i * torch.median(gt_i) / torch.median(pred_i)
            # Clamp PRED depth values to min/max values
            pred_i = pred_i.clamp(self.min_depth, self.max_depth)
            # Threshold accuracies
            thresh = torch.max((gt_i / pred_i), (pred_i / gt_i))
            a1 = (thresh < 1.25).float().mean()
            a2 = (thresh < 1.25 ** 2).float().mean()
            a3 = (thresh < 1.25 ** 3).float().mean()
            # Linear-space errors
            diff_i = gt_i - pred_i
            abs_rel = torch.mean(torch.abs(diff_i) / gt_i)
            sq_rel = torch.mean(diff_i ** 2 / gt_i)
            rmse = torch.sqrt(torch.mean(diff_i ** 2))
            # Log-space errors (the log difference is shared by rmse_log and silog)
            log_err = torch.log(pred_i) - torch.log(gt_i)
            mean_sq_log_err = torch.mean(log_err ** 2)
            rmse_log = torch.sqrt(mean_sq_log_err)
            # E[x^2] - E[x]^2 can round to a tiny negative value when the
            # log error is nearly constant; clamp so sqrt does not yield NaN
            variance = mean_sq_log_err - torch.mean(log_err) ** 2
            silog = torch.sqrt(variance.clamp(min=0)) * 100
            metrics.append([abs_rel, sq_rel, rmse, rmse_log, silog, a1, a2, a3])
        # Return metrics
        return torch.tensor(metrics, dtype=gt.dtype)

    def evaluate(self, batch, output, flipped_output=None):
        """
        Evaluate predictions

        Parameters
        ----------
        batch : Dict
            Dictionary containing ground-truth information
        output : Dict
            Dictionary containing predictions
        flipped_output : Dict or None
            Optional flipped output for post-processing
            (required when `post_process` is enabled)

        Returns
        -------
        metrics : Dict
            Dictionary with calculated metrics (entries that evaluated to None are removed)
        predictions : Dict
            Dictionary with additional predictions

        Raises
        ------
        ValueError
            If post-processing is enabled but `flipped_output` was not provided
        """
        metrics, predictions = {}, {}
        if self.name not in batch:
            return metrics, predictions
        # For each output item
        for key, val in output.items():
            # Only items that correspond to this task
            if not key.startswith(self.name) or 'debug' in key:
                continue
            # Outputs without a context level are treated as a single context 0
            wrapped = not is_dict(val)
            if wrapped:
                val = {0: val}
            flipped = None
            if self.post_process:
                if flipped_output is None:
                    raise ValueError(
                        'flipped_output is required when post_process is enabled')
                flipped = flipped_output[key]
                # Mirror the wrapping above so [ctx][i] selects a scale and not
                # a batch sample -- assumes the flipped output has the same
                # structure as the regular output (TODO confirm with caller)
                if wrapped and not is_dict(flipped):
                    flipped = {0: flipped}
            # Loop over every context
            for ctx, scales in val.items():
                # Loop over every scale
                for i in range(1 if self.only_first else len(scales)):
                    pred = scales[i]
                    gt = batch[self.name][ctx]
                    pred_pp = None
                    if self.post_process:
                        pred_pp = post_process_depth(pred, flipped[ctx][i], method='mean')
                    # Lower scales are brought to the resolution of the first scale
                    if i > 0:
                        pred = self.interp_nearest(pred, scales[0])
                        if self.post_process:
                            pred_pp = self.interp_nearest(pred_pp, scales[0])
                    scale_tag = '_%d' % i if not self.only_first else ''
                    if pred.dim() == 4:
                        suffix = '(%s)' % str(ctx) + scale_tag
                        for mode in self.modes:
                            metrics[f'{key}|{mode}{suffix}'] = self.compute(
                                gt=gt,
                                pred=pred_pp if 'pp' in mode else pred,
                                use_gt_scale='gt' in mode,
                                mask=None,
                            )
                    elif pred.dim() == 5:
                        # Extra dimension 1 is evaluated slice by slice
                        for j in range(pred.shape[1]):
                            suffix = '(%s_%d)' % (str(ctx), j) + scale_tag
                            for mode in self.modes:
                                metrics[f'{key}|{mode}{suffix}'] = self.compute(
                                    gt=gt[:, j],
                                    pred=pred_pp[:, j] if 'pp' in mode else pred[:, j],
                                    use_gt_scale='gt' in mode,
                                    mask=None,
                                )
        return dict_remove_nones(metrics), predictions