|
r""" Logging """
|
|
|
|
import datetime
|
|
import logging
|
|
import os
|
|
|
|
from tensorboardX import SummaryWriter
|
|
import torch
|
|
|
|
|
|
class Logger:
    r""" Writes results of training/testing """

    @classmethod
    def initialize(cls, args, training):
        r""" Creates the run's log directory and wires up file, console and TensorBoard output.

        The directory is ``logs/<name>.log``. When training, ``<name>`` is
        ``args.logpath`` (replaced by a timestamp if it is the empty string).
        When testing, it is ``_TEST_<checkpoint stem><timestamp>``.
        ``os.makedirs`` is called without ``exist_ok``, so an already existing
        directory raises.

        Args:
            args: parsed options. Reads ``logpath`` (training) or ``load``
                (testing), plus ``benchmark``. When training, every attribute
                is echoed to the log.
            training: selects the training vs. testing naming scheme and
                whether the argument table is logged.
        """
        stamp = datetime.datetime.now().strftime('_%m%d_%H%M%S')

        if training:
            run_name = args.logpath
        else:
            # File name of the checkpoint up to its first '.', e.g. 'a/b/best.pt' -> 'best'.
            ckpt_stem = args.load.split('/')[-1].split('.')[0]
            run_name = '_TEST_' + ckpt_stem + stamp
        if run_name == '':
            run_name = stamp

        cls.logpath = os.path.join('logs', run_name + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        # File output; basicConfig does nothing if the root logger already has handlers.
        logging.basicConfig(
            filemode='w',
            filename=os.path.join(cls.logpath, 'log.txt'),
            level=logging.INFO,
            format='%(message)s',
            datefmt='%m-%d %H:%M:%S',
        )

        # Mirror every INFO message to the console as well.
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.INFO)
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logging.getLogger('').addHandler(stream_handler)

        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        if not training:
            return
        logging.info(':======== Convolutional Hough Matching Networks =========')
        for name, value in args.__dict__.items():
            logging.info('| %20s: %-24s' % (name, str(value)))
        logging.info(':========================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Emits ``msg`` at INFO level on the root logger (log file and console once initialized) """
        logging.info(msg)

    @classmethod
    def save_model(cls, model, epoch, val_pck):
        r""" Overwrites ``<logpath>/pck_best_model.pt`` with ``model``'s state dict and logs it.

        Args:
            model: module whose ``state_dict()`` is saved.
            epoch: epoch number reported in the log line.
            val_pck: validation PCK reported in the log line.
        """
        weights = model.state_dict()
        torch.save(weights, os.path.join(cls.logpath, 'pck_best_model.pt'))
        cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))
|
|
|
|
|
|
class AverageMeter:
    r""" Accumulates per-sample evaluation results and per-batch losses, and logs their averages """

    def __init__(self, benchamrk):
        r""" Constructor of AverageMeter

        Args:
            benchamrk: benchmark name. The misspelled parameter name is kept so
                existing keyword callers keep working. The value is only stored
                as ``self.benchmark`` and is not used by this class.
        """
        self.benchmark = benchamrk
        self.buffer_keys = ['pck']
        # One list of per-sample values per metric key.
        self.buffer = {key: [] for key in self.buffer_keys}
        # One entry per update() call that supplied a loss.
        self.loss_buffer = []

    @staticmethod
    def _mean(values):
        r""" Returns ``sum(values) / len(values)``, or ``None`` for an empty list instead of dividing by zero """
        if not values:
            return None
        return sum(values) / len(values)

    def update(self, eval_result, loss=None):
        r""" Adds one batch of results.

        Args:
            eval_result: mapping with an entry for every key in ``buffer_keys``.
                Each entry is iterated and its elements are appended to the
                corresponding buffer.
            loss: loss value of the batch; ``None`` leaves ``loss_buffer`` untouched.
        """
        for key in self.buffer_keys:
            self.buffer[key] += eval_result[key]

        if loss is not None:
            self.loss_buffer.append(loss)

    def write_result(self, split, epoch):
        r""" Logs the epoch-level average loss and metrics for ``split``.

        Metrics with no accumulated values are omitted instead of raising
        ZeroDivisionError.
        """
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch

        avg_loss = self._mean(self.loss_buffer)
        if avg_loss is not None:
            msg += 'Loss: %5.2f ' % avg_loss

        for key in self.buffer_keys:
            avg = self._mean(self.buffer[key])
            if avg is not None:
                # NOTE(review): write_process scales this same average by 100 but this line
                # does not; one of the two is likely on the wrong scale -- confirm against
                # what the evaluator puts in eval_result.
                msg += '%s: %6.2f ' % (key.upper(), avg)
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch):
        r""" Logs running averages after batch ``batch_idx`` (0-based) of ``datalen``.

        Metrics with no accumulated values are omitted instead of raising
        ZeroDivisionError.
        """
        msg = '[Epoch: %02d] ' % epoch
        msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        if self.loss_buffer:
            msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
            msg += 'Avg Loss: %5.5f ' % self._mean(self.loss_buffer)

        for key in self.buffer_keys:
            avg = self._mean(self.buffer[key])
            if avg is not None:
                msg += 'Avg %s: %5.2f ' % (key.upper(), avg * 100)
        Logger.info(msg)

    def write_test_process(self, batch_idx, datalen):
        r""" Logs running test averages after batch ``batch_idx`` (0-based) of ``datalen``.

        For ``'pck'``, buffered tensors are stacked and averaged over the first
        dimension. Each resulting entry (scaled by 100) is printed, e.g. one per
        threshold if each sample stores a vector. Empty metrics are skipped.
        """
        msg = '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)

        for key in self.buffer_keys:
            values = self.buffer[key]
            if not values:
                continue  # torch.stack([]) / division by zero would raise
            if key == 'pck':
                pcks = torch.stack(values).mean(dim=0) * 100
                # reshape(-1) also lets scalar-per-sample buffers (0-d mean) be iterated.
                val = ''.join('%5.2f ' % p.item() for p in pcks.reshape(-1))
                msg += 'Avg %s: %s ' % (key.upper(), val)
            else:
                msg += 'Avg %s: %5.2f ' % (key.upper(), self._mean(values))
        Logger.info(msg)

    def get_test_result(self):
        r""" Returns ``{key: mean over samples * 100}`` for every buffered metric.

        Raises:
            RuntimeError: from ``torch.stack`` if a buffer is still empty.
        """
        return {
            key: torch.stack(self.buffer[key]).mean(dim=0) * 100
            for key in self.buffer_keys
        }
|
|
|