Spaces:
Runtime error
Runtime error
# --------------------------------------------------------------- | |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
# | |
# This work is licensed under the NVIDIA Source Code License | |
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file. | |
# --------------------------------------------------------------- | |
import numpy as np | |
from PIL import Image | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
class StackedMNIST(dset.MNIST):
    """MNIST variant whose samples stack three digits into one RGB image.

    Each sample combines three MNIST images, one per colour channel, picked
    through independent random permutations of the wrapped MNIST split. The
    integer label encodes the three digits positionally as
    ``100 * d_channel0 + 10 * d_channel1 + d_channel2``. The dataset is twice
    as long as the wrapped split.

    Args:
        root, train, transform, target_transform, download: forwarded
            unchanged to ``torchvision.datasets.MNIST``.
        seed: optional seed for the channel permutations. ``None`` (the
            default) draws from the global ``np.random`` state exactly as
            before, so ``np.random.seed`` still controls the pairing. An int
            makes the pairing reproducible independently of global state.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, seed=None):
        super(StackedMNIST, self).__init__(root=root, train=train, transform=transform,
                                           target_transform=target_transform, download=download)
        rng = np.random if seed is None else np.random.RandomState(seed)
        n = len(self.data)

        def _doubled_permutation():
            # Two independent permutations back to back: every base image is
            # used exactly twice per channel across the 2 * n samples.
            return np.hstack([rng.permutation(n), rng.permutation(n)])

        # Draw in channel order (0, 1, 2) so the RNG call sequence matches the
        # original implementation under a given global seed.
        index1 = _doubled_permutation()
        index2 = _doubled_permutation()
        index3 = _doubled_permutation()
        self.num_images = 2 * n
        # One (idx_channel0, idx_channel1, idx_channel2) tuple per sample.
        self.index = list(zip(index1, index2, index3))

    def __len__(self):
        """Return the number of stacked samples (twice the base split size)."""
        return self.num_images

    def __getitem__(self, index):
        """Return ``(img, target)`` for the stacked sample at ``index``.

        ``img`` is a 28x28 RGB PIL image, or the result of ``self.transform``
        when one is set. ``target`` is the three-digit integer label, or the
        result of ``self.target_transform`` when one is set.
        """
        img = np.zeros((28, 28, 3), dtype=np.uint8)
        target = 0
        for channel, base_idx in enumerate(self.index[index]):
            img[:, :, channel] = self.data[base_idx]
            # Channel 0 contributes the hundreds digit, channel 2 the units.
            target += int(self.targets[base_idx]) * 10 ** (2 - channel)
        # A uint8 (H, W, 3) array is inferred as RGB. The explicit ``mode``
        # argument is deprecated in recent Pillow releases.
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
def _data_transforms_stacked_mnist():
    """Get train and validation data transforms for StackedMNIST.

    Both pipelines are identical:

    1. Pad by 2 pixels on every side (e.g. a 28x28 StackedMNIST image becomes
       32x32).
    2. Convert to a float tensor in [0, 1].
    3. Normalize each of the three channels with mean 0.5 and std 0.5,
       mapping values to [-1, 1].

    Returns:
        tuple: ``(train_transform, valid_transform)``, two independent
        ``transforms.Compose`` objects.
    """
    def _build():
        return transforms.Compose([
            transforms.Pad(padding=2),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    # Build two separate objects so that mutating one pipeline (e.g.
    # appending an augmentation to the train transform) does not affect
    # the other.
    return _build(), _build()