# Logging utilities for RL training: console, CSV, TensorBoard, and wandb output.
import csv
import datetime
from collections import defaultdict

import numpy as np
import torch
import torchvision
import wandb
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
# Console display layout: each entry is (data key, display label, formatter tag),
# where the tag is one of 'int' / 'float' / 'time' (see MetersGroup._format).
COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
                       ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
                       ('episode_reward', 'R', 'float'),
                       ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]

# Same layout for evaluation dumps (no fps column).
COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
                      ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
                      ('episode_reward', 'R', 'float'),
                      ('total_time', 'T', 'time')]
class AverageMeter(object):
    """Maintains the running average of a stream of scalar values."""

    def __init__(self):
        self._total = 0
        self._n = 0

    def update(self, value, n=1):
        """Accumulate `value`, counted as `n` observations."""
        self._total += value
        self._n += n

    def value(self):
        """Return the mean so far; 0 if nothing has been logged yet."""
        denominator = self._n if self._n > 0 else 1
        return self._total / denominator
class MetersGroup(object):
    """Aggregates scalar metrics into AverageMeters and dumps them to
    console, wandb, and (optionally) a CSV file.

    Args:
        csv_file_name: pathlib.Path of the CSV file backing this group.
        formating: list of (data key, display label, type tag) tuples
            driving the console output (see COMMON_TRAIN_FORMAT).
        use_wandb: if True, `dump` also forwards averaged data to wandb.
    """

    def __init__(self, csv_file_name, formating, use_wandb):
        self._csv_file_name = csv_file_name
        self._formating = formating
        self._meters = defaultdict(AverageMeter)
        # CSV handle/writer are created lazily on the first _dump_to_csv call.
        self._csv_file = None
        self._csv_writer = None
        self.use_wandb = use_wandb

    def log(self, key, value, n=1):
        """Accumulate `value` (weighted by `n`) into the meter for `key`."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Average all meters and return a flat dict with the 'train/' or
        'eval/' prefix stripped and any remaining '/' turned into '_'."""
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _remove_old_entries(self, data):
        """Rewrite the CSV keeping only rows with episode < data['episode'],
        so a resumed run does not duplicate already-logged episodes."""
        rows = []
        with self._csv_file_name.open('r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                if 'episode' in row:
                    # BUGFIX: covers weird cases where CSV are badly written
                    if row['episode'] == '':
                        rows.append(row)
                        continue
                    # BUGFIX: use `is None` instead of type(...) == type(None)
                    if row['episode'] is None:
                        continue
                    if float(row['episode']) >= data['episode']:
                        break
                rows.append(row)
        with self._csv_file_name.open('w') as f:
            # To handle CSV that have more keys than new data.
            # BUGFIX: union keys across ALL retained rows; the original read
            # the leaked loop variable `row` and only used the last row seen.
            keys = set(data.keys())
            for row in rows:
                keys |= set(row.keys())
            writer = csv.DictWriter(f,
                                    fieldnames=sorted(keys),
                                    restval=0.0)
            writer.writeheader()
            for row in rows:
                writer.writerow(row)

    def _dump_to_csv(self, data):
        """Append one averaged-data row to the CSV, lazily opening the file
        and rewriting it when new columns appear mid-run."""
        if self._csv_writer is None:
            should_write_header = True
            if self._csv_file_name.exists():
                self._remove_old_entries(data)
                should_write_header = False
            self._csv_file = self._csv_file_name.open('a')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            if should_write_header:
                self._csv_writer.writeheader()
        # To handle components that start training later
        # (restval covers only when data has less keys than the CSV):
        # rewrite the whole file with the wider header.
        if self._csv_writer.fieldnames != sorted(data.keys()) and \
                len(self._csv_writer.fieldnames) < len(data.keys()):
            self._csv_file.close()
            self._csv_file = self._csv_file_name.open('r')
            dict_reader = csv.DictReader(self._csv_file)
            rows = [row for row in dict_reader]
            self._csv_file.close()
            self._csv_file = self._csv_file_name.open('w')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            self._csv_writer.writeheader()
            for row in rows:
                self._csv_writer.writerow(row)
        self._csv_writer.writerow(data)
        self._csv_file.flush()

    def _format(self, key, value, ty):
        """Render one 'key: value' console fragment according to type tag `ty`."""
        if ty == 'int':
            value = int(value)
            return f'{key}: {value}'
        elif ty == 'float':
            return f'{key}: {value:.04f}'
        elif ty == 'time':
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{key}: {value}'
        else:
            # BUGFIX: `raise <str>` is itself a TypeError ("exceptions must
            # derive from BaseException"); raise a proper exception instead.
            raise ValueError(f'invalid format type: {ty}')

    def _dump_to_console(self, data, prefix):
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = [f'| {prefix: <14}']
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print(' | '.join(pieces))

    def _dump_to_wandb(self, data):
        wandb.log(data)

    def dump(self, step, prefix):
        """Flush all accumulated meters (to wandb/console) and reset them."""
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['frame'] = step
        if self.use_wandb:
            wandb_data = {prefix + '/' + key: val for key, val in data.items()}
            self._dump_to_wandb(data=wandb_data)
        # CSV dumping intentionally disabled in the original; kept for reference.
        # self._dump_to_csv(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()
class Logger(object):
    """Top-level logging facade routing metrics to TensorBoard, wandb, and
    the train/eval MetersGroups (console output).

    Args:
        log_dir: pathlib.Path under which 'train.csv', 'eval.csv', and the
            'tb' TensorBoard directory are created.
        use_tb: if True, open a SummaryWriter for scalars/images/videos.
        use_wandb: if True, meter dumps and visuals also go to wandb.
    """

    def __init__(self, log_dir, use_tb, use_wandb):
        self._log_dir = log_dir
        self._train_mg = MetersGroup(log_dir / 'train.csv',
                                     formating=COMMON_TRAIN_FORMAT,
                                     use_wandb=use_wandb)
        self._eval_mg = MetersGroup(log_dir / 'eval.csv',
                                    formating=COMMON_EVAL_FORMAT,
                                    use_wandb=use_wandb)
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        else:
            self._sw = None
        self.use_wandb = use_wandb

    def _try_sw_log(self, key, value, step):
        # No-op when TensorBoard is disabled.
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def log(self, key, value, step):
        """Log one scalar; `key` must be 'train/...' or 'eval/...'."""
        assert key.startswith('train') or key.startswith('eval')
        # BUGFIX: isinstance instead of an exact-type check, so Tensor
        # subclasses (e.g. torch.nn.Parameter) are also unwrapped to floats.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value)

    def log_metrics(self, metrics, step, ty):
        """Log a dict of scalars under the 'train' or 'eval' prefix `ty`."""
        for key, value in metrics.items():
            self.log(f'{ty}/{key}', value, step)

    def dump(self, step, ty=None):
        """Flush meter groups; `ty` None flushes both eval and train."""
        if ty is None or ty == 'eval':
            self._eval_mg.dump(step, 'eval')
        if ty is None or ty == 'train':
            self._train_mg.dump(step, 'train')

    def log_and_dump_ctx(self, step, ty):
        """Context manager: log via `ctx(key, value)`, dump on exit."""
        return LogAndDumpCtx(self, step, ty)

    def log_visual(self, data, step):
        """Send images (rank-3 arrays) / videos (rank 4-5) to TB and wandb."""
        if self._sw is not None:
            for k, v in data.items():
                if len(v.shape) == 3:
                    self._sw.add_image(k, v)
                else:
                    # add_video expects a batch dim; promote (T, C, H, W)
                    # to (1, T, C, H, W).
                    if len(v.shape) == 4:
                        v = np.expand_dims(v, axis=0)
                    self._sw.add_video(k, v, global_step=step, fps=15)
        if self.use_wandb:
            for k, v in data.items():
                if type(v) is not np.ndarray:
                    # NOTE(review): presumably a torch tensor on GPU — confirm.
                    v = v.cpu()
                if v.dtype not in [np.uint8]:
                    # Assumes float data in [0, 1] — scale to uint8 range.
                    v = v * 255
                    v = np.uint8(v)
                if len(v.shape) == 3:
                    if v.shape[0] == 3:
                        # CHW -> HWC, the layout wandb.Image expects.
                        v = v.transpose(1, 2, 0)
                    # Note: defaulting to save only one image/video to save storage on wandb
                    wandb.log({k: wandb.Image(v)},)
                else:
                    # Note: defaulting to save only one image/video to save storage on wandb
                    wandb.log({k: wandb.Video(v, fps=15, format="gif")},)
class LogAndDumpCtx:
    """Context manager returned by Logger.log_and_dump_ctx.

    Calling the context object logs a metric under the configured prefix;
    leaving the `with` block dumps the corresponding meter group.
    """

    def __init__(self, logger, step, ty):
        self._owner = logger
        self._at_step = step
        self._prefix = ty

    def __enter__(self):
        return self

    def __call__(self, key, value):
        # Prefix the key ('train' or 'eval') and forward to the logger.
        self._owner.log(f'{self._prefix}/{key}', value, self._at_step)

    def __exit__(self, *args):
        # Flush the meter group for this prefix, regardless of exceptions.
        self._owner.dump(self._at_step, self._prefix)