Jiading Fang
add define
fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import wandb
from vidar.utils.config import cfg_has
from vidar.utils.distributed import world_size
from vidar.utils.logging import pcolor
from vidar.utils.types import is_dict, is_tensor, is_seq, is_namespace
from vidar.utils.viz import viz_depth, viz_inv_depth, viz_normals, viz_optical_flow, viz_camera
class WandbLogger:
    """
    Wandb logger class to monitor training

    Parameters
    ----------
    cfg : Config
        Configuration with parameters. Reads folder, entity, project and the optional
        name, tags, notes, only_first and num_{train,validation,test}_logs entries.
        cfg.name and cfg.url are overwritten with the values of the created run.
    verbose : Bool
        Print information on screen if enabled
    """
    def __init__(self, cfg, verbose=False):
        super().__init__()
        # Number of samples whose images are logged per split (0 disables it, see log_data)
        self.num_logs = {
            'train': cfg_has(cfg, 'num_train_logs', 0),
            'val': cfg_has(cfg, 'num_validation_logs', 0),
            'test': cfg_has(cfg, 'num_test_logs', 0),
        }
        self._name = cfg.name if cfg_has(cfg, 'name') else None
        self._dir = cfg.folder
        self._entity = cfg.entity
        self._project = cfg.project
        self._tags = cfg_has(cfg, 'tags', '')
        self._notes = cfg_has(cfg, 'notes', '')
        self._id = None
        self._anonymous = None
        self._log_model = True
        self._experiment = self._create_experiment()
        # Metrics and images are buffered here until log_metrics receives
        # an 'epochs' or 'samples' entry, which triggers a flush to wandb
        self._metrics = OrderedDict()
        self.only_first = cfg_has(cfg, 'only_first', False)
        # Write the resolved run name and URL back into the configuration
        cfg.name = self.run_name
        cfg.url = self.run_url
        if verbose:
            self.print()

    @staticmethod
    def finish():
        """Finish wandb session"""
        wandb.finish()

    def print(self):
        """Print run name and URL on screen"""
        font_base = {'color': 'red', 'attrs': ('bold', 'dark')}
        font_name = {'color': 'red', 'attrs': ('bold',)}
        font_underline = {'color': 'red', 'attrs': ('underline',)}
        print(pcolor('#' * 60, **font_base))
        print(pcolor('### WandB: ', **font_base) +
              pcolor('{}'.format(self.run_name), **font_name))
        print(pcolor('### ', **font_base) +
              pcolor('{}'.format(self.run_url), **font_underline))
        print(pcolor('#' * 60, **font_base))

    def __getstate__(self):
        """
        Get the current logger state for pickling

        The live experiment object is dropped and only its id is kept, so that the
        ``experiment`` property re-creates the run (wandb.init with resume='allow')
        with the same id when it is accessed again after unpickling.
        """
        state = self.__dict__.copy()
        state['_id'] = self._experiment.id if self._experiment is not None else None
        state['_experiment'] = None
        return state

    def _create_experiment(self):
        """Creates and returns a new experiment (resuming the run with id self._id if set)"""
        experiment = wandb.init(
            name=self._name, dir=self._dir, project=self._project,
            anonymous=self._anonymous, reinit=True, id=self._id, notes=self._notes,
            resume='allow', tags=self._tags, entity=self._entity
        )
        # NOTE(review): recent wandb versions deprecate calling run.save() without
        # arguments (attribute changes are persisted automatically) -- confirm it is needed
        wandb.run.save()
        return experiment

    def watch(self, model: nn.Module, log='gradients', log_freq=100):
        """Watch model parameters/gradients through the wandb experiment"""
        self.experiment.watch(model, log=log, log_freq=log_freq)

    @property
    def experiment(self):
        """Returns the experiment (creates a new if it doesn't exist)"""
        if self._experiment is None:
            self._experiment = self._create_experiment()
        return self._experiment

    @property
    def run_name(self):
        """Returns run name (None if no experiment has been created)"""
        return wandb.run.name if self._experiment else None

    @property
    def run_url(self):
        """Returns run URL (None if no experiment has been created)"""
        if not self._experiment:
            return None
        run = wandb.run
        return f'https://app.wandb.ai/{run.entity}/{run.project}/runs/{run.id}'

    def log_config(self, cfg):
        """Log model configuration (a deep copy is converted, so cfg is left untouched)"""
        cfg = recursive_convert_config(deepcopy(cfg))
        self.experiment.config.update(cfg, allow_val_change=True)

    def log_metrics(self, metrics):
        """
        Buffer training metrics, sending the whole buffer to wandb (and clearing it)
        whenever the new metrics contain an 'epochs' or 'samples' entry
        """
        self._metrics.update(metrics)
        if 'epochs' in metrics or 'samples' in metrics:
            self.experiment.log(self._metrics)
            self._metrics.clear()

    def log_images(self, batch, output, prefix, ontology=None):
        """
        Log images depending on their nature

        Images are added to the metrics buffer and are sent to wandb together with
        the next flushing log_metrics call. Ground-truth entries are logged under
        ``prefix-gt`` and predictions under ``prefix-pred``.

        Parameters
        ----------
        batch : Dict
            Dictionary containing batch information
        output : Dict
            Dictionary containing output information (must have a 'predictions' entry)
        prefix : String
            Prefix string for the log name
        ontology : Dict
            Dictionary with ontology information (currently unused)
        """
        for data, suffix in zip([batch, output['predictions']], ['-gt', '-pred']):
            full_prefix = prefix + suffix
            for key in data.keys():
                # Order matters: the first matching rule decides how the key is visualized
                if key.startswith('rgb'):
                    log_fn = log_rgb
                elif key.startswith('depth'):
                    log_fn = log_depth
                elif key.startswith('inv_depth'):
                    log_fn = log_inv_depth
                elif 'normals' in key:
                    log_fn = log_normals
                elif key.startswith('stddev'):
                    log_fn = log_stddev
                elif key.startswith('logvar'):
                    log_fn = log_logvar
                elif 'optical_flow' in key:
                    log_fn = log_optical_flow
                elif 'mask' in key or 'valid' in key:
                    # Use the gt/pred suffix here as well, otherwise a predicted mask
                    # with the same key would overwrite the ground-truth one
                    log_fn = log_rgb
                else:
                    # NOTE: camera, uncertainty, semantic, scene-flow and keypoint-score
                    # logging are currently disabled
                    continue
                self._metrics.update(log_fn(
                    key, full_prefix, data, only_first=self.only_first))

    def log_data(self, mode, batch, output, dataset, prefix, ontology=None):
        """
        Log images for a subset of samples spread evenly over the dataset

        Parameters
        ----------
        mode : String
            Split name ('train', 'val' or 'test'), selects the number of samples to log
        batch : Dict
            Dictionary containing batch information (must have an 'idx' entry)
        output : Dict
            Dictionary containing output information
        dataset : Dataset
            Dataset being iterated (only its length is used)
        prefix : String
            Prefix string for the log name
        ontology : Dict
            Dictionary with ontology information
        """
        num_logs = self.num_logs[mode]
        if num_logs <= 0:
            return
        idx = batch['idx'][0]
        # Keep the interval a multiple of the world size (interval 0 -> log every sample)
        interval = (len(dataset) // world_size() // num_logs) * world_size()
        if interval == 0 or (idx % interval == 0 and idx < interval * num_logs):
            prefix = '{}-{}-{}'.format(mode, prefix, idx.item())
            self.log_images(batch, output, prefix, ontology=ontology)
def recursive_convert_config(cfg):
    """
    Convert configuration to dictionary recursively

    Nested namespaces (as detected by ``is_namespace``) are converted as well.
    A new dictionary is built at every level, so ``cfg`` itself is not modified.

    Parameters
    ----------
    cfg : Config
        Namespace-like configuration (must expose ``__dict__``)

    Returns
    -------
    cfg_dict : Dict
        Dictionary version of the configuration
    """
    return {key: recursive_convert_config(val) if is_namespace(val) else val
            for key, val in cfg.__dict__.items()}
def prep_image(key, prefix, image):
    """
    Wrap an image into a dictionary entry ready to be logged by wandb

    Tensors are converted to numpy arrays: a 2D tensor first gets a leading
    dimension, a 4D tensor is reduced to its first element along dimension 0, and
    dimension 0 is then moved to the end (presumably CHW -> HWC). Non-tensor
    inputs are handed to wandb unchanged.

    Parameters
    ----------
    key : String
        Image name, used as caption and as the suffix of the log name
    prefix : String
        Prefix string for the log name
    image : Tensor or array-like
        Image to log

    Returns
    -------
    log : Dict
        Single-entry dictionary mapping 'prefix-key' to a wandb.Image
    """
    name = '{}-{}'.format(prefix, key)
    if is_tensor(image):
        ndim = image.dim()
        if ndim == 2:
            image = image.unsqueeze(0)
        elif ndim == 4:
            image = image[0]
        channels_last = image.detach().permute(1, 2, 0)
        image = channels_last.cpu().numpy()
    return {name: wandb.Image(image, caption=key)}
def log_sequence(key, prefix, data, i, only_first, fn):
    """
    Log every image stored in a dictionary and/or sequence by dispatching to ``fn``

    Names given to ``fn``:
      - dictionary value that is not a sequence: ``key(ctx)``
      - dictionary value that is a sequence: ``key(ctx)_idx``
      - plain sequence: ``key_idx``
      - anything else: ``key``
    Values inside a dictionary whose ``dim()`` is 5 are additionally split along
    dimension 1, inserting ``_j`` after the context (e.g. ``key(ctx_j)``).

    Parameters
    ----------
    key : String
        Base name for the logged entries
    prefix : String
        Prefix string forwarded to ``fn``
    data : Dict, List, Tuple or single entry
        Images to log
    i : Int
        Index forwarded to ``fn``
    only_first : Bool
        If enabled, only the first element of each sequence is logged
    fn : Function
        Logging function called as ``fn(name, prefix, value, i)``, returning a dictionary

    Returns
    -------
    log : Dict
        Union of every dictionary returned by ``fn`` (later entries overwrite earlier ones)
    """
    collected = {}

    def _log_entry(head, tail, value):
        # 5D values are split along dim 1 ("_j" goes before the tail), others logged as-is
        if value.dim() == 5:
            for j in range(value.shape[1]):
                collected.update(fn('%s_%d%s' % (head, j, tail), prefix, value[:, j], i))
        else:
            collected.update(fn(head + tail, prefix, value, i))

    if is_dict(data):
        for ctx, entry in data.items():
            head = '%s(%s' % (key, str(ctx))
            if not is_seq(entry):
                _log_entry(head, ')', entry)
                continue
            if only_first:
                entry = entry[:1]
            for idx, item in enumerate(entry):
                _log_entry(head, ')_%d' % idx, item)
    elif is_seq(data):
        items = data[:1] if only_first else data
        for idx, item in enumerate(items):
            collected.update(fn('%s_%d' % (key, idx), prefix, item, i))
    else:
        collected.update(fn('%s' % key, prefix, data, i))
    return collected
def log_rgb(key, prefix, batch, i=0, only_first=None):
    """
    Log RGB image(s); also used for masks

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    rgb = batch[key] if is_dict(batch) else batch
    if is_seq(rgb) or is_dict(rgb):
        return log_sequence(key, prefix, rgb, i, only_first, log_rgb)
    image = rgb[i]
    # Boolean masks are routed here as well, and clamp is not implemented for bool tensors
    if is_tensor(image) and image.dtype == torch.bool:
        image = image.float()
    return prep_image(key, prefix, image.clamp(min=0.0, max=1.0))
def log_depth(key, prefix, batch, i=0, only_first=None):
    """
    Log depth map(s), visualized with viz_depth ignoring zero values

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        depth = batch[key]
    else:
        depth = batch
    if not is_seq(depth) and not is_dict(depth):
        return prep_image(key, prefix, viz_depth(depth[i], filter_zeros=True))
    return log_sequence(key, prefix, depth, i, only_first, log_depth)
def log_inv_depth(key, prefix, batch, i=0, only_first=None):
    """
    Log inverse depth map(s), visualized with viz_inv_depth

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        inv_depth = batch[key]
    else:
        inv_depth = batch
    if not is_seq(inv_depth) and not is_dict(inv_depth):
        return prep_image(key, prefix, viz_inv_depth(inv_depth[i]))
    return log_sequence(key, prefix, inv_depth, i, only_first, log_inv_depth)
def log_normals(key, prefix, batch, i=0, only_first=None):
    """
    Log surface normal map(s), visualized with viz_normals

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        normals = batch[key]
    else:
        normals = batch
    if not is_seq(normals) and not is_dict(normals):
        return prep_image(key, prefix, viz_normals(normals[i]))
    return log_sequence(key, prefix, normals, i, only_first, log_normals)
def log_optical_flow(key, prefix, batch, i=0, only_first=None):
    """
    Log optical flow map(s), visualized with viz_optical_flow

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        flow = batch[key]
    else:
        flow = batch
    if not is_seq(flow) and not is_dict(flow):
        return prep_image(key, prefix, viz_optical_flow(flow[i]))
    return log_sequence(key, prefix, flow, i, only_first, log_optical_flow)
def log_stddev(key, prefix, batch, i=0, only_first=None):
    """
    Log standard deviation map(s), visualized with viz_inv_depth and the 'jet' colormap

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        stddev = batch[key]
    else:
        stddev = batch
    if not is_seq(stddev) and not is_dict(stddev):
        return prep_image(key, prefix, viz_inv_depth(stddev[i], colormap='jet'))
    return log_sequence(key, prefix, stddev, i, only_first, log_stddev)
def log_logvar(key, prefix, batch, i=0, only_first=None):
    """
    Log log-variance map(s); exp(logvar) is visualized with viz_inv_depth and the 'jet' colormap

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log along the first dimension (presumably the batch)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        logvar = batch[key]
    else:
        logvar = batch
    if not is_seq(logvar) and not is_dict(logvar):
        variance = torch.exp(logvar[i])
        return prep_image(key, prefix, viz_inv_depth(variance, colormap='jet'))
    return log_sequence(key, prefix, logvar, i, only_first, log_logvar)
def log_camera(key, prefix, batch, i=0, only_first=None):
    """
    Log camera(s), visualized with viz_camera

    Parameters
    ----------
    key : String
        Key used to look up the data in ``batch`` and to name the log entry
    prefix : String
        Prefix string for the log name
    batch : Dict or data
        Dictionary containing ``key``, or the data itself
    i : Int
        Index of the element to log (passed to the camera's ``__getitem__``)
    only_first : Bool
        If enabled, only the first element of each sequence is logged

    Returns
    -------
    log : Dict
        Dictionary with the wandb image(s) to log
    """
    if is_dict(batch):
        camera = batch[key]
    else:
        camera = batch
    if not is_seq(camera) and not is_dict(camera):
        return prep_image(key, prefix, viz_camera(camera[i]))
    return log_sequence(key, prefix, camera, i, only_first, log_camera)