CarlosMalaga's picture
Upload 201 files
2f044c1 verified
raw
history blame
2.49 kB
import os
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
from torch.utils.data import Dataset, IterableDataset
from relik.common.log import get_logger
# Module-level logger, named after this module via relik's ``get_logger``.
logger = get_logger(__name__)
class BaseDataset(Dataset):
    """Map-style dataset base class that serves samples from ``self.data``.

    Subclasses are expected to override :meth:`load` and :meth:`collate_fn`;
    the base implementations only raise ``NotImplementedError``.

    Args:
        name: Dataset name, shown in ``repr``.
        path: Location(s) of the data. Stored as-is on ``self.path``; nothing
            is read from it in this base class.
        data: Pre-loaded data. ``__len__`` and ``__getitem__`` delegate to it
            directly, so it must support ``len()`` and indexing.
        **kwargs: Accepted so subclasses can forward extra options; ignored here.

    Raises:
        ValueError: If neither ``path`` nor ``data`` is provided.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        self.path = path
        # Third ancestor of this module's path; presumably the project root — confirm.
        self.project_folder = Path(__file__).parent.parent.parent
        self.data = data

    def __len__(self) -> int:
        """Return ``len(self.data)``.

        Raises:
            TypeError: If ``self.data`` is still ``None`` (e.g. the dataset was
                built from ``path`` only and nothing populated ``data``).
        """
        return len(self.data)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Return ``self.data[index]`` unchanged."""
        return self.data[index]

    def __repr__(self) -> str:
        # Use the concrete class name so subclasses are identifiable in logs.
        return f"{type(self).__name__}(name={self.name!r}, path={self.path!r})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load data from one or more ``paths`` into a single dataset.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Merge a list of samples into a batch; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class IterableBaseDataset(IterableDataset):
    """Iterable-style dataset base class that streams samples from ``self.data``.

    Subclasses are expected to override :meth:`load` and :meth:`collate_fn`;
    the base implementations only raise ``NotImplementedError``.

    Args:
        name: Dataset name, shown in ``repr``.
        path: Location(s) of the data. Stored as-is on ``self.path``; nothing
            is read from it in this base class.
        data: Pre-loaded data; ``__iter__`` iterates over it directly.
        *args: Accepted for subclass compatibility; ignored here.
        **kwargs: Accepted for subclass compatibility; ignored here.

    Raises:
        ValueError: If neither ``path`` nor ``data`` is provided.
    """

    def __init__(
        self,
        name: str,
        # ``os.PathLike`` (a superset of ``Path``) matches ``load`` and ``BaseDataset``.
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        self.path = path
        # Third ancestor of this module's path; presumably the project root — confirm.
        self.project_folder = Path(__file__).parent.parent.parent
        self.data = data

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield each element of ``self.data`` in order.

        Each call starts a fresh pass. The dataset can therefore be iterated
        repeatedly, provided ``self.data`` itself is re-iterable (a list, not a
        one-shot iterator).
        """
        yield from self.data

    def __repr__(self) -> str:
        # Use the concrete class name so subclasses are identifiable in logs.
        return f"{type(self).__name__}(name={self.name!r}, path={self.path!r})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load data from one or more ``paths`` into a single dataset.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Merge a list of samples into a batch; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError