# sinder / data.py
# Hosting-page metadata (not part of the module): haoqiwang, commit
# "add files" (9ae1b1e), file size 4.06 kB.
from glob import glob
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
class _PatchAlignedCrop:
    """Trim the last two dims of a tensor down to a multiple of the patch size.

    Implemented as a module-level callable class rather than a ``lambda`` so
    that the resulting transform can be pickled (needed when a ``DataLoader``
    uses worker processes with the "spawn" start method).
    """

    def __init__(self, patch_size, max_edge_size):
        self.patch_size = patch_size
        self.max_edge_size = max_edge_size

    def _limit(self, edge):
        # Largest multiple of patch_size that fits in ``edge``, capped at
        # max_edge_size. The cap itself is NOT rounded to a patch multiple.
        return min(self.max_edge_size, edge - edge % self.patch_size)

    def __call__(self, img):
        # Dims 1 and 2 are cropped from the top-left corner; dim 0 is kept.
        return img[:, : self._limit(img.shape[1]), : self._limit(img.shape[2])]


def make_transform(
    smaller_edge_size: int, patch_size, center_crop=False, max_edge_size=812
) -> transforms.Compose:
    """Build the image preprocessing pipeline.

    The pipeline is: bicubic resize (smaller edge -> ``smaller_edge_size``),
    optional square center crop, tensor conversion, ImageNet mean/std
    normalization, and finally a top-left crop so that both spatial dims are
    a multiple of ``patch_size`` and at most ``max_edge_size``.

    Args:
        smaller_edge_size: Target length of the shorter image edge; must be
            positive.
        patch_size: Spatial dims of the output are trimmed to a multiple of
            this value.
        center_crop: If true, center-crop to a square of side
            ``smaller_edge_size`` right after resizing.
        max_edge_size: Upper bound applied to each spatial dim in the final
            crop. It is used as-is, so pass a multiple of ``patch_size`` if
            the output must stay patch-aligned when the cap is hit.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.

    Raises:
        AssertionError: If ``smaller_edge_size`` is not positive.
    """
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    interpolation_mode = transforms.InterpolationMode.BICUBIC
    assert smaller_edge_size > 0

    steps = [
        transforms.Resize(
            size=smaller_edge_size,
            interpolation=interpolation_mode,
            antialias=True,
        ),
    ]
    # The two variants of the pipeline differ only by this one step.
    if center_crop:
        steps.append(transforms.CenterCrop(smaller_edge_size))
    steps.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
            ),
            transforms.Lambda(_PatchAlignedCrop(patch_size, max_edge_size)),
        ]
    )
    return transforms.Compose(steps)
class VisualDataset(Dataset):
    """Dataset over a list of image paths used for visualization.

    When no paths are supplied, a fixed set of bundled example images under
    ``resources/`` is used.
    """

    def __init__(self, transform, imgs=None):
        """Store the transform and choose the list of image paths.

        Args:
            transform: Callable applied to each loaded RGB image; a falsy
                value means images are returned untransformed.
            imgs: Optional list of image paths; ``None`` selects the
                built-in examples.
        """
        self.transform = transform
        if imgs is not None:
            self.files = imgs
        else:
            self.files = [
                'resources/example.jpg',
                'resources/villa.png',
                'resources/000000037740.jpg',
                'resources/000000064359.jpg',
                'resources/000000066635.jpg',
                'resources/000000078420.jpg',
            ]

    def __len__(self):
        """Return the number of image paths."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply the transform if set."""
        path = self.files[index]
        image = Image.open(path).convert('RGB')
        return self.transform(image) if self.transform else image
class ImageNetDataset(Dataset):
    """ImageNet training images, subsampled with a fixed stride.

    Files are discovered with the glob ``data/imagenet/train/*/*.JPEG``
    (relative to the current working directory) and every ``step``-th file is
    kept, where ``step = len(files) // num_train_max`` (at least 1).

    Because the stride is floored, the number of kept files can exceed
    ``num_train_max`` (e.g. 15 files with ``num_train_max=10`` keeps all 15).

    NOTE: ``glob`` does not guarantee an ordering, so which files are selected
    may differ between machines/filesystems.
    """

    def __init__(self, transform, num_train_max=1000000):
        """Collect and subsample the training file list.

        Args:
            transform: Callable applied to each loaded RGB image; a falsy
                value means images are returned untransformed.
            num_train_max: Approximate target number of images; must be
                non-zero (``0`` raises ``ZeroDivisionError``).
        """
        self.transform = transform
        files = glob('data/imagenet/train/*/*.JPEG')
        # With fewer files than num_train_max (or none at all) the integer
        # division is 0, and a zero slice step raises ValueError. Clamp to 1
        # so that small datasets are simply used in full.
        step = max(1, len(files) // num_train_max)
        self.files = files[::step]

    def __len__(self):
        """Return the number of selected image files."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply the transform if set."""
        img = Image.open(self.files[index]).convert('RGB')
        # Same falsy-transform guard as VisualDataset.__getitem__.
        if self.transform:
            img = self.transform(img)
        return img
def load_data(args, model):
    """Create the ImageNet training dataset for ``model``.

    Reads ``args.resolution`` and ``args.num_train_max``; images are resized,
    center-cropped and aligned to ``model.patch_size``.
    """
    return ImageNetDataset(
        transform=make_transform(
            args.resolution, model.patch_size, center_crop=True
        ),
        num_train_max=args.num_train_max,
    )
def load_visual_data(args, model):
    """Create the visualization dataset for ``model``.

    Reads ``args.visual_size`` for the resize target and the optional
    ``imgs`` attribute of ``args`` for the image paths (absent -> built-in
    examples). No center crop is applied and the edge cap is 1792.
    """
    return VisualDataset(
        transform=make_transform(
            args.visual_size, model.patch_size, max_edge_size=1792
        ),
        imgs=vars(args).get('imgs'),
    )