Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
from collections import defaultdict | |
from collections import deque | |
import torch | |
import time | |
from datetime import datetime | |
from .comm import is_main_process | |
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    ``median`` and ``avg`` are computed over the most recent ``window_size``
    values; ``global_avg`` covers every value passed to :meth:`update`.
    All three are read-only properties, so they can be used directly in
    format strings, e.g. ``"{:.4f}".format(meter.median)``.
    """

    def __init__(self, window_size=20):
        # Bounded buffer: once full, appending drops the oldest value.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Record ``value`` in the window and in the running totals.

        A NaN ``value`` is kept in the window (so the windowed statistics
        show it) but contributes 0 to ``total``, so ``global_avg`` is not
        permanently poisoned. It is still counted in ``count``.
        """
        self.deque.append(value)
        self.count += 1
        # NaN is the only value that compares unequal to itself.
        if value != value:
            value = 0
        self.total += value

    def _window_tensor(self):
        """Return the current window as a float64 tensor.

        An explicit floating dtype is required because ``torch.mean`` rejects
        integer tensors, and ``torch.tensor`` infers int64 for all-int input.
        """
        return torch.tensor(list(self.deque), dtype=torch.float64)

    @property
    def median(self):
        """Median of the values in the window, as a Python float.

        For an even number of values torch returns the lower of the two
        middle values.
        """
        return self._window_tensor().median().item()

    @property
    def avg(self):
        """Arithmetic mean of the values in the window, as a Python float."""
        return self._window_tensor().mean().item()

    @property
    def global_avg(self):
        """Mean of every value recorded so far (NaNs counted as 0).

        Raises ZeroDivisionError if :meth:`update` has never been called.
        """
        return self.total / self.count
class AverageMeter(object):
    """Maintain the latest value together with a running weighted average.

    Attributes (all reset to 0 by :meth:`reset`):
        val: the most recent value given to :meth:`update`.
        avg: ``sum / count`` as of the last update.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh ``avg``.

        For plain numbers, raises ZeroDivisionError if the accumulated weight
        is 0 after this call.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
class MetricLogger(object):
    """Collect named :class:`SmoothedValue` meters and format them for logging.

    Meters are created on first :meth:`update` for a name and can be read
    back as attributes, e.g. ``logger.loss`` after ``logger.update(loss=x)``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword into the meter of the same name.

        Tensor values are converted with ``.item()``; after conversion every
        value must be a Python ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes; raise AttributeError for anything else."""
        # Called only after normal lookup fails. Read ``meters`` through
        # ``__dict__``: if it is absent (e.g. while copy/pickle rebuild the
        # object before restoring state), ``self.meters`` would re-enter
        # ``__getattr__`` and recurse until RecursionError.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Render each meter as ``name: median (global_avg)``, joined by the delimiter."""
        loss_str = [
            "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(loss_str)
# TensorBoard logging support (added by haotian).
class TensorboardLogger(MetricLogger):
    """:class:`MetricLogger` that also writes every update to TensorBoard.

    A ``tensorboardX.SummaryWriter`` is created only when
    ``is_main_process()`` is true; on other processes ``writer`` is ``None``
    and only the in-memory meters are updated.

    Args:
        log_dir: directory passed to ``SummaryWriter``.
        start_iter: step value used for the first logged scalars.
        delimiter: separator placed between meters by ``__str__``.
    """

    def __init__(self,
                 log_dir,
                 start_iter=0,
                 delimiter='\t'
                 ):
        super(TensorboardLogger, self).__init__(delimiter)
        # TensorBoard global step; advanced once per ``update`` call.
        self.iteration = start_iter
        self.writer = self._get_tensorboard_writer(log_dir)

    @staticmethod
    def _get_tensorboard_writer(log_dir):
        """Return a ``SummaryWriter`` for ``log_dir`` on the main process, else ``None``.

        Raises:
            ImportError: if ``tensorboardX`` cannot be imported (checked on
                every process, before the main-process test).
        """
        try:
            from tensorboardX import SummaryWriter
        except ImportError as err:
            raise ImportError(
                'To use tensorboard please install tensorboardX '
                '[ pip install tensorflow tensorboardX ].'
            ) from err
        if is_main_process():
            return SummaryWriter('{}'.format(log_dir))
        return None

    def update(self, **kwargs):
        """Update the meters and, when a writer exists, log each value as a scalar.

        All keywords of one call share the current ``iteration`` as their
        step; the step is then incremented once, on every process.
        """
        super(TensorboardLogger, self).update(**kwargs)
        if self.writer is not None:
            for k, v in kwargs.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                assert isinstance(v, (float, int))
                self.writer.add_scalar(k, v, self.iteration)
        self.iteration += 1