# S12/dataset/dataset.py
# Commit fae2821 by Sijuade: "Create dataset/dataset.py" (510 bytes)
import torchvision
class CIFAR10Dataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose image transform is called with an ``image=`` keyword.

    Unlike ``torchvision.datasets.CIFAR10.__getitem__``, this class does not
    convert images to PIL. Each raw entry of ``self.data`` is passed to
    ``transform`` as ``transform(image=...)``, and the ``"image"`` entry of the
    returned mapping is used. That convention matches dict-returning pipelines
    such as an Albumentations ``Compose``; this is an assumption about the
    intended use, to be confirmed.
    """

    def __init__(
        self,
        root="~/data/cifar10",
        train=True,
        download=True,
        transform=None,
        target_transform=None,
    ):
        """Create the dataset.

        All arguments are forwarded to ``torchvision.datasets.CIFAR10``.

        Args:
            root: Dataset directory.
            train: Which split to load; per the torchvision docs, True selects
                the training split and False the test split.
            download: Whether torchvision should download the data into ``root``.
            transform: Optional callable invoked as ``transform(image=img)``.
                It must return a mapping with an ``"image"`` key.
            target_transform: Optional callable applied to each label.
                None (the default) leaves labels unchanged.
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, with the optional transforms applied.

        Without a ``transform``, the image is returned exactly as stored in
        ``self.data``. This is presumably a uint8 HWC NumPy array in torchvision's
        CIFAR10; confirm against the installed version.
        """
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Use a keyword call with a dict result, not torchvision's positional
            # ``transform(img)``.
            image = self.transform(image=image)["image"]
        # Apply target_transform the same way the parent __getitem__ does.
        # Without this, a target_transform set on the instance would be silently ignored.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label