import abc

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import (
    EVAL_DATALOADERS,
    TRAIN_DATALOADERS,
)
from torch import nn
from torch.utils.data import DataLoader

from src.dataset import DATASET_REGISTRY


class AbstractModel(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.train_dataset = None
        self.val_dataset = None
        self.metric_evaluator = None
        self.init_model()

    def setup(self, stage):
        if stage in ["fit", "validate", "test"]:
            self.train_dataset = DATASET_REGISTRY.get("BlenderDataset")(
                **self.cfg["dataset"]["train"]["params"],
            )
            self.val_dataset = DATASET_REGISTRY.get("BlenderDataset")(
                **self.cfg["dataset"]["val"]["params"],
            )
            # Metric evaluation is currently disabled; re-enable it by
            # instantiating the evaluator here (`validation_step` and
            # `validation_epoch_end` guard against it being None).
            # self.metric_evaluator = SHRECMetricEvaluator(
            #     embed_dim=self.cfg["model"]["embed_dim"]
            # )

    @abc.abstractmethod
    def init_model(self):
        """
        Initialize the model architecture; called once at the end of `__init__`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, batch):
        raise NotImplementedError

    @abc.abstractmethod
    def compute_loss(self, forwarded_batch, input_batch):
        """
        Compute the loss for one batch.

        Args:
            forwarded_batch: output of the `forward` method
            input_batch: the batch that was fed to `forward`

        Returns:
            loss: computed loss
        """
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        # 1. Get embeddings from the model
        forwarded_batch = self.forward(batch)
        # 2. Calculate the loss
        loss = self.compute_loss(forwarded_batch=forwarded_batch, input_batch=batch)
        # 3. Update the monitor
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        # 1. Get embeddings from the model
        forwarded_batch = self.forward(batch)
        # 2. Calculate the loss
        loss = self.compute_loss(forwarded_batch=forwarded_batch, input_batch=batch)
        # 3. Update the metric evaluator with this batch (skipped while the
        #    evaluator is disabled in `setup`)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        if self.metric_evaluator is not None:
            self.metric_evaluator.append(
                g_emb=forwarded_batch["pc_embedding_feats"].float().clone().detach(),
                q_emb=forwarded_batch["query_embedding_feats"].float().clone().detach(),
                query_ids=batch["query_ids"],
                gallery_ids=batch["point_cloud_ids"],
                target_ids=batch["point_cloud_ids"],
            )
        return {"loss": loss}

    def validation_epoch_end(self, outputs) -> None:
        """
        Hook called at the end of the validation epoch to post-process the
        outputs of the validation steps; note that it runs before
        `training_epoch_end()`.

        Args:
            outputs: outputs of the validation steps
        """
        if self.metric_evaluator is not None:
            self.log_dict(
                self.metric_evaluator.evaluate(),
                prog_bar=True,
                on_step=False,
                on_epoch=True,
            )
            self.metric_evaluator.reset()

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        train_loader = DataLoader(
            dataset=self.train_dataset,
            collate_fn=self.train_dataset.collate_fn,
            **self.cfg["data_loader"]["train"]["params"],
        )
        return train_loader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        val_loader = DataLoader(
            dataset=self.val_dataset,
            collate_fn=self.val_dataset.collate_fn,
            **self.cfg["data_loader"]["val"]["params"],
        )
        return val_loader

    def configure_optimizers(self):
        # Subclasses are expected to override this and return their optimizer
        # (and, optionally, a learning-rate scheduler).
        pass
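
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# subclass showing how the abstract hooks above fit together. The class name,
# the two linear encoders, the `pc_feats`/`query_feats` batch keys, and the
# `pc_dim`/`query_dim` config entries are hypothetical placeholders; a real
# subclass would supply its own encoders, loss, and optimizer settings.
# ---------------------------------------------------------------------------
class ToyEmbeddingModel(AbstractModel):
    def init_model(self):
        # Two small encoders projecting both modalities into a shared
        # embedding space of size `embed_dim`.
        embed_dim = self.cfg["model"]["embed_dim"]
        self.pc_encoder = nn.Linear(self.cfg["model"]["pc_dim"], embed_dim)
        self.query_encoder = nn.Linear(self.cfg["model"]["query_dim"], embed_dim)

    def forward(self, batch):
        # Produce the embedding dict that `validation_step` expects.
        return {
            "pc_embedding_feats": self.pc_encoder(batch["pc_feats"]),
            "query_embedding_feats": self.query_encoder(batch["query_feats"]),
        }

    def compute_loss(self, forwarded_batch, input_batch):
        # Pull matching point-cloud/query pairs together; a plain MSE stands
        # in here for whatever contrastive/metric loss the real model uses.
        return nn.functional.mse_loss(
            forwarded_batch["pc_embedding_feats"],
            forwarded_batch["query_embedding_feats"],
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)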