import os
from os.path import isfile, join
from shutil import copyfile

import hydra
import torch
import wandb
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from lightning_fabric.utilities.rank_zero import _get_rank
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import LearningRateMonitor

from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch
from models.module import VonFisherGeolocalizer
torch.set_float32_matmul_precision("high") # TODO do we need that? | |
# Registering the "eval" resolver allows for advanced config | |
# interpolation with arithmetic operations in hydra: | |
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html | |
OmegaConf.register_new_resolver("eval", eval) | |
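# For illustration, an assumed config snippet using this resolver (the key
# names here are hypothetical, not taken from this project's configs):
#
#     optimizer:
#         lr: ${eval:'${base_lr} * ${datamodule.batch_size} / 256'}
#
# OmegaConf evaluates the quoted expression with Python's eval when the
# config is resolved.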


def wandb_init(cfg):
    """Create or restore a persistent wandb run id so resumed jobs reattach to the same run."""
    directory = cfg.checkpoints.dirpath
    if isfile(join(directory, "wandb_id.txt")):
        with open(join(directory, "wandb_id.txt"), "r") as f:
            wandb_id = f.readline().strip()
    else:
        rank = _get_rank()
        wandb_id = wandb.util.generate_id()
        print(f"Generated wandb id: {wandb_id}")
        # Only rank 0 (or a single-process run, where rank is None) writes the id file.
        if rank == 0 or rank is None:
            with open(join(directory, "wandb_id.txt"), "w") as f:
                f.write(str(wandb_id))
    return wandb_id


def load_model(cfg, dict_config, wandb_id, callbacks):
    directory = cfg.checkpoints.dirpath
    if isfile(join(directory, "last.ckpt")):
        # Resume: restore weights from the last checkpoint and reattach to the same wandb run.
        ckpt_path = join(directory, "last.ckpt")
        logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
        model = VonFisherGeolocalizer.load_from_checkpoint(ckpt_path, cfg=cfg.model)
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        # Fresh run: nothing to resume; log the resolved config to wandb.
        ckpt_path = None
        logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        logger._wandb_init.update({"config": log_dict})
        model = VonFisherGeolocalizer(cfg.model)
    trainer, strategy = cfg.trainer, cfg.trainer.strategy
    # from pytorch_lightning.profilers import PyTorchProfiler
    trainer = instantiate(
        trainer,
        strategy=strategy,
        logger=logger,
        callbacks=callbacks,
        # profiler=PyTorchProfiler(
        #     dirpath="logs",
        #     schedule=torch.profiler.schedule(wait=1, warmup=3, active=3, repeat=1),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
        #     record_shapes=True,
        #     with_stack=True,
        #     with_flops=True,
        #     with_modules=True,
        # ),
    )
    return trainer, model, ckpt_path


def project_init(cfg):
    print("Working directory set to {}".format(os.getcwd()))
    directory = cfg.checkpoints.dirpath
    os.makedirs(directory, exist_ok=True)
    # Persist the composed hydra config next to the checkpoints for reproducibility.
    copyfile(".hydra/config.yaml", join(directory, "config.yaml"))


def callback_init(cfg):
    checkpoint_callback = instantiate(cfg.checkpoints)
    progress_bar = instantiate(cfg.progress_bar)
    lr_monitor = LearningRateMonitor()
    # Maintain an exponential moving average of the `network` weights in
    # `ema_network`, starting at `start_ema_step`.
    ema_callback = EMACallback(
        "network",
        "ema_network",
        decay=cfg.model.ema_decay,
        start_ema_step=cfg.model.start_ema_step,
        init_ema_random=False,
    )
    fix_nan_callback = FixNANinGrad(
        monitor=["train/loss"],
    )
    increase_data_epoch_callback = IncreaseDataEpoch()
    callbacks = [
        checkpoint_callback,
        progress_bar,
        lr_monitor,
        ema_callback,
        fix_nan_callback,
        increase_data_epoch_callback,
    ]
    return callbacks


def init_datamodule(cfg):
    datamodule = instantiate(cfg.datamodule)
    return datamodule


def hydra_boilerplate(cfg):
    dict_config = OmegaConf.to_container(cfg, resolve=True)
    callbacks = callback_init(cfg)
    datamodule = init_datamodule(cfg)
    project_init(cfg)
    wandb_id = wandb_init(cfg)
    trainer, model, ckpt_path = load_model(cfg, dict_config, wandb_id, callbacks)
    return trainer, model, datamodule, ckpt_path


# config_path/config_name below are assumptions (the script calls main() with
# no arguments, so a @hydra.main decorator is required); adjust them to this
# project's actual config layout.
@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    if "stage" in cfg and cfg.stage == "debug":
        # lovely_tensors gives compact, human-readable tensor reprs while debugging.
        import lovely_tensors as lt

        lt.monkey_patch()
    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule
    # model = torch.compile(model)
    if cfg.mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    elif cfg.mode == "eval":
        trainer.test(model, datamodule=datamodule)
    elif cfg.mode == "traineval":
        cfg.mode = "train"
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        cfg.mode = "test"
        trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    main()
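
# Illustrative invocations, assuming this file is named train.py (the file
# name and any config overrides beyond `mode` and `checkpoints.dirpath` are
# assumptions, not verified against this project):
#
#     python train.py mode=train
#     python train.py mode=eval checkpoints.dirpath=checkpoints/my_run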