import torchvision
class CIFAR10Dataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose image transform is called with an ``image=`` keyword.

    Unlike the stock torchvision ``__getitem__``, this class hands the raw
    entry of ``self.data`` to the transform as ``transform(image=...)`` and
    reads the result back from the ``"image"`` key of the returned mapping.
    NOTE(review): this looks like the Albumentations calling convention; a
    plain torchvision transform (positional call, returns the image directly)
    would not work here -- confirm against the callers.

    Args:
        root: Directory where the dataset is stored or downloaded to.
        train: Selects the training split when true, otherwise the test split.
        download: Forwarded to the base class; defaults to ``True`` here.
        transform: Optional callable invoked as ``transform(image=image)``;
            it must return a mapping containing an ``"image"`` entry.
        target_transform: Optional callable applied to the label. Defaults to
            ``None``, in which case labels are returned unchanged.
    """

    def __init__(
        self,
        root="~/data/cifar10",
        train=True,
        download=True,
        transform=None,
        target_transform=None,
    ):
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is taken directly from ``self.data`` (no intermediate
        conversion is performed by this method) and, when a transform is
        configured, replaced by the ``"image"`` entry of the transform output.
        """
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            # Keyword call + dict lookup: the transform returns a mapping,
            # not the image itself.
            image = self.transform(image=image)["image"]

        # Honour the base-class label hook, which the override would
        # otherwise silently ignore.
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label