File size: 510 Bytes
fae2821 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torchvision
class CIFAR10Dataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is called with an ``image=`` keyword.

    The transform, when provided, is invoked as ``transform(image=...)`` and
    the ``"image"`` entry of its return value is used as the sample. This is
    the keyword-in / mapping-out calling convention (presumably intended for
    Albumentations-style pipelines -- confirm against the caller).

    No ``target_transform`` is accepted by the constructor or applied to the
    label in ``__getitem__``.
    """

    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        """Forward every argument unchanged to the parent constructor.

        Args:
            root: Dataset directory handed to the parent class.
            train: Passed through as the parent's ``train`` flag.
            download: Passed through as the parent's ``download`` flag.
            transform: Optional callable used by ``__getitem__``; it must
                accept an ``image`` keyword and return something indexable
                by the key ``"image"``.
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at *index*.

        The image is read directly from ``self.data`` and the label from
        ``self.targets`` (both presumably populated by the parent
        constructor). When no transform is set, the raw stored entry is
        returned as-is.
        """
        sample = self.data[index]
        target = self.targets[index]

        # Without a transform there is nothing to apply: hand back raw data.
        if self.transform is None:
            return sample, target

        # The transform takes the image by keyword and returns a mapping.
        return self.transform(image=sample)["image"], target