import functools
import time
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only

from . import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir

    Args:
        task_func: Task callable; it is invoked as `task_func(cfg=cfg)`.

    Returns:
        A wrapper taking a single `cfg: DictConfig` and returning whatever
        `task_func` returns. Exceptions from `task_func` are logged and re-raised.
    """

    @functools.wraps(task_func)  # keep the task's name/docstring on the wrapper
    def wrap(cfg: DictConfig):
        # apply extra utilities
        extras(cfg)

        # Start the clock before `try` so `start_time` is always bound in `finally`.
        start_time = time.time()

        # execute the task
        try:
            ret = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # save exception to `.log` file
            raise  # bare `raise` re-raises the active exception with its traceback
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return ret

    return wrap


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    Each utility is toggled by a flag under `cfg.extras`
    (`ignore_warnings`, `enforce_tags`, `print_config`).
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


@rank_zero_only
def save_file(path: Union[str, Path], content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup).

    Args:
        path: Destination file; it is overwritten if it exists.
        content: Text written to the file (UTF-8).
    """
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)


def _instantiate_targets(cfg: DictConfig, kind: str) -> List[Any]:
    """Instantiate every entry of `cfg` that is a DictConfig carrying a `_target_` key.

    Entries without `_target_` (or that are not DictConfigs) are skipped silently.
    `kind` is only used in the log message.
    """
    objects: List[Any] = []
    for conf in cfg.values():
        if isinstance(conf, DictConfig) and "_target_" in conf:
            log.info(f"Instantiating {kind} <{conf._target_}>")
            objects.append(hydra.utils.instantiate(conf))
    return objects


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    Returns an empty list (with a warning) when the config is empty or None.

    Raises:
        TypeError: If a non-empty config is not a DictConfig.
    """
    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    return _instantiate_targets(callbacks_cfg, "callback")


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    Returns an empty list (with a warning) when the config is empty or None.

    Raises:
        TypeError: If a non-empty config is not a DictConfig.
    """
    if not logger_cfg:
        log.warning("Logger config is empty.")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    return _instantiate_targets(logger_cfg, "logger")


@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters (total / trainable / non-trainable)

    Args:
        object_dict: Mapping that must contain `"cfg"`, `"model"` and `"trainer"`.
    """
    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # save number of model parameters (single pass over the parameters)
    total = trainable = 0
    for p in model.parameters():
        count = p.numel()
        total += count
        if p.requires_grad:
            trainable += count
    hparams["model/params/total"] = total
    hparams["model/params/trainable"] = trainable
    hparams["model/params/non_trainable"] = total - trainable

    # Copy every top-level config section, resolving interpolations so loggers
    # receive plain containers (covers both DictConfig and ListConfig values).
    for k in cfg.keys():
        value = cfg.get(k)
        if OmegaConf.is_config(value):
            value = OmegaConf.to_container(value, resolve=True)
        hparams[k] = value

    # Send hparams to all loggers. `trainer.logger` is only the first one, so
    # iterate `trainer.loggers`.
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)


def get_metric_value(metric_dict: dict, metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    Args:
        metric_dict: Mapping of metric names to logged values.
        metric_name: Metric to look up; a falsy value skips retrieval.

    Returns:
        The metric as a Python number, or None when `metric_name` is falsy.

    Raises:
        Exception: If `metric_name` is not present in `metric_dict`.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name]
    # Values exposing `.item()` (e.g. tensors) are unwrapped to Python scalars;
    # plain numbers are returned unchanged.
    if hasattr(metric_value, "item"):
        metric_value = metric_value.item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value


def close_loggers() -> None:
    """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
    log.info("Closing loggers...")

    if find_spec("wandb"):  # if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()