NCTC / models /research /maskgan /train_mask_gan.py
NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
45.9 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Launch example:
[IMDB]
python train_mask_gan.py --data_dir
/tmp/imdb --data_set imdb --batch_size 128
--sequence_length 20 --base_directory /tmp/maskGAN_v0.01
--hparams="gen_rnn_size=650,gen_num_layers=2,dis_rnn_size=650,dis_num_layers=2
,critic_learning_rate=0.0009756,dis_learning_rate=0.0000585,
dis_train_iterations=8,gen_learning_rate=0.0016624,
gen_full_learning_rate_steps=1e9,gen_learning_rate_decay=0.999999,
rl_discount_rate=0.8835659" --mode TRAIN --max_steps 1000000
--generator_model seq2seq_vd --discriminator_model seq2seq_vd
--is_present_rate 0.5 --summaries_every 25 --print_every 25
--max_num_to_print=3 --generator_optimizer=adam
--seq2seq_share_embedding=True --baseline_method=critic
--attention_option=luong --n_gram_eval=4 --mask_strategy=contiguous
--gen_training_strategy=reinforce --dis_pretrain_steps=100
--perplexity_threshold=1000000
--dis_share_embedding=True --maskgan_ckpt
/tmp/model.ckpt-171091
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from functools import partial
import os
import time
# Dependency imports
import numpy as np
from six.moves import xrange
import tensorflow as tf
import pretrain_mask_gan
from data import imdb_loader
from data import ptb_loader
from model_utils import helper
from model_utils import model_construction
from model_utils import model_losses
from model_utils import model_optimization
# Data.
from model_utils import model_utils
from model_utils import n_gram
from models import evaluation_utils
from models import rollout
# Compact, non-scientific numpy output for the values printed during training.
np.set_printoptions(precision=3, suppress=True)

# Operational modes accepted by the --mode flag.
MODE_TRAIN = 'TRAIN'
MODE_TRAIN_EVAL = 'TRAIN_EVAL'
MODE_VALIDATION = 'VALIDATION'
MODE_TEST = 'TEST'
## Binary and setup FLAGS.
tf.app.flags.DEFINE_enum(
    'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL],
    'What this binary will do.')
tf.app.flags.DEFINE_string('master', '',
                           """Name of the TensorFlow master to use.""")
tf.app.flags.DEFINE_string('eval_master', '',
                           """Name prefix of the Tensorflow eval master.""")
tf.app.flags.DEFINE_integer('task', 0,
                            """Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer(
    'ps_tasks', 0, 'Number of tasks in the ps job. If 0 no ps job is used.')

## General FLAGS.
tf.app.flags.DEFINE_string(
    'hparams', '', 'Comma separated list of name=value hyperparameter pairs.')
tf.app.flags.DEFINE_integer('batch_size', 20, 'The batch size.')
tf.app.flags.DEFINE_integer('vocab_size', 10000, 'The vocabulary size.')
tf.app.flags.DEFINE_integer('sequence_length', 20, 'The sequence length.')
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            'Maximum number of steps to run.')
tf.app.flags.DEFINE_string(
    'mask_strategy', 'random', 'Strategy for masking the words. Determine the '
    'characteristics of how the words are dropped out. One of '
    "['contiguous', 'random'].")
tf.app.flags.DEFINE_float('is_present_rate', 0.5,
                          'Percent of tokens present in the forward sequence.')
tf.app.flags.DEFINE_float('is_present_rate_decay', None, 'Decay rate for the '
                          'percent of words that are real (are present).')
tf.app.flags.DEFINE_string(
    'generator_model', 'seq2seq',
    "Type of Generator model. One of ['rnn', 'seq2seq', 'seq2seq_zaremba', "
    "'rnn_zaremba', 'rnn_nas', 'seq2seq_nas']")
tf.app.flags.DEFINE_string(
    'attention_option', None,
    "Attention mechanism. One of [None, 'luong', 'bahdanau']")
tf.app.flags.DEFINE_string(
    'discriminator_model', 'bidirectional',
    "Type of Discriminator model. One of ['cnn', 'rnn', 'bidirectional', "
    "'rnn_zaremba', 'bidirectional_zaremba', 'rnn_nas', 'rnn_vd', 'seq2seq_vd']"
)
tf.app.flags.DEFINE_boolean('seq2seq_share_embedding', False,
                            'Whether to share the '
                            'embeddings between the encoder and decoder.')
tf.app.flags.DEFINE_boolean(
    'dis_share_embedding', False, 'Whether to share the '
    'embeddings between the generator and discriminator.')
tf.app.flags.DEFINE_boolean('dis_update_share_embedding', False, 'Whether the '
                            'discriminator should update the shared embedding.')
tf.app.flags.DEFINE_boolean('use_gen_mode', False,
                            'Use the mode of the generator '
                            'to produce samples.')
tf.app.flags.DEFINE_boolean('critic_update_dis_vars', False,
                            'Whether the critic '
                            'updates the discriminator variables.')

## Training FLAGS.
# NOTE(review): FLAGS.gen_pretrain_steps and FLAGS.dis_pretrain_steps are read
# by this module but are not defined in this section; presumably they are
# defined by the imported `pretrain_mask_gan` module -- confirm before adding
# them here (a second definition would raise a duplicate-flag error).
tf.app.flags.DEFINE_string(
    'gen_training_strategy', 'reinforce',
    "Method for training the Generator. One of ['cross_entropy', 'reinforce']")
tf.app.flags.DEFINE_string(
    'generator_optimizer', 'adam',
    "Type of Generator optimizer. One of ['sgd', 'adam']")
tf.app.flags.DEFINE_float('grad_clipping', 10., 'Norm for gradient clipping.')
tf.app.flags.DEFINE_float('advantage_clipping', 5., 'Clipping for advantages.')
tf.app.flags.DEFINE_string(
    'baseline_method', None,
    "Approach for baseline. One of ['critic', 'dis_batch', 'ema', None]")
tf.app.flags.DEFINE_float('perplexity_threshold', 15000,
                          'Limit for perplexity before terminating job.')
tf.app.flags.DEFINE_float('zoneout_drop_prob', 0.1,
                          'Probability for dropping parameter for zoneout.')
tf.app.flags.DEFINE_float('keep_prob', 0.5,
                          'Probability for keeping parameter for dropout.')

## Logging and evaluation FLAGS.
tf.app.flags.DEFINE_integer('print_every', 250,
                            'Frequency to print and log the '
                            'outputs of the model.')
tf.app.flags.DEFINE_integer('max_num_to_print', 5,
                            'Number of samples to log/print.')
tf.app.flags.DEFINE_boolean('print_verbose', False, 'Whether to print in full.')
tf.app.flags.DEFINE_integer('summaries_every', 100,
                            'Frequency to compute summaries.')
tf.app.flags.DEFINE_boolean('eval_language_model', False,
                            'Whether to evaluate on '
                            'all words as in language modeling.')
tf.app.flags.DEFINE_float('eval_interval_secs', 60,
                          'Delay for evaluating model.')
tf.app.flags.DEFINE_integer(
    'n_gram_eval', 4, """The degree of the n-grams to use for evaluation.""")
tf.app.flags.DEFINE_integer(
    'epoch_size_override', None,
    'If an integer, this dictates the size of the epochs and will potentially '
    'not iterate over all the data.')
tf.app.flags.DEFINE_integer('eval_epoch_size_override', None,
                            'Number of evaluation steps.')

## Directories and checkpoints.
tf.app.flags.DEFINE_string('base_directory', '/tmp/maskGAN_v0.00',
                           'Base directory for the logging, events and graph.')
tf.app.flags.DEFINE_string('data_set', 'ptb', 'Data set to operate on. One of '
                           "['ptb', 'imdb']")
tf.app.flags.DEFINE_string('data_dir', '/tmp/data/ptb',
                           'Directory for the training data.')
tf.app.flags.DEFINE_string(
    'language_model_ckpt_dir', None,
    'Directory storing checkpoints to initialize the model. Pretrained models '
    'are stored at /tmp/maskGAN/pretrained/')
tf.app.flags.DEFINE_string(
    'language_model_ckpt_dir_reversed', None,
    'Directory storing checkpoints of reversed models to initialize the model. '
    'Pretrained models are stored at /tmp/PTB/pretrained_reversed')
tf.app.flags.DEFINE_string(
    'maskgan_ckpt', None,
    'Override which checkpoint file to use to restore the '
    'model. A pretrained seq2seq_zaremba model is stored at '
    '/tmp/maskGAN/pretrain/seq2seq_zaremba/train/model.ckpt-64912')
tf.app.flags.DEFINE_boolean('wasserstein_objective', False,
                            '(DEPRECATED) Whether to use the WGAN training.')
tf.app.flags.DEFINE_integer('num_rollouts', 1,
                            'The number of rolled out predictions to make.')
tf.app.flags.DEFINE_float('c_lower', -0.01, 'Lower bound for weights.')
tf.app.flags.DEFINE_float('c_upper', 0.01, 'Upper bound for weights.')

FLAGS = tf.app.flags.FLAGS
def create_hparams():
  """Build the default MaskGAN hyperparameters and apply --hparams overrides.

  Returns:
    A `tf.contrib.training.HParams` holding the defaults listed below, with
    any comma-separated `name=value` pairs from FLAGS.hparams parsed on top.
  """
  defaults = dict(
      # Layer counts and RNN widths.
      gen_num_layers=2,
      dis_num_layers=2,
      gen_rnn_size=740,
      dis_rnn_size=740,
      # Optimisation.
      gen_learning_rate=5e-4,
      dis_learning_rate=5e-3,
      critic_learning_rate=5e-3,
      dis_train_iterations=1,
      gen_learning_rate_decay=1.0,
      gen_full_learning_rate_steps=1e7,
      baseline_decay=0.999999,
      rl_discount_rate=0.9,
      # Dropout keep probabilities (variational dropout).
      gen_vd_keep_prob=0.5,
      dis_vd_keep_prob=0.5,
      dis_pretrain_learning_rate=5e-3,
      # Discriminator sizes.
      dis_num_filters=128,
      dis_hidden_dim=128,
      # Keep probabilities for the NAS cells.
      gen_nas_keep_prob_0=0.85,
      gen_nas_keep_prob_1=0.55,
      dis_nas_keep_prob_0=0.85,
      dis_nas_keep_prob_1=0.55)
  hparams = tf.contrib.training.HParams(**defaults)

  # Values given on the command line take precedence over the defaults.
  return hparams.parse(FLAGS.hparams) if FLAGS.hparams else hparams
def create_MaskGAN(hparams, is_training):
  """Create the MaskGAN model.

  Builds the following in the current default graph:
    * the generator, plus a second, variable-reusing generator copy used for
      evaluation;
    * the discriminator, applied to both the generated and the real sequence;
    * an optional critic;
    * the losses, train ops, summaries and a checkpoint saver.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    is_training: Boolean indicating operational mode (train/inference). The
      separate evaluation generator is always built with is_training=False
      and is_validating=True, and its logits (used for the perplexity) are
      evaluated with a teacher forcing regime.

  Returns:
    model: Namedtuple for specifying the MaskGAN.

  Raises:
    NotImplementedError: If FLAGS.gen_pretrain_steps is set (generator
      pretraining is not implemented here), or if FLAGS.gen_training_strategy
      is neither 'cross_entropy' nor 'reinforce'.
  """
  global_step = tf.Variable(0, name='global_step', trainable=False)

  # The learning rate is a non-trainable variable that callers update by
  # feeding `new_learning_rate` and running `learning_rate_update`.
  new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
  learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False)
  learning_rate_update = tf.assign(learning_rate, new_learning_rate)

  # Same feed-and-assign pattern for the 'percent_real' value; it is only
  # read by the summary below (presumably fed with the is_present_rate used
  # by the training loop -- confirm).
  new_rate = tf.placeholder(tf.float32, [], name='new_rate')
  percent_real_var = tf.Variable(0.0, trainable=False)
  percent_real_update = tf.assign(percent_real_var, new_rate)

  ## Placeholders.
  inputs = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  targets = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  # Boolean mask over token positions; presumably True where the token is
  # kept (present) and False where it is masked out -- confirm.
  present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  # TODO(adai): Placeholder for IMDB label.

  ## Real Sequence is the targets.
  real_sequence = targets

  ## Fake Sequence from the Generator.
  # TODO(adai): Generator must have IMDB labels placeholder.
  (fake_sequence, fake_logits, fake_log_probs, fake_gen_initial_state,
   fake_gen_final_state, _) = model_construction.create_generator(
       hparams,
       inputs,
       targets,
       present,
       is_training=is_training,
       is_validating=False)
  # Evaluation copy of the generator. reuse=True presumably makes it share
  # the variables created by the call above.
  (_, eval_logits, _, eval_initial_state, eval_final_state,
   _) = model_construction.create_generator(
       hparams,
       inputs,
       targets,
       present,
       is_training=False,
       is_validating=True,
       reuse=True)

  ## Discriminator.
  fake_predictions = model_construction.create_discriminator(
      hparams,
      fake_sequence,
      is_training=is_training,
      inputs=inputs,
      present=present)
  # Same discriminator (reuse=True) applied to the real sequence.
  real_predictions = model_construction.create_discriminator(
      hparams,
      real_sequence,
      is_training=is_training,
      reuse=True,
      inputs=inputs,
      present=present)

  ## Critic.
  # The critic will be used to estimate the forward rewards to the Generator.
  if FLAGS.baseline_method == 'critic':
    est_state_values = model_construction.create_critic(
        hparams, fake_sequence, is_training=is_training)
  else:
    est_state_values = None

  ## Discriminator Loss.
  [dis_loss, dis_loss_fake, dis_loss_real] = model_losses.create_dis_loss(
      fake_predictions, real_predictions, present)

  ## Average log-perplexity for only missing words. However, to do this,
  # the logits are still computed using teacher forcing, that is, the ground
  # truth tokens are fed in at each time point to be valid.
  avg_log_perplexity = model_losses.calculate_log_perplexity(
      eval_logits, targets, present)

  ## Generator Objective.
  # 1. Cross Entropy losses on missing tokens.
  fake_cross_entropy_losses = model_losses.create_masked_cross_entropy_loss(
      targets, present, fake_logits)

  # 2. GAN REINFORCE losses.
  # Note: `fake_log_probs` is rebound here to the value returned by the
  # REINFORCE objective; that value is what the summaries and the returned
  # Model carry.
  [
      fake_RL_loss, fake_log_probs, fake_rewards, fake_advantages,
      fake_baselines, fake_averages_op, critic_loss, cumulative_rewards
  ] = model_losses.calculate_reinforce_objective(
      hparams, fake_log_probs, fake_predictions, present, est_state_values)

  ## Pre-training.
  if FLAGS.gen_pretrain_steps:
    raise NotImplementedError
    # # TODO(liamfedus): Rewrite this.
    # fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses)
    # gen_pretrain_op = model_optimization.create_gen_pretrain_op(
    #     hparams, fwd_cross_entropy_loss, global_step)
  else:
    gen_pretrain_op = None
  if FLAGS.dis_pretrain_steps:
    dis_pretrain_op = model_optimization.create_dis_pretrain_op(
        hparams, dis_loss, global_step)
  else:
    dis_pretrain_op = None

  ## Generator Train Op.
  # 1. Cross-Entropy.
  if FLAGS.gen_training_strategy == 'cross_entropy':
    gen_loss = tf.reduce_mean(fake_cross_entropy_losses)
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_gen_train_op(
         hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')

  # 2. GAN (REINFORCE)
  elif FLAGS.gen_training_strategy == 'reinforce':
    gen_loss = fake_RL_loss
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_reinforce_gen_train_op(
         hparams, learning_rate, gen_loss, fake_averages_op, global_step)

  else:
    raise NotImplementedError

  ## Discriminator Train Op.
  dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op(
      hparams, dis_loss, global_step)

  ## Critic Train Op.
  # When a critic loss exists, its train op is grouped into `dis_train_op`,
  # so running the discriminator step also trains the critic.
  if critic_loss is not None:
    [critic_train_op, _, _] = model_optimization.create_critic_train_op(
        hparams, critic_loss, global_step)
    dis_train_op = tf.group(dis_train_op, critic_train_op)

  ## Summaries.
  with tf.name_scope('general'):
    tf.summary.scalar('percent_real', percent_real_var)
    tf.summary.scalar('learning_rate', learning_rate)

  with tf.name_scope('generator_objectives'):
    tf.summary.scalar('gen_objective', tf.reduce_mean(gen_loss))
    tf.summary.scalar('gen_loss_cross_entropy',
                      tf.reduce_mean(fake_cross_entropy_losses))

  with tf.name_scope('REINFORCE'):
    with tf.name_scope('objective'):
      tf.summary.scalar('fake_RL_loss', tf.reduce_mean(fake_RL_loss))

    with tf.name_scope('rewards'):
      helper.variable_summaries(cumulative_rewards, 'rewards')

    with tf.name_scope('advantages'):
      helper.variable_summaries(fake_advantages, 'advantages')

    with tf.name_scope('baselines'):
      helper.variable_summaries(fake_baselines, 'baselines')

    with tf.name_scope('log_probs'):
      helper.variable_summaries(fake_log_probs, 'log_probs')

  with tf.name_scope('discriminator_losses'):
    tf.summary.scalar('dis_loss', dis_loss)
    tf.summary.scalar('dis_loss_fake_sequence', dis_loss_fake)
    tf.summary.scalar('dis_loss_prob_fake_sequence', tf.exp(-dis_loss_fake))
    tf.summary.scalar('dis_loss_real_sequence', dis_loss_real)
    tf.summary.scalar('dis_loss_prob_real_sequence', tf.exp(-dis_loss_real))

  if critic_loss is not None:
    with tf.name_scope('critic_losses'):
      tf.summary.scalar('critic_loss', critic_loss)

  with tf.name_scope('logits'):
    helper.variable_summaries(fake_logits, 'fake_logits')

  # Per-variable value and gradient summaries for generator and discriminator.
  for v, g in zip(gen_vars, gen_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  for v, g in zip(dis_vars, dis_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  merge_summaries_op = tf.summary.merge_all()
  # Text summary of generated samples, fed through the placeholder.
  text_summary_placeholder = tf.placeholder(tf.string)
  text_summary_op = tf.summary.text('Samples', text_summary_placeholder)

  # Model saver.
  saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5)

  # Named tuple that captures elements of the MaskGAN model.
  Model = collections.namedtuple('Model', [
      'inputs', 'targets', 'present', 'percent_real_update', 'new_rate',
      'fake_sequence', 'fake_logits', 'fake_rewards', 'fake_baselines',
      'fake_advantages', 'fake_log_probs', 'fake_predictions',
      'real_predictions', 'fake_cross_entropy_losses', 'fake_gen_initial_state',
      'fake_gen_final_state', 'eval_initial_state', 'eval_final_state',
      'avg_log_perplexity', 'dis_loss', 'gen_loss', 'critic_loss',
      'cumulative_rewards', 'dis_train_op', 'gen_train_op', 'gen_pretrain_op',
      'dis_pretrain_op', 'merge_summaries_op', 'global_step',
      'new_learning_rate', 'learning_rate_update', 'saver', 'text_summary_op',
      'text_summary_placeholder'
  ])

  # Positional construction: the order must match the field list above.
  model = Model(
      inputs, targets, present, percent_real_update, new_rate, fake_sequence,
      fake_logits, fake_rewards, fake_baselines, fake_advantages,
      fake_log_probs, fake_predictions, real_predictions,
      fake_cross_entropy_losses, fake_gen_initial_state, fake_gen_final_state,
      eval_initial_state, eval_final_state, avg_log_perplexity, dis_loss,
      gen_loss, critic_loss, cumulative_rewards, dis_train_op, gen_train_op,
      gen_pretrain_op, dis_pretrain_op, merge_summaries_op, global_step,
      new_learning_rate, learning_rate_update, saver, text_summary_op,
      text_summary_placeholder)
  return model
def compute_geometric_average(percent_captured):
  """Compute the geometric average of the n-gram metrics.

  Args:
    percent_captured: Dict mapping an n-gram order key (e.g. '2') to the
      fraction of n-grams of that order captured.

  Returns:
    The geometric mean of the dict's values.

  Raises:
    ValueError: If `percent_captured` is empty (the mean is undefined).
  """
  if not percent_captured:
    raise ValueError('percent_captured must contain at least one entry.')
  res = 1.
  # `.values()` works on Python 2 and 3; `.iteritems()` is Python 2 only.
  for n_gram_percent in percent_captured.values():
    res *= n_gram_percent
  return np.power(res, 1. / float(len(percent_captured)))
def compute_arithmetic_average(percent_captured):
  """Compute the arithmetic average of the n-gram metrics.

  Args:
    percent_captured: Dict mapping an n-gram order key (e.g. '2') to the
      fraction of n-grams of that order captured.

  Returns:
    The arithmetic mean of the dict's values.

  Raises:
    ValueError: If `percent_captured` is empty (the mean is undefined).
  """
  if not percent_captured:
    raise ValueError('percent_captured must contain at least one entry.')
  # `.values()` works on Python 2 and 3; `.iteritems()` is Python 2 only.
  return sum(percent_captured.values()) / float(len(percent_captured))
def get_iterator(data):
  """Return the batch iterator over `data` for the configured FLAGS.data_set.

  Args:
    data: Dataset to iterate over, in the form expected by the loader that
      matches FLAGS.data_set.

  Returns:
    The loader's iterator, batched with FLAGS.batch_size and
    FLAGS.sequence_length. Callers in this module unpack each item as a
    3-tuple `(inputs, targets, _)`.

  Raises:
    ValueError: If FLAGS.data_set is neither 'ptb' nor 'imdb'.
  """
  if FLAGS.data_set == 'ptb':
    return ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                   FLAGS.sequence_length,
                                   FLAGS.epoch_size_override)
  if FLAGS.data_set == 'imdb':
    return imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                     FLAGS.sequence_length)
  # Fail loudly instead of falling through to an unbound local.
  raise ValueError("Unsupported data_set %r; expected one of ['ptb', 'imdb']."
                   % FLAGS.data_set)
def _add_scalar_summary(sv, sess, tag, value, step):
  """Record a single scalar `value` under `tag` at `step` via the Supervisor."""
  summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
  sv.SummaryComputed(sess, summary, global_step=step)


def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train model.

  Builds the graph and optionally pretrains the generator and/or
  discriminator. It then loops until the Supervisor asks to stop. For each
  generator batch, it runs `hparams.dis_train_iterations` discriminator
  updates (on a separate, offset data iterator) followed by one generator
  update.

  On the chief replica only:
    * scalar/graph summaries are written every FLAGS.summaries_every global
      steps;
    * progress and samples are printed and logged every FLAGS.print_every
      global steps.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to train on.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment (writable file-like object); it is
      closed when training finishes.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary keyed by the n-gram order as a string
      (e.g. '2'); each value is that order's n-gram counts in the data set,
      as passed to `evaluation_utils.sequence_ngram_evaluation`.

  Raises:
    ValueError: If FLAGS.num_rollouts is smaller than 1.
    FloatingPointError: If, at a summary step on the chief, the rolling
      training perplexity is not finite or reaches
      FLAGS.perplexity_threshold.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          raise ValueError(
              'num_rollouts must be >= 1, got %r.' % FLAGS.num_rollouts)

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor. It will take care of initialization,
        # summaries, checkpoints, and recovery.
        sv = tf.train.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session. Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:

          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word, data_ngram_counts,
                                                 is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(
                sv, sess, model, data, log, id_to_word, data_ngram_counts,
                is_chief)

          # Last value of `step // every` at which summaries / printing ran.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop; each pass walks a fresh
          # data iterator.
          while not sv.ShouldStop():
            is_present_rate = FLAGS.is_present_rate

            # NOTE(review): the decay is applied to the flag value afresh on
            # every pass, so it does not compound across passes -- confirm
            # this is intended.
            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)

            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.
            # The discriminator reads its own iterator, advanced ahead of the
            # generator's; with parameter servers, each task is offset further.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)

            for _ in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

                # Determine the state had the Generator run over real data. We
                # use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            if FLAGS.ps_tasks > 0:
              gen_offset = FLAGS.task * 1000 + 1
              for _ in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  # Wrap around the generator's own iterator (not the
                  # discriminator's) and restart its state from zeros.
                  iterator = get_iterator(data)
                  gen_initial_state_eval = zeros_state
                  next(iterator)

            for x, y, _ in iterator:
              for _ in xrange(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {model.inputs: x, model.targets: y, model.present: p}

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Decay the learning rate only after the full-rate steps.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess, model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                  [model.eval_final_state,
                   model.fake_gen_final_state], train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              # Floor division is required: this module enables true division,
              # under which `step / summaries_every` grows on every step and
              # would trigger the summaries on every step.
              if is_chief and (step // FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step // FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  print('Training raising FloatingPointError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary: n-gram
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = (
                      evaluation_utils.sequence_ngram_evaluation(
                          sess, model.fake_sequence, log, train_feed,
                          data_ngram_count, int(n)))
                  # Store the value so the averages below are not all zero.
                  avg_percent_captured[n] = batch_percent_captured
                  _add_scalar_summary(sv, sess,
                                      'general/%s-grams_percent_correct' % n,
                                      batch_percent_captured, step)

                # Summary: geometric_avg
                _add_scalar_summary(
                    sv, sess, 'general/geometric_avg',
                    compute_geometric_average(avg_percent_captured), step)

                # Summary: arithmetic_avg
                _add_scalar_summary(
                    sv, sess, 'general/arithmetic_avg',
                    compute_arithmetic_average(avg_percent_captured), step)

                # Summary: perplexity
                _add_scalar_summary(sv, sess, 'general/perplexity', perplexity,
                                    step)

              ## Printing and logging
              if is_chief and (step // FLAGS.print_every > print_step_division):
                print_step_division = step // FLAGS.print_every
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                log.write(' gen_learning_rate: %.6f\n' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = (
                      evaluation_utils.sequence_ngram_evaluation(
                          sess, model.fake_sequence, log, train_feed,
                          data_ngram_count, int(n)))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))
                geometric_avg = compute_geometric_average(avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                log.write(' geometric_avg: %.3f.\n' % geometric_avg)
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                log.write(' arithmetic_avg: %.3f.\n' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(
                    log, step, is_present_rate, avg_epoch_dis_loss,
                    avg_epoch_gen_loss)

                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                                 train_feed)
                log.flush()

  log.close()
def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word,
                  data_ngram_counts, eval_saver):
  """Evaluate the newest checkpoint in `train_dir` over one pass of `data`.

  Restores the latest checkpoint with `eval_saver` and runs the model over
  every batch from `get_iterator(data)`. The mask RNG is seeded with 0 so the
  masks are reproducible. Afterwards it prints, logs and writes summaries for
  the perplexity, losses, n-gram metrics and text samples.

  The function returns early, doing nothing further, if no checkpoint exists
  yet or if the data iterator yields no batches.

  Args:
    data: Dataset.
    sv: Supervisor.
    model: The GAN model we have just built.
    sess: A session to use.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary keyed by the n-gram order as a string
      (e.g. '2'); each value is that order's n-gram counts in the data set,
      as passed to `evaluation_utils.sequence_ngram_evaluation`.
    eval_saver: Saver used to restore the evaluation variables.

  Raises:
    FloatingPointError: If the evaluation perplexity is not finite or reaches
      FLAGS.perplexity_threshold.
  """
  tf.logging.info('Evaluate Once.')
  # Load the last model checkpoint, or initialize the graph.
  model_save_path = tf.train.latest_checkpoint(train_dir)
  if not model_save_path:
    tf.logging.warning('No checkpoint yet in: %s', train_dir)
    return

  tf.logging.info('Starting eval of: %s' % model_save_path)
  tf.logging.info('Only restoring trainable variables.')
  eval_saver.restore(sess, model_save_path)

  # Run the requested number of evaluation steps
  avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
  cumulative_costs = 0.

  # Average percent captured for each of the n-grams.
  avg_percent_captured = {'2': 0., '3': 0., '4': 0.}

  # Set a random seed to keep fixed mask.
  np.random.seed(0)
  gen_iters = 0

  # Generator statefulness over the epoch.
  # TODO(liamfedus): Check this.
  [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
      [model.eval_initial_state, model.fake_gen_initial_state])

  if FLAGS.eval_language_model:
    is_present_rate = 0.
    tf.logging.info('Overriding is_present_rate=0. for evaluation.')
    print('Overriding is_present_rate=0. for evaluation.')

  iterator = get_iterator(data)

  for x, y, _ in iterator:
    if FLAGS.eval_language_model:
      is_present_rate = 0.
    else:
      is_present_rate = FLAGS.is_present_rate
    tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate)

    model_utils.assign_percent_real(sess, model.percent_real_update,
                                    model.new_rate, is_present_rate)

    # Randomly mask out tokens.
    p = model_utils.generate_mask()

    eval_feed = {model.inputs: x, model.targets: y, model.present: p}

    if FLAGS.data_set == 'ptb':
      # Statefulness for *evaluation* Generator.
      for i, (c, h) in enumerate(model.eval_initial_state):
        eval_feed[c] = gen_initial_state_eval[i].c
        eval_feed[h] = gen_initial_state_eval[i].h

      # Statefulness for the Generator.
      for i, (c, h) in enumerate(model.fake_gen_initial_state):
        eval_feed[c] = fake_gen_initial_state_eval[i].c
        eval_feed[h] = fake_gen_initial_state_eval[i].h

    [
        gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval,
        gen_initial_state_eval, fake_gen_initial_state_eval, step
    ] = sess.run(
        [
            model.avg_log_perplexity, model.dis_loss, model.gen_loss,
            model.eval_final_state, model.fake_gen_final_state,
            model.global_step
        ],
        feed_dict=eval_feed)

    for n, data_ngram_count in data_ngram_counts.items():
      batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
          sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n))
      avg_percent_captured[n] += batch_percent_captured

    cumulative_costs += gen_log_perplexity_eval

    avg_epoch_dis_loss.append(dis_loss_eval)
    avg_epoch_gen_loss.append(gen_loss_eval)

    gen_iters += 1

  if gen_iters == 0:
    # Nothing was evaluated: the averages would divide by zero, and `step`,
    # `eval_feed` and `is_present_rate` would be unbound below.
    tf.logging.warning('No evaluation batches produced for: %s',
                       model_save_path)
    return

  # Calculate rolling metrics.
  perplexity = np.exp(cumulative_costs / gen_iters)
  for n in avg_percent_captured:
    avg_percent_captured[n] /= gen_iters

  # Confirm perplexity is not infinite.
  if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold:
    print('Evaluation raising FloatingPointError.')
    raise FloatingPointError(
        'Evaluation infinite perplexity: %.3f' % perplexity)

  ## Printing and logging.
  evaluation_utils.print_and_log_losses(log, step, is_present_rate,
                                        avg_epoch_dis_loss, avg_epoch_gen_loss)
  print(' perplexity: %.3f' % perplexity)
  log.write(' perplexity: %.3f\n' % perplexity)

  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    print(' percent of %d-grams captured: %.3f.' % (n, n_gram_percent))
    log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent))

  # Samples are generated from the last evaluated batch.
  samples = evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                           eval_feed)

  ## Summaries.
  summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed)
  sv.SummaryComputed(sess, summary_str)

  # Summary: text
  summary_str = sess.run(model.text_summary_op,
                         {model.text_summary_placeholder: '\n\n'.join(samples)})
  sv.SummaryComputed(sess, summary_str, global_step=step)

  # Summary: n-gram
  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    summary_percent_str = tf.Summary(value=[
        tf.Summary.Value(
            tag='general/%d-grams_percent_correct' % n,
            simple_value=n_gram_percent)
    ])
    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

  # Summary: geometric_avg
  geometric_avg = compute_geometric_average(avg_percent_captured)
  summary_geometric_avg_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg)
  ])
  sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step)

  # Summary: arithmetic_avg
  arithmetic_avg = compute_arithmetic_average(avg_percent_captured)
  summary_arithmetic_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/arithmetic_avg', simple_value=arithmetic_avg)
  ])
  sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step)

  # Summary: perplexity
  summary_perplexity_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
  ])
  sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
def evaluate_model(hparams, data, train_dir, log, id_to_word,
                   data_ngram_counts):
  """Evaluate MaskGAN model.

  Waits until a checkpoint exists in `train_dir`, builds an evaluation graph,
  and then calls `evaluate_once` every `FLAGS.eval_interval_secs` seconds
  through `tf.Supervisor.Loop` until the supervisor stops.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    train_dir: Path to a directory containing checkpoints.
    log: Open, writable log file; passed through to `evaluate_once`.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.

  Raises:
    NotImplementedError: If FLAGS.mode is not one of the evaluation modes
      (validation, train_eval, test).
    ValueError: If FLAGS.num_rollouts is less than 1.
  """
  tf.logging.error('Evaluate model.')
  # Boolean indicating operational mode.
  is_training = False

  # Each evaluation mode writes its summaries to its own subdirectory.
  eval_subdirs = {
      MODE_VALIDATION: 'validation',
      MODE_TRAIN_EVAL: 'train_eval',
      MODE_TEST: 'test',
  }
  if FLAGS.mode not in eval_subdirs:
    raise NotImplementedError(
        'evaluate_model does not support mode %r.' % (FLAGS.mode,))
  logdir = FLAGS.base_directory + '/' + eval_subdirs[FLAGS.mode]

  # Wait for a checkpoint to exist.
  print(train_dir)
  print(tf.train.latest_checkpoint(train_dir))
  while not tf.train.latest_checkpoint(train_dir):
    tf.logging.error('Waiting for checkpoint...')
    print('Waiting for checkpoint...')
    time.sleep(10)

  with tf.Graph().as_default():
    # NOTE(review): an empty container name selects the default resource
    # container, so this does not actually isolate trials from each other.
    container_name = ''
    with tf.container(container_name):
      # Construct the model.
      if FLAGS.num_rollouts == 1:
        model = create_MaskGAN(hparams, is_training)
      elif FLAGS.num_rollouts > 1:
        model = rollout.create_rollout_MaskGAN(hparams, is_training)
      else:
        raise ValueError(
            'FLAGS.num_rollouts must be >= 1, got %r.' % (FLAGS.num_rollouts,))

      # Create the supervisor. It will take care of initialization, summaries,
      # checkpoints, and recovery. We only pass the trainable variables
      # to load since things like baselines keep batch_size which may not
      # match between training and evaluation.
      evaluation_variables = tf.trainable_variables()
      evaluation_variables.append(model.global_step)
      eval_saver = tf.train.Saver(var_list=evaluation_variables)
      sv = tf.Supervisor(logdir=logdir)
      sess = sv.PrepareSession(FLAGS.eval_master, start_standard_services=False)

      tf.logging.info('Before sv.Loop.')
      # Re-run evaluate_once periodically against the newest checkpoint.
      sv.Loop(FLAGS.eval_interval_secs, evaluate_once,
              (data, sv, model, sess, train_dir, log, id_to_word,
               data_ngram_counts, eval_saver))

      sv.WaitForStop()
      tf.logging.info('sv.Stop().')
      sv.Stop()
def main(_):
  """Entry point: loads data and vocabulary, then trains or evaluates.

  Which of `train_model` / `evaluate_model` runs, and which data partition it
  receives, is selected by FLAGS.mode. A per-mode log file is created under
  FLAGS.base_directory and closed when the run finishes (or fails).

  Args:
    _: Unused positional argv passed by `tf.app.run`.

  Raises:
    NotImplementedError: For an unsupported FLAGS.data_set or FLAGS.mode, or
      when a test run is requested for a data set without a test partition.
  """
  hparams = create_hparams()
  train_dir = FLAGS.base_directory + '/train'

  # Load data set.
  test_data = None
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, test_data, _ = raw_data
    valid_data_flat = valid_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    # TODO(liamfedus): Get an IMDB test partition.
    train_data, valid_data = raw_data
    valid_data_flat = [word for review in valid_data for word in review]
  else:
    raise NotImplementedError(
        'Unsupported data_set %r.' % (FLAGS.data_set,))

  if FLAGS.mode in (MODE_TRAIN, MODE_TRAIN_EVAL):
    data_set = train_data
  elif FLAGS.mode == MODE_VALIDATION:
    data_set = valid_data
  elif FLAGS.mode == MODE_TEST:
    # Previously an IMDB test run failed with an UnboundLocalError here.
    if test_data is None:
      raise NotImplementedError(
          'data_set %r has no test partition.' % (FLAGS.data_set,))
    data_set = test_data
  else:
    raise NotImplementedError('Unsupported mode %r.' % (FLAGS.mode,))

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  else:  # 'imdb'; every other data_set was rejected above.
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  # .items() (not .iteritems()) so this also runs under Python 3.
  id_to_word = {v: k for k, v in word_to_id.items()}

  # Dictionary of validation-set n-gram counts, keyed by str(n).
  data_ngram_counts = {}
  for n in (2, 3, 4):
    ngram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=n)
    data_ngram_counts[str(n)] = n_gram.construct_ngrams_dict(ngram_tuples)
    print('Unique %d-grams: %d' % (n, len(data_ngram_counts[str(n)])))

  # TODO(liamfedus): This was necessary because there was a problem with our
  # originally trained IMDB models. The EOS_INDEX was off by one, which means,
  # two words were mapping to index 86933. The presence of '</s>' is going
  # to throw and out of vocabulary error.
  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  tf.gfile.MakeDirs(FLAGS.base_directory)
  log_filenames = {
      MODE_TRAIN: 'train-log.txt',
      MODE_VALIDATION: 'validation-log.txt',
      MODE_TRAIN_EVAL: 'train_eval-log.txt',
  }
  # FLAGS.mode was validated above, so the only remaining mode is MODE_TEST.
  log_filename = log_filenames.get(FLAGS.mode, 'test-log.txt')
  log = tf.gfile.GFile(
      os.path.join(FLAGS.base_directory, log_filename), mode='w')

  # Always close the log so buffered output is flushed even if a run fails.
  try:
    if FLAGS.mode == MODE_TRAIN:
      train_model(hparams, data_set, train_dir, log, id_to_word,
                  data_ngram_counts)
    elif FLAGS.mode in (MODE_VALIDATION, MODE_TRAIN_EVAL, MODE_TEST):
      evaluate_model(hparams, data_set, train_dir, log, id_to_word,
                     data_ngram_counts)
    else:
      raise NotImplementedError('Unsupported mode %r.' % (FLAGS.mode,))
  finally:
    log.close()
if __name__ == '__main__':
  # Parse command-line flags, then hand control to main(argv).
  tf.app.run(main=main)