Plonk / train_von_fisher.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
raw
history blame
4.85 kB
import os
import hydra
import wandb
from os.path import isfile, join
from shutil import copyfile
import torch
from omegaconf import OmegaConf
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from pytorch_lightning.callbacks import LearningRateMonitor
from lightning_fabric.utilities.rank_zero import _get_rank
from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch
from models.module import VonFisherGeolocalizer
# "high" lets PyTorch use faster reduced-precision algorithms (e.g. TF32)
# for float32 matmuls when the hardware supports them.
# TODO: confirm this precision trade-off is still wanted.
torch.set_float32_matmul_precision("high")
# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in hydra:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
# NOTE: this executes arbitrary Python found in the config; only use trusted
# configs. OmegaConf raises ValueError when a resolver name is registered
# twice (e.g. if this module is imported more than once), hence the guard.
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)
def wandb_init(cfg):
    """Return the W&B run id for this experiment, creating it on first use.

    The id is persisted in ``<cfg.checkpoints.dirpath>/wandb_id.txt`` so a
    restarted job passes the same id to the logger and can resume its run.

    Args:
        cfg: Hydra config; only ``cfg.checkpoints.dirpath`` is read.

    Returns:
        The run id as a string (surrounding whitespace stripped when read
        back from disk).
    """
    id_path = join(cfg.checkpoints.dirpath, "wandb_id.txt")
    if isfile(id_path):
        with open(id_path, "r", encoding="utf-8") as f:
            # Strip so a trailing newline (e.g. from a manual edit) does not
            # become part of the run id.
            return f.readline().strip()
    rank = _get_rank()
    wandb_id = wandb.util.generate_id()
    print(f"Generated wandb id: {wandb_id}")
    # Only the main process persists the id; rank is presumably None when no
    # distributed rank is set.
    if rank == 0 or rank is None:
        with open(id_path, "w", encoding="utf-8") as f:
            f.write(str(wandb_id))
    return wandb_id
def load_model(cfg, dict_config, wandb_id, callbacks):
    """Build the trainer and model, resuming from ``last.ckpt`` when present.

    Args:
        cfg: Hydra config; reads ``checkpoints.dirpath``, ``logger``,
            ``model`` and ``trainer``.
        dict_config: Resolved plain-container copy of ``cfg``; its ``model``
            and ``dataset`` entries are attached to the logger config for
            fresh runs.
        wandb_id: Run id handed to the logger (with ``resume="allow"``).
        callbacks: Callbacks passed to the trainer.

    Returns:
        ``(trainer, model, ckpt_path)``; ``ckpt_path`` is the path of
        ``last.ckpt`` when resuming, otherwise ``None``.
    """
    logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
    last_ckpt = join(cfg.checkpoints.dirpath, "last.ckpt")
    if isfile(last_ckpt):
        ckpt_path = last_ckpt
        model = VonFisherGeolocalizer.load_from_checkpoint(
            ckpt_path, cfg=cfg.model
        )
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        ckpt_path = None
        # Only fresh runs get the config attached (presumably a resumed run
        # already has it). NOTE: `_wandb_init` is a private logger attribute
        # and may change across Lightning versions.
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        logger._wandb_init.update({"config": log_dict})
        model = VonFisherGeolocalizer(cfg.model)
    trainer_cfg = cfg.trainer
    trainer = instantiate(
        trainer_cfg,
        strategy=trainer_cfg.strategy,
        logger=logger,
        callbacks=callbacks,
        # For performance debugging, a profiler (e.g.
        # pytorch_lightning.profilers.PyTorchProfiler) can be passed via
        # `profiler=`.
    )
    return trainer, model, ckpt_path
def project_init(cfg):
    """Create the checkpoint directory and snapshot the run config into it.

    Args:
        cfg: Hydra config; only ``cfg.checkpoints.dirpath`` is read.
    """
    print("Working directory set to {}".format(os.getcwd()))
    directory = cfg.checkpoints.dirpath
    os.makedirs(directory, exist_ok=True)
    # Locate Hydra's saved config via the run's output directory instead of
    # the CWD: with Hydra >= 1.2 defaults (selected by version_base=None) the
    # process only chdirs into the output dir when `hydra.job.chdir` is set.
    # Fall back to the CWD-relative path outside a Hydra run or when the
    # output subdir is disabled.
    config_path = ".hydra/config.yaml"
    if HydraConfig.initialized():
        hydra_cfg = HydraConfig.get()
        if hydra_cfg.output_subdir is not None:
            config_path = join(
                hydra_cfg.runtime.output_dir, hydra_cfg.output_subdir, "config.yaml"
            )
    copyfile(config_path, join(directory, "config.yaml"))
def callback_init(cfg):
    """Instantiate the callbacks handed to the Lightning trainer.

    Args:
        cfg: Hydra config; reads ``checkpoints``, ``progress_bar``,
            ``model.ema_decay`` and ``model.start_ema_step``.

    Returns:
        A list, in order: checkpoint callback, progress bar, LR monitor,
        EMACallback over ``network``/``ema_network``, FixNANinGrad watching
        ``train/loss``, and IncreaseDataEpoch.
    """
    # List elements are evaluated left to right, so construction order is
    # the same as the returned order.
    return [
        instantiate(cfg.checkpoints),
        instantiate(cfg.progress_bar),
        LearningRateMonitor(),
        EMACallback(
            "network",
            "ema_network",
            decay=cfg.model.ema_decay,
            start_ema_step=cfg.model.start_ema_step,
            init_ema_random=False,
        ),
        FixNANinGrad(monitor=["train/loss"]),
        IncreaseDataEpoch(),
    ]
def init_datamodule(cfg):
    """Instantiate and return the datamodule described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
def hydra_boilerplate(cfg):
    """Run all setup steps and return what ``main`` needs to start a run.

    Args:
        cfg: The composed Hydra config.

    Returns:
        ``(trainer, model, datamodule, ckpt_path)``.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)
    cbs = callback_init(cfg)
    dm = init_datamodule(cfg)
    # project_init creates the checkpoint directory that wandb_init writes
    # its id file into, so it must run first.
    project_init(cfg)
    trainer, model, ckpt_path = load_model(cfg, resolved, wandb_init(cfg), cbs)
    return trainer, model, dm, ckpt_path
@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    """Train and/or test the VonFisherGeolocalizer according to ``cfg.mode``.

    Modes:
        ``"train"``: fit, resuming from ``last.ckpt`` when present.
        ``"eval"``: run the test loop only.
        ``"traineval"``: fit, then run the test loop on the resulting model.

    Raises:
        ValueError: if ``cfg.mode`` is not one of the modes above. This is
            checked before any setup side effects (directories, wandb id).
    """
    valid_modes = ("train", "eval", "traineval")
    if cfg.mode not in valid_modes:
        raise ValueError(f"Unknown mode {cfg.mode!r}; expected one of {valid_modes}")
    if "stage" in cfg and cfg.stage == "debug":
        import lovely_tensors as lt

        lt.monkey_patch()
    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule
    # model = torch.compile(model)
    if cfg.mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    elif cfg.mode == "eval":
        trainer.test(model, datamodule=datamodule)
    else:  # "traineval"
        # cfg.mode is rewritten per phase, presumably for code that reads the
        # shared config — verify before removing.
        cfg.mode = "train"
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        cfg.mode = "test"
        trainer.test(model, datamodule=datamodule)
if __name__ == "__main__":
    # No arguments: the @hydra.main decorator builds and supplies `cfg`.
    main()