Spaces:
Runtime error
Runtime error
import os, re | |
from omegaconf import OmegaConf | |
import logging | |
mainlogger = logging.getLogger('mainlogger') | |
import torch | |
from collections import OrderedDict | |
def init_workspace(name, logdir, model_config, lightning_config, rank=0): | |
workdir = os.path.join(logdir, name) | |
ckptdir = os.path.join(workdir, "checkpoints") | |
cfgdir = os.path.join(workdir, "configs") | |
loginfo = os.path.join(workdir, "loginfo") | |
# Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower) | |
os.makedirs(workdir, exist_ok=True) | |
os.makedirs(ckptdir, exist_ok=True) | |
os.makedirs(cfgdir, exist_ok=True) | |
os.makedirs(loginfo, exist_ok=True) | |
if rank == 0: | |
if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks: | |
os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'), exist_ok=True) | |
OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml")) | |
OmegaConf.save(OmegaConf.create({"lightning": lightning_config}), os.path.join(cfgdir, "lightning.yaml")) | |
return workdir, ckptdir, cfgdir, loginfo | |
def check_config_attribute(config, name): | |
if name in config: | |
value = getattr(config, name) | |
return value | |
else: | |
return None | |
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger): | |
default_callbacks_cfg = { | |
"model_checkpoint": { | |
"target": "pytorch_lightning.callbacks.ModelCheckpoint", | |
"params": { | |
"dirpath": ckptdir, | |
"filename": "{epoch}", | |
"verbose": True, | |
"save_last": False, | |
} | |
}, | |
"batch_logger": { | |
"target": "callbacks.ImageLogger", | |
"params": { | |
"save_dir": logdir, | |
"batch_frequency": 1000, | |
"max_images": 4, | |
"clamp": True, | |
} | |
}, | |
"learning_rate_logger": { | |
"target": "pytorch_lightning.callbacks.LearningRateMonitor", | |
"params": { | |
"logging_interval": "step", | |
"log_momentum": False | |
} | |
}, | |
"cuda_callback": { | |
"target": "callbacks.CUDACallback" | |
}, | |
} | |
## optional setting for saving checkpoints | |
monitor_metric = check_config_attribute(config.model.params, "monitor") | |
if monitor_metric is not None: | |
mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.") | |
default_callbacks_cfg["model_checkpoint"]["params"]["monitor"] = monitor_metric | |
default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3 | |
default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min" | |
if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks: | |
mainlogger.info('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') | |
default_metrics_over_trainsteps_ckpt_dict = { | |
'metrics_over_trainsteps_checkpoint': {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', | |
'params': { | |
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), | |
"filename": "{epoch}-{step}", | |
"verbose": True, | |
'save_top_k': -1, | |
'every_n_train_steps': 10000, | |
'save_weights_only': True | |
} | |
} | |
} | |
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) | |
if "callbacks" in lightning_config: | |
callbacks_cfg = lightning_config.callbacks | |
else: | |
callbacks_cfg = OmegaConf.create() | |
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) | |
return callbacks_cfg | |
def get_trainer_logger(lightning_config, logdir, on_debug): | |
default_logger_cfgs = { | |
"tensorboard": { | |
"target": "pytorch_lightning.loggers.TensorBoardLogger", | |
"params": { | |
"save_dir": logdir, | |
"name": "tensorboard", | |
} | |
}, | |
"testtube": { | |
"target": "pytorch_lightning.loggers.CSVLogger", | |
"params": { | |
"name": "testtube", | |
"save_dir": logdir, | |
} | |
}, | |
} | |
os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True) | |
default_logger_cfg = default_logger_cfgs["tensorboard"] | |
if "logger" in lightning_config: | |
logger_cfg = lightning_config.logger | |
else: | |
logger_cfg = OmegaConf.create() | |
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) | |
return logger_cfg | |
def get_trainer_strategy(lightning_config): | |
default_strategy_dict = { | |
"target": "pytorch_lightning.strategies.DDPShardedStrategy" | |
} | |
if "strategy" in lightning_config: | |
strategy_cfg = lightning_config.strategy | |
return strategy_cfg | |
else: | |
strategy_cfg = OmegaConf.create() | |
strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg) | |
return strategy_cfg | |
def load_checkpoints(model, model_cfg): | |
if check_config_attribute(model_cfg, "pretrained_checkpoint"): | |
pretrained_ckpt = model_cfg.pretrained_checkpoint | |
assert os.path.exists(pretrained_ckpt), "Error: Pre-trained checkpoint NOT found at:%s"%pretrained_ckpt | |
mainlogger.info(">>> Load weights from pretrained checkpoint") | |
pl_sd = torch.load(pretrained_ckpt, map_location="cpu") | |
try: | |
if 'state_dict' in pl_sd.keys(): | |
model.load_state_dict(pl_sd["state_dict"], strict=True) | |
mainlogger.info(">>> Loaded weights from pretrained checkpoint: %s"%pretrained_ckpt) | |
else: | |
# deepspeed | |
new_pl_sd = OrderedDict() | |
for key in pl_sd['module'].keys(): | |
new_pl_sd[key[16:]]=pl_sd['module'][key] | |
model.load_state_dict(new_pl_sd, strict=True) | |
except: | |
model.load_state_dict(pl_sd) | |
else: | |
mainlogger.info(">>> Start training from scratch") | |
return model | |
def set_logger(logfile, name='mainlogger'): | |
logger = logging.getLogger(name) | |
logger.setLevel(logging.INFO) | |
fh = logging.FileHandler(logfile, mode='w') | |
fh.setLevel(logging.INFO) | |
ch = logging.StreamHandler() | |
ch.setLevel(logging.DEBUG) | |
fh.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s")) | |
ch.setFormatter(logging.Formatter("%(message)s")) | |
logger.addHandler(fh) | |
logger.addHandler(ch) | |
return logger |