"""Experiment logging to console, CSV, TensorBoard and Weights & Biases."""
import csv
import datetime
from collections import defaultdict

import numpy as np
import torch
import torchvision
import wandb
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter

# Console layouts: one (metric key, display label, format type) triple per
# column, where the format type is one of 'int', 'float' or 'time'.
COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
                       ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
                       ('episode_reward', 'R', 'float'),
                       ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]

COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
                      ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
                      ('episode_reward', 'R', 'float'),
                      ('total_time', 'T', 'time')]


class AverageMeter(object):
    """Accumulate values and report their running average."""

    def __init__(self):
        self._sum = 0
        self._count = 0

    def update(self, value, n=1):
        """Add ``value`` to the running sum and ``n`` to the sample count."""
        self._sum += value
        self._count += n

    def value(self):
        """Return sum / count; the count is floored at 1 to avoid dividing by 0."""
        return self._sum / max(1, self._count)


class MetersGroup(object):
    """A set of named meters that is periodically flushed to the outputs.

    Args:
        csv_file_name: path-like object (must support ``.open`` and
            ``.exists``) of the CSV file the metrics are appended to.
        formating: list of (key, display key, type) triples describing the
            console line, e.g. ``COMMON_TRAIN_FORMAT``.
        use_wandb: when true, every dump is also sent to ``wandb.log``.
    """

    def __init__(self, csv_file_name, formating, use_wandb):
        self._csv_file_name = csv_file_name
        self._formating = formating
        self._meters = defaultdict(AverageMeter)
        self._csv_file = None
        self._csv_writer = None
        self.use_wandb = use_wandb

    def log(self, key, value, n=1):
        """Record ``value`` under ``key`` in the corresponding meter."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Return ``{metric name: averaged value}`` for all current meters.

        The leading ``train``/``eval`` prefix and its separator are stripped
        from each key, and remaining ``/`` are replaced with ``_``.
        """
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _remove_old_entries(self, data):
        """Rewrite the existing CSV, dropping rows at/after the new episode.

        Rows are kept up to (excluding) the first one whose ``episode`` is
        greater than or equal to ``data['episode']``. If ``data`` carries no
        ``episode`` the file is not truncated.

        Returns:
            The sorted list of column names written as the file header, so
            that the caller can append rows with matching columns.
        """
        cutoff = data.get('episode')
        rows = []
        with self._csv_file_name.open('r', newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                if 'episode' in row:
                    episode = row['episode']
                    # BUGFIX: covers weird cases where CSV are badly written
                    if episode == '':
                        rows.append(row)
                        continue
                    if episode is None:
                        # Row shorter than the header: drop it.
                        continue
                    if cutoff is not None and float(episode) >= cutoff:
                        break
                rows.append(row)
            old_fieldnames = reader.fieldnames or []

        # To handle CSV that have more keys than new data
        keys = set(data.keys())
        if rows:
            keys |= set(old_fieldnames)
        keys = sorted(keys)

        with self._csv_file_name.open('w', newline='') as f:
            # extrasaction='ignore' skips the ``None`` key DictReader adds
            # for rows that have more cells than the header.
            writer = csv.DictWriter(f, fieldnames=keys, restval=0.0,
                                    extrasaction='ignore')
            writer.writeheader()
            writer.writerows(rows)
        return keys

    def _dump_to_csv(self, data):
        """Append ``data`` as one row of the CSV file, creating it if needed.

        On the first call an existing file is first cleaned with
        ``_remove_old_entries``. Whenever ``data`` contains a column that the
        file header lacks, the file is rewritten with the union of columns.
        """
        if self._csv_writer is None:
            should_write_header = True
            fieldnames = sorted(data.keys())
            if self._csv_file_name.exists():
                # Reuse the header actually written, so appended rows stay
                # aligned even when the old file has extra columns.
                fieldnames = self._remove_old_entries(data)
                should_write_header = False
            self._csv_file = self._csv_file_name.open('a', newline='')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=fieldnames,
                                              restval=0.0)
            if should_write_header:
                self._csv_writer.writeheader()

        # To handle components that start training later
        # (restval covers only when data has less keys than the CSV)
        current = self._csv_writer.fieldnames
        if any(key not in current for key in data):
            self._csv_file.close()
            with self._csv_file_name.open('r', newline='') as f:
                rows = list(csv.DictReader(f))
            fieldnames = sorted(set(current) | set(data.keys()))
            self._csv_file = self._csv_file_name.open('w', newline='')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=fieldnames,
                                              restval=0.0,
                                              extrasaction='ignore')
            self._csv_writer.writeheader()
            self._csv_writer.writerows(rows)

        self._csv_writer.writerow(data)
        self._csv_file.flush()

    def _format(self, key, value, ty):
        """Render ``key: value`` according to ``ty``.

        Raises:
            ValueError: if ``ty`` is not 'int', 'float' or 'time'.
        """
        if ty == 'int':
            value = int(value)
            return f'{key}: {value}'
        elif ty == 'float':
            return f'{key}: {value:.04f}'
        elif ty == 'time':
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{key}: {value}'
        else:
            # A bare string cannot be raised in Python 3.
            raise ValueError(f'invalid format type: {ty}')

    def _dump_to_console(self, data, prefix):
        """Print one line with the columns listed in the configured format.

        Metrics missing from ``data`` are printed as 0.
        """
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = [f'| {prefix: <14}']
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print(' | '.join(pieces))

    def _dump_to_wandb(self, data):
        """Send ``data`` to Weights & Biases."""
        wandb.log(data)

    def dump(self, step, prefix):
        """Flush the averaged meters to all outputs, then reset the meters.

        Does nothing when no value was logged since the last dump.

        Args:
            step: value stored under the ``frame`` column.
            prefix: 'train' or 'eval'; used as wandb key prefix and console
                tag.
        """
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['frame'] = step
        if self.use_wandb:
            wandb_data = {prefix + '/' + key: val for key, val in data.items()}
            self._dump_to_wandb(data=wandb_data)
        self._dump_to_csv(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()


class Logger(object):
    """Front-end routing metrics to a train and an eval ``MetersGroup``.

    Args:
        log_dir: directory path (must support the ``/`` operator) receiving
            ``train.csv``, ``eval.csv`` and the ``tb`` TensorBoard folder.
        use_tb: when true, scalars and visuals also go to TensorBoard.
        use_wandb: when true, dumps and visuals also go to wandb.
    """

    def __init__(self, log_dir, use_tb, use_wandb):
        self._log_dir = log_dir
        self._train_mg = MetersGroup(log_dir / 'train.csv',
                                     formating=COMMON_TRAIN_FORMAT,
                                     use_wandb=use_wandb)
        self._eval_mg = MetersGroup(log_dir / 'eval.csv',
                                    formating=COMMON_EVAL_FORMAT,
                                    use_wandb=use_wandb)
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        else:
            self._sw = None
        self.use_wandb = use_wandb

    def _try_sw_log(self, key, value, step):
        """Write a scalar to TensorBoard when it is enabled."""
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def log(self, key, value, step):
        """Log one scalar; ``key`` must start with 'train' or 'eval'."""
        assert key.startswith('train') or key.startswith('eval')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value)

    def log_metrics(self, metrics, step, ty):
        """Log every item of ``metrics`` under the ``ty/`` prefix."""
        for key, value in metrics.items():
            self.log(f'{ty}/{key}', value, step)

    def dump(self, step, ty=None):
        """Dump the 'eval' and/or 'train' group; both when ``ty`` is None."""
        if ty is None or ty == 'eval':
            self._eval_mg.dump(step, 'eval')
        if ty is None or ty == 'train':
            self._train_mg.dump(step, 'train')

    def log_and_dump_ctx(self, step, ty):
        """Return a context manager that logs under ``ty`` and dumps on exit."""
        return LogAndDumpCtx(self, step, ty)

    def log_visual(self, data, step):
        """Log images (3-D values) and videos (other ranks) from ``data``.

        Args:
            data: mapping from name to a numpy array or torch tensor.
            step: global step attached to the TensorBoard entries.
        """
        if self._sw is not None:
            for k, v in data.items():
                if len(v.shape) == 3:
                    self._sw.add_image(k, v, global_step=step)
                else:
                    if len(v.shape) == 4:
                        # Add a leading axis; works for arrays and tensors.
                        v = v[None]
                    self._sw.add_video(k, v, global_step=step, fps=15)
        if self.use_wandb:
            for k, v in data.items():
                # Convert to numpy first so the dtype test below compares
                # numpy dtypes (a torch dtype never equals np.uint8).
                if isinstance(v, torch.Tensor):
                    v = v.detach().cpu().numpy()
                elif not isinstance(v, np.ndarray):
                    v = np.asarray(v)
                if v.dtype != np.uint8:
                    # Assumes non-uint8 inputs are scaled to [0, 1] -- TODO
                    # confirm; clipping avoids wrap-around on conversion.
                    v = np.clip(v * 255, 0, 255).astype(np.uint8)
                if len(v.shape) == 3:
                    if v.shape[0] == 3:
                        # Move a leading size-3 axis to the end.
                        v = v.transpose(1, 2, 0)
                    # Note: defaulting to save only one image/video to save storage on wandb
                    wandb.log({k: wandb.Image(v)})
                else:
                    # Note: defaulting to save only one image/video to save storage on wandb
                    wandb.log({k: wandb.Video(v, fps=15, format="gif")})


class LogAndDumpCtx:
    """Context manager: call it to log metrics, which are dumped on exit."""

    def __init__(self, logger, step, ty):
        self._logger = logger
        self._step = step
        self._ty = ty

    def __enter__(self):
        return self

    def __call__(self, key, value):
        """Log ``value`` as ``<ty>/<key>`` at the stored step."""
        self._logger.log(f'{self._ty}/{key}', value, self._step)

    def __exit__(self, *args):
        # Dumps even when the body raised; the exception is not suppressed.
        self._logger.dump(self._step, self._ty)