Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Data utils for CIFAR-10 and CIFAR-100.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import copy | |
import cPickle | |
import os | |
import augmentation_transforms | |
import numpy as np | |
import policies as found_policies | |
import tensorflow as tf | |
# pylint:disable=logging-format-interpolation | |
class DataSet(object):
  """Dataset object that produces augmented training and eval data.

  On construction the CIFAR-10 or CIFAR-100 pickled files are read from
  `hparams.data_path`, scaled to [0, 1], normalized with
  `augmentation_transforms.MEANS` / `augmentation_transforms.STDS`, and the
  labels are one-hot encoded. The 50000 training images are shuffled with a
  fixed seed and split into train and validation sets; if `hparams.eval_test`
  is set, the test split is loaded as well.

  Attributes:
    hparams: The hyperparameter object passed to the constructor.
    epochs: Epoch counter; incremented by `next_batch` whenever it reshuffles
      the training set, and set to 0 by `reset`.
    curr_train_index: Index of the next training example `next_batch` reads.
    good_policies: Augmentation policies from `policies.good_policies()`.
    train_images, train_labels: Training split; images are (N, 32, 32, 3)
      normalized floats, labels are one-hot rows.
    val_images, val_labels: Validation split, same layout as training.
    test_images, test_labels: Test split; only set when `hparams.eval_test`.
    num_train: Number of training examples.
  """

  def __init__(self, hparams):
    """Loads, normalizes and splits the dataset described by `hparams`.

    Args:
      hparams: Object exposing `dataset` ('cifar10' or 'cifar100'),
        `data_path`, `train_size`, `validation_size`, `eval_test` and
        `batch_size` attributes.

    Raises:
      NotImplementedError: If `hparams.dataset` is not a supported dataset.
      AssertionError: If `train_size + validation_size` exceeds the number of
        available training images.
    """
    # Validate the dataset name before doing any expensive work.
    if hparams.dataset == 'cifar10':
      num_classes = 10
    elif hparams.dataset == 'cifar100':
      num_classes = 100
    else:
      raise NotImplementedError(
          'Unimplemented dataset: {}'.format(hparams.dataset))

    self.hparams = hparams
    self.epochs = 0
    self.curr_train_index = 0
    self.good_policies = found_policies.good_policies()

    # The loader assumes 5 x 10000 training images (five CIFAR-10 batch files,
    # or the single CIFAR-100 'train' file) plus 10000 test images.
    num_data_batches_to_load = 5
    train_dataset_size = 10000 * num_data_batches_to_load
    total_dataset_size = train_dataset_size
    if hparams.eval_test:
      total_dataset_size += 10000
    assert hparams.train_size + hparams.validation_size <= train_dataset_size

    if hparams.dataset == 'cifar10':
      tf.logging.info('Cifar10')
      datafiles = [
          'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
          'data_batch_5']
      if hparams.eval_test:
        datafiles.append('test_batch')
      label_key = 'labels'
    else:
      datafiles = ['train']
      if hparams.eval_test:
        datafiles.append('test')
      label_key = 'fine_labels'

    # Training files come first and the test file (if any) last, so the test
    # split is everything past `train_dataset_size` below.
    image_chunks = []
    all_labels = []
    for f in datafiles:
      d = unpickle(os.path.join(hparams.data_path, f))
      image_chunks.append(
          np.asarray(d['data'], dtype=np.uint8).reshape(-1, 3072))
      all_labels.extend(d[label_key])
    # reshape() raises if the files did not contain the expected image count.
    all_data = np.concatenate(image_chunks, axis=0).reshape(
        total_dataset_size, 3072)
    del image_chunks

    # Rows are stored channel-major (3 x 32 x 32); convert to NHWC.
    all_data = all_data.reshape(-1, 3, 32, 32)
    all_data = all_data.transpose(0, 2, 3, 1).copy()
    all_data = all_data / 255.0
    mean = augmentation_transforms.MEANS
    std = augmentation_transforms.STDS
    tf.logging.info('mean:{} std: {}'.format(mean, std))
    all_data = (all_data - mean) / std
    all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
    assert len(all_data) == len(all_labels)
    tf.logging.info(
        'In CIFAR10 loader, number of images: {}'.format(len(all_data)))

    # Break off test data
    if hparams.eval_test:
      self.test_images = all_data[train_dataset_size:]
      self.test_labels = all_labels[train_dataset_size:]

    # Shuffle the rest of the data with a fixed seed so the train/validation
    # split is reproducible. Note this reseeds NumPy's global RNG.
    all_data = all_data[:train_dataset_size]
    all_labels = all_labels[:train_dataset_size]
    np.random.seed(0)
    perm = np.arange(len(all_data))
    np.random.shuffle(perm)
    all_data = all_data[perm]
    all_labels = all_labels[perm]

    # Break into train and val
    train_size, val_size = hparams.train_size, hparams.validation_size
    self.train_images = all_data[:train_size]
    self.train_labels = all_labels[:train_size]
    self.val_images = all_data[train_size:train_size + val_size]
    self.val_labels = all_labels[train_size:train_size + val_size]
    self.num_train = self.train_images.shape[0]

  def next_batch(self):
    """Return the next minibatch of augmented data.

    When fewer than `batch_size` unread training examples remain, the
    training set is reshuffled via `reset` and the epoch counter is
    incremented before the batch is taken.

    Each image goes through `augmentation_transforms.apply_policy` with a
    policy drawn uniformly at random from `good_policies`, then
    `zero_pad_and_crop(..., 4)`, `random_flip` and `cutout_numpy`.

    Returns:
      A tuple `(images, labels)`: a float32 array of the augmented images and
      the corresponding one-hot labels.
    """
    batch_size = self.hparams.batch_size
    if self.curr_train_index + batch_size > self.num_train:
      # reset() zeroes the epoch counter, so carry it across manually.
      epoch = self.epochs + 1
      self.reset()
      self.epochs = epoch
    start = self.curr_train_index
    end = start + batch_size
    images = self.train_images[start:end]
    labels = self.train_labels[start:end]

    final_imgs = []
    for data in images:
      epoch_policy = self.good_policies[np.random.choice(
          len(self.good_policies))]
      final_img = augmentation_transforms.apply_policy(epoch_policy, data)
      final_img = augmentation_transforms.random_flip(
          augmentation_transforms.zero_pad_and_crop(final_img, 4))
      # Apply cutout
      final_img = augmentation_transforms.cutout_numpy(final_img)
      final_imgs.append(final_img)

    self.curr_train_index = end
    return np.array(final_imgs, np.float32), labels

  def reset(self):
    """Reset training data and index into the training data.

    Shuffles `train_images` and `train_labels` with one shared permutation,
    rewinds `curr_train_index` to 0 and sets `epochs` to 0.
    """
    self.epochs = 0
    # Shuffle the training data
    perm = np.arange(self.num_train)
    np.random.shuffle(perm)
    assert self.num_train == self.train_images.shape[0], (
        'Error incorrect shuffling mask')
    self.train_images = self.train_images[perm]
    self.train_labels = self.train_labels[perm]
    self.curr_train_index = 0
def unpickle(f):
  """Load and return the pickled object stored at path `f`.

  The file is opened with `tf.gfile.Open`. In this file it is used to read
  the CIFAR data files, whose contents the caller indexes as a dict.

  Args:
    f: Path of the pickle file to load.

  Returns:
    The unpickled object.
  """
  tf.logging.info('loading file: {}'.format(f))
  # Pickles are binary data: open in 'rb' so the bytes reach cPickle
  # unmodified, and use a context manager so the handle is closed even if
  # loading fails.
  # NOTE(review): unpickling can execute arbitrary code; only load trusted
  # dataset files.
  with tf.gfile.Open(f, 'rb') as fo:
    return cPickle.load(fo)