Spaces:
Running
Running
from typing import Callable, List, Optional, Union | |
import torch | |
from pytorch_lightning import LightningDataModule | |
from torch.utils.data import DataLoader, Dataset | |
def _identity(sample):
    """Return ``sample`` unchanged; the default transform of ``SampleDataset``."""
    return sample


class SampleDataset(Dataset):
    """Map-style dataset pairing each input sample with its target.

    Args:
        x: Input samples; must support ``len()`` and integer indexing.
        y: Targets, indexed with the same positions as ``x``.
        transforms: Optional callable applied to every input sample when it
            is fetched. When ``None``, samples are returned unchanged.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None) -> None:
        super().__init__()
        self.x = x
        self.y = y
        # The default is a module-level function rather than a lambda so that
        # instances stay picklable (a lambda defined here cannot be pickled).
        # Swap in real defaults if needed, e.g. Resize + ToTensor for images.
        self.transforms = _identity if transforms is None else transforms

    def __len__(self) -> int:
        """Return the number of input samples."""
        return len(self.x)

    def __getitem__(self, index: int):
        """Return ``(transforms(x[index]), y[index])``.

        Only the input goes through ``transforms``; the target is returned
        as stored.
        """
        sample = self.transforms(self.x[index])
        target = self.y[index]
        return sample, target
class SampleDataModule(LightningDataModule):
    """Data module that splits in-memory samples into train/validation sets.

    The leading samples of ``x``/``y`` form the training set and the trailing
    ``val_ratio`` fraction forms the validation set. Both datasets are built
    during construction, so the dataloaders can be requested right away.

    Args:
        x: Input samples; must support ``len()`` and slicing.
        y: Targets, sliced with the same boundaries as ``x``.
        transforms: Optional callable forwarded to every ``SampleDataset``.
        val_ratio: Fraction of samples held out for validation, in ``[0, 1)``.
        batch_size: Positive batch size used by both dataloaders.

    Raises:
        AssertionError: If ``val_ratio`` is outside ``[0, 1)`` or
            ``batch_size`` is not a positive ``int``.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None,
                 val_ratio: float = 0,
                 batch_size: int = 32) -> None:
        super().__init__()
        assert 0 <= val_ratio < 1, "val_ratio must be in [0, 1)"
        assert isinstance(batch_size, int), "batch_size must be an int"
        # A zero batch size would turn the fallback slice `x[-0:]` into the
        # whole dataset, and DataLoader rejects non-positive sizes anyway, so
        # fail here instead of later.
        assert batch_size > 0, "batch_size must be positive"

        self.x = x
        self.y = y
        self.transforms = transforms
        self.val_ratio = val_ratio
        self.batch_size = batch_size

        # Build the datasets eagerly so the dataloaders work right after
        # construction, without an explicit hook call from the caller.
        self.setup()
        self.prepare_data()

    def setup(self, stage: Optional[str] = None) -> None:
        """Do nothing; the datasets are built in ``prepare_data``."""
        pass

    def prepare_data(self) -> None:
        """Build ``train_dataset`` and ``val_dataset`` from ``x`` and ``y``."""
        n_samples: int = len(self.x)
        # The validation count is truncated, so any remainder stays in the
        # training set.
        train_size: int = n_samples - int(n_samples * self.val_ratio)
        self.train_dataset = SampleDataset(x=self.x[:train_size],
                                           y=self.y[:train_size],
                                           transforms=self.transforms)

        if train_size < n_samples:
            val_start = train_size
        else:
            # Nothing was held out (val_ratio is 0 or too small for this many
            # samples): reuse the last `batch_size` training samples so a
            # validation set still exists. These overlap with the training
            # data.
            val_start = -self.batch_size
        self.val_dataset = SampleDataset(x=self.x[val_start:],
                                         y=self.y[val_start:],
                                         transforms=self.transforms)

    def train_dataloader(self) -> DataLoader:
        """Return a loader over ``train_dataset`` using ``batch_size``."""
        return DataLoader(dataset=self.train_dataset,
                          batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Return a loader over ``val_dataset`` using ``batch_size``."""
        return DataLoader(dataset=self.val_dataset,
                          batch_size=self.batch_size)