''' Datasets
    This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets.
'''
import os
import os.path
import sys
from PIL import Image
import numpy as np
from tqdm import tqdm, trange
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_url, check_integrity
import torch.utils.data as data
from torch.utils.data import DataLoader


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def is_image_file(filename):
    """Checks if a file is an image.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
def find_classes(dir):
    """Finds the class subfolders in a dataset directory and maps each
    class name to an integer index."""
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx):
    """Walks the directory tree rooted at dir and returns a list of
    (image path, class index) tuples."""
    images = []
    dir = os.path.expanduser(dir)
    for target in tqdm(sorted(os.listdir(dir))):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)
    return images
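
# A minimal sketch of how these helpers fit together (the root path and class
# folder names are hypothetical):
#
#   classes, class_to_idx = find_classes('/data/imagenet/train')
#   imgs = make_dataset('/data/imagenet/train', class_to_idx)
#   # imgs is a list like [('/data/imagenet/train/n01440764/x.JPEG', 0), ...]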
def pil_loader(path):
    # Open path as a file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
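
# default_loader dispatches on the active torchvision image backend, e.g.
# (with a hypothetical path):
#
#   img = default_loader('/data/imagenet/train/n01440764/x.JPEG')
#   # -> an RGB PIL Image (or an accimage Image if that backend is active)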
class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dogball/xxx.png
        root/dogball/xxy.png
        root/dogball/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, load_in_mem=False,
                 index_filename='imagenet_imgs.npz', **kwargs):
        classes, class_to_idx = find_classes(root)
        # Load pre-computed image directory walk
        if os.path.exists(index_filename):
            print('Loading pre-saved Index file %s...' % index_filename)
            # allow_pickle is required because the saved index is an
            # object array of (path, class_index) tuples.
            imgs = np.load(index_filename, allow_pickle=True)['imgs']
        # If first time, walk the folder directory and save the
        # results to a pre-computed file.
        else:
            print('Generating Index file %s...' % index_filename)
            imgs = make_dataset(root, class_to_idx)
            np.savez_compressed(index_filename, **{'imgs': imgs})
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.load_in_mem = load_in_mem

        if self.load_in_mem:
            print('Loading all images into memory...')
            self.data, self.labels = [], []
            for index in tqdm(range(len(self.imgs))):
                path, target = imgs[index][0], imgs[index][1]
                img = self.loader(path)
                # Apply the transform up front so __getitem__ can skip it
                self.data.append(self.transform(img) if self.transform is not None else img)
                self.labels.append(target)
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if self.load_in_mem:
            # Images were already transformed when loaded into memory
            img = self.data[index]
            target = self.labels[index]
        else:
            path, target = self.imgs[index]
            img = self.loader(str(path))
            if self.transform is not None:
                img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, int(target)

    def __len__(self):
        return len(self.imgs)
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
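
# A minimal usage sketch for ImageFolder (paths and sizes are hypothetical):
#
#   transform = transforms.Compose([
#       transforms.Resize(128),
#       transforms.CenterCrop(128),
#       transforms.ToTensor()])
#   dataset = ImageFolder('/data/imagenet/train', transform=transform)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
#   imgs, labels = next(iter(loader))  # imgs: [64, 3, 128, 128], labels: [64]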
''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 file, to avoid
    having to load individual images all the time. '''
import h5py as h5
import torch
class ILSVRC_HDF5(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 load_in_mem=False, train=True, download=False, validate_seed=0,
                 val_split=0, **kwargs):  # last four are dummies
        self.root = root
        # Count the images by reading the length of the labels array
        with h5.File(root, 'r') as f:
            self.num_imgs = len(f['labels'])

        self.target_transform = target_transform

        # Set the transform here
        self.transform = transform

        # Load the entire dataset into memory?
        self.load_in_mem = load_in_mem

        # If loading into memory, do so now
        if self.load_in_mem:
            print('Loading %s into memory...' % root)
            with h5.File(root, 'r') as f:
                self.data = f['imgs'][:]
                self.labels = f['labels'][:]
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        # If the entire dataset is loaded in RAM, get the image from memory
        if self.load_in_mem:
            img = self.data[index]
            target = self.labels[index]
        # Else load it from disk
        else:
            with h5.File(self.root, 'r') as f:
                img = f['imgs'][index]
                target = f['labels'][index]

        # Note: self.transform is stored but not applied here; instead we apply
        # our own fixed transform, scaling uint8 images from [0, 255] to [-1, 1]
        img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, int(target)

    def __len__(self):
        return self.num_imgs
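
# A minimal usage sketch for ILSVRC_HDF5, assuming an HDF5 file that contains
# an 'imgs' dataset (uint8 images) and a 'labels' dataset (the path is
# hypothetical):
#
#   dataset = ILSVRC_HDF5('/data/ILSVRC128.hdf5', load_in_mem=False)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True)
#   imgs, labels = next(iter(loader))  # imgs are float tensors in [-1, 1]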
import pickle


class CIFAR10(dset.CIFAR10):

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=True, validate_seed=0,
                 val_split=0, load_in_mem=True, **kwargs):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.val_split = val_split

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # Now load the pickled numpy arrays
        self.data = []
        self.labels = []
        for fentry in self.train_list:
            f = fentry[0]
            file = os.path.join(self.root, self.base_folder, f)
            with open(file, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.labels += entry['labels']
                else:
                    self.labels += entry['fine_labels']
        self.data = np.concatenate(self.data)

        # Randomly select indices for validation
        if self.val_split > 0:
            label_indices = [[] for _ in range(max(self.labels) + 1)]
            for i, l in enumerate(self.labels):
                label_indices[l] += [i]
            label_indices = np.asarray(label_indices)

            # Randomly grab a val_split fraction of the dataset,
            # split evenly across classes
            np.random.seed(validate_seed)
            self.val_indices = []
            num_val_per_class = int(len(self.data) * val_split) // (max(self.labels) + 1)
            for l_i in label_indices:
                self.val_indices += list(l_i[np.random.choice(len(l_i), num_val_per_class, replace=False)])

        if self.train == 'validate':
            self.data = self.data[self.val_indices]
            self.labels = list(np.asarray(self.labels)[self.val_indices])

            self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32))
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        elif self.train:
            print(np.shape(self.data))
            if self.val_split > 0:
                self.data = np.delete(self.data, self.val_indices, axis=0)
                self.labels = list(np.delete(np.asarray(self.labels), self.val_indices, axis=0))

            self.data = self.data.reshape((int(50e3 * (1. - self.val_split)), 3, 32, 32))
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        else:
            f = self.test_list[0][0]
            file = os.path.join(self.root, self.base_folder, f)
            with open(file, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
                self.data = entry['data']
                if 'labels' in entry:
                    self.labels = entry['labels']
                else:
                    self.labels = entry['fine_labels']
            self.data = self.data.reshape((10000, 3, 32, 32))
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.labels[index]

        # Doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)
class CIFAR100(CIFAR10):
    base_folder = 'cifar-100-python'
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
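
# A minimal usage sketch for the CIFAR wrappers with a held-out validation
# split (the root path is hypothetical). Passing train='validate' selects the
# held-out images chosen by validate_seed:
#
#   train_set = CIFAR10(root='./data', train=True, download=True,
#                       transform=transforms.ToTensor(),
#                       validate_seed=0, val_split=0.1)
#   val_set = CIFAR10(root='./data', train='validate', download=True,
#                     transform=transforms.ToTensor(),
#                     validate_seed=0, val_split=0.1)
#   # 45,000 training images and 5,000 validation images, split evenly by class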