Spaces:
Runtime error
Runtime error
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
from collections import OrderedDict | |
from copy import deepcopy | |
import torch | |
import torch.nn as nn | |
import wandb | |
from vidar.utils.config import cfg_has | |
from vidar.utils.distributed import world_size | |
from vidar.utils.logging import pcolor | |
from vidar.utils.types import is_dict, is_tensor, is_seq, is_namespace | |
from vidar.utils.viz import viz_depth, viz_inv_depth, viz_normals, viz_optical_flow, viz_camera | |
class WandbLogger:
    """
    Wandb logger class to monitor training

    A wandb run is created (or resumed) on construction. Metrics are buffered
    and only sent to wandb when an ``epochs`` or ``samples`` entry arrives.

    Parameters
    ----------
    cfg : Config
        Configuration with parameters. ``folder``, ``entity`` and ``project``
        are read directly; ``name``, ``tags``, ``notes``, ``only_first`` and
        ``num_{train,validation,test}_logs`` are read through ``cfg_has``.
        ``cfg.name`` and ``cfg.url`` are overwritten with the run name and URL.
    verbose : Bool
        Print information on screen if enabled
    """
    def __init__(self, cfg, verbose=False):
        super().__init__()

        # Number of image logs per split; 0 disables image logging (see log_data)
        self.num_logs = {
            'train': cfg_has(cfg, 'num_train_logs', 0),
            'val': cfg_has(cfg, 'num_validation_logs', 0),
            'test': cfg_has(cfg, 'num_test_logs', 0),
        }

        self._name = cfg.name if cfg_has(cfg, 'name') else None
        self._dir = cfg.folder
        self._entity = cfg.entity
        self._project = cfg.project
        self._tags = cfg_has(cfg, 'tags', '')
        self._notes = cfg_has(cfg, 'notes', '')

        self._id = None
        self._anonymous = None
        self._log_model = True

        self._experiment = self._create_experiment()
        self._metrics = OrderedDict()

        self.only_first = cfg_has(cfg, 'only_first', False)

        # Write the resolved run name / URL back into the configuration.
        # run_name and run_url are properties, so these are values, not methods.
        cfg.name = self.run_name
        cfg.url = self.run_url

        if verbose:
            self.print()

    @staticmethod
    def finish():
        """Finish wandb session (calls ``wandb.finish``)"""
        wandb.finish()

    def print(self):
        """Print run name and URL on screen"""
        font_base = {'color': 'red', 'attrs': ('bold', 'dark')}
        font_name = {'color': 'red', 'attrs': ('bold',)}
        font_underline = {'color': 'red', 'attrs': ('underline',)}

        print(pcolor('#' * 60, **font_base))
        print(pcolor('### WandB: ', **font_base) +
              pcolor('{}'.format(self.run_name), **font_name))
        print(pcolor('### ', **font_base) +
              pcolor('{}'.format(self.run_url), **font_underline))
        print(pcolor('#' * 60, **font_base))

    def __getstate__(self):
        """Get the current logger state for pickling"""
        state = self.__dict__.copy()
        # Keep only the run id and drop the live run handle (presumably not
        # picklable); `experiment` re-creates it with id=_id and resume='allow'
        state['_id'] = self._experiment.id if self._experiment is not None else None
        state['_experiment'] = None
        return state

    def _create_experiment(self):
        """Creates and returns a new wandb run (id=self._id, resume='allow')"""
        experiment = wandb.init(
            name=self._name, dir=self._dir, project=self._project,
            anonymous=self._anonymous, reinit=True, id=self._id, notes=self._notes,
            resume='allow', tags=self._tags, entity=self._entity
        )
        wandb.run.save()
        return experiment

    def watch(self, model: nn.Module, log='gradients', log_freq=100):
        """Watch training parameters of ``model`` through the wandb run"""
        self.experiment.watch(model, log=log, log_freq=log_freq)

    @property
    def experiment(self):
        """Returns the experiment (creates a new one if it doesn't exist)"""
        if self._experiment is None:
            self._experiment = self._create_experiment()
        return self._experiment

    @property
    def run_name(self):
        """Returns run name, or None if no experiment exists"""
        if not self._experiment:
            return None
        return wandb.run.name

    @property
    def run_url(self):
        """Returns run URL, or None if no experiment exists"""
        if not self._experiment:
            return None
        return f'https://app.wandb.ai/{wandb.run.entity}/{wandb.run.project}/runs/{wandb.run.id}'

    def log_config(self, cfg):
        """Log model configuration as the wandb run config"""
        # Deep-copy first: recursive_convert_config converts namespaces in place
        cfg = recursive_convert_config(deepcopy(cfg))
        self.experiment.config.update(cfg, allow_val_change=True)

    def log_metrics(self, metrics):
        """
        Buffer training metrics and flush them to wandb.

        Metrics accumulate in ``self._metrics``. They are sent (and the buffer
        cleared) only when ``metrics`` contains an ``epochs`` or ``samples`` key.
        """
        self._metrics.update(metrics)
        if 'epochs' in metrics or 'samples' in metrics:
            self.experiment.log(self._metrics)
            self._metrics.clear()

    def log_images(self, batch, output, prefix, ontology=None):
        """
        Log images depending on its nature

        Parameters
        ----------
        batch : Dict
            Dictionary containing batch information (logged with a ``-gt`` suffix)
        output : Dict
            Dictionary containing output information; its ``predictions``
            entry is logged with a ``-pred`` suffix
        prefix : String
            Prefix string for the log name
        ontology : Dict
            Dictionary with ontology information (currently unused)
        """
        for data, suffix in zip([batch, output['predictions']], ['-gt', '-pred']):
            for key in data:
                # Order matters: e.g. a 'depth_mask' key is caught by the depth branch
                if key.startswith('rgb'):
                    self._metrics.update(log_rgb(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif key.startswith('depth'):
                    self._metrics.update(log_depth(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif key.startswith('inv_depth'):
                    self._metrics.update(log_inv_depth(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif 'normals' in key:
                    self._metrics.update(log_normals(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif key.startswith('stddev'):
                    self._metrics.update(log_stddev(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif key.startswith('logvar'):
                    self._metrics.update(log_logvar(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif 'optical_flow' in key:
                    self._metrics.update(log_optical_flow(
                        key, prefix + suffix, data, only_first=self.only_first))
                elif 'mask' in key or 'valid' in key:
                    # NOTE(review): masks are logged without the -gt/-pred suffix,
                    # so a mask key present in both batch and predictions is
                    # overwritten by the prediction. Confirm this is intended.
                    self._metrics.update(log_rgb(
                        key, prefix, data, only_first=self.only_first))
                # Camera, uncertainty, semantic, scene-flow and keypoint-score
                # logging are currently disabled (log_camera exists but is not
                # dispatched here).

    def log_data(self, mode, batch, output, dataset, prefix, ontology=None):
        """
        Log images for roughly ``num_logs[mode]`` evenly spaced samples.

        The sample index ``batch['idx'][0]`` is logged when it falls on a
        multiple of the per-process interval and within the first
        ``num_logs`` intervals. If the interval is 0 (dataset too small),
        every sample is logged.
        """
        idx = batch['idx'][0]
        num_logs = self.num_logs[mode]
        if num_logs > 0:
            interval = (len(dataset) // world_size() // num_logs) * world_size()
            if interval == 0 or (idx % interval == 0 and idx < interval * num_logs):
                prefix = '{}-{}-{}'.format(mode, prefix, idx.item())
                self.log_images(batch, output, prefix, ontology=ontology)
def recursive_convert_config(cfg):
    """
    Turn a (possibly nested) namespace-like configuration into dictionaries.

    The conversion happens in place. The returned dictionary is ``cfg``'s own
    ``__dict__``, and every nested namespace stored in it is replaced by its
    converted dictionary. Pass a copy if the original object must stay intact.
    """
    attributes = cfg.__dict__
    for name in list(attributes):
        value = attributes[name]
        if is_namespace(value):
            attributes[name] = recursive_convert_config(value)
    return attributes
def prep_image(key, prefix, image):
    """
    Wrap an image in a ``wandb.Image`` stored under ``'<prefix>-<key>'``.

    Tensor inputs are converted to numpy first:
    - a 2D tensor gains a leading axis;
    - a 4D tensor is reduced to its first element;
    - the leading axis (presumably channels) is then moved last.

    Non-tensor inputs are passed to wandb unchanged.
    """
    if is_tensor(image):
        ndim = image.dim()
        if ndim == 2:
            image = image.unsqueeze(0)
        elif ndim == 4:
            image = image[0]
        image = image.detach().permute(1, 2, 0).cpu().numpy()
    log_name = '{}-{}'.format(prefix, key)
    return {log_name: wandb.Image(image, caption=key)}
def log_sequence(key, prefix, data, i, only_first, fn):
    """
    Log every image contained in a list, tuple or dict by calling ``fn`` on it.

    Generated log names:
    - dict value that is a tensor: ``key(ctx)``;
    - dict value that is a list: ``key(ctx)_idx``;
    - 5D tensors inside dicts are split along dim 1 into ``key(ctx_j)[...]``;
    - list/tuple: ``key_idx``;
    - anything else: ``key``.

    When ``only_first`` is truthy, only the first element of each list is used.
    """
    log = {}

    def _emit(tag, suffix, tensor):
        # Unroll 5D tensors along dim 1, one log entry per slice
        if tensor.dim() == 5:
            for j in range(tensor.shape[1]):
                log.update(fn('%s(%s_%d)%s' % (key, tag, j, suffix), prefix, tensor[:, j], i))
        else:
            log.update(fn('%s(%s)%s' % (key, tag, suffix), prefix, tensor, i))

    if is_dict(data):
        for ctx, value in data.items():
            tag = str(ctx)
            if is_seq(value):
                items = value[:1] if only_first else value
                for idx, item in enumerate(items):
                    _emit(tag, '_%d' % idx, item)
            else:
                _emit(tag, '', value)
    elif is_seq(data):
        items = data[:1] if only_first else data
        for idx, item in enumerate(items):
            log.update(fn('%s_%d' % (key, idx), prefix, item, i))
    else:
        log.update(fn('%s' % key, prefix, data, i))
    return log
def log_rgb(key, prefix, batch, i=0, only_first=None):
    """
    Log an RGB image with values clamped to [0, 1].

    ``batch`` is either a dictionary (the image is read from ``batch[key]``)
    or the image data itself. Sequences and dicts are expanded through
    ``log_sequence``; otherwise sample ``i`` is logged.
    """
    # NOTE(review): also used for mask/valid keys. Confirm that their dtype
    # supports clamp (e.g. bool tensors may not).
    if is_dict(batch):
        rgb = batch[key]
    else:
        rgb = batch
    if not (is_seq(rgb) or is_dict(rgb)):
        return prep_image(key, prefix, rgb[i].clamp(min=0.0, max=1.0))
    return log_sequence(key, prefix, rgb, i, only_first, log_rgb)
def log_depth(key, prefix, batch, i=0, only_first=None):
    """
    Log a depth map colorized by ``viz_depth`` (called with ``filter_zeros=True``).

    ``batch`` is a dictionary holding ``key`` or the depth data itself;
    sequences and dicts are expanded through ``log_sequence``.
    """
    if is_dict(batch):
        depth = batch[key]
    else:
        depth = batch
    if not (is_seq(depth) or is_dict(depth)):
        colorized = viz_depth(depth[i], filter_zeros=True)
        return prep_image(key, prefix, colorized)
    return log_sequence(key, prefix, depth, i, only_first, log_depth)
def log_inv_depth(key, prefix, batch, i=0, only_first=None):
    """
    Log an inverse depth map colorized by ``viz_inv_depth``.

    ``batch`` is a dictionary holding ``key`` or the data itself;
    sequences and dicts are expanded through ``log_sequence``.
    """
    if is_dict(batch):
        inv_depth = batch[key]
    else:
        inv_depth = batch
    if not (is_seq(inv_depth) or is_dict(inv_depth)):
        colorized = viz_inv_depth(inv_depth[i])
        return prep_image(key, prefix, colorized)
    return log_sequence(key, prefix, inv_depth, i, only_first, log_inv_depth)
def log_normals(key, prefix, batch, i=0, only_first=None):
    """
    Log a surface-normals map colorized by ``viz_normals``.

    ``batch`` is a dictionary holding ``key`` or the data itself;
    sequences and dicts are expanded through ``log_sequence``.
    """
    if is_dict(batch):
        normals = batch[key]
    else:
        normals = batch
    if not (is_seq(normals) or is_dict(normals)):
        colorized = viz_normals(normals[i])
        return prep_image(key, prefix, colorized)
    return log_sequence(key, prefix, normals, i, only_first, log_normals)
def log_optical_flow(key, prefix, batch, i=0, only_first=None):
    """
    Log an optical-flow field colorized by ``viz_optical_flow``.

    ``batch`` is a dictionary holding ``key`` or the data itself;
    sequences and dicts are expanded through ``log_sequence``.
    """
    if is_dict(batch):
        optical_flow = batch[key]
    else:
        optical_flow = batch
    if not (is_seq(optical_flow) or is_dict(optical_flow)):
        colorized = viz_optical_flow(optical_flow[i])
        return prep_image(key, prefix, colorized)
    return log_sequence(key, prefix, optical_flow, i, only_first, log_optical_flow)
def log_stddev(key, prefix, batch, i=0, only_first=None):
    """
    Log a standard-deviation map colorized by ``viz_inv_depth`` ('jet' colormap).

    ``batch`` is a dictionary holding ``key`` or the data itself;
    sequences and dicts are expanded through ``log_sequence``.
    """
    if is_dict(batch):
        stddev = batch[key]
    else:
        stddev = batch
    if not (is_seq(stddev) or is_dict(stddev)):
        colorized = viz_inv_depth(stddev[i], colormap='jet')
        return prep_image(key, prefix, colorized)
    return log_sequence(key, prefix, stddev, i, only_first, log_stddev)
def log_logvar(key, prefix, batch, i=0, only_first=None):
    """
    Log a log-variance map.

    The map is exponentiated (``torch.exp``, i.e. converted back to variance)
    and colorized by ``viz_inv_depth`` with the 'jet' colormap. ``batch`` is
    a dictionary holding ``key`` or the data itself; sequences and dicts are
    expanded through ``log_sequence``.
    """
    logvar = batch[key] if is_dict(batch) else batch
    if is_seq(logvar) or is_dict(logvar):
        return log_sequence(key, prefix, logvar, i, only_first, log_logvar)
    return prep_image(key, prefix, viz_inv_depth(torch.exp(logvar[i]), colormap='jet'))
def log_camera(key, prefix, batch, i=0, only_first=None):
    """
    Log a camera visualization produced by ``viz_camera``.

    ``batch`` is a dictionary holding ``key`` or the camera data itself;
    sequences and dicts are expanded through ``log_sequence``.
    """
    if is_dict(batch):
        camera = batch[key]
    else:
        camera = batch
    if not (is_seq(camera) or is_dict(camera)):
        rendered = viz_camera(camera[i])
        return prep_image(key, prefix, rendered)
    return log_sequence(key, prefix, camera, i, only_first, log_camera)