import os
import hydra
import wandb
from os.path import isfile, join
from shutil import copyfile

import torch

from omegaconf import OmegaConf
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from pytorch_lightning.callbacks import LearningRateMonitor
from lightning_fabric.utilities.rank_zero import _get_rank
from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch
from models.module import DiffGeolocalizer

torch.set_float32_matmul_precision("high")  # allow TF32 matmuls on recent GPUs (faster, slightly less precise than strict fp32)

# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in hydra:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
OmegaConf.register_new_resolver("eval", eval)
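# Example (hypothetical keys): a YAML value such as
#   total_steps: ${eval:'${epochs} * ${steps_per_epoch}'}
# resolves to the product of the two referenced fields.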


def wandb_init(cfg):
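    """Return a wandb run id: reuse the id stored in the checkpoint directory
    when resuming (and no logger suffix is set); otherwise generate a new id,
    which rank 0 writes to disk for later resumes."""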
    directory = cfg.checkpoints.dirpath
    if isfile(join(directory, "wandb_id.txt")) and cfg.logger_suffix == "":
        with open(join(directory, "wandb_id.txt"), "r") as f:
            wandb_id = f.readline()
    else:
        rank = _get_rank()
        wandb_id = wandb.util.generate_id()
        print(f"Generated wandb id: {wandb_id}")
        if rank == 0 or rank is None:
            with open(join(directory, "wandb_id.txt"), "w") as f:
                f.write(str(wandb_id))

    return wandb_id


def load_model(cfg, dict_config, wandb_id, callbacks):
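    """Instantiate the logger, model and trainer. If `last.ckpt` exists in the
    checkpoint directory, restore the model from it and return its path as
    `ckpt_path` so training resumes; otherwise build a fresh model and log the
    model/dataset config to wandb."""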
    directory = cfg.checkpoints.dirpath
    logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
    ckpt_path = join(directory, "last.ckpt")
    if isfile(ckpt_path):
        model = DiffGeolocalizer.load_from_checkpoint(ckpt_path, cfg=cfg.model)
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        ckpt_path = None
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        logger._wandb_init.update({"config": log_dict})
        model = DiffGeolocalizer(cfg.model)

    # from pytorch_lightning.profilers import PyTorchProfiler

    trainer = instantiate(
        cfg.trainer,
        strategy=cfg.trainer.strategy,
        logger=logger,
        callbacks=callbacks,
        # profiler=PyTorchProfiler(
        #     dirpath="logs",
        #     schedule=torch.profiler.schedule(wait=1, warmup=3, active=3, repeat=1),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
        #     record_shapes=True,
        #     with_stack=True,
        #     with_flops=True,
        #     with_modules=True,
        # ),
    )
    return trainer, model, ckpt_path


def project_init(cfg):
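    """Create the checkpoint directory and copy the resolved Hydra config
    next to the checkpoints."""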
    print("Working directory set to {}".format(os.getcwd()))
    directory = cfg.checkpoints.dirpath
    os.makedirs(directory, exist_ok=True)
    copyfile(".hydra/config.yaml", join(directory, "config.yaml"))


def callback_init(cfg):
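    """Build the list of Lightning callbacks: checkpointing, progress bar,
    learning-rate monitoring, EMA weights, NaN-gradient fixing, and per-epoch
    data scheduling."""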
    checkpoint_callback = instantiate(cfg.checkpoints)
    progress_bar = instantiate(cfg.progress_bar)
    lr_monitor = LearningRateMonitor()
    ema_callback = EMACallback(
        "network",
        "ema_network",
        decay=cfg.model.ema_decay,
        start_ema_step=cfg.model.start_ema_step,
        init_ema_random=False,
    )
    fix_nan_callback = FixNANinGrad(
        monitor=["train/loss"],
    )
    increase_data_epoch_callback = IncreaseDataEpoch()
    callbacks = [
        checkpoint_callback,
        progress_bar,
        lr_monitor,
        ema_callback,
        fix_nan_callback,
        increase_data_epoch_callback,
    ]
    return callbacks


def init_datamodule(cfg):
    datamodule = instantiate(cfg.datamodule)
    return datamodule


def hydra_boilerplate(cfg):
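    """Wire everything together from the Hydra config: callbacks, datamodule,
    checkpoint directory, wandb id, and finally the (trainer, model, ckpt_path)
    triple used by `main`."""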
    dict_config = OmegaConf.to_container(cfg, resolve=True)
    callbacks = callback_init(cfg)
    datamodule = init_datamodule(cfg)
    project_init(cfg)
    wandb_id = wandb_init(cfg)
    trainer, model, ckpt_path = load_model(cfg, dict_config, wandb_id, callbacks)
    return trainer, model, datamodule, ckpt_path


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
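    """Entry point: optionally patch tensors with lovely_tensors in the debug
    stage, then run training and/or evaluation depending on `cfg.mode`."""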
    if "stage" in cfg and cfg.stage == "debug":
        import lovely_tensors as lt

        lt.monkey_patch()
    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule
    # model = torch.compile(model)
    if cfg.mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    elif cfg.mode == "eval":
        trainer.test(model, datamodule=datamodule)
    elif cfg.mode == "traineval":
        cfg.mode = "train"
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        cfg.mode = "test"
        trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    main()