AK391
files
d380b77
raw
history blame
5.28 kB
import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings
import torch
from pytorch_lightning import seed_everything
LOGGER = logging.getLogger(__name__)
def check_and_warn_input_range(tensor, min_value, max_value, name):
actual_min = tensor.min()
actual_max = tensor.max()
if actual_min < min_value or actual_max > max_value:
warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")
def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
for k, v in cur_dict.items():
target_key = prefix + k
target[target_key] = target.get(target_key, default) + v
def average_dicts(dict_list):
result = {}
norm = 1e-3
for dct in dict_list:
sum_dict_with_prefix(result, dct, '')
norm += 1
for k in list(result):
result[k] /= norm
return result
def add_prefix_to_keys(dct, prefix):
return {prefix + k: v for k, v in dct.items()}
def set_requires_grad(module, value):
for param in module.parameters():
param.requires_grad = value
def flatten_dict(dct):
result = {}
for k, v in dct.items():
if isinstance(k, tuple):
k = '_'.join(k)
if isinstance(v, dict):
for sub_k, sub_v in flatten_dict(v).items():
result[f'{k}_{sub_k}'] = sub_v
else:
result[k] = v
return result
class LinearRamp:
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
self.start_value = start_value
self.end_value = end_value
self.start_iter = start_iter
self.end_iter = end_iter
def __call__(self, i):
if i < self.start_iter:
return self.start_value
if i >= self.end_iter:
return self.end_value
part = (i - self.start_iter) / (self.end_iter - self.start_iter)
return self.start_value * (1 - part) + self.end_value * part
class LadderRamp:
def __init__(self, start_iters, values):
self.start_iters = start_iters
self.values = values
assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))
def __call__(self, i):
segment_i = bisect.bisect_right(self.start_iters, i)
return self.values[segment_i]
def get_ramp(kind='ladder', **kwargs):
if kind == 'linear':
return LinearRamp(**kwargs)
if kind == 'ladder':
return LadderRamp(**kwargs)
raise ValueError(f'Unexpected ramp kind: {kind}')
def print_traceback_handler(sig, frame):
LOGGER.warning(f'Received signal {sig}')
bt = ''.join(traceback.format_stack())
LOGGER.warning(f'Requested stack trace:\n{bt}')
def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler):
LOGGER.warning(f'Setting signal {sig} handler {handler}')
signal.signal(sig, handler)
def handle_deterministic_config(config):
seed = dict(config).get('seed', None)
if seed is None:
return False
seed_everything(seed)
return True
def get_shape(t):
if torch.is_tensor(t):
return tuple(t.shape)
elif isinstance(t, dict):
return {n: get_shape(q) for n, q in t.items()}
elif isinstance(t, (list, tuple)):
return [get_shape(q) for q in t]
elif isinstance(t, numbers.Number):
return type(t)
else:
raise ValueError('unexpected type {}'.format(type(t)))
def get_has_ddp_rank():
master_port = os.environ.get('MASTER_PORT', None)
node_rank = os.environ.get('NODE_RANK', None)
local_rank = os.environ.get('LOCAL_RANK', None)
world_size = os.environ.get('WORLD_SIZE', None)
has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
return has_rank
def handle_ddp_subprocess():
def main_decorator(main_func):
@functools.wraps(main_func)
def new_main(*args, **kwargs):
# Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
has_parent = parent_cwd is not None
has_rank = get_has_ddp_rank()
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
if has_parent:
# we are in the worker
sys.argv.extend([
f'hydra.run.dir={parent_cwd}',
# 'hydra/hydra_logging=disabled',
# 'hydra/job_logging=disabled'
])
# do nothing if this is a top-level process
# TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
main_func(*args, **kwargs)
return new_main
return main_decorator
def handle_ddp_parent_process():
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
has_parent = parent_cwd is not None
has_rank = get_has_ddp_rank()
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
if parent_cwd is None:
os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
return has_parent