# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Launch example: [IMDB] python train_mask_gan.py --data_dir /tmp/imdb --data_set imdb --batch_size 128 --sequence_length 20 --base_directory /tmp/maskGAN_v0.01 --hparams="gen_rnn_size=650,gen_num_layers=2,dis_rnn_size=650,dis_num_layers=2 ,critic_learning_rate=0.0009756,dis_learning_rate=0.0000585, dis_train_iterations=8,gen_learning_rate=0.0016624, gen_full_learning_rate_steps=1e9,gen_learning_rate_decay=0.999999, rl_discount_rate=0.8835659" --mode TRAIN --max_steps 1000000 --generator_model seq2seq_vd --discriminator_model seq2seq_vd --is_present_rate 0.5 --summaries_every 25 --print_every 25 --max_num_to_print=3 --generator_optimizer=adam --seq2seq_share_embedding=True --baseline_method=critic --attention_option=luong --n_gram_eval=4 --mask_strategy=contiguous --gen_training_strategy=reinforce --dis_pretrain_steps=100 --perplexity_threshold=1000000 --dis_share_embedding=True --maskgan_ckpt /tmp/model.ckpt-171091 """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections from functools import partial import os import time # Dependency imports import numpy as np from six.moves import xrange import tensorflow as tf import pretrain_mask_gan from data import imdb_loader from data import ptb_loader from model_utils import helper from model_utils import model_construction from model_utils import model_losses from model_utils import model_optimization # Data. from model_utils import model_utils from model_utils import n_gram from models import evaluation_utils from models import rollout np.set_printoptions(precision=3) np.set_printoptions(suppress=True) MODE_TRAIN = 'TRAIN' MODE_TRAIN_EVAL = 'TRAIN_EVAL' MODE_VALIDATION = 'VALIDATION' MODE_TEST = 'TEST' ## Binary and setup FLAGS. tf.app.flags.DEFINE_enum( 'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL], 'What this binary will do.') tf.app.flags.DEFINE_string('master', '', """Name of the TensorFlow master to use.""") tf.app.flags.DEFINE_string('eval_master', '', """Name prefix of the Tensorflow eval master.""") tf.app.flags.DEFINE_integer('task', 0, """Task id of the replica running the training.""") tf.app.flags.DEFINE_integer('ps_tasks', 0, """Number of tasks in the ps job. If 0 no ps job is used.""") ## General FLAGS. tf.app.flags.DEFINE_string( 'hparams', '', 'Comma separated list of name=value hyperparameter pairs.') tf.app.flags.DEFINE_integer('batch_size', 20, 'The batch size.') tf.app.flags.DEFINE_integer('vocab_size', 10000, 'The vocabulary size.') tf.app.flags.DEFINE_integer('sequence_length', 20, 'The sequence length.') tf.app.flags.DEFINE_integer('max_steps', 1000000, 'Maximum number of steps to run.') tf.app.flags.DEFINE_string( 'mask_strategy', 'random', 'Strategy for masking the words. Determine the ' 'characterisitics of how the words are dropped out. One of ' "['contiguous', 'random'].") tf.app.flags.DEFINE_float('is_present_rate', 0.5, 'Percent of tokens present in the forward sequence.') tf.app.flags.DEFINE_float('is_present_rate_decay', None, 'Decay rate for the ' 'percent of words that are real (are present).') tf.app.flags.DEFINE_string( 'generator_model', 'seq2seq', "Type of Generator model. One of ['rnn', 'seq2seq', 'seq2seq_zaremba'," "'rnn_zaremba', 'rnn_nas', 'seq2seq_nas']") tf.app.flags.DEFINE_string( 'attention_option', None, "Attention mechanism. One of [None, 'luong', 'bahdanau']") tf.app.flags.DEFINE_string( 'discriminator_model', 'bidirectional', "Type of Discriminator model. One of ['cnn', 'rnn', 'bidirectional', " "'rnn_zaremba', 'bidirectional_zaremba', 'rnn_nas', 'rnn_vd', 'seq2seq_vd']" ) tf.app.flags.DEFINE_boolean('seq2seq_share_embedding', False, 'Whether to share the ' 'embeddings between the encoder and decoder.') tf.app.flags.DEFINE_boolean( 'dis_share_embedding', False, 'Whether to share the ' 'embeddings between the generator and discriminator.') tf.app.flags.DEFINE_boolean('dis_update_share_embedding', False, 'Whether the ' 'discriminator should update the shared embedding.') tf.app.flags.DEFINE_boolean('use_gen_mode', False, 'Use the mode of the generator ' 'to produce samples.') tf.app.flags.DEFINE_boolean('critic_update_dis_vars', False, 'Whether the critic ' 'updates the discriminator variables.') ## Training FLAGS. tf.app.flags.DEFINE_string( 'gen_training_strategy', 'reinforce', "Method for training the Generator. One of ['cross_entropy', 'reinforce']") tf.app.flags.DEFINE_string( 'generator_optimizer', 'adam', "Type of Generator optimizer. One of ['sgd', 'adam']") tf.app.flags.DEFINE_float('grad_clipping', 10., 'Norm for gradient clipping.') tf.app.flags.DEFINE_float('advantage_clipping', 5., 'Clipping for advantages.') tf.app.flags.DEFINE_string( 'baseline_method', None, "Approach for baseline. One of ['critic', 'dis_batch', 'ema', None]") tf.app.flags.DEFINE_float('perplexity_threshold', 15000, 'Limit for perplexity before terminating job.') tf.app.flags.DEFINE_float('zoneout_drop_prob', 0.1, 'Probability for dropping parameter for zoneout.') tf.app.flags.DEFINE_float('keep_prob', 0.5, 'Probability for keeping parameter for dropout.') ## Logging and evaluation FLAGS. tf.app.flags.DEFINE_integer('print_every', 250, 'Frequency to print and log the ' 'outputs of the model.') tf.app.flags.DEFINE_integer('max_num_to_print', 5, 'Number of samples to log/print.') tf.app.flags.DEFINE_boolean('print_verbose', False, 'Whether to print in full.') tf.app.flags.DEFINE_integer('summaries_every', 100, 'Frequency to compute summaries.') tf.app.flags.DEFINE_boolean('eval_language_model', False, 'Whether to evaluate on ' 'all words as in language modeling.') tf.app.flags.DEFINE_float('eval_interval_secs', 60, 'Delay for evaluating model.') tf.app.flags.DEFINE_integer( 'n_gram_eval', 4, """The degree of the n-grams to use for evaluation.""") tf.app.flags.DEFINE_integer( 'epoch_size_override', None, 'If an integer, this dictates the size of the epochs and will potentially ' 'not iterate over all the data.') tf.app.flags.DEFINE_integer('eval_epoch_size_override', None, 'Number of evaluation steps.') ## Directories and checkpoints. tf.app.flags.DEFINE_string('base_directory', '/tmp/maskGAN_v0.00', 'Base directory for the logging, events and graph.') tf.app.flags.DEFINE_string('data_set', 'ptb', 'Data set to operate on. One of' "['ptb', 'imdb']") tf.app.flags.DEFINE_string('data_dir', '/tmp/data/ptb', 'Directory for the training data.') tf.app.flags.DEFINE_string( 'language_model_ckpt_dir', None, 'Directory storing checkpoints to initialize the model. Pretrained models' 'are stored at /tmp/maskGAN/pretrained/') tf.app.flags.DEFINE_string( 'language_model_ckpt_dir_reversed', None, 'Directory storing checkpoints of reversed models to initialize the model.' 'Pretrained models stored at' 'are stored at /tmp/PTB/pretrained_reversed') tf.app.flags.DEFINE_string( 'maskgan_ckpt', None, 'Override which checkpoint file to use to restore the ' 'model. A pretrained seq2seq_zaremba model is stored at ' '/tmp/maskGAN/pretrain/seq2seq_zaremba/train/model.ckpt-64912') tf.app.flags.DEFINE_boolean('wasserstein_objective', False, '(DEPRECATED) Whether to use the WGAN training.') tf.app.flags.DEFINE_integer('num_rollouts', 1, 'The number of rolled out predictions to make.') tf.app.flags.DEFINE_float('c_lower', -0.01, 'Lower bound for weights.') tf.app.flags.DEFINE_float('c_upper', 0.01, 'Upper bound for weights.') FLAGS = tf.app.flags.FLAGS def create_hparams(): """Create the hparams object for generic training hyperparameters.""" hparams = tf.contrib.training.HParams( gen_num_layers=2, dis_num_layers=2, gen_rnn_size=740, dis_rnn_size=740, gen_learning_rate=5e-4, dis_learning_rate=5e-3, critic_learning_rate=5e-3, dis_train_iterations=1, gen_learning_rate_decay=1.0, gen_full_learning_rate_steps=1e7, baseline_decay=0.999999, rl_discount_rate=0.9, gen_vd_keep_prob=0.5, dis_vd_keep_prob=0.5, dis_pretrain_learning_rate=5e-3, dis_num_filters=128, dis_hidden_dim=128, gen_nas_keep_prob_0=0.85, gen_nas_keep_prob_1=0.55, dis_nas_keep_prob_0=0.85, dis_nas_keep_prob_1=0.55) # Command line flags override any of the preceding hyperparameter values. if FLAGS.hparams: hparams = hparams.parse(FLAGS.hparams) return hparams def create_MaskGAN(hparams, is_training): """Create the MaskGAN model. Args: hparams: Hyperparameters for the MaskGAN. is_training: Boolean indicating operational mode (train/inference). evaluated with a teacher forcing regime. Return: model: Namedtuple for specifying the MaskGAN. """ global_step = tf.Variable(0, name='global_step', trainable=False) new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate') learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False) learning_rate_update = tf.assign(learning_rate, new_learning_rate) new_rate = tf.placeholder(tf.float32, [], name='new_rate') percent_real_var = tf.Variable(0.0, trainable=False) percent_real_update = tf.assign(percent_real_var, new_rate) ## Placeholders. inputs = tf.placeholder( tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length]) targets = tf.placeholder( tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length]) present = tf.placeholder( tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length]) # TODO(adai): Placeholder for IMDB label. ## Real Sequence is the targets. real_sequence = targets ## Fakse Sequence from the Generator. # TODO(adai): Generator must have IMDB labels placeholder. (fake_sequence, fake_logits, fake_log_probs, fake_gen_initial_state, fake_gen_final_state, _) = model_construction.create_generator( hparams, inputs, targets, present, is_training=is_training, is_validating=False) (_, eval_logits, _, eval_initial_state, eval_final_state, _) = model_construction.create_generator( hparams, inputs, targets, present, is_training=False, is_validating=True, reuse=True) ## Discriminator. fake_predictions = model_construction.create_discriminator( hparams, fake_sequence, is_training=is_training, inputs=inputs, present=present) real_predictions = model_construction.create_discriminator( hparams, real_sequence, is_training=is_training, reuse=True, inputs=inputs, present=present) ## Critic. # The critic will be used to estimate the forward rewards to the Generator. if FLAGS.baseline_method == 'critic': est_state_values = model_construction.create_critic( hparams, fake_sequence, is_training=is_training) else: est_state_values = None ## Discriminator Loss. [dis_loss, dis_loss_fake, dis_loss_real] = model_losses.create_dis_loss( fake_predictions, real_predictions, present) ## Average log-perplexity for only missing words. However, to do this, # the logits are still computed using teacher forcing, that is, the ground # truth tokens are fed in at each time point to be valid. avg_log_perplexity = model_losses.calculate_log_perplexity( eval_logits, targets, present) ## Generator Objective. # 1. Cross Entropy losses on missing tokens. fake_cross_entropy_losses = model_losses.create_masked_cross_entropy_loss( targets, present, fake_logits) # 2. GAN REINFORCE losses. [ fake_RL_loss, fake_log_probs, fake_rewards, fake_advantages, fake_baselines, fake_averages_op, critic_loss, cumulative_rewards ] = model_losses.calculate_reinforce_objective( hparams, fake_log_probs, fake_predictions, present, est_state_values) ## Pre-training. if FLAGS.gen_pretrain_steps: raise NotImplementedError # # TODO(liamfedus): Rewrite this. # fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses) # gen_pretrain_op = model_optimization.create_gen_pretrain_op( # hparams, fwd_cross_entropy_loss, global_step) else: gen_pretrain_op = None if FLAGS.dis_pretrain_steps: dis_pretrain_op = model_optimization.create_dis_pretrain_op( hparams, dis_loss, global_step) else: dis_pretrain_op = None ## Generator Train Op. # 1. Cross-Entropy. if FLAGS.gen_training_strategy == 'cross_entropy': gen_loss = tf.reduce_mean(fake_cross_entropy_losses) [gen_train_op, gen_grads, gen_vars] = model_optimization.create_gen_train_op( hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE') # 2. GAN (REINFORCE) elif FLAGS.gen_training_strategy == 'reinforce': gen_loss = fake_RL_loss [gen_train_op, gen_grads, gen_vars] = model_optimization.create_reinforce_gen_train_op( hparams, learning_rate, gen_loss, fake_averages_op, global_step) else: raise NotImplementedError ## Discriminator Train Op. dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op( hparams, dis_loss, global_step) ## Critic Train Op. if critic_loss is not None: [critic_train_op, _, _] = model_optimization.create_critic_train_op( hparams, critic_loss, global_step) dis_train_op = tf.group(dis_train_op, critic_train_op) ## Summaries. with tf.name_scope('general'): tf.summary.scalar('percent_real', percent_real_var) tf.summary.scalar('learning_rate', learning_rate) with tf.name_scope('generator_objectives'): tf.summary.scalar('gen_objective', tf.reduce_mean(gen_loss)) tf.summary.scalar('gen_loss_cross_entropy', tf.reduce_mean(fake_cross_entropy_losses)) with tf.name_scope('REINFORCE'): with tf.name_scope('objective'): tf.summary.scalar('fake_RL_loss', tf.reduce_mean(fake_RL_loss)) with tf.name_scope('rewards'): helper.variable_summaries(cumulative_rewards, 'rewards') with tf.name_scope('advantages'): helper.variable_summaries(fake_advantages, 'advantages') with tf.name_scope('baselines'): helper.variable_summaries(fake_baselines, 'baselines') with tf.name_scope('log_probs'): helper.variable_summaries(fake_log_probs, 'log_probs') with tf.name_scope('discriminator_losses'): tf.summary.scalar('dis_loss', dis_loss) tf.summary.scalar('dis_loss_fake_sequence', dis_loss_fake) tf.summary.scalar('dis_loss_prob_fake_sequence', tf.exp(-dis_loss_fake)) tf.summary.scalar('dis_loss_real_sequence', dis_loss_real) tf.summary.scalar('dis_loss_prob_real_sequence', tf.exp(-dis_loss_real)) if critic_loss is not None: with tf.name_scope('critic_losses'): tf.summary.scalar('critic_loss', critic_loss) with tf.name_scope('logits'): helper.variable_summaries(fake_logits, 'fake_logits') for v, g in zip(gen_vars, gen_grads): helper.variable_summaries(v, v.op.name) helper.variable_summaries(g, 'grad/' + v.op.name) for v, g in zip(dis_vars, dis_grads): helper.variable_summaries(v, v.op.name) helper.variable_summaries(g, 'grad/' + v.op.name) merge_summaries_op = tf.summary.merge_all() text_summary_placeholder = tf.placeholder(tf.string) text_summary_op = tf.summary.text('Samples', text_summary_placeholder) # Model saver. saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5) # Named tuple that captures elements of the MaskGAN model. Model = collections.namedtuple('Model', [ 'inputs', 'targets', 'present', 'percent_real_update', 'new_rate', 'fake_sequence', 'fake_logits', 'fake_rewards', 'fake_baselines', 'fake_advantages', 'fake_log_probs', 'fake_predictions', 'real_predictions', 'fake_cross_entropy_losses', 'fake_gen_initial_state', 'fake_gen_final_state', 'eval_initial_state', 'eval_final_state', 'avg_log_perplexity', 'dis_loss', 'gen_loss', 'critic_loss', 'cumulative_rewards', 'dis_train_op', 'gen_train_op', 'gen_pretrain_op', 'dis_pretrain_op', 'merge_summaries_op', 'global_step', 'new_learning_rate', 'learning_rate_update', 'saver', 'text_summary_op', 'text_summary_placeholder' ]) model = Model( inputs, targets, present, percent_real_update, new_rate, fake_sequence, fake_logits, fake_rewards, fake_baselines, fake_advantages, fake_log_probs, fake_predictions, real_predictions, fake_cross_entropy_losses, fake_gen_initial_state, fake_gen_final_state, eval_initial_state, eval_final_state, avg_log_perplexity, dis_loss, gen_loss, critic_loss, cumulative_rewards, dis_train_op, gen_train_op, gen_pretrain_op, dis_pretrain_op, merge_summaries_op, global_step, new_learning_rate, learning_rate_update, saver, text_summary_op, text_summary_placeholder) return model def compute_geometric_average(percent_captured): """Compute the geometric average of the n-gram metrics.""" res = 1. for _, n_gram_percent in percent_captured.iteritems(): res *= n_gram_percent return np.power(res, 1. / float(len(percent_captured))) def compute_arithmetic_average(percent_captured): """Compute the arithmetic average of the n-gram metrics.""" N = len(percent_captured) res = 0. for _, n_gram_percent in percent_captured.iteritems(): res += n_gram_percent return res / float(N) def get_iterator(data): """Return the data iterator.""" if FLAGS.data_set == 'ptb': iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size, FLAGS.sequence_length, FLAGS.epoch_size_override) elif FLAGS.data_set == 'imdb': iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size, FLAGS.sequence_length) return iterator def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts): """Train model. Args: hparams: Hyperparameters for the MaskGAN. data: Data to evaluate. log_dir: Directory to save checkpoints. log: Readable log for the experiment. id_to_word: Dictionary of indices to words. data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the data_set. """ print('Training model.') tf.logging.info('Training model.') # Boolean indicating operational mode. is_training = True # Write all the information to the logs. log.write('hparams\n') log.write(str(hparams)) log.flush() is_chief = FLAGS.task == 0 with tf.Graph().as_default(): with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): container_name = '' with tf.container(container_name): # Construct the model. if FLAGS.num_rollouts == 1: model = create_MaskGAN(hparams, is_training) elif FLAGS.num_rollouts > 1: model = rollout.create_rollout_MaskGAN(hparams, is_training) else: raise ValueError print('\nTrainable Variables in Graph:') for v in tf.trainable_variables(): print(v) ## Retrieve the initial savers. init_savers = model_utils.retrieve_init_savers(hparams) ## Initial saver function to supervisor. init_fn = partial(model_utils.init_fn, init_savers) # Create the supervisor. It will take care of initialization, # summaries, checkpoints, and recovery. sv = tf.train.Supervisor( logdir=log_dir, is_chief=is_chief, saver=model.saver, global_step=model.global_step, save_model_secs=60, recovery_wait_secs=30, summary_op=None, init_fn=init_fn) # Get an initialized, and possibly recovered session. Launch the # services: Checkpointing, Summaries, step counting. # # When multiple replicas of this program are running the services are # only launched by the 'chief' replica. with sv.managed_session(FLAGS.master) as sess: ## Pretrain the generator. if FLAGS.gen_pretrain_steps: pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log, id_to_word, data_ngram_counts, is_chief) ## Pretrain the discriminator. if FLAGS.dis_pretrain_steps: pretrain_mask_gan.pretrain_discriminator( sv, sess, model, data, log, id_to_word, data_ngram_counts, is_chief) # Initial indicators for printing and summarizing. print_step_division = -1 summary_step_division = -1 # Run iterative computation in a loop. while not sv.ShouldStop(): is_present_rate = FLAGS.is_present_rate if FLAGS.is_present_rate_decay is not None: is_present_rate *= (1. - FLAGS.is_present_rate_decay) model_utils.assign_percent_real(sess, model.percent_real_update, model.new_rate, is_present_rate) # GAN training. avg_epoch_gen_loss, avg_epoch_dis_loss = [], [] cumulative_costs = 0. gen_iters = 0 # Generator and Discriminator statefulness initial evaluation. # TODO(liamfedus): Throughout the code I am implicitly assuming # that the Generator and Discriminator are equal sized. [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run( [model.eval_initial_state, model.fake_gen_initial_state]) dis_initial_state_eval = fake_gen_initial_state_eval # Save zeros state to reset later. zeros_state = fake_gen_initial_state_eval ## Offset Discriminator. if FLAGS.ps_tasks == 0: dis_offset = 1 else: dis_offset = FLAGS.task * 1000 + 1 dis_iterator = get_iterator(data) for i in range(dis_offset): try: dis_x, dis_y, _ = next(dis_iterator) except StopIteration: dis_iterator = get_iterator(data) dis_initial_state_eval = zeros_state dis_x, dis_y, _ = next(dis_iterator) p = model_utils.generate_mask() # Construct the train feed. train_feed = { model.inputs: dis_x, model.targets: dis_y, model.present: p } if FLAGS.data_set == 'ptb': # Statefulness of the Generator being used for Discriminator. for i, (c, h) in enumerate(model.fake_gen_initial_state): train_feed[c] = dis_initial_state_eval[i].c train_feed[h] = dis_initial_state_eval[i].h # Determine the state had the Generator run over real data. We # use this state for the Discriminator. [dis_initial_state_eval] = sess.run( [model.fake_gen_final_state], train_feed) ## Training loop. iterator = get_iterator(data) gen_initial_state_eval = zeros_state if FLAGS.ps_tasks > 0: gen_offset = FLAGS.task * 1000 + 1 for i in range(gen_offset): try: next(iterator) except StopIteration: dis_iterator = get_iterator(data) dis_initial_state_eval = zeros_state next(dis_iterator) for x, y, _ in iterator: for _ in xrange(hparams.dis_train_iterations): try: dis_x, dis_y, _ = next(dis_iterator) except StopIteration: dis_iterator = get_iterator(data) dis_initial_state_eval = zeros_state dis_x, dis_y, _ = next(dis_iterator) if FLAGS.data_set == 'ptb': [dis_initial_state_eval] = sess.run( [model.fake_gen_initial_state]) p = model_utils.generate_mask() # Construct the train feed. train_feed = { model.inputs: dis_x, model.targets: dis_y, model.present: p } # Statefulness for the Discriminator. if FLAGS.data_set == 'ptb': for i, (c, h) in enumerate(model.fake_gen_initial_state): train_feed[c] = dis_initial_state_eval[i].c train_feed[h] = dis_initial_state_eval[i].h _, dis_loss_eval, step = sess.run( [model.dis_train_op, model.dis_loss, model.global_step], feed_dict=train_feed) # Determine the state had the Generator run over real data. # Use this state for the Discriminator. [dis_initial_state_eval] = sess.run( [model.fake_gen_final_state], train_feed) # Randomly mask out tokens. p = model_utils.generate_mask() # Construct the train feed. train_feed = {model.inputs: x, model.targets: y, model.present: p} # Statefulness for Generator. if FLAGS.data_set == 'ptb': tf.logging.info('Generator is stateful.') print('Generator is stateful.') # Statefulness for *evaluation* Generator. for i, (c, h) in enumerate(model.eval_initial_state): train_feed[c] = gen_initial_state_eval[i].c train_feed[h] = gen_initial_state_eval[i].h # Statefulness for Generator. for i, (c, h) in enumerate(model.fake_gen_initial_state): train_feed[c] = fake_gen_initial_state_eval[i].c train_feed[h] = fake_gen_initial_state_eval[i].h # Determine whether to decay learning rate. lr_decay = hparams.gen_learning_rate_decay**max( step + 1 - hparams.gen_full_learning_rate_steps, 0.0) # Assign learning rate. gen_learning_rate = hparams.gen_learning_rate * lr_decay model_utils.assign_learning_rate(sess, model.learning_rate_update, model.new_learning_rate, gen_learning_rate) [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run( [ model.gen_train_op, model.gen_loss, model.avg_log_perplexity, model.global_step ], feed_dict=train_feed) cumulative_costs += gen_log_perplexity_eval gen_iters += 1 # Determine the state had the Generator run over real data. [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run( [model.eval_final_state, model.fake_gen_final_state], train_feed) avg_epoch_dis_loss.append(dis_loss_eval) avg_epoch_gen_loss.append(gen_loss_eval) ## Summaries. # Calulate rolling perplexity. perplexity = np.exp(cumulative_costs / gen_iters) if is_chief and (step / FLAGS.summaries_every > summary_step_division): summary_step_division = step / FLAGS.summaries_every # Confirm perplexity is not infinite. if (not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold): print('Training raising FloatingPoinError.') raise FloatingPointError( 'Training infinite perplexity: %.3f' % perplexity) # Graph summaries. summary_str = sess.run( model.merge_summaries_op, feed_dict=train_feed) sv.SummaryComputed(sess, summary_str) # Summary: n-gram avg_percent_captured = {'2': 0., '3': 0., '4': 0.} for n, data_ngram_count in data_ngram_counts.iteritems(): batch_percent_captured = evaluation_utils.sequence_ngram_evaluation( sess, model.fake_sequence, log, train_feed, data_ngram_count, int(n)) summary_percent_str = tf.Summary(value=[ tf.Summary.Value( tag='general/%s-grams_percent_correct' % n, simple_value=batch_percent_captured) ]) sv.SummaryComputed( sess, summary_percent_str, global_step=step) # Summary: geometric_avg geometric_avg = compute_geometric_average(avg_percent_captured) summary_geometric_avg_str = tf.Summary(value=[ tf.Summary.Value( tag='general/geometric_avg', simple_value=geometric_avg) ]) sv.SummaryComputed( sess, summary_geometric_avg_str, global_step=step) # Summary: arithmetic_avg arithmetic_avg = compute_arithmetic_average( avg_percent_captured) summary_arithmetic_avg_str = tf.Summary(value=[ tf.Summary.Value( tag='general/arithmetic_avg', simple_value=arithmetic_avg) ]) sv.SummaryComputed( sess, summary_arithmetic_avg_str, global_step=step) # Summary: perplexity summary_perplexity_str = tf.Summary(value=[ tf.Summary.Value( tag='general/perplexity', simple_value=perplexity) ]) sv.SummaryComputed( sess, summary_perplexity_str, global_step=step) ## Printing and logging if is_chief and (step / FLAGS.print_every > print_step_division): print_step_division = (step / FLAGS.print_every) print('global_step: %d' % step) print(' perplexity: %.3f' % perplexity) print(' gen_learning_rate: %.6f' % gen_learning_rate) log.write('global_step: %d\n' % step) log.write(' perplexity: %.3f\n' % perplexity) log.write(' gen_learning_rate: %.6f' % gen_learning_rate) # Average percent captured for each of the n-grams. avg_percent_captured = {'2': 0., '3': 0., '4': 0.} for n, data_ngram_count in data_ngram_counts.iteritems(): batch_percent_captured = evaluation_utils.sequence_ngram_evaluation( sess, model.fake_sequence, log, train_feed, data_ngram_count, int(n)) avg_percent_captured[n] = batch_percent_captured print(' percent of %s-grams captured: %.3f.' % (n, batch_percent_captured)) log.write(' percent of %s-grams captured: %.3f.\n' % (n, batch_percent_captured)) geometric_avg = compute_geometric_average(avg_percent_captured) print(' geometric_avg: %.3f.' % geometric_avg) log.write(' geometric_avg: %.3f.' % geometric_avg) arithmetic_avg = compute_arithmetic_average( avg_percent_captured) print(' arithmetic_avg: %.3f.' % arithmetic_avg) log.write(' arithmetic_avg: %.3f.' % arithmetic_avg) evaluation_utils.print_and_log_losses( log, step, is_present_rate, avg_epoch_dis_loss, avg_epoch_gen_loss) if FLAGS.gen_training_strategy == 'reinforce': evaluation_utils.generate_RL_logs(sess, model, log, id_to_word, train_feed) else: evaluation_utils.generate_logs(sess, model, log, id_to_word, train_feed) log.flush() log.close() def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word, data_ngram_counts, eval_saver): """Evaluate model for a number of steps. Args: data: Dataset. sv: Supervisor. model: The GAN model we have just built. sess: A session to use. train_dir: Path to a directory containing checkpoints. log: Evaluation log for evaluation. id_to_word: Dictionary of indices to words. data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the data_set. eval_saver: Evaluation saver.r. """ tf.logging.info('Evaluate Once.') # Load the last model checkpoint, or initialize the graph. model_save_path = tf.latest_checkpoint(train_dir) if not model_save_path: tf.logging.warning('No checkpoint yet in: %s', train_dir) return tf.logging.info('Starting eval of: %s' % model_save_path) tf.logging.info('Only restoring trainable variables.') eval_saver.restore(sess, model_save_path) # Run the requested number of evaluation steps avg_epoch_gen_loss, avg_epoch_dis_loss = [], [] cumulative_costs = 0. # Average percent captured for each of the n-grams. avg_percent_captured = {'2': 0., '3': 0., '4': 0.} # Set a random seed to keep fixed mask. np.random.seed(0) gen_iters = 0 # Generator statefulness over the epoch. # TODO(liamfedus): Check this. [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run( [model.eval_initial_state, model.fake_gen_initial_state]) if FLAGS.eval_language_model: is_present_rate = 0. tf.logging.info('Overriding is_present_rate=0. for evaluation.') print('Overriding is_present_rate=0. for evaluation.') iterator = get_iterator(data) for x, y, _ in iterator: if FLAGS.eval_language_model: is_present_rate = 0. else: is_present_rate = FLAGS.is_present_rate tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate) model_utils.assign_percent_real(sess, model.percent_real_update, model.new_rate, is_present_rate) # Randomly mask out tokens. p = model_utils.generate_mask() eval_feed = {model.inputs: x, model.targets: y, model.present: p} if FLAGS.data_set == 'ptb': # Statefulness for *evaluation* Generator. for i, (c, h) in enumerate(model.eval_initial_state): eval_feed[c] = gen_initial_state_eval[i].c eval_feed[h] = gen_initial_state_eval[i].h # Statefulness for the Generator. for i, (c, h) in enumerate(model.fake_gen_initial_state): eval_feed[c] = fake_gen_initial_state_eval[i].c eval_feed[h] = fake_gen_initial_state_eval[i].h [ gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval, gen_initial_state_eval, fake_gen_initial_state_eval, step ] = sess.run( [ model.avg_log_perplexity, model.dis_loss, model.gen_loss, model.eval_final_state, model.fake_gen_final_state, model.global_step ], feed_dict=eval_feed) for n, data_ngram_count in data_ngram_counts.iteritems(): batch_percent_captured = evaluation_utils.sequence_ngram_evaluation( sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n)) avg_percent_captured[n] += batch_percent_captured cumulative_costs += gen_log_perplexity_eval avg_epoch_dis_loss.append(dis_loss_eval) avg_epoch_gen_loss.append(gen_loss_eval) gen_iters += 1 # Calulate rolling metrics. perplexity = np.exp(cumulative_costs / gen_iters) for n, _ in avg_percent_captured.iteritems(): avg_percent_captured[n] /= gen_iters # Confirm perplexity is not infinite. if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold: print('Evaluation raising FloatingPointError.') raise FloatingPointError( 'Evaluation infinite perplexity: %.3f' % perplexity) ## Printing and logging. evaluation_utils.print_and_log_losses(log, step, is_present_rate, avg_epoch_dis_loss, avg_epoch_gen_loss) print(' perplexity: %.3f' % perplexity) log.write(' perplexity: %.3f\n' % perplexity) for n, n_gram_percent in avg_percent_captured.iteritems(): n = int(n) print(' percent of %d-grams captured: %.3f.' % (n, n_gram_percent)) log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent)) samples = evaluation_utils.generate_logs(sess, model, log, id_to_word, eval_feed) ## Summaries. summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed) sv.SummaryComputed(sess, summary_str) # Summary: text summary_str = sess.run(model.text_summary_op, {model.text_summary_placeholder: '\n\n'.join(samples)}) sv.SummaryComputed(sess, summary_str, global_step=step) # Summary: n-gram for n, n_gram_percent in avg_percent_captured.iteritems(): n = int(n) summary_percent_str = tf.Summary(value=[ tf.Summary.Value( tag='general/%d-grams_percent_correct' % n, simple_value=n_gram_percent) ]) sv.SummaryComputed(sess, summary_percent_str, global_step=step) # Summary: geometric_avg geometric_avg = compute_geometric_average(avg_percent_captured) summary_geometric_avg_str = tf.Summary(value=[ tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg) ]) sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step) # Summary: arithmetic_avg arithmetic_avg = compute_arithmetic_average(avg_percent_captured) summary_arithmetic_avg_str = tf.Summary(value=[ tf.Summary.Value( tag='general/arithmetic_avg', simple_value=arithmetic_avg) ]) sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step) # Summary: perplexity summary_perplexity_str = tf.Summary(value=[ tf.Summary.Value(tag='general/perplexity', simple_value=perplexity) ]) sv.SummaryComputed(sess, summary_perplexity_str, global_step=step) def evaluate_model(hparams, data, train_dir, log, id_to_word, data_ngram_counts): """Evaluate MaskGAN model. Args: hparams: Hyperparameters for the MaskGAN. data: Data to evaluate. train_dir: Path to a directory containing checkpoints. id_to_word: Dictionary of indices to words. data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the data_set. """ tf.logging.error('Evaluate model.') # Boolean indicating operational mode. is_training = False if FLAGS.mode == MODE_VALIDATION: logdir = FLAGS.base_directory + '/validation' elif FLAGS.mode == MODE_TRAIN_EVAL: logdir = FLAGS.base_directory + '/train_eval' elif FLAGS.mode == MODE_TEST: logdir = FLAGS.base_directory + '/test' else: raise NotImplementedError # Wait for a checkpoint to exist. print(train_dir) print(tf.train.latest_checkpoint(train_dir)) while not tf.train.latest_checkpoint(train_dir): tf.logging.error('Waiting for checkpoint...') print('Waiting for checkpoint...') time.sleep(10) with tf.Graph().as_default(): # Use a separate container for each trial container_name = '' with tf.container(container_name): # Construct the model. if FLAGS.num_rollouts == 1: model = create_MaskGAN(hparams, is_training) elif FLAGS.num_rollouts > 1: model = rollout.create_rollout_MaskGAN(hparams, is_training) else: raise ValueError # Create the supervisor. It will take care of initialization, summaries, # checkpoints, and recovery. We only pass the trainable variables # to load since things like baselines keep batch_size which may not # match between training and evaluation. evaluation_variables = tf.trainable_variables() evaluation_variables.append(model.global_step) eval_saver = tf.train.Saver(var_list=evaluation_variables) sv = tf.Supervisor(logdir=logdir) sess = sv.PrepareSession(FLAGS.eval_master, start_standard_services=False) tf.logging.info('Before sv.Loop.') sv.Loop(FLAGS.eval_interval_secs, evaluate_once, (data, sv, model, sess, train_dir, log, id_to_word, data_ngram_counts, eval_saver)) sv.WaitForStop() tf.logging.info('sv.Stop().') sv.Stop() def main(_): hparams = create_hparams() train_dir = FLAGS.base_directory + '/train' # Load data set. if FLAGS.data_set == 'ptb': raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir) train_data, valid_data, test_data, _ = raw_data valid_data_flat = valid_data elif FLAGS.data_set == 'imdb': raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir) # TODO(liamfedus): Get an IMDB test partition. train_data, valid_data = raw_data valid_data_flat = [word for review in valid_data for word in review] else: raise NotImplementedError if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_TRAIN_EVAL: data_set = train_data elif FLAGS.mode == MODE_VALIDATION: data_set = valid_data elif FLAGS.mode == MODE_TEST: data_set = test_data else: raise NotImplementedError # Dictionary and reverse dictionry. if FLAGS.data_set == 'ptb': word_to_id = ptb_loader.build_vocab( os.path.join(FLAGS.data_dir, 'ptb.train.txt')) elif FLAGS.data_set == 'imdb': word_to_id = imdb_loader.build_vocab( os.path.join(FLAGS.data_dir, 'vocab.txt')) id_to_word = {v: k for k, v in word_to_id.iteritems()} # Dictionary of Training Set n-gram counts. bigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=2) trigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=3) fourgram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=4) bigram_counts = n_gram.construct_ngrams_dict(bigram_tuples) trigram_counts = n_gram.construct_ngrams_dict(trigram_tuples) fourgram_counts = n_gram.construct_ngrams_dict(fourgram_tuples) print('Unique %d-grams: %d' % (2, len(bigram_counts))) print('Unique %d-grams: %d' % (3, len(trigram_counts))) print('Unique %d-grams: %d' % (4, len(fourgram_counts))) data_ngram_counts = { '2': bigram_counts, '3': trigram_counts, '4': fourgram_counts } # TODO(liamfedus): This was necessary because there was a problem with our # originally trained IMDB models. The EOS_INDEX was off by one, which means, # two words were mapping to index 86933. The presence of '' is going # to throw and out of vocabulary error. FLAGS.vocab_size = len(id_to_word) print('Vocab size: %d' % FLAGS.vocab_size) tf.gfile.MakeDirs(FLAGS.base_directory) if FLAGS.mode == MODE_TRAIN: log = tf.gfile.GFile( os.path.join(FLAGS.base_directory, 'train-log.txt'), mode='w') elif FLAGS.mode == MODE_VALIDATION: log = tf.gfile.GFile( os.path.join(FLAGS.base_directory, 'validation-log.txt'), mode='w') elif FLAGS.mode == MODE_TRAIN_EVAL: log = tf.gfile.GFile( os.path.join(FLAGS.base_directory, 'train_eval-log.txt'), mode='w') else: log = tf.gfile.GFile( os.path.join(FLAGS.base_directory, 'test-log.txt'), mode='w') if FLAGS.mode == MODE_TRAIN: train_model(hparams, data_set, train_dir, log, id_to_word, data_ngram_counts) elif FLAGS.mode == MODE_VALIDATION: evaluate_model(hparams, data_set, train_dir, log, id_to_word, data_ngram_counts) elif FLAGS.mode == MODE_TRAIN_EVAL: evaluate_model(hparams, data_set, train_dir, log, id_to_word, data_ngram_counts) elif FLAGS.mode == MODE_TEST: evaluate_model(hparams, data_set, train_dir, log, id_to_word, data_ngram_counts) else: raise NotImplementedError if __name__ == '__main__': tf.app.run()