from typing import Any, List, Optional, Sequence, Union

import hydra
import lightning as pl
import torch
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from relik.common.log import get_logger
from relik.retriever.data.datasets import GoldenRetrieverDataset

logger = get_logger(__name__)


class GoldenRetrieverPLDataModule(pl.LightningDataModule):
def __init__(
self,
train_dataset: Optional[GoldenRetrieverDataset] = None,
val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
num_workers: Optional[Union[DictConfig, int]] = None,
datasets: Optional[DictConfig] = None,
*args,
**kwargs,
):
super().__init__()
self.datasets = datasets
        # normalize `num_workers`: accept a single int and expand it into a
        # per-split DictConfig so each dataloader can read its own value
        if num_workers is None:
            num_workers = 0
        if isinstance(num_workers, int):
            num_workers = DictConfig(
                {"train": num_workers, "val": num_workers, "test": num_workers}
            )
        self.num_workers = num_workers
        # data: datasets can be passed in directly or instantiated later from
        # `self.datasets` in `setup()`
        self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset
        self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets
        self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets

    def prepare_data(self, *args, **kwargs):
        """
        Prepare the data before training. Lightning calls this method only
        once (and on a single process), so it is the place to download,
        tokenize, or otherwise preprocess the data.
        """
        pass

    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            # there is usually a single train dataset; if you need more than
            # one train loader, follow the same logic as the val and test
            # datasets below
            if self.train_dataset is None:
                self.train_dataset = hydra.utils.instantiate(self.datasets.train)
                self.val_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.val
                ]
        if stage == "test":
            if self.test_datasets is None:
                self.test_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.test
                ]
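
    # A sketch of the (hypothetical) Hydra config that `setup()` consumes.
    # What is grounded in the code above: `datasets.train` is a single config
    # and `datasets.val` / `datasets.test` are iterables of configs; the
    # `_target_` paths and the `name`/`path` arguments are assumptions that
    # depend on the concrete GoldenRetrieverDataset subclass in use:
    #
    #   datasets:
    #     train:
    #       _target_: relik.retriever.data.datasets.GoldenRetrieverDataset
    #       name: train
    #       path: data/train.jsonl
    #     val:
    #       - _target_: relik.retriever.data.datasets.GoldenRetrieverDataset
    #         name: val
    #         path: data/val.jsonl
    #     test:
    #       - _target_: relik.retriever.data.datasets.GoldenRetrieverDataset
    #         name: test
    #         path: data/test.jsonl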

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        torch_dataset = self.train_dataset.to_torch_dataset()
        return DataLoader(
            torch_dataset,
            # the dataset yields ready-made batches, which is why automatic
            # batching, shuffling, and collation are all disabled here
            shuffle=False,
            batch_size=None,
            num_workers=self.num_workers.train,
            pin_memory=True,
            collate_fn=lambda x: x,
        )

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
dataloaders = []
for dataset in self.val_datasets:
torch_dataset = dataset.to_torch_dataset()
dataloaders.append(
DataLoader(
torch_dataset,
shuffle=False,
batch_size=None,
num_workers=self.num_workers.val,
pin_memory=True,
collate_fn=lambda x: x,
)
)
return dataloaders

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
dataloaders = []
for dataset in self.test_datasets:
torch_dataset = dataset.to_torch_dataset()
dataloaders.append(
DataLoader(
torch_dataset,
shuffle=False,
batch_size=None,
num_workers=self.num_workers.test,
pin_memory=True,
collate_fn=lambda x: x,
)
)
return dataloaders

    def predict_dataloader(self) -> EVAL_DATALOADERS:
raise NotImplementedError

    def transfer_batch_to_device(
        self, batch: Any, device: torch.device, dataloader_idx: int
    ) -> Any:
        # pass-through: rely on Lightning's default device transfer
        return super().transfer_batch_to_device(batch, device, dataloader_idx)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"{self.datasets=}, "
            f"{self.num_workers=})"
        )
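

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # It assumes dataset configs with Hydra `_target_` entries pointing at a
    # concrete GoldenRetrieverDataset subclass; the `name`/`path` arguments
    # are hypothetical placeholders.
    from omegaconf import OmegaConf

    datasets_cfg = OmegaConf.create(
        {
            "train": {
                "_target_": "relik.retriever.data.datasets.GoldenRetrieverDataset",
                "name": "train",
                "path": "data/train.jsonl",
            },
            "val": [
                {
                    "_target_": "relik.retriever.data.datasets.GoldenRetrieverDataset",
                    "name": "val",
                    "path": "data/val.jsonl",
                }
            ],
            "test": [
                {
                    "_target_": "relik.retriever.data.datasets.GoldenRetrieverDataset",
                    "name": "test",
                    "path": "data/test.jsonl",
                }
            ],
        }
    )

    # a single int for `num_workers` is expanded to all three splits
    dm = GoldenRetrieverPLDataModule(datasets=datasets_cfg, num_workers=4)
    dm.setup("fit")
    train_loader = dm.train_dataloader()
    val_loaders = dm.val_dataloader()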