from typing import Any, List, Optional, Sequence, Union

import hydra
import lightning as pl
import torch
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from relik.common.log import get_logger
from relik.retriever.data.datasets import GoldenRetrieverDataset

logger = get_logger(__name__)
|
|
class GoldenRetrieverPLDataModule(pl.LightningDataModule):
    """Lightning ``DataModule`` wrapping Golden Retriever datasets.

    Datasets can either be passed in already instantiated or be described by a
    Hydra ``DictConfig`` (``datasets``), in which case they are instantiated
    lazily in :meth:`setup`.

    Args:
        train_dataset: Pre-instantiated training dataset, if available.
        val_datasets: Pre-instantiated validation datasets, if available.
        test_datasets: Pre-instantiated test datasets, if available.
        num_workers: Number of ``DataLoader`` workers, either a single int used
            for every split or a ``DictConfig`` with ``train``, ``val`` and
            ``test`` keys.
        datasets: Hydra configuration used to instantiate any dataset that was
            not passed in directly.
    """

    def __init__(
        self,
        train_dataset: Optional[GoldenRetrieverDataset] = None,
        val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
        test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
        num_workers: Optional[Union[DictConfig, int]] = None,
        datasets: Optional[DictConfig] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.datasets = datasets
        # Normalize `num_workers` to a per-split mapping: a plain int is
        # expanded to the same value for train, val and test.
        if num_workers is None:
            num_workers = 0
        if isinstance(num_workers, int):
            num_workers = DictConfig(
                {"train": num_workers, "val": num_workers, "test": num_workers}
            )
        self.num_workers = num_workers

        self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset
        self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets
        self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets
|
    def prepare_data(self, *args, **kwargs):
        """
        Prepare the data before training. This method is called only once.
        It is used to download the data, tokenize it, etc.
        """
        pass
|
    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            # There is usually a single training dataset, while validation
            # and test may have several splits each.
            if self.train_dataset is None:
                self.train_dataset = hydra.utils.instantiate(self.datasets.train)
                self.val_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.val
                ]
        if stage == "test":
            if self.test_datasets is None:
                self.test_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.test
                ]
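    # A minimal sketch of the ``datasets`` config consumed above. The
    # ``_target_`` path is the dataset class imported in this module; the
    # remaining fields (``name``, ``path``) are hypothetical and depend on the
    # GoldenRetrieverDataset constructor:
    #
    #     datasets:
    #       train:
    #         _target_: relik.retriever.data.datasets.GoldenRetrieverDataset
    #         name: train
    #         path: data/train.jsonl
    #       val:
    #         - _target_: relik.retriever.data.datasets.GoldenRetrieverDataset
    #           name: val
    #           path: data/val.jsonl
    #       test:
    #         - _target_: relik.retriever.data.datasets.GoldenRetrieverDataset
    #           name: test
    #           path: data/test.jsonl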
|
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        torch_dataset = self.train_dataset.to_torch_dataset()
        # Batching is handled by the dataset itself: automatic batching is
        # disabled (`batch_size=None`) and the collate function is the
        # identity, so each item yielded by the dataset is treated as a
        # ready-made batch.
        return DataLoader(
            torch_dataset,
            shuffle=False,
            batch_size=None,
            num_workers=self.num_workers.train,
            pin_memory=True,
            collate_fn=lambda x: x,
        )
|
    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        dataloaders = []
        for dataset in self.val_datasets:
            torch_dataset = dataset.to_torch_dataset()
            dataloaders.append(
                DataLoader(
                    torch_dataset,
                    shuffle=False,
                    batch_size=None,
                    num_workers=self.num_workers.val,
                    pin_memory=True,
                    collate_fn=lambda x: x,
                )
            )
        return dataloaders
|
    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        dataloaders = []
        for dataset in self.test_datasets:
            torch_dataset = dataset.to_torch_dataset()
            dataloaders.append(
                DataLoader(
                    torch_dataset,
                    shuffle=False,
                    batch_size=None,
                    num_workers=self.num_workers.test,
                    pin_memory=True,
                    collate_fn=lambda x: x,
                )
            )
        return dataloaders
|
    def predict_dataloader(self) -> EVAL_DATALOADERS:
        raise NotImplementedError
|
    def transfer_batch_to_device(
        self, batch: Any, device: torch.device, dataloader_idx: int
    ) -> Any:
        return super().transfer_batch_to_device(batch, device, dataloader_idx)
|
    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"{self.datasets=}, "
            f"{self.num_workers=})"
        )
|
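# A minimal usage sketch (hypothetical: assumes `train_ds` and `val_ds` are
# GoldenRetrieverDataset instances and `model` is a LightningModule):
#
#     datamodule = GoldenRetrieverPLDataModule(
#         train_dataset=train_ds,
#         val_datasets=[val_ds],
#         num_workers=4,
#     )
#     trainer = pl.Trainer(max_steps=1_000)
#     trainer.fit(model, datamodule=datamodule)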
|