|
import torch |
|
from torchvision import datasets |
|
import torchvision.transforms as transforms |
|
|
|
# Mini-batch size shared by the train and test DataLoaders built in data_loader().
batch_size: int = 128
|
|
|
def data_transform(*, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Build the train and test image transform pipelines.

    The training pipeline applies random augmentation (horizontal flip,
    rotation within +/-10 degrees, a 32x32 random crop with 4 pixels of
    padding, and color jitter) before converting to a tensor and
    normalizing. The test pipeline only converts and normalizes.

    Both pipelines share a single ``Normalize`` step, so train and test
    inputs are always scaled identically.

    Args:
        mean: Per-channel mean passed to ``transforms.Normalize``.
            Defaults to ``(0.5, 0.5, 0.5)``, the previously hard-coded value.
        std: Per-channel standard deviation passed to
            ``transforms.Normalize``. Defaults to ``(0.5, 0.5, 0.5)``.

    Returns:
        tuple: ``(transform_train, transform_test)``, both
        ``transforms.Compose`` pipelines.
    """
    # One Normalize instance used by both pipelines keeps train/test scaling
    # in lockstep. The constants used to be duplicated and could drift apart.
    # With the defaults, ToTensor's [0, 1] output is mapped to [-1, 1].
    normalize = transforms.Normalize(mean, std)

    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    return transform_train, transform_test
|
|
|
def data_loader(transform_train, transform_test, *, root='./data', num_workers=2):
    """Create CIFAR-10 train and test DataLoaders.

    Both splits are requested with ``download=True``, so they are fetched
    into ``root`` if not already present there. The batch size comes from
    the module-level ``batch_size`` constant, which is read at call time.

    Args:
        transform_train: Transform applied to each training image.
        transform_test: Transform applied to each test image.
        root: Directory where the CIFAR-10 data is stored or downloaded.
            Defaults to ``'./data'``, the previously hard-coded path.
        num_workers: Number of loader worker subprocesses for each
            DataLoader. Defaults to 2, the previously hard-coded value.

    Returns:
        tuple: ``(train_loader, test_loader)``. The train loader shuffles
        its samples; the test loader iterates in dataset order.
    """
    train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    # No shuffling for evaluation: a fixed order makes test runs comparable.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, test_loader
|
|