|
import torch
|
|
import pytorch_lightning as pl
|
|
from pathlib import Path
|
|
from typing import Any
|
|
import torchvision
|
|
import wandb
|
|
|
|
|
|
class EvalSaveCallback(pl.Callback):
    """Lightning callback that dumps every evaluated batch to disk.

    After each validation or test batch, a dictionary holding selected batch
    inputs and model outputs is written with ``torch.save`` to one ``.pt``
    file inside ``save_dir``.
    """

    def __init__(self, save_dir: Path) -> None:
        """Remember where the ``.pt`` files should be written.

        Args:
            save_dir: Destination directory. A plain string is accepted too
                and is converted to ``Path``.
        """
        super().__init__()
        # Coerce so the ``/`` join in ``save`` also works when a str is given.
        self.save_dir = Path(save_dir)

    def save(self, outputs, batch, batch_idx):
        """Write one batch and its outputs to ``<batch_idx>_<name[0]>.pt``.

        Args:
            outputs: Indexable with the keys ``"output"`` and ``"valid_bev"``.
            batch: Indexable with the keys ``"name"``, ``"image"`` and
                ``"seg_masks"``; ``batch["name"][0]`` becomes part of the
                file name.
            batch_idx: Batch index, zero-padded to six digits in the file
                name.
        """
        name = batch['name']

        filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"
        # torch.save does not create missing directories; make sure the
        # target exists so the first batch does not fail on a fresh run.
        filename.parent.mkdir(parents=True, exist_ok=True)

        torch.save({
            "fpv": batch['image'],
            "seg_masks": batch['seg_masks'],
            'name': name,
            "output": outputs["output"],
            "valid_bev": outputs["valid_bev"],
        }, filename)

    def _save_if_present(self, outputs, batch, batch_idx) -> None:
        """Shared body of the batch-end hooks: skip falsy outputs, else save."""
        if not outputs:
            return

        self.save(outputs, batch, batch_idx)

    def on_test_batch_end(self, trainer: pl.Trainer,
                          pl_module: pl.LightningModule,
                          outputs: torch.Tensor | Any | None,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
        """Save the batch after a test step, unless ``outputs`` is falsy."""
        self._save_if_present(outputs, batch, batch_idx)

    def on_validation_batch_end(self, trainer: pl.Trainer,
                                pl_module: pl.LightningModule,
                                outputs: torch.Tensor | Any | None,
                                batch: Any,
                                batch_idx: int,
                                dataloader_idx: int = 0) -> None:
        """Save the batch after a validation step, unless ``outputs`` is falsy."""
        self._save_if_present(outputs, batch, batch_idx)
|
|
|
|
|
|
class ImageLoggerCallback(pl.Callback):
    """Lightning callback that logs input, ground-truth and prediction grids.

    On the first batch (``batch_idx == 0``) of a training or validation loop
    the module is re-run on the batch without gradients and the resulting
    images are sent to ``trainer.logger.experiment.log`` as ``wandb.Image``
    objects.
    """

    def __init__(self, num_classes):
        """Store how many class channels should be visualised.

        Args:
            num_classes: Number of classes; one ground-truth grid and one
                prediction grid are logged for each index in
                ``range(num_classes)``.
        """
        super().__init__()
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        """Build image grids for one batch and log them under ``<mode>/images``.

        Args:
            trainer: Trainer whose ``logger.experiment`` exposes ``log``.
            pl_module: Unused here; kept for interface compatibility.
            outputs: Indexable with ``"output"`` and ``"valid_bev"``.
                NOTE(review): the indexing below assumes ``"output"`` is
                channels-first 4-D and that ``"valid_bev"[..., :-1]`` matches
                the channels-last layout of the prediction; confirm.
            batch: Indexable with ``"image"`` and ``"seg_masks"`` (class index
                on the last axis of ``"seg_masks"``).
            batch_idx: Unused here; kept for interface compatibility.
            mode: Prefix of the logged key, e.g. ``"train"`` or ``"val"``.
        """
        fpv_grid = torchvision.utils.make_grid(
            batch["image"], nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # Threshold first: the comparison allocates a new tensor, so zeroing
        # the invalid cells afterwards no longer writes into the caller's
        # ``outputs['output']`` through the permuted view. The result is the
        # same as masking before thresholding because 0 is below 0.5.
        pred = (outputs['output'].permute(0, 2, 3, 1) > 0.5).float()
        pred[outputs["valid_bev"][..., :-1] == 0] = 0
        pred = pred.permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            gt_class_i = batch['seg_masks'][..., i]
            gt_class_i_grid = torchvision.utils.make_grid(
                gt_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)

            pred_class_i = pred[:, i]
            pred_class_i_grid = torchvision.utils.make_grid(
                pred_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)

            images += [
                wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
            ]

        trainer.logger.experiment.log({f"{mode}/images": images})

    def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule,
                                outputs, batch, batch_idx,
                                dataloader_idx: int = 0):
        """Log images for the first validation batch.

        ``dataloader_idx`` is accepted (and ignored) so the hook does not
        raise ``TypeError`` when Lightning passes it.
        """
        if batch_idx != 0:
            return

        with torch.no_grad():
            outputs = pl_module(batch)

        self.log_image(trainer, pl_module, outputs,
                       batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
        """Log images for the first training batch using an eval-mode forward pass."""
        if batch_idx != 0:
            return

        # Remember the current mode so it is restored exactly, even if the
        # forward pass or the logging call raises.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                outputs = pl_module(batch)

            self.log_image(trainer, pl_module, outputs,
                           batch, batch_idx, mode="train")
        finally:
            pl_module.train(was_training)
|
|
|