import os from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import torch from torch.utils.data import Dataset, IterableDataset from relik.common.log import get_logger logger = get_logger(__name__) class BaseDataset(Dataset): def __init__( self, name: str, path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None, data: Any = None, **kwargs, ): super().__init__() self.name = name if path is None and data is None: raise ValueError("Either `path` or `data` must be provided") self.path = path self.project_folder = Path(__file__).parent.parent.parent self.data = data def __len__(self) -> int: return len(self.data) def __getitem__( self, index ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: return self.data[index] def __repr__(self) -> str: return f"Dataset({self.name=}, {self.path=})" def load( self, paths: Union[str, os.PathLike, List[str], List[os.PathLike]], *args, **kwargs, ) -> Any: # load data from single or multiple paths in one single dataset raise NotImplementedError @staticmethod def collate_fn(batch: Any, *args, **kwargs) -> Any: raise NotImplementedError class IterableBaseDataset(IterableDataset): def __init__( self, name: str, path: Optional[Union[str, Path, List[str], List[Path]]] = None, data: Any = None, *args, **kwargs, ): super().__init__() self.name = name if path is None and data is None: raise ValueError("Either `path` or `data` must be provided") self.path = path self.project_folder = Path(__file__).parent.parent.parent self.data = data def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: for sample in self.data: yield sample def __repr__(self) -> str: return f"Dataset({self.name=}, {self.path=})" def load( self, paths: Union[str, os.PathLike, List[str], List[os.PathLike]], *args, **kwargs, ) -> Any: # load data from single or multiple paths in one single dataset raise NotImplementedError @staticmethod def collate_fn(batch: Any, *args, **kwargs) -> Any: raise NotImplementedError