# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutoAugment Train/Eval module. | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import contextlib | |
import os | |
import time | |
import custom_ops as ops | |
import data_utils | |
import helper_utils | |
import numpy as np | |
from shake_drop import build_shake_drop_model | |
from shake_shake import build_shake_shake_model | |
import tensorflow as tf | |
from wrn import build_wrn_model | |
tf.flags.DEFINE_string('model_name', 'wrn',
                       'wrn, shake_shake_32, shake_shake_96, shake_shake_112, '
                       'pyramid_net')
tf.flags.DEFINE_string('checkpoint_dir', '/tmp/training', 'Training directory.')
tf.flags.DEFINE_string('data_path', '/tmp/data',
                       'Directory where dataset is located.')
tf.flags.DEFINE_string('dataset', 'cifar10',
                       'Dataset to train with. Either cifar10 or cifar100.')
tf.flags.DEFINE_integer('use_cpu', 1, '1 to run on CPU, 0 to run on GPU.')

FLAGS = tf.flags.FLAGS

arg_scope = tf.contrib.framework.arg_scope
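
# Example invocation (the script name and paths below are illustrative, not
# defined by this file):
#
#   python train_cifar.py --model_name=wrn --dataset=cifar10 \
#       --data_path=/tmp/data --checkpoint_dir=/tmp/training --use_cpu=0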


def setup_arg_scopes(is_training):
  """Sets up the arg scopes that will be used when building an image model.

  Args:
    is_training: Is the model training or not.

  Returns:
    Arg scopes to be put around the model being constructed.
  """
  batch_norm_decay = 0.9
  batch_norm_epsilon = 1e-5
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # Epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      # Collection containing the moving mean and moving variance.
      'is_training': is_training,
  }

  scopes = []
  scopes.append(arg_scope([ops.batch_norm], **batch_norm_params))
  return scopes
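
# Inside the returned scopes, calls to ops.batch_norm pick up the parameters
# above as defaults. A minimal usage sketch (assuming a feature tensor `net`
# and that ops.batch_norm is arg_scope-aware):
#
#   with contextlib.nested(*setup_arg_scopes(is_training=True)):
#     net = ops.batch_norm(net)  # uses decay=0.9, epsilon=1e-5, scale=True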


def build_model(inputs, num_classes, is_training, hparams):
  """Constructs the vision model being trained/evaled.

  Args:
    inputs: input features/images being fed to the image model being built.
    num_classes: number of output classes being predicted.
    is_training: is the model training or not.
    hparams: additional hyperparameters associated with the image model.

  Returns:
    The logits of the image model.
  """
  scopes = setup_arg_scopes(is_training)
  # contextlib.nested is Python 2 only, consistent with the __future__
  # imports above.
  with contextlib.nested(*scopes):
    if hparams.model_name == 'pyramid_net':
      logits = build_shake_drop_model(
          inputs, num_classes, is_training)
    elif hparams.model_name == 'wrn':
      logits = build_wrn_model(
          inputs, num_classes, hparams.wrn_size)
    elif hparams.model_name == 'shake_shake':
      logits = build_shake_shake_model(
          inputs, num_classes, hparams, is_training)
  return logits
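
# On Python 3, where contextlib.nested does not exist, an equivalent sketch
# would use contextlib.ExitStack (assumption: the scopes behave as ordinary
# context managers):
#
#   with contextlib.ExitStack() as stack:
#     for scope in scopes:
#       stack.enter_context(scope)
#     logits = ...  # build the model as above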


class CifarModel(object):
  """Builds an image model for Cifar10/Cifar100."""

  def __init__(self, hparams):
    self.hparams = hparams

  def build(self, mode):
    """Construct the cifar model."""
    assert mode in ['train', 'eval']
    self.mode = mode
    self._setup_misc(mode)
    self._setup_images_and_labels()
    self._build_graph(self.images, self.labels, mode)

    self.init = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())

  def _setup_misc(self, mode):
    """Sets up miscellaneous in the cifar model constructor."""
    self.lr_rate_ph = tf.Variable(0.0, name='lrn_rate', trainable=False)
    self.reuse = None if (mode == 'train') else True
    self.batch_size = self.hparams.batch_size
    if mode == 'eval':
      self.batch_size = 25
    # Assumed wiring for assign_epoch() below; the ops it runs are not
    # defined anywhere in the extracted source, so this is a best-guess
    # reconstruction.
    self.epoch = tf.Variable(0.0, name='epoch', trainable=False)
    self._new_epoch = tf.placeholder(tf.float32, name='new_epoch')
    self._epoch_update = self.epoch.assign(self._new_epoch)

  def _setup_images_and_labels(self):
    """Sets up image and label placeholders for the cifar model."""
    if FLAGS.dataset == 'cifar10':
      self.num_classes = 10
    else:
      self.num_classes = 100
    self.images = tf.placeholder(tf.float32, [self.batch_size, 32, 32, 3])
    self.labels = tf.placeholder(tf.float32,
                                 [self.batch_size, self.num_classes])
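    # Example feed shapes (CIFAR-10, eval mode with batch_size=25):
    # images [25, 32, 32, 3] and one-hot labels [25, 10].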

  def assign_epoch(self, session, epoch_value):
    session.run(self._epoch_update, feed_dict={self._new_epoch: epoch_value})

  def _build_graph(self, images, labels, mode):
    """Constructs the TF graph for the cifar model.

    Args:
      images: A 4-D image Tensor.
      labels: A 2-D labels Tensor.
      mode: string indicating the training mode (e.g., 'train', 'eval').
    """
    is_training = 'train' in mode
    if is_training:
      self.global_step = tf.train.get_or_create_global_step()

    logits = build_model(
        images,
        self.num_classes,
        is_training,
        self.hparams)
    self.predictions, self.cost = helper_utils.setup_loss(
        logits, labels)
    self.accuracy, self.eval_op = tf.metrics.accuracy(
        tf.argmax(labels, 1), tf.argmax(self.predictions, 1))
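    # tf.metrics.accuracy returns (value, update_op); the streaming counters
    # it creates are local variables, which is why local_variables_initializer
    # is part of self.init.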
    self._calc_num_trainable_params()

    # Adds L2 weight decay to the cost.
    self.cost = helper_utils.decay_weights(self.cost,
                                           self.hparams.weight_decay_rate)

    if is_training:
      self._build_train_op()

    # Setup checkpointing for this child model.
    # Keep 2 or more checkpoints around during training.
    with tf.device('/cpu:0'):
      self.saver = tf.train.Saver(max_to_keep=2)

    self.init = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())

  def _calc_num_trainable_params(self):
    self.num_trainable_params = np.sum([
        np.prod(var.get_shape().as_list()) for var in tf.trainable_variables()
    ])
    tf.logging.info('number of trainable params: {}'.format(
        self.num_trainable_params))

  def _build_train_op(self):
    """Builds the train op for the cifar model."""
    hparams = self.hparams
    tvars = tf.trainable_variables()
    grads = tf.gradients(self.cost, tvars)
    if hparams.gradient_clipping_by_global_norm > 0.0:
      grads, norm = tf.clip_by_global_norm(
          grads, hparams.gradient_clipping_by_global_norm)
      tf.summary.scalar('grad_norm', norm)

    # Setup the initial learning rate.
    initial_lr = self.lr_rate_ph
    optimizer = tf.train.MomentumOptimizer(
        initial_lr,
        0.9,
        use_nesterov=True)
    self.optimizer = optimizer

    apply_op = optimizer.apply_gradients(
        zip(grads, tvars), global_step=self.global_step, name='train_step')
    train_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies([apply_op]):
      self.train_op = tf.group(*train_ops)
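    # Note: self.lr_rate_ph is a non-trainable variable rather than a true
    # placeholder; presumably it is assigned between steps (e.g., by
    # helper_utils during training) so the learning-rate schedule takes
    # effect. The schedule itself is not defined in this file.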


class CifarModelTrainer(object):
  """Trains an instance of the CifarModel class."""

  def __init__(self, hparams):
    self._session = None
    self.hparams = hparams
    self.model_dir = os.path.join(FLAGS.checkpoint_dir, 'model')
    self.log_dir = os.path.join(FLAGS.checkpoint_dir, 'log')
    # Set the random seed so that the same validation set is used for each
    # model.
    np.random.seed(0)
    self.data_loader = data_utils.DataSet(hparams)
    np.random.seed()  # Put the random seed back to random.
    self.data_loader.reset()

  def save_model(self, step=None):
    """Dumps the model checkpoint into self.model_dir.

    Args:
      step: If provided, creates a checkpoint with the given step
        number, instead of overwriting the existing checkpoints.
    """
    model_save_name = os.path.join(self.model_dir, 'model.ckpt')
    if not tf.gfile.IsDirectory(self.model_dir):
      tf.gfile.MakeDirs(self.model_dir)
    self.saver.save(self.session, model_save_name, global_step=step)
    tf.logging.info('Saved child model')

  def extract_model_spec(self):
    """Restores the latest checkpoint from model_dir, or saves an initial one."""
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    if checkpoint_path is not None:
      self.saver.restore(self.session, checkpoint_path)
      tf.logging.info('Loaded child model checkpoint from %s',
                      checkpoint_path)
    else:
      self.save_model(step=0)

  def eval_child_model(self, model, data_loader, mode):
    """Evaluate the child model.

    Args:
      model: image model that will be evaluated.
      data_loader: dataset object to extract eval data from.
      mode: will the model be evaluated on train, val or test.

    Returns:
      Accuracy of the model on the specified dataset.
    """
    tf.logging.info('Evaluating child model in mode %s', mode)
    while True:
      try:
        with self._new_session(model):
          accuracy = helper_utils.eval_child_model(
              self.session,
              model,
              data_loader,
              mode)
          tf.logging.info('Eval child model accuracy: {}'.format(accuracy))
          # If evaluation ran without raising the errors below, break from
          # the retry loop.
          break
      except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
        tf.logging.info('Retryable error caught: %s. Retrying.', e)

    return accuracy

  @contextlib.contextmanager
  def _new_session(self, m):
    """Creates a new session for model m."""
    # Create a new session for this model, initialize variables, and
    # save / restore from checkpoint.
    self._session = tf.Session(
        '',
        config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False))
    self.session.run(m.init)

    # Load in a previous checkpoint, or save this one.
    self.extract_model_spec()
    try:
      yield
    finally:
      tf.Session.reset('')
      self._session = None

  def _build_models(self):
    """Builds the image models for train and eval."""
    # Determine if we should build the train and eval model. When using
    # distributed training we only want to build one or the other and not
    # both.
    with tf.variable_scope('model', use_resource=False):
      m = CifarModel(self.hparams)
      m.build('train')
      self._num_trainable_params = m.num_trainable_params
      self._saver = m.saver
    with tf.variable_scope('model', reuse=True, use_resource=False):
      meval = CifarModel(self.hparams)
      meval.build('eval')
    return m, meval

  def _calc_starting_epoch(self, m):
    """Calculates the starting epoch for model m based on global step."""
    hparams = self.hparams
    batch_size = hparams.batch_size
    steps_per_epoch = int(hparams.train_size / batch_size)
    with self._new_session(m):
      curr_step = self.session.run(m.global_step)
    total_steps = steps_per_epoch * hparams.num_epochs
    epochs_left = (total_steps - curr_step) // steps_per_epoch
    starting_epoch = hparams.num_epochs - epochs_left
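    # Worked example (illustrative numbers): with train_size=50000 and
    # batch_size=128, steps_per_epoch = 390. If num_epochs = 200 and the
    # restored global step is 780, then total_steps = 78000, epochs_left =
    # (78000 - 780) // 390 = 198, and training resumes at epoch 2.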
    return starting_epoch

  def _run_training_loop(self, m, curr_epoch):
    """Trains the cifar model `m` for one epoch."""
    start_time = time.time()
    while True:
      try:
        with self._new_session(m):
          train_accuracy = helper_utils.run_epoch_training(
              self.session, m, self.data_loader, curr_epoch)
          tf.logging.info('Saving model after epoch')
          self.save_model(step=curr_epoch)
          break
      except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
        tf.logging.info('Retryable error caught: %s. Retrying.', e)
    tf.logging.info('Finished epoch: {}'.format(curr_epoch))
    tf.logging.info('Epoch time (min): {}'.format(
        (time.time() - start_time) / 60.0))
    return train_accuracy

  def _compute_final_accuracies(self, meval):
    """Run once training is finished to compute final val/test accuracies."""
    valid_accuracy = self.eval_child_model(meval, self.data_loader, 'val')
    if self.hparams.eval_test:
      test_accuracy = self.eval_child_model(meval, self.data_loader, 'test')
    else:
      test_accuracy = 0
    tf.logging.info('Test Accuracy: {}'.format(test_accuracy))
    return valid_accuracy, test_accuracy

  def run_model(self):
    """Trains and evaluates the image model."""
    hparams = self.hparams

    # Build the child graph.
    with tf.Graph().as_default(), tf.device(
        '/cpu:0' if FLAGS.use_cpu else '/gpu:0'):
      m, meval = self._build_models()

      # Figure out what epoch we are on.
      starting_epoch = self._calc_starting_epoch(m)

      # Run the validation error right at the beginning.
      valid_accuracy = self.eval_child_model(
          meval, self.data_loader, 'val')
      tf.logging.info('Before Training Epoch: {} Val Acc: {}'.format(
          starting_epoch, valid_accuracy))
      training_accuracy = None

      for curr_epoch in xrange(starting_epoch, hparams.num_epochs):
        # Run one training epoch.
        training_accuracy = self._run_training_loop(m, curr_epoch)
        valid_accuracy = self.eval_child_model(
            meval, self.data_loader, 'val')
        tf.logging.info('Epoch: {} Valid Acc: {}'.format(
            curr_epoch, valid_accuracy))

      valid_accuracy, test_accuracy = self._compute_final_accuracies(
          meval)
      tf.logging.info(
          'Train Acc: {} Valid Acc: {} Test Acc: {}'.format(
              training_accuracy, valid_accuracy, test_accuracy))

  @property
  def saver(self):
    return self._saver

  @property
  def session(self):
    return self._session

  @property
  def num_trainable_params(self):
    return self._num_trainable_params


def main(_):
  if FLAGS.dataset not in ['cifar10', 'cifar100']:
    raise ValueError('Invalid dataset: %s' % FLAGS.dataset)
  hparams = tf.contrib.training.HParams(
      train_size=50000,
      validation_size=0,
      eval_test=1,
      dataset=FLAGS.dataset,
      data_path=FLAGS.data_path,
      batch_size=128,
      gradient_clipping_by_global_norm=5.0)
  if FLAGS.model_name == 'wrn':
    hparams.add_hparam('model_name', 'wrn')
    hparams.add_hparam('num_epochs', 200)
    hparams.add_hparam('wrn_size', 160)
    hparams.add_hparam('lr', 0.1)
    hparams.add_hparam('weight_decay_rate', 5e-4)
  elif FLAGS.model_name == 'shake_shake_32':
    hparams.add_hparam('model_name', 'shake_shake')
    hparams.add_hparam('num_epochs', 1800)
    hparams.add_hparam('shake_shake_widen_factor', 2)
    hparams.add_hparam('lr', 0.01)
    hparams.add_hparam('weight_decay_rate', 0.001)
  elif FLAGS.model_name == 'shake_shake_96':
    hparams.add_hparam('model_name', 'shake_shake')
    hparams.add_hparam('num_epochs', 1800)
    hparams.add_hparam('shake_shake_widen_factor', 6)
    hparams.add_hparam('lr', 0.01)
    hparams.add_hparam('weight_decay_rate', 0.001)
  elif FLAGS.model_name == 'shake_shake_112':
    hparams.add_hparam('model_name', 'shake_shake')
    hparams.add_hparam('num_epochs', 1800)
    hparams.add_hparam('shake_shake_widen_factor', 7)
    hparams.add_hparam('lr', 0.01)
    hparams.add_hparam('weight_decay_rate', 0.001)
  elif FLAGS.model_name == 'pyramid_net':
    hparams.add_hparam('model_name', 'pyramid_net')
    hparams.add_hparam('num_epochs', 1800)
    hparams.add_hparam('lr', 0.05)
    hparams.add_hparam('weight_decay_rate', 5e-5)
    hparams.batch_size = 64
  else:
    raise ValueError('Invalid model name: %s' % FLAGS.model_name)
  cifar_trainer = CifarModelTrainer(hparams)
  cifar_trainer.run_model()


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()