import logging
from argparse import ArgumentParser
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Mapping, Optional, Union

import ignite.distributed as idist
import torch
import yaml
from ignite.contrib.engines import common
from ignite.engine import Engine
from ignite.engine.events import Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from ignite.handlers.early_stopping import EarlyStopping
from ignite.handlers.terminate_on_nan import TerminateOnNan
from ignite.handlers.time_limit import TimeLimit
from ignite.utils import setup_logger


def setup_parser(config_path="base_config.yaml"):
    """Build an `ArgumentParser` whose options mirror the keys of the YAML config."""
    with open(config_path, "r") as f:
        config = yaml.safe_load(f.read())

    parser = ArgumentParser()
    parser.add_argument("--config", default=None, type=str)
    parser.add_argument("--backend", default=None, type=str)
    for k, v in config.items():
        # Booleans become flags; everything else keeps its YAML value as the default.
        if isinstance(v, bool):
            parser.add_argument(f"--{k}", action="store_true")
        else:
            parser.add_argument(f"--{k}", default=v, type=type(v))

    return parser


def log_metrics(engine: Engine, tag: str) -> None:
    """Log `engine.state.metrics` with given `engine` and `tag`.

    Parameters
    ----------
    engine
        instance of `Engine` whose metrics to log.
    tag
        a string to add at the start of output.
    """
    epoch_size = engine.state.epoch_length
    # Convert the global iteration counter into an iteration index within the current epoch.
    local_iteration = engine.state.iteration - epoch_size * (engine.state.epoch - 1)
    metrics_format = f"{tag} Epoch {engine.state.epoch} - [{local_iteration} / {epoch_size}] : {engine.state.metrics}"
    engine.logger.info(metrics_format)


def resume_from(
    to_load: Mapping,
    checkpoint_fp: Union[str, Path],
    logger: Logger,
    strict: bool = True,
    model_dir: Optional[str] = None,
) -> None:
    """Loads state dict from a checkpoint file to resume the training.

    Parameters
    ----------
    to_load
        a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
    checkpoint_fp
        path to the checkpoint file
    logger
        to log info about resuming from a checkpoint
    strict
        whether to strictly enforce that the keys in `state_dict` match the keys
        returned by this module's `state_dict()` function. Default: True
    model_dir
        directory in which to save the object
    """
    if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https://"):
        # Remote checkpoint: download (and cache) it via torch.hub.
        checkpoint = torch.hub.load_state_dict_from_url(
            checkpoint_fp,
            model_dir=model_dir,
            map_location="cpu",
            check_hash=True,
        )
    else:
        if isinstance(checkpoint_fp, str):
            checkpoint_fp = Path(checkpoint_fp)

        if not checkpoint_fp.exists():
            raise FileNotFoundError(f"Given {str(checkpoint_fp)} does not exist.")
        checkpoint = torch.load(checkpoint_fp, map_location="cpu")

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=strict)
    logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)


def setup_output_dir(config: Any, rank: int) -> Path:
    """Create output folder."""
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{now}-backend-{config.backend}-lr-{config.lr}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    # Broadcast the path from rank 0 so every process uses the same directory.
    return Path(idist.broadcast(config.output_dir, src=0))
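
# Illustrative sketch (not part of the original module): `setup_output_dir` is
# intended to run once per process under `idist.Parallel`, so that rank 0
# creates the run directory and the broadcast hands the same path to every
# worker. The `training` entry point below is hypothetical.
#
#   def training(local_rank, config):
#       config.output_dir = setup_output_dir(config, idist.get_rank())
#       ...
#
#   with idist.Parallel(backend=config.backend) as parallel:
#       parallel.run(training, config)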


def setup_logging(config: Any) -> Logger:
    """Setup logger with `ignite.utils.setup_logger()`.

    Parameters
    ----------
    config
        config object. config has to contain `debug` and `output_dir` attributes.

    Returns
    -------
    logger
        an instance of `Logger`
    """
    green = "\033[32m"
    reset = "\033[0m"
    logger = setup_logger(
        name=f"{green}[ignite]{reset}",
        level=logging.DEBUG if config.debug else logging.INFO,
        format="%(name)s: %(message)s",
        filepath=config.output_dir / "training-info.log",
    )
    return logger
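

# Minimal smoke test (an assumption, not part of the original module): parse
# the YAML-backed CLI, create the run directory, and set up logging. Assumes
# `base_config.yaml` exists and defines at least `output_dir`, `lr`, and a
# boolean `debug` key. `resume_from` is shown commented out because it needs
# real objects (and a real checkpoint path) to restore.
if __name__ == "__main__":
    config = setup_parser().parse_args()
    config.output_dir = setup_output_dir(config, rank=idist.get_rank())
    logger = setup_logging(config)
    logger.info("Run directory: %s", config.output_dir)
    # resume_from({"model": model, "optimizer": optimizer}, checkpoint_fp, logger)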