# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

import os

from vidar.utils.config import cfg_has
from vidar.utils.data import make_list
from vidar.utils.types import is_dict, is_list
from vidar.utils.viz import viz_depth, viz_optical_flow
from vidar.utils.write import write_depth, write_image, write_pickle, write_npz


class Saver:
    """
    Saver class to store batch information and model predictions to disk

    Parameters
    ----------
    cfg : Config
        Configuration with parameters
    ckpt : String
        Name of the model checkpoint (used to create the save folder)
    """
    def __init__(self, cfg, ckpt=None):

        # Root output folder; save_data does nothing when this is None
        self.folder = cfg_has(cfg, 'folder', None)

        # Lists of output modes requested per data type (checked with `in` inside save)
        self.rgb = make_list(cfg.rgb) if cfg_has(cfg, 'rgb') else []
        self.depth = make_list(cfg.depth) if cfg_has(cfg, 'depth') else []
        self.pose = make_list(cfg.pose) if cfg_has(cfg, 'pose') else []
        self.optical_flow = make_list(cfg.optical_flow) if cfg_has(cfg, 'optical_flow') else []

        # Whether to additionally pickle the whole collected data dictionary
        self.store_data = cfg_has(cfg, 'store_data', False)
        # Read with cfg_has like every other option above (was cfg.has, inconsistent)
        self.separate = cfg_has(cfg, 'separate', False)

        # Checkpoint basename without extension, used as an extra sub-folder
        self.ckpt = None if ckpt is None else \
            os.path.splitext(os.path.basename(ckpt))[0]

        self.naming = cfg_has(cfg, 'naming', 'filename')
        assert self.naming in ['filename', 'splitname'], \
            'Invalid naming for saver: {}'.format(self.naming)

    def get_filename(self, path, batch, idx, i):
        """
        Get the base filename (without suffix) for a sample, creating its folder if needed

        Parameters
        ----------
        path : String
            Folder where data will be saved
        batch : Dict
            Dictionary with batch information (reads batch['filename'] for 'filename' naming)
        idx : Int
            Batch index in the split (used for 'splitname' naming)
        i : Int
            Index within batch

        Returns
        -------
        filename : String
            Base filename; callers append key/context suffixes and extensions
        """
        if self.naming == 'filename':
            # NOTE(review): assumes batch['filename'] is indexed as [context][batch] - confirm
            filename = os.path.join(path, batch['filename'][0][i]).replace('{}', 'rgb')
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            return filename
        elif self.naming == 'splitname':
            if self.separate:
                filename = os.path.join(path, '%010d' % idx, '%010d' % idx)
                # Each sample gets its own sub-folder, which must exist before writing
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                return filename
            else:
                return os.path.join(path, '%010d' % idx)
        else:
            raise NotImplementedError('Invalid naming for saver: {}'.format(self.naming))

    def save_data(self, batch, output, prefix):
        """
        Prepare for data saving

        Parameters
        ----------
        batch : Dict
            Dictionary with batch information
        output : Dict
            Dictionary with output information (reads output['predictions'])
        prefix : String
            Prefix string for the log name (used as a sub-folder)
        """
        # Saving is disabled when no output folder is configured
        if self.folder is None:
            return

        idx = batch['idx']
        predictions = output['predictions']

        path = os.path.join(self.folder, prefix)
        if self.ckpt is not None:
            path = os.path.join(path, self.ckpt)
        os.makedirs(path, exist_ok=True)

        # Only the first element of the batch is saved
        self.save(batch, predictions, path, idx, 0)

    def save(self, batch, predictions, path, idx, i):
        """
        Save batch and prediction information

        Parameters
        ----------
        batch : Dict
            Dictionary with batch information
        predictions : Dict
            Dictionary with output predictions
        path : String
            Path where data will be saved
        idx : Int
            Batch index in the split
        i : Int
            Index within batch

        Returns
        -------
        data : Dict
            Dictionary with output data that was saved
        """
        filename = self.get_filename(path, batch, idx, i)

        # Fall back to 'intrinsics' when 'raw_intrinsics' is not in the batch
        if 'raw_intrinsics' in batch:
            raw_intrinsics = batch['raw_intrinsics'][0][i].cpu()
        elif 'intrinsics' in batch:
            raw_intrinsics = batch['intrinsics'][0][i].cpu()
        else:
            raw_intrinsics = None
        intrinsics = batch['intrinsics'][0][i].cpu() if 'intrinsics' in batch else None

        data = {
            'raw_intrinsics': raw_intrinsics,
            'intrinsics': intrinsics,
        }

        # Ground-truth entries first, then predictions (same order as before)
        self._save_batch(batch, data, filename, i, raw_intrinsics)
        self._save_predictions(predictions, data, filename, i, intrinsics)

        if self.store_data:
            write_pickle('%s' % filename, data)

        return data

    def _save_batch(self, batch, data, filename, i, raw_intrinsics):
        """Collect ground-truth rgb/depth/pose entries into data and write requested files"""
        for key in batch.keys():
            if key.startswith('rgb'):
                data[key + '_gt'] = {k: v[i].cpu() for k, v in batch[key].items()}
                for ctx in batch[key].keys():
                    rgb = batch[key][ctx][i].cpu()
                    if 'gt' in self.rgb:
                        # A 5-dim sample is written as one image per slice along axis 1
                        if rgb.dim() == 5:
                            for j in range(rgb.shape[1]):
                                write_image('%s_%s(%d_%d)_gt.png' % (filename, key, j, ctx), rgb[:, j])
                        else:
                            write_image('%s_%s(%d)_gt.png' % (filename, key, ctx), rgb)
            if key.startswith('depth'):
                data[key + '_gt'] = {k: v[i].cpu() for k, v in batch[key].items()}
                for ctx in batch[key].keys():
                    depth = batch[key][ctx][i].cpu()
                    if 'gt_png' in self.depth:
                        write_depth('%s_%s(%d)_gt.png' % (filename, key, ctx), depth)
                    if 'gt_npz' in self.depth:
                        write_depth('%s_%s(%d)_gt.npz' % (filename, key, ctx),
                                    depth, intrinsics=raw_intrinsics)
                    if 'gt_viz' in self.depth:
                        write_image('%s_%s(%d)_gt_viz.png' % (filename, key, ctx),
                                    viz_depth(depth, filter_zeros=True))
            if key.startswith('pose'):
                pose = {k: v[i].cpu() for k, v in batch[key].items()}
                data[key + '_gt'] = pose
                if 'gt' in self.pose:
                    write_pickle('%s_%s_gt' % (filename, key), pose)

    def _save_predictions(self, predictions, data, filename, i, intrinsics):
        """Collect predicted rgb/depth/pose/flow/mask entries into data and write requested files"""
        for key in predictions.keys():
            if key.startswith('rgb'):
                data[key + '_pred'] = {k: v[i].cpu() for k, v in predictions[key].items()}
                for ctx in predictions[key].keys():
                    rgb = predictions[key][ctx][i].cpu()
                    if 'pred' in self.rgb:
                        if rgb.dim() == 5:
                            for j in range(rgb.shape[1]):
                                write_image('%s_%s(%d_%d)_pred.png' % (filename, key, j, ctx), rgb[:, j])
                        else:
                            write_image('%s_%s(%d)_pred.png' % (filename, key, ctx), rgb)
            if key.startswith('depth'):
                data[key + '_pred'] = {k: v[i].cpu() for k, v in predictions[key].items()}
                for ctx in predictions[key].keys():
                    # NOTE(review): [0] presumably selects the first (full-resolution) scale - confirm
                    depth = predictions[key][ctx][0][i].cpu()
                    if 'png' in self.depth:
                        write_depth('%s_%s(%d)_pred.png' % (filename, key, ctx), depth)
                    if 'npz' in self.depth:
                        write_depth('%s_%s(%d)_pred.npz' % (filename, key, ctx),
                                    depth, intrinsics=intrinsics)
                    if 'viz' in self.depth:
                        write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx),
                                    viz_depth(depth))
            if key.startswith('pose'):
                # Distinct comprehension names avoid shadowing the outer loop's `key`
                pose = {k: v[i].cpu() for k, v in predictions[key].items()}
                data[key + '_pred'] = pose
                if 'pred' in self.pose:
                    write_pickle('%s_%s_pred' % (filename, key), pose)
            if key.startswith('fwd_optical_flow'):
                optical_flow = {k: v[i].cpu() for k, v in predictions[key].items()}
                data[key + '_pred'] = optical_flow
                if 'npz' in self.optical_flow:
                    write_npz('%s_%s_pred' % (filename, key),
                              {'fwd_optical_flow': optical_flow})
                if 'viz' in self.optical_flow:
                    for ctx in optical_flow.keys():
                        write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx),
                                    viz_optical_flow(optical_flow[ctx]))
            if key.startswith('mask'):
                if is_dict(predictions[key]):
                    data[key] = {k: v[i].cpu() for k, v in predictions[key].items()}
                    for ctx in data[key].keys():
                        write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx),
                                    predictions[key][ctx][0])
                elif is_list(predictions[key]):
                    # List elements are values, not (key, value) pairs, so no unpacking here;
                    # the context used in the filename is the integer position in the list
                    data[key] = [v[i].cpu() for v in predictions[key]]
                    for ctx in range(len(data[key])):
                        write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx),
                                    predictions[key][ctx][0])
                else:
                    data[key] = predictions[key][i].cpu()
                    write_image('%s_%s_pred_viz.png' % (filename, key), predictions[key][0])