# -*- coding: utf-8 -*-
r"""
Lightning Trainer Setup
=======================
   Setup logic for the Lightning Trainer.
"""
import os
from argparse import Namespace
from datetime import datetime
from typing import Optional, Union

import click
import pandas as pd

import pytorch_lightning as pl
from polos.models.utils import apply_to_sample
from pytorch_lightning.callbacks import (
    Callback,
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import LightningLoggerBase, WandbLogger, TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_only


class TrainerConfig:
    """
    The TrainerConfig class defines the default hyper-parameters used to
    initialize our Lightning Trainer. These parameters are then overwritten
    with the values defined in the YAML file.

    -------------------- General Parameters -------------------------

    :param seed: Training seed.

    :param deterministic: If true enables cudnn.deterministic. Might make your system
        slower, but ensures reproducibility.

    :param model: Model class we want to train.

    :param verbose: Verbosity mode.

    :param overfit_batches: Uses this much data of the training set. If nonzero, will use
        the same training set for validation and testing. If the training dataloaders
        have shuffle=True, Lightning will automatically disable it.

    :param lr_finder: Runs a small portion of the training during which the learning rate is
        increased after each processed batch and the corresponding loss is logged. The result
        is a lr vs. loss plot that can be used as guidance for choosing an optimal initial lr.

    :param accumulate_grad_batches: Accumulates gradients over this many batches before
        performing an optimizer step.

    -------------------- Model Checkpoint & Early Stopping -------------------------

    :param early_stopping: If true enables EarlyStopping.

    :param save_top_k: If save_top_k == k, the best k models according to the metric
        monitored will be saved.

    :param monitor: Metric to be monitored.

    :param save_weights_only: Saves only the weights of the model.

    :param period: Interval (number of epochs) between checkpoints.

    :param metric_mode: One of {min, max}. In min mode, training will stop when the
        metric monitored has stopped decreasing; in max mode it will stop when the
        metric monitored has stopped increasing.

    :param min_delta: Minimum change in the monitored metric to qualify as an improvement.

    :param patience: Number of epochs with no improvement after which training will be stopped.
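
    Example (illustrative sketch; the values below are placeholders, not the
    project defaults)::

        >>> hparams = TrainerConfig({"monitor": "kendall", "patience": 2}).namespace()
        >>> hparams.patience
        2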
    """

    seed: int = 3
    deterministic: bool = True
    model: Optional[str] = None
    verbose: bool = False
    overfit_batches: Union[int, float] = 0.0

    # Model Checkpoint & Early Stopping
    early_stopping: bool = True
    save_top_k: int = 1
    monitor: str = "kendall"
    save_weights_only: bool = False
    metric_mode: str = "max"
    min_delta: float = 0.0
    patience: int = 1
    accumulate_grad_batches: int = 1
    lr_finder: bool = False

    def __init__(self, initial_data: dict) -> None:
        trainer_attr = pl.Trainer.default_attributes()
        for key in trainer_attr:
            setattr(self, key, trainer_attr[key])

        for key in initial_data:
            if hasattr(self, key):
                setattr(self, key, initial_data[key])

    def namespace(self) -> Namespace:
        return Namespace(
            **{
                name: getattr(self, name)
                for name in dir(self)
                if not callable(getattr(self, name)) and not name.startswith("__")
            }
        )


class TrainReport(Callback):
    """ Logger Callback that echos results during training. """

    _stack: list = []  # stack to keep metrics from all epochs

    @rank_zero_only
    def on_validation_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        metrics = trainer.callback_metrics
        metrics = LightningLoggerBase._flatten_dict(metrics, "_")
        metrics = apply_to_sample(lambda x: x.item(), metrics)
        self._stack.append(metrics)
        # pl_module.print() # Print newline

    @rank_zero_only
    def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        click.secho("\nTraining Report Experiment:", fg="yellow")
        # The first entry comes from the validation sanity check and is skipped.
        index_column = ["Epoch " + str(i) for i in range(len(self._stack) - 1)]
        df = pd.DataFrame(self._stack[1:], index=index_column)
        # Drop columns that add no value to the report (ignored if missing).
        drop_columns = [
            "train_loss_step",
            "gpu_id: 0/memory.used (MB)",
            "train_loss_epoch",
            "train_avg_loss",
        ]
        df = df.drop(columns=drop_columns, errors="ignore")
        click.secho("{}".format(df), fg="yellow")


def build_trainer(hparams: Namespace, resume_from_checkpoint: Optional[str]) -> pl.Trainer:
    """
    :param hparams: Namespace with the hyper-parameters defined in TrainerConfig.
    :param resume_from_checkpoint: Path of a checkpoint to resume training from (or None).

    :returns: Lightning Trainer (obj)
    """
    # Early Stopping Callback
    early_stop_callback = EarlyStopping(
        monitor=hparams.monitor,
        min_delta=hparams.min_delta,
        patience=hparams.patience,
        verbose=hparams.verbose,
        mode=hparams.metric_mode,
    )

    # WandB Logger
    wandb_logger = WandbLogger(name="polos",
                               project="polos_cvpr",
                               save_dir="experiments/",
                               version="version_" + datetime.now().strftime("%d-%m-%Y--%H-%M-%S"))

    tb_logger = TensorBoardLogger(
        save_dir="experiments/",
        version="version_" + datetime.now().strftime("%d-%m-%Y--%H-%M-%S"),
        name="lightning",
    )

    # Model Checkpoint Callback
    ckpt_path = os.path.join("experiments/lightning/", wandb_logger.version)

    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_path,
        save_top_k=hparams.save_top_k,
        verbose=hparams.verbose,
        monitor=hparams.monitor,
        save_weights_only=hparams.save_weights_only,
        period=1,
        mode=hparams.metric_mode,
    )
    other_callbacks = [early_stop_callback, checkpoint_callback, TrainReport()]

    trainer = pl.Trainer(
        logger=[wandb_logger, tb_logger],
        callbacks=other_callbacks,
        gradient_clip_val=hparams.gradient_clip_val,
        gpus=hparams.gpus,
        log_gpu_memory="all",
        deterministic=hparams.deterministic,
        overfit_batches=hparams.overfit_batches,
        check_val_every_n_epoch=1,
        fast_dev_run=False,
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        max_epochs=hparams.max_epochs,
        min_epochs=hparams.min_epochs,
        limit_train_batches=hparams.limit_train_batches,
        limit_val_batches=hparams.limit_val_batches,
        val_check_interval=hparams.val_check_interval,
        distributed_backend=hparams.distributed_backend,
        precision=hparams.precision,
        weights_summary="top",
        profiler=hparams.profiler,
        resume_from_checkpoint=resume_from_checkpoint,
    )
    return trainer
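

# Minimal usage sketch (illustrative, not part of this module's API): `configs` is
# assumed to be a dict parsed from the YAML experiment file and `model` an already
# instantiated LightningModule; both names are placeholders.
#
#   hparams = TrainerConfig(configs).namespace()
#   trainer = build_trainer(hparams, resume_from_checkpoint=None)
#   trainer.fit(model)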