# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Program which trains models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import tensorflow as tf

import adversarial_attack
import model_lib
from datasets import dataset_factory


FLAGS = flags.FLAGS

flags.DEFINE_integer('max_steps', -1, 'Number of steps to stop at.')

flags.DEFINE_string('output_dir', None,
                    'Training directory where checkpoints will be saved.')

flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter servers.')

flags.DEFINE_integer('task', 0, 'Task ID for running distributed training.')

flags.DEFINE_string('master', '', 'Tensorflow master.')

flags.DEFINE_string('model_name', 'resnet_v2_50', 'Name of the model.')

flags.DEFINE_string('dataset', 'imagenet',
                    'Dataset: "tiny_imagenet" or "imagenet".')

flags.DEFINE_integer('dataset_image_size', 64,
                     'Size of the images in the dataset.')

flags.DEFINE_integer('num_summary_images', 3,
                     'Number of images to display in Tensorboard.')

flags.DEFINE_integer(
    'save_summaries_steps', 100,
    'The frequency with which summaries are saved, in steps.')

flags.DEFINE_integer(
    'save_summaries_secs', None,
    'The frequency with which summaries are saved, in seconds.')

flags.DEFINE_integer(
    'save_model_steps', 500,
    'The frequency with which the model is saved, in steps.')

flags.DEFINE_string('hparams', '', 'Hyper parameters.')

flags.DEFINE_integer('replicas_to_aggregate', 1,
                     'Number of gradients to collect before param updates.')

flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_float('moving_average_decay', 0.9999,
                   'The decay to use for the moving average.')

# Flags to control fine tuning

flags.DEFINE_string('finetune_checkpoint_path', None,
                    'Path to checkpoint for fine tuning. '
                    'If None then no fine tuning is done.')

flags.DEFINE_string('finetune_exclude_pretrained_scopes', '',
                    'Variable scopes to exclude when loading checkpoint for '
                    'fine tuning.')

flags.DEFINE_string('finetune_trainable_scopes', None,
                    'If set then it defines list of variable scopes for '
                    'trainable variables.')
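
# Example flag values (illustrative only; the values below are hypothetical).
# Assuming model_lib.default_hparams() returns a tf.contrib.training.HParams
# object, --hparams accepts comma-separated name=value pairs; the names shown
# are the ones read in main():
#
#   --hparams='batch_size=32,train_lp_weight=0.5,weight_decay=0.0001'
#
# Fine tuning from a pretrained checkpoint, excluding the classification
# layer (paths and scope names are hypothetical):
#
#   --finetune_checkpoint_path=/path/to/pretrained/model.ckpt \
#   --finetune_exclude_pretrained_scopes=resnet_v2_50/logits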


def _get_finetuning_init_fn(variable_averages):
  """Returns an init function used for fine tuning."""
  if not FLAGS.finetune_checkpoint_path:
    return None
  if tf.train.latest_checkpoint(FLAGS.output_dir):
    # Training has already started in output_dir; do not override its
    # checkpoint with the fine-tuning one.
    return None

  if tf.gfile.IsDirectory(FLAGS.finetune_checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(
        FLAGS.finetune_checkpoint_path)
  else:
    checkpoint_path = FLAGS.finetune_checkpoint_path
  if not checkpoint_path:
    tf.logging.warning('Not doing fine tuning, can not find checkpoint in %s',
                       FLAGS.finetune_checkpoint_path)
    return None

  tf.logging.info('Fine-tuning from %s', checkpoint_path)

  if FLAGS.finetune_exclude_pretrained_scopes:
    exclusions = {
        scope.strip()
        for scope in FLAGS.finetune_exclude_pretrained_scopes.split(',')
    }
  else:
    exclusions = set()
  filtered_model_variables = [
      v for v in tf.contrib.framework.get_model_variables()
      if not any([v.op.name.startswith(e) for e in exclusions])
  ]

  if variable_averages:
    variables_to_restore = {}
    for v in filtered_model_variables:
      if v in tf.trainable_variables():
        variables_to_restore[variable_averages.average_name(v)] = v
      else:
        variables_to_restore[v.op.name] = v
  else:
    variables_to_restore = {v.op.name: v for v in filtered_model_variables}

  assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
      checkpoint_path, variables_to_restore)
  if assign_fn:
    return lambda _, sess: assign_fn(sess)
  else:
    return None


def main(_):
  assert FLAGS.output_dir, '--output_dir has to be provided'
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)
  params = model_lib.default_hparams()
  params.parse(FLAGS.hparams)
  tf.logging.info('User provided hparams: %s', FLAGS.hparams)
  tf.logging.info('All hyper parameters: %s', params)
  batch_size = params.batch_size
  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
      # dataset
      dataset, examples_per_epoch, num_classes, bounds = (
          dataset_factory.get_dataset(
              FLAGS.dataset,
              'train',
              batch_size,
              FLAGS.dataset_image_size,
              is_training=True))
      dataset_iterator = dataset.make_one_shot_iterator()
      images, labels = dataset_iterator.get_next()
      one_hot_labels = tf.one_hot(labels, num_classes)

      # set up model
      global_step = tf.train.get_or_create_global_step()
      model_fn = model_lib.get_model(FLAGS.model_name, num_classes)
      if params.train_adv_method == 'clean':
        logits = model_fn(images, is_training=True)
        adv_examples = None
      else:
        # Run the attack against the model in eval mode, then train on the
        # concatenation of clean and adversarial examples.
        model_fn_eval_mode = lambda x: model_fn(x, is_training=False)
        adv_examples = adversarial_attack.generate_adversarial_examples(
            images, bounds, model_fn_eval_mode, params.train_adv_method)
        all_examples = tf.concat([images, adv_examples], axis=0)
        logits = model_fn(all_examples, is_training=True)
        one_hot_labels = tf.concat([one_hot_labels, one_hot_labels], axis=0)

      # update trainable variables if fine tuning is used
      model_lib.filter_trainable_variables(FLAGS.finetune_trainable_scopes)

      # set up losses
      total_loss = tf.losses.softmax_cross_entropy(
          onehot_labels=one_hot_labels,
          logits=logits,
          label_smoothing=params.label_smoothing)
      tf.summary.scalar('loss_xent', total_loss)
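
      # Logit pairing loss. When adversarial training is enabled above,
      # `logits` holds the clean-example logits in the first half of the
      # batch and the adversarial-example logits in the second half, so an
      # L2 penalty between the two halves encourages similar predictions on
      # clean and adversarial inputs.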
      if params.train_lp_weight > 0:
        # Split the concatenated logits back into the clean and adversarial
        # halves of the batch.
        logits1, logits2 = tf.split(logits, 2)
        loss_lp = tf.losses.mean_squared_error(
            logits1, logits2, weights=params.train_lp_weight)
        tf.summary.scalar('loss_lp', loss_lp)
        total_loss += loss_lp

      if params.weight_decay > 0:
        loss_wd = (
            params.weight_decay
            * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
        )
        tf.summary.scalar('loss_wd', loss_wd)
        total_loss += loss_wd

      # Setup the moving averages:
      if FLAGS.moving_average_decay and (FLAGS.moving_average_decay > 0):
        with tf.name_scope('moving_average'):
          moving_average_variables = tf.contrib.framework.get_model_variables()
          variable_averages = tf.train.ExponentialMovingAverage(
              FLAGS.moving_average_decay, global_step)
      else:
        moving_average_variables = None
        variable_averages = None

      # set up optimizer and training op
      learning_rate, steps_per_epoch = model_lib.get_lr_schedule(
          params, examples_per_epoch, FLAGS.replicas_to_aggregate)

      optimizer = model_lib.get_optimizer(params, learning_rate)

      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables)

      train_op = tf.contrib.training.create_train_op(
          total_loss, optimizer,
          update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

      tf.summary.image('images', images[0:FLAGS.num_summary_images])
      if adv_examples is not None:
        tf.summary.image('adv_images',
                         adv_examples[0:FLAGS.num_summary_images])
      tf.summary.scalar('total_loss', total_loss)
      tf.summary.scalar('learning_rate', learning_rate)
      tf.summary.scalar('current_epoch',
                        tf.to_double(global_step) / steps_per_epoch)

      # Training
      is_chief = FLAGS.task == 0
      scaffold = tf.train.Scaffold(
          init_fn=_get_finetuning_init_fn(variable_averages))
      hooks = [
          tf.train.LoggingTensorHook({'total_loss': total_loss,
                                      'global_step': global_step},
                                     every_n_iter=1),
          tf.train.NanTensorHook(total_loss),
      ]
      chief_only_hooks = [
          tf.train.SummarySaverHook(save_steps=FLAGS.save_summaries_steps,
                                    save_secs=FLAGS.save_summaries_secs,
                                    output_dir=FLAGS.output_dir,
                                    scaffold=scaffold),
          tf.train.CheckpointSaverHook(FLAGS.output_dir,
                                       save_steps=FLAGS.save_model_steps,
                                       scaffold=scaffold),
      ]

      if FLAGS.max_steps > 0:
        hooks.append(tf.train.StopAtStepHook(last_step=FLAGS.max_steps))

      # hook for sync replica training
      hooks.append(optimizer.make_session_run_hook(is_chief))

      with tf.train.MonitoredTrainingSession(
          master=FLAGS.master,
          is_chief=is_chief,
          checkpoint_dir=FLAGS.output_dir,
          scaffold=scaffold,
          hooks=hooks,
          chief_only_hooks=chief_only_hooks,
          save_checkpoint_secs=None,
          save_summaries_steps=None,
          save_summaries_secs=None) as session:
        while not session.should_stop():
          session.run([train_op])


if __name__ == '__main__':
  app.run(main)
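
# Example invocation (illustrative only; assumes this script is saved as
# train.py, and all paths/values below are hypothetical):
#
#   python train.py \
#     --output_dir=/tmp/adv_train \
#     --model_name=resnet_v2_50 \
#     --dataset=tiny_imagenet \
#     --dataset_image_size=64 \
#     --max_steps=100000 \
#     --hparams='batch_size=32'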