from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
    """Average the values stored under a single key of a dict, ignoring non-finite entries."""

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, values):
        value = values[self.key]
        # Drop NaN/Inf entries so they do not corrupt the running mean.
        value = value[torch.isfinite(value)]
        return super().update(value)
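
# Illustrative use of AverageKeyMeter (not part of the training pipeline; "total" is the
# loss key that GenericModule relies on below):
#   meter = AverageKeyMeter("total")
#   meter.update({"total": torch.tensor([0.5, float("nan"), 1.5])})  # the NaN is dropped
#   meter.compute()  # -> tensor(1.)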


class GenericModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        # Fall back to the default architecture when the config does not name one.
        name = "map_perception_net" if name is None else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        # The validation loss meters are created lazily, once the loss keys are known.
        self.losses_val = None

    def forward(self, batch):
        return self.model(batch)
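
    # Interface assumed of self.model (inferred from the calls below, not a formal spec):
    #   - forward(batch) -> dict of predictions
    #   - loss(pred, batch) -> dict of loss tensors, including a "total" entry
    #   - metrics() -> dict of torchmetrics, each updated as metric(pred, batch)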

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"train/loss/{k}": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
            on_epoch=True,
            sync_dist=True,
        )
        return losses["total"].mean()

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            # Lazily create one meter per loss key, logged as "val/<key>/loss".
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="val/",
                postfix="/loss",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, on_epoch=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, on_epoch=True)
        return pred

    def test_step(self, batch, batch_idx):
        pred = self(batch)
        return pred

    def on_validation_epoch_start(self):
        # Reset the loss meters so they are rebuilt for each validation epoch.
        self.losses_val = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler_args = cfg_scheduler.get("args", {})
            # Resolve the "$total_epochs" placeholder to the actual number of epochs.
            for key in scheduler_args:
                if scheduler_args[key] == "$total_epochs":
                    scheduler_args[key] = int(self.trainer.max_epochs)
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **scheduler_args
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                # Matches the keys logged in validation_step (prefix "val/", postfix "/loss").
                "monitor": "val/total/loss",
                "strict": True,
                "name": "learning_rate",
            }
        return ret
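
    # Illustrative scheduler config (assumed Hydra/OmegaConf layout; the scheduler name
    # and arguments are examples, not values taken from this project's configs):
    #   training:
    #     lr: 1.0e-4
    #     lr_scheduler:
    #       name: CosineAnnealingLR   # any class from torch.optim.lr_scheduler
    #       args:
    #         T_max: $total_epochs    # replaced by trainer.max_epochs at runtime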

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."

        checkpoint = torch.load(
            checkpoint_path,
            map_location=map_location or (lambda storage, loc: storage),
        )
        if find_best:
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                # The callback state key embeds its kwargs (monitor, mode, ...) as a dict literal.
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
                    logger.info("Loading best checkpoint %s", best_name)
            # Re-load only if the best checkpoint is a different file.
            if best_name != Path(checkpoint_path).name:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )

        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        # Handle checkpoints where the config was nested under a single "cfg" key.
        if list(cfg_ckpt.keys()) == ["cfg"]:
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)

        # Overrides passed at load time take precedence over the checkpointed config.
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)

        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
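

# Minimal usage sketch (illustrative; only the config keys read by this module are shown,
# the model will typically require additional architecture-specific keys, and the data
# setup is omitted):
#   cfg = OmegaConf.create({"model": {"name": "map_perception_net"}, "training": {"lr": 1e-4}})
#   module = GenericModule(cfg)
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(module, datamodule=...)
#   # Reload, picking the best checkpoint tracked by a ModelCheckpoint callback:
#   module = GenericModule.load_from_checkpoint("path/to/last.ckpt", find_best=True)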