# -*- coding: utf-8 -*-
r"""
Lightning Trainer Setup
=======================
Setup logic for the lightning trainer.
"""
import os
from argparse import Namespace
from datetime import datetime
from typing import Union

import click
import pandas as pd
import pytorch_lightning as pl
from polos.models.utils import apply_to_sample
from pytorch_lightning.callbacks import (
    Callback,
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_only


class TrainerConfig:
    """
    The TrainerConfig class is used to define default hyper-parameters that
    are used to initialize our Lightning Trainer. These parameters are then
    overwritten with the values defined in the YAML file.

    -------------------- General Parameters -------------------------

    :param seed: Training seed.

    :param deterministic: If true enables cudnn.deterministic. Might make your
        system slower, but ensures reproducibility.

    :param model: Model class we want to train.

    :param verbose: Verbosity mode.

    :param overfit_batches: Uses this much data of the training set. If nonzero,
        will use the same training set for validation and testing. If the
        training dataloaders have shuffle=True, Lightning will automatically
        disable it.

    :param lr_finder: Runs a small portion of the training where the learning
        rate is increased after each processed batch and the corresponding loss
        is logged. The result is a lr vs. loss plot that can be used as guidance
        for choosing an optimal initial lr.

    -------------------- Model Checkpoint & Early Stopping -------------------------

    :param early_stopping: If true enables EarlyStopping.

    :param save_top_k: If save_top_k == k, the best k models according to the
        metric monitored will be saved.

    :param monitor: Metric to be monitored.

    :param save_weights_only: Saves only the weights of the model.

    :param period: Interval (number of epochs) between checkpoints.

    :param metric_mode: One of {min, max}. In min mode, training will stop when
        the metric monitored has stopped decreasing; in max mode it will stop
        when the metric monitored has stopped increasing.

    :param min_delta: Minimum change in the monitored metric to qualify as an
        improvement.

    :param patience: Number of epochs with no improvement after which training
        will be stopped.
    """

    seed: int = 3
    deterministic: bool = True
    model: str = None
    verbose: bool = False
    overfit_batches: Union[int, float] = 0.0

    # Model Checkpoint & Early Stopping
    early_stopping: bool = True
    save_top_k: int = 1
    monitor: str = "kendall"
    save_weights_only: bool = False
    metric_mode: str = "max"
    min_delta: float = 0.0
    patience: int = 1
    accumulate_grad_batches: int = 1
    lr_finder: bool = False

    def __init__(self, initial_data: dict) -> None:
        # Start from the Lightning Trainer defaults, then override any
        # attribute that is present in the provided configuration dict.
        trainer_attr = pl.Trainer.default_attributes()
        for key in trainer_attr:
            setattr(self, key, trainer_attr[key])

        for key in initial_data:
            if hasattr(self, key):
                setattr(self, key, initial_data[key])

    def namespace(self) -> Namespace:
        return Namespace(
            **{
                name: getattr(self, name)
                for name in dir(self)
                if not callable(getattr(self, name)) and not name.startswith("__")
            }
        )


class TrainReport(Callback):
    """
    Logger Callback that echoes results during training.
""" _stack: list = [] # stack to keep metrics from all epochs @rank_zero_only def on_validation_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule ) -> None: metrics = trainer.callback_metrics metrics = LightningLoggerBase._flatten_dict(metrics, "_") metrics = apply_to_sample(lambda x: x.item(), metrics) self._stack.append(metrics) # pl_module.print() # Print newline @rank_zero_only def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: click.secho("\nTraining Report Experiment:", fg="yellow") index_column = ["Epoch " + str(i) for i in range(len(self._stack) - 1)] df = pd.DataFrame(self._stack[1:], index=index_column) # Clean dataframe columns del df["train_loss_step"] del df["gpu_id: 0/memory.used (MB)"] del df["train_loss_epoch"] del df["train_avg_loss"] click.secho("{}".format(df), fg="yellow") def build_trainer(hparams: Namespace, resume_from_checkpoint) -> pl.Trainer: """ :param hparams: Namespace :returns: Lightning Trainer (obj) """ # Early Stopping Callback early_stop_callback = EarlyStopping( monitor=hparams.monitor, min_delta=hparams.min_delta, patience=hparams.patience, verbose=hparams.verbose, mode=hparams.metric_mode, ) # TestTube Logger Callback wandb_logger = WandbLogger(name="polos", project="polos_cvpr", save_dir="experiments/", version="version_" + datetime.now().strftime("%d-%m-%Y--%H-%M-%S")) tb_logger = TensorBoardLogger( save_dir="experiments/", version="version_" + datetime.now().strftime("%d-%m-%Y--%H-%M-%S"), name="lightning", ) # Model Checkpoint Callback ckpt_path = os.path.join("experiments/lightning/", wandb_logger.version) checkpoint_callback = ModelCheckpoint( dirpath=ckpt_path, save_top_k=hparams.save_top_k, verbose=hparams.verbose, monitor=hparams.monitor, save_weights_only=hparams.save_weights_only, period=1, mode=hparams.metric_mode, ) other_callbacks = [early_stop_callback, checkpoint_callback, TrainReport()] trainer = pl.Trainer( logger=[wandb_logger,tb_logger], callbacks=other_callbacks, gradient_clip_val=hparams.gradient_clip_val, gpus=hparams.gpus, log_gpu_memory="all", deterministic=hparams.deterministic, overfit_batches=hparams.overfit_batches, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=hparams.accumulate_grad_batches, max_epochs=hparams.max_epochs, min_epochs=hparams.min_epochs, limit_train_batches=hparams.limit_train_batches, limit_val_batches=hparams.limit_val_batches, val_check_interval=hparams.val_check_interval, distributed_backend=hparams.distributed_backend, precision=hparams.precision, weights_summary="top", profiler=hparams.profiler, resume_from_checkpoint=resume_from_checkpoint, ) return trainer