# Plonk / evaluation.py
# nicolas-dufour's picture
# squash: merge all unpushed commits
# c4c7cee
# raw
# history blame
# 1.96 kB
import os
from models.module import DiffGeolocalizer
import hydra
from os.path import join
import torch
from omegaconf import OmegaConf
from omegaconf import open_dict
from hydra.utils import instantiate
from models.eval_best_model import EvalModule
# Select PyTorch's "high" float32 matmul precision setting for the whole
# process (see torch.set_float32_matmul_precision for the exact trade-off).
torch.set_float32_matmul_precision("high")
# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in hydra:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
# NOTE(security): the resolver is Python's built-in ``eval``, so every
# ``${eval:...}`` interpolation is executed as Python code. Only use
# configs and command-line overrides from trusted sources.
OmegaConf.register_new_resolver("eval", eval)
def load_model(cfg, dict_config, wandb_id):
    """Build the evaluation module and the trainer described by ``cfg``.

    Args:
        cfg: Hydra/OmegaConf config; ``cfg.logger``, ``cfg.model`` and
            ``cfg.trainer`` (including ``cfg.trainer.strategy``) are read.
        dict_config: Plain-container view of the config; its ``"model"`` and
            ``"dataset"`` entries are recorded on the logger.
        wandb_id: Value forwarded to the logger as ``id`` (together with
            ``resume="allow"``).

    Returns:
        A ``(trainer, model)`` tuple.
    """
    # NOTE: the logger is instantiated and handed the run config, but it is
    # not passed to the trainer below, so the trainer uses whatever logger
    # ``cfg.trainer`` itself specifies.
    run_logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
    logged_config = {
        "model": dict_config["model"],
        "dataset": dict_config["dataset"],
    }
    run_logger._wandb_init.update({"config": logged_config})

    eval_module = EvalModule(cfg.model)
    eval_trainer = instantiate(cfg.trainer, strategy=cfg.trainer.strategy)
    return eval_trainer, eval_module
def hydra_boilerplate(cfg):
    """Resolve ``cfg`` into plain containers and build the trainer and model.

    Args:
        cfg: Hydra/OmegaConf config; ``cfg.wandb_id`` is read here and the
            config is forwarded to ``load_model``.

    Returns:
        The ``(trainer, model)`` pair produced by ``load_model``.
    """
    # Fully resolve interpolations so the logged config holds concrete values.
    resolved = OmegaConf.to_container(cfg, resolve=True)
    return load_model(cfg, resolved, cfg.wandb_id)
import copy
def init_datamodule(cfg):
    """Instantiate and return the object described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
if __name__ == "__main__":
    import sys

    # Insert an extra Hydra override ahead of the user's own arguments so the
    # composed config exposes Hydra's runtime config sources under
    # ``cfg.pt_model_path``.
    sys.argv = [
        sys.argv[0],
        "+pt_model_path=${hydra:runtime.config_sources}",
        *sys.argv[1:],
    ]

    @hydra.main(config_path="configs", config_name="config", version_base=None)
    def main(cfg):
        """Run the test loop for the model located via the config sources.

        Args:
            cfg: Config composed by Hydra, including the injected
                ``pt_model_path`` entry.
        """
        # ``open_dict`` lets new keys (wandb_id, checkpoint) be added to cfg.
        with open_dict(cfg):
            # NOTE(review): index 1 is presumably the primary config
            # directory among Hydra's config sources -- confirm the ordering.
            run_dir = cfg.pt_model_path[1]["path"]
            # NOTE(review): this stores the *path* of wandb_id.txt, not the
            # id written inside it -- confirm this is what the logger expects.
            cfg.wandb_id = join(run_dir, "wandb_id.txt")
            cfg.checkpoint = join(run_dir, "last.ckpt")
            cfg.computer.devices = 1

        trainer, model = hydra_boilerplate(cfg)
        test_data = init_datamodule(cfg)
        trainer.test(model, datamodule=test_data)

    main()