|
import pytorch_lightning as pl |
|
import torchvision |
|
from torch.utils.data import DataLoader, Dataset |
|
from torchvision import transforms |
|
|
|
|
|
class MNISTDataDictWrapper(Dataset):
    """Adapt a dataset of ``(image, label)`` pairs to yield dictionaries.

    Every item of the wrapped dataset is unpacked into exactly two values and
    re-emitted as ``{"jpg": image, "cls": label}``; the length is forwarded
    unchanged.
    """

    def __init__(self, dset):
        """Remember the dataset to wrap.

        Args:
            dset: An indexable, sized dataset whose items unpack into two
                values (image, label).
        """
        super().__init__()
        self.dset = dset

    def __getitem__(self, i):
        """Return item ``i`` of the wrapped dataset as a dictionary."""
        sample = self.dset[i]
        image, label = sample
        return dict(jpg=image, cls=label)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        size = len(self.dset)
        return size
|
|
|
|
|
class MNISTLoader(pl.LightningDataModule):
    """Lightning data module serving MNIST as ``{"jpg", "cls"}`` dictionaries.

    Images are converted with ``ToTensor`` and rescaled to ``[-1, 1]``. The
    train split backs ``train_dataloader``; the test split backs both
    ``val_dataloader`` and ``test_dataloader``. Both splits are downloaded to
    ``.data/`` when the module is constructed.

    Args:
        batch_size: Batch size used by every dataloader.
        num_workers: Number of DataLoader worker processes (0 = load in the
            main process).
        prefetch_factor: Batches prefetched per worker. Only used when
            ``num_workers > 0``.
        shuffle: Whether to shuffle. NOTE: applied to the train, val and test
            loaders alike.
    """

    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
        super().__init__()

        # ToTensor yields floats in [0, 1]; Normalize(mean=0.5, std=0.5)
        # computes (x - 0.5) / 0.5 == 2x - 1, mapping them to [-1, 1].
        # Unlike a lambda, Normalize is picklable, so worker processes started
        # with the "spawn" method (macOS/Windows default) can receive it.
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )

        self.batch_size = batch_size
        self.num_workers = num_workers
        # DataLoader raises ValueError when a prefetch factor is supplied
        # without worker processes, so store None in that case and only
        # forward the value when workers are used (see _make_dataloader).
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
        self.shuffle = shuffle

        self.train_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=".data/", train=True, download=True, transform=transform
            )
        )
        self.test_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=".data/", train=False, download=True, transform=transform
            )
        )

    def prepare_data(self):
        """No-op: both splits are already downloaded in ``__init__``."""
        pass

    def _make_dataloader(self, dataset):
        """Build a DataLoader over ``dataset`` with this module's settings."""
        extra_kwargs = {}
        if self.num_workers > 0 and self.prefetch_factor is not None:
            # prefetch_factor is only accepted in multiprocessing mode.
            extra_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            **extra_kwargs,
        )

    def train_dataloader(self):
        """Return a DataLoader over the MNIST train split."""
        return self._make_dataloader(self.train_dataset)

    def test_dataloader(self):
        """Return a DataLoader over the MNIST test split."""
        return self._make_dataloader(self.test_dataset)

    def val_dataloader(self):
        """Return a DataLoader over the MNIST test split (used for validation)."""
        return self._make_dataloader(self.test_dataset)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the MNIST test split with ToTensor followed by a
    # rescale to [-1, 1], wrap it, and fetch the first sample dictionary.
    to_symmetric_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda img: img * 2.0 - 1.0),
        ]
    )
    mnist_test = torchvision.datasets.MNIST(
        root=".data/",
        train=False,
        download=True,
        transform=to_symmetric_tensor,
    )
    dset = MNISTDataDictWrapper(mnist_test)
    ex = dset[0]
|
|