CarlosMalaga's picture
Upload 201 files
2f044c1 verified
raw
history blame
5.81 kB
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import hydra
import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset
from relik.common.log import get_logger
from relik.retriever.data.base.datasets import BaseDataset
# Module-level logger, obtained from the project's `get_logger` helper with this
# module's name.
logger = get_logger(__name__)

# Translates the short stage names ("train", "val", "test") into the matching
# Lightning `RunningStage` members.
STAGES_COMPATIBILITY_MAP = dict(
    train=RunningStage.TRAINING,
    val=RunningStage.VALIDATING,
    test=RunningStage.TESTING,
)

# Stages that `PredictionCallback` falls back to when its `stages` argument is `None`.
DEFAULT_STAGES = {
    RunningStage.VALIDATING,
    RunningStage.TESTING,
    RunningStage.SANITY_CHECKING,
    RunningStage.PREDICTING,
}
class PredictionCallback(pl.Callback):
    """
    Base Lightning callback that computes predictions at the end of the validation
    and test epochs and forwards them to a list of downstream callbacks.

    Subclasses must implement `__call__`, which is expected to return the predictions.

    Args:
        batch_size (`int`):
            Stored on the instance as `self.batch_size`; not read by this base class.
        stages (`Optional[Set[Union[str, RunningStage]]]`):
            The stages this callback is associated with. The short names `"train"`,
            `"val"` and `"test"` are translated through `STAGES_COMPATIBILITY_MAP`;
            every value is then converted with `RunningStage(...)`. If `None`,
            `DEFAULT_STAGES` is used. Stored as the list `self.stages`.
        other_callbacks (`Optional[Union[List[DictConfig], List[NLPTemplateCallback]]]`):
            Callbacks invoked with the predictions. `DictConfig` entries are
            instantiated with Hydra; other entries are kept as they are.
        datasets (`Optional[Union[DictConfig, BaseDataset]]`):
            Stored on the instance as `self.datasets`; not read by this base class.
        dataloaders (`Optional[Union[DataLoader, List[DataLoader]]]`):
            Stored on the instance as `self.dataloaders`; not read by this base class.
    """

    def __init__(
        self,
        batch_size: int = 32,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[
            Union[List[DictConfig], List["NLPTemplateCallback"]]
        ] = None,
        datasets: Optional[Union[DictConfig, BaseDataset]] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        # parameters
        self.batch_size = batch_size
        self.datasets = datasets
        self.dataloaders = dataloaders

        # callback initialization
        if stages is None:
            stages = DEFAULT_STAGES

        # compatibility: accept the short names "train" / "val" / "test"
        stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages}
        self.stages = [RunningStage(stage) for stage in stages]

        # Build a fresh list instead of overwriting entries of the container that
        # was passed in, so the caller's object is left untouched.
        self.other_callbacks = [
            hydra.utils.instantiate(callback, _recursive_=False)
            if isinstance(callback, DictConfig)
            else callback
            for callback in (other_callbacks or [])
        ]

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> Any:
        """
        Compute the predictions. Must be overridden by subclasses.

        Args:
            trainer (`pl.Trainer`):
                The trainer running the loop.
            pl_module (`pl.LightningModule`):
                The module being evaluated.

        Returns:
            The predictions, which are handed to every callback in `other_callbacks`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # it should return the predictions
        raise NotImplementedError

    def _predict_and_dispatch(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Compute the predictions once and pass them to every downstream callback."""
        predictions = self(trainer, pl_module)
        for callback in self.other_callbacks:
            callback(
                trainer=trainer,
                pl_module=pl_module,
                callback=self,
                predictions=predictions,
            )

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """Run the prediction step and the downstream callbacks after validation."""
        self._predict_and_dispatch(trainer, pl_module)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Run the prediction step and the downstream callbacks after testing."""
        self._predict_and_dispatch(trainer, pl_module)

    @staticmethod
    def _get_datasets_and_dataloaders(
        dataset: Optional[Union[Dataset, DictConfig]],
        dataloader: Optional[DataLoader],
        trainer: pl.Trainer,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Dataset], List[DataLoader]]:
        """
        Get the datasets and dataloaders from the datamodule or from the dataset provided.

        Args:
            dataset (`Optional[Union[Dataset, DictConfig]]`):
                The dataset to use. A `DictConfig` is instantiated with Hydra.
                If `None`, the datamodule is used.
            dataloader (`Optional[DataLoader]`):
                The dataloader to use. Only considered when `dataset` is provided;
                if `None` in that case, a new `DataLoader` is built.
            trainer (`pl.Trainer`):
                The trainer that contains the datamodule.
            dataloader_kwargs (`Optional[Dict[str, Any]]`):
                The kwargs to pass to the dataloader. The mapping is copied, never
                modified in place.
            collate_fn (`Optional[Callable]`):
                The collate function to use. If not given, the `collate_fn` of the
                first dataset is used.
            collate_fn_kwargs (`Optional[Dict[str, Any]]`):
                The kwargs bound to the dataset's `collate_fn`. Ignored when
                `collate_fn` is given. `None` means no extra kwargs.

        Returns:
            `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders.
        """
        if dataset is None:
            # get the dataloaders and datasets from the datamodule
            is_testing = trainer.state.stage == RunningStage.TESTING
            datasets = (
                trainer.datamodule.test_datasets
                if is_testing
                else trainer.datamodule.val_datasets
            )
            dataloaders = (
                trainer.test_dataloaders if is_testing else trainer.val_dataloaders
            )
            return datasets, dataloaders

        # a dataset is provided, use it
        # Copy so that adding "collate_fn" below does not leak into the caller's dict.
        dataloader_kwargs = dict(dataloader_kwargs or {})

        if isinstance(dataset, DictConfig):
            dataset = hydra.utils.instantiate(dataset, _recursive_=False)
        datasets = dataset if isinstance(dataset, list) else [dataset]

        if dataloader is not None:
            dataloaders = (
                [dataloader] if isinstance(dataloader, DataLoader) else dataloader
            )
        else:
            # `collate_fn_kwargs` defaults to None, which cannot be unpacked with `**`;
            # treat it as "no extra kwargs".
            collate_fn = collate_fn or partial(
                datasets[0].collate_fn, **(collate_fn_kwargs or {})
            )
            dataloader_kwargs["collate_fn"] = collate_fn
            # NOTE(review): only the first dataset gets a DataLoader here, even when
            # a list of datasets was provided.
            dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)]

        return datasets, dataloaders
class NLPTemplateCallback:
    """
    Interface for the callbacks that a `PredictionCallback` invokes with its predictions.

    Implementations override `__call__`; this base version only raises.
    """

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        callback: PredictionCallback,
        predictions: Dict[str, Any],
        *args,
        **kwargs,
    ) -> Any:
        """
        Consume the predictions produced by a `PredictionCallback`.

        Args:
            trainer (`pl.Trainer`):
                The trainer running the loop.
            pl_module (`pl.LightningModule`):
                The module being evaluated.
            callback (`PredictionCallback`):
                The prediction callback that produced `predictions`.
            predictions (`Dict[str, Any]`):
                The predictions to process.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError