from glob import glob

import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

|
def make_transform(
    smaller_edge_size: int,
    patch_size: int,
    center_crop: bool = False,
    max_edge_size: int = 812,
) -> transforms.Compose:
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    interpolation_mode = transforms.InterpolationMode.BICUBIC
    assert smaller_edge_size > 0

    def crop_to_patch_multiple(img):
        # Trim height and width down to the nearest multiple of patch_size,
        # capped at max_edge_size, so the image tiles exactly into patches.
        # Note the cap itself only preserves divisibility when max_edge_size
        # is a multiple of patch_size (812 = 58 * 14, 1792 = 128 * 14).
        return img[
            :,
            : min(max_edge_size, img.shape[1] - img.shape[1] % patch_size),
            : min(max_edge_size, img.shape[2] - img.shape[2] % patch_size),
        ]

    # Build the pipeline once and insert the optional center crop, instead of
    # duplicating the whole transform list in both branches.
    steps = [
        transforms.Resize(
            size=smaller_edge_size,
            interpolation=interpolation_mode,
            antialias=True,
        )
    ]
    if center_crop:
        steps.append(transforms.CenterCrop(smaller_edge_size))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        transforms.Lambda(crop_to_patch_multiple),
    ]
    return transforms.Compose(steps)
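
# A minimal usage sketch (hypothetical values; patch_size=14 matches
# ViT-style backbones such as DINOv2). The output tensor's spatial
# dimensions are always multiples of patch_size:
#
#   tf = make_transform(smaller_edge_size=224, patch_size=14)
#   x = tf(Image.open('resources/example.jpg').convert('RGB'))
#   assert x.shape[1] % 14 == 0 and x.shape[2] % 14 == 0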
|
|
class VisualDataset(Dataset):
    # A handful of bundled example images, or a user-supplied list of paths.
    def __init__(self, transform, imgs=None):
        self.transform = transform
        if imgs is None:
            self.files = [
                'resources/example.jpg',
                'resources/villa.png',
                'resources/000000037740.jpg',
                'resources/000000064359.jpg',
                'resources/000000066635.jpg',
                'resources/000000078420.jpg',
            ]
        else:
            self.files = imgs

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = self.files[index]
        img = Image.open(path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img
|
|
class ImageNetDataset(Dataset):
    # ImageNet train images, evenly strided down to roughly num_train_max
    # files at most.
    def __init__(self, transform, num_train_max=1000000):
        self.transform = transform
        self.files = glob('data/imagenet/train/*/*.JPEG')
        # max(1, ...) guards against a zero stride (a ValueError in the
        # slice below) when there are fewer files than num_train_max.
        step = max(1, len(self.files) // num_train_max)
        self.files = self.files[::step]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).convert('RGB')
        img = self.transform(img)
        return img
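
# Stride arithmetic, assuming the standard ImageNet-1k train split of
# 1,281,167 images (illustrative only; num_train_max is an approximate cap):
#
#   num_train_max=1000000  ->  step = 1281167 // 1000000 = 1   (keeps all)
#   num_train_max=100000   ->  step = 1281167 // 100000  = 12  (keeps ~106,764)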
|
|
def load_data(args, model):
    # Training data: center-cropped ImageNet at the training resolution.
    transform = make_transform(
        args.resolution, model.patch_size, center_crop=True
    )
    dataset = ImageNetDataset(
        transform=transform, num_train_max=args.num_train_max
    )
    return dataset
|
|
def load_visual_data(args, model):
    # Visualization data: no center crop, and a larger edge cap so wide
    # images keep more context.
    transform = make_transform(
        args.visual_size, model.patch_size, max_edge_size=1792
    )
    dataset = VisualDataset(transform=transform, imgs=getattr(args, 'imgs', None))
    return dataset
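

# A minimal sketch of wiring the loaders together, assuming an argparse-style
# namespace with the attributes used above, a model exposing patch_size, and
# the example images under resources/ being present (all stand-ins here; the
# real entry point lives elsewhere).
if __name__ == '__main__':
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    args = SimpleNamespace(visual_size=224, imgs=None)
    model = SimpleNamespace(patch_size=14)  # e.g. a ViT-style backbone
    loader = DataLoader(load_visual_data(args, model), batch_size=1)
    for batch in loader:
        # (1, 3, H, W), with H and W multiples of patch_size
        print(batch.shape)
        break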
|
|