Jiading Fang
add define
fc16538
raw
history blame
4.26 kB
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import argparse
import os
from functools import partial
from termcolor import colored
from vidar.utils.distributed import on_rank_0
def pcolor(string, color, on_color=None, attrs=None):
    """
    Wraps a string in terminal color escape codes (via termcolor)

    Parameters
    ----------
    string : String
        Text to be colored
    color : String
        Foreground color name
    on_color : String
        Optional background color name
    attrs : list[String]
        Optional text attributes (e.g. 'bold', 'dark')

    Returns
    -------
    string: String
        The colored string, ready for printing
    """
    return colored(string, color, on_color=on_color, attrs=attrs)
@on_rank_0
def print_config(config):
    """
    Prints a model configuration as a colored, indented tree

    Every nested ``argparse.Namespace`` entry is printed as a colored section
    header, followed by its contents indented two extra dashes. Every other
    entry is printed as ``--- key: value``. The whole tree is surrounded by
    blank lines.

    Parameters
    ----------
    config : Config
        Model configuration (its entries are read through ``__dict__``;
        nested sections are recognized when they are argparse.Namespace)
    """
    # Header color and attributes per nesting level. Levels deeper than the
    # last entry reuse the last entry instead of raising KeyError.
    header_colors = (
        ('red', ('bold', 'dark')),
        ('cyan', ('bold', 'dark')),
        ('green', ('bold', 'dark')),
        ('green', ('bold', 'dark')),
    )
    line_colors = ('blue', ())

    def print_recursive(rec_args, pad=3, level=0):
        """Print the entries of rec_args, recursing into nested namespaces."""
        color, attrs = header_colors[min(level, len(header_colors) - 1)]
        for key, val in rec_args.__dict__.items():
            prefix = '{} {}'.format('-' * pad, key)
            if isinstance(val, argparse.Namespace):
                print(pcolor(prefix + ':', color=color, attrs=attrs))
                print_recursive(val, pad + 2, level + 1)
            else:
                print('{}: {}'.format(pcolor(prefix,
                                             color=line_colors[0],
                                             attrs=line_colors[1]), val))

    print()
    print_recursive(config)
    print()
def set_debug(debug):
    """
    Silences noisy third-party terminal output unless debugging

    When ``debug`` is falsy, ``NCCL_DEBUG`` is set to an empty string and
    ``WANDB_SILENT`` is set to ``'true'`` in the process environment. When
    ``debug`` is truthy, the environment is left untouched (values set by an
    earlier call are not restored).

    Parameters
    ----------
    debug : Bool
        Debugging flag (True keeps the environment as-is)
    """
    if debug:
        return
    quiet_env = {
        'NCCL_DEBUG': '',
        'WANDB_SILENT': 'true',
    }
    os.environ.update(quiet_env)
class AvgMeter:
    """
    Moving-average meter for logging

    Keeps the most recent ``n_max`` values and reports their mean.

    Parameters
    ----------
    n_max : int
        Maximum number of values kept in the window (oldest values are
        discarded first)
    """
    def __init__(self, n_max=100):
        self.n_max = n_max
        self.values = []

    def __call__(self, value):
        """Append a new value and return the average of the current window"""
        self.values.append(value)
        # Drop the whole excess from the front rather than a single item, so
        # the window stays correct even if n_max was lowered after creation.
        excess = len(self.values) - self.n_max
        if excess > 0:
            del self.values[:excess]
        return self.get()

    def get(self):
        """Get the current average, or NaN when no values are stored"""
        if not self.values:
            # Avoid ZeroDivisionError on an empty (e.g. freshly reset) meter
            return float('nan')
        return sum(self.values) / len(self.values)

    def reset(self):
        """Reset meter"""
        self.values.clear()

    def get_and_reset(self):
        """Get the current average (NaN if empty) and reset"""
        average = self.get()
        self.reset()
        return average