"""CIFAR-10 input pipeline: augmentation transforms and data loaders."""

import torch
from torchvision import datasets
import torchvision.transforms as transforms

# Mini-batch size used by both the training and the test loader.
batch_size = 128

# Per-channel statistics passed to ``transforms.Normalize``; shared by the
# train and test pipelines so both see identically scaled inputs.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)


def _to_normalized_tensor():
    """Return fresh ``ToTensor`` + ``Normalize`` steps common to both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]


def data_transform():
    """Build the image transforms for training and evaluation.

    The training pipeline applies random augmentation (horizontal flip,
    rotation of up to 10 degrees, 32-pixel crop with 4 pixels of padding,
    colour jitter) before tensor conversion and normalisation. The test
    pipeline only converts to a tensor and normalises.

    Returns:
        A ``(transform_train, transform_test)`` tuple of
        ``transforms.Compose`` objects.
    """
    augmentation = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
        ),
    ]
    # Augmentation runs first; ToTensor/Normalize are the final steps.
    transform_train = transforms.Compose(augmentation + _to_normalized_tensor())

    # No randomness at evaluation time: tensor conversion + normalisation only.
    transform_test = transforms.Compose(_to_normalized_tensor())

    return transform_train, transform_test


def data_loader(transform_train, transform_test):
    """Create the CIFAR-10 training and test data loaders.

    Both splits are stored under ``./data`` and are requested with
    ``download=True``.

    Args:
        transform_train: Transform applied to every training sample.
        transform_test: Transform applied to every test sample.

    Returns:
        A ``(train_loader, test_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects using the module-level
        ``batch_size`` and two worker processes each. Only the training
        loader shuffles.
    """
    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    # Evaluation order is kept fixed, so shuffling is disabled here.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    return train_loader, test_loader