import torchvision


class CIFAR10Dataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose image transform is called with a keyword ``image``.

    The image transform is invoked as ``transform(image=image)`` and the
    transformed image is read back from the ``"image"`` key of the returned
    mapping. This is the call convention of dict-returning augmentation
    pipelines (presumably an Albumentations ``Compose`` — confirm with callers).
    The raw ``self.data[index]`` entry is passed through as-is. Unlike
    torchvision's own ``CIFAR10.__getitem__``, it is not converted to a PIL image
    first.

    Args:
        root: Dataset directory, forwarded to ``torchvision.datasets.CIFAR10``.
        train: Select the training split when True, the test split otherwise.
        download: Forwarded to ``torchvision.datasets.CIFAR10``.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``"image"`` entry.
        target_transform: Optional callable applied to the label. Added for
            parity with the parent class, which exposes this hook; it is
            keyword-compatible with earlier call sites because it comes last
            and defaults to None.
    """

    def __init__(
        self,
        root="~/data/cifar10",
        train=True,
        download=True,
        transform=None,
        target_transform=None,
    ):
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``, with transforms applied."""
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Dict-in/dict-out convention: pass the image by keyword, read it back by key.
            image = self.transform(image=image)["image"]
        if self.target_transform is not None:
            # Previously ignored even though the parent class stores this hook.
            label = self.target_transform(label)
        return image, label