|
import os |
|
from pathlib import Path |
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch.utils.data import Dataset, IterableDataset |
|
|
|
from relik.common.log import get_logger |
|
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
|
|
|
|
|
class BaseDataset(Dataset):
    """Map-style dataset base class backed by a sized, indexable ``data`` object.

    ``__len__`` and ``__getitem__`` delegate directly to ``self.data``.
    ``load`` and ``collate_fn`` raise ``NotImplementedError`` in this class;
    subclasses are expected to override them.

    Args:
        name: Identifier stored on the instance and shown in ``repr``.
        path: Location(s) of the dataset. This class only stores the value;
            it never opens it.
        data: Object holding the samples. It must support ``len()`` and
            indexing for ``__len__`` / ``__getitem__`` to work.
        **kwargs: Accepted and ignored by this base class.

    Raises:
        ValueError: If both ``path`` and ``data`` are ``None``.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        self.path = path
        # Three directory levels above this module's file (presumably the
        # project root; confirm against the package layout).
        self.project_folder = Path(__file__).parent.parent.parent
        # NOTE: `data` stays None when only `path` is given; this class does
        # not call `load` itself.
        self.data = data

    def __len__(self) -> int:
        """Return ``len(self.data)``."""
        return len(self.data)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Return ``self.data[index]`` unchanged."""
        return self.data[index]

    def __repr__(self) -> str:
        """Return a debug string showing the dataset name and path."""
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load samples from ``paths``.

        Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Combine a list of samples into a batch.

        Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
|
|
|
|
class IterableBaseDataset(IterableDataset):
    """Iterable-style dataset base class that yields the items of ``data``.

    ``__iter__`` iterates over ``self.data`` and yields each item unchanged.
    ``load`` and ``collate_fn`` raise ``NotImplementedError`` in this class;
    subclasses are expected to override them.

    Args:
        name: Identifier stored on the instance and shown in ``repr``.
        path: Location(s) of the dataset. This class only stores the value;
            it never opens it.
        data: Iterable holding the samples.
        *args: Accepted and ignored by this base class.
        **kwargs: Accepted and ignored by this base class.

    Raises:
        ValueError: If both ``path`` and ``data`` are ``None``.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        self.path = path
        # Three directory levels above this module's file (presumably the
        # project root; confirm against the package layout).
        self.project_folder = Path(__file__).parent.parent.parent
        # NOTE: `data` stays None when only `path` is given; this class does
        # not call `load` itself.
        self.data = data

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield every sample of ``self.data`` in iteration order."""
        yield from self.data

    def __repr__(self) -> str:
        """Return a debug string showing the dataset name and path."""
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load samples from ``paths``.

        Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Combine a list of samples into a batch.

        Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
|