# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Generate samples from the MaskGAN. Launch command: python generate_samples.py --data_dir=/tmp/data/imdb --data_set=imdb --batch_size=256 --sequence_length=20 --base_directory=/tmp/imdb --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2, gen_vd_keep_prob=1.0" --generator_model=seq2seq_vd --discriminator_model=seq2seq_vd --is_present_rate=0.5 --maskgan_ckpt=/tmp/model.ckpt-45494 --seq2seq_share_embedding=True --dis_share_embedding=True --attention_option=luong --mask_strategy=contiguous --baseline_method=critic --number_epochs=4 """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from functools import partial import os # Dependency imports import numpy as np from six.moves import xrange import tensorflow as tf import train_mask_gan from data import imdb_loader from data import ptb_loader # Data. from model_utils import helper from model_utils import model_utils SAMPLE_TRAIN = 'TRAIN' SAMPLE_VALIDATION = 'VALIDATION' ## Sample Generation. ## Binary and setup FLAGS. 
tf.app.flags.DEFINE_enum('sample_mode', 'TRAIN',
                         [SAMPLE_TRAIN, SAMPLE_VALIDATION],
                         'Dataset to sample from.')
tf.app.flags.DEFINE_string('output_path', '/tmp', 'Model output directory.')
tf.app.flags.DEFINE_boolean(
    'output_masked_logs', False,
    'Whether to display for human evaluation (show masking).')
tf.app.flags.DEFINE_integer('number_epochs', 1,
                            'The number of epochs to produce.')

FLAGS = tf.app.flags.FLAGS


def get_iterator(data):
  """Return the batch iterator for the configured data set.

  Args:
    data: Raw data in the format produced by the corresponding loader.

  Returns:
    An iterator yielding (inputs, targets, ...) batches of shape
    [FLAGS.batch_size, FLAGS.sequence_length].

  Raises:
    NotImplementedError: If FLAGS.data_set is not 'ptb' or 'imdb'.
  """
  if FLAGS.data_set == 'ptb':
    return ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                   FLAGS.sequence_length,
                                   FLAGS.epoch_size_override)
  elif FLAGS.data_set == 'imdb':
    return imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                     FLAGS.sequence_length)
  # Fail loudly for unknown data sets.  The original fell through and
  # returned an unbound local, raising a confusing UnboundLocalError.
  raise NotImplementedError('Unknown data_set: %s' % FLAGS.data_set)


def convert_to_human_readable(id_to_word, arr, p, max_num_to_print):
  """Convert a np.array of indices into words using id_to_word dictionary.

  Tokens that were masked out (p == 0, i.e. imputed by the Generator) are
  prefixed with '*' so human evaluators can see which words were generated.

  Args:
    id_to_word: Dictionary of vocabulary indices to words.
    arr: 2D np.array of word indices, [batch_size, sequence_length].
    p: 2D present-mask array of the same shape (1 = originally present,
      0 = imputed).
    max_num_to_print: Maximum number of sequences to return.

  Returns:
    List of human-readable strings, one per sequence.
  """
  assert arr.ndim == 2

  samples = []
  for sequence_id in xrange(min(len(arr), max_num_to_print)):
    tokens = []
    for i, index in enumerate(arr[sequence_id, :]):
      word = str(id_to_word[index])
      # Star-prefix the tokens the Generator filled in.
      tokens.append(word if p[sequence_id, i] == 1 else '*' + word)
    samples.append(' '.join(tokens))
  return samples


def write_unmasked_log(log, id_to_word, sequence_eval):
  """Helper function for logging evaluated sequences without mask."""
  indices_arr = np.asarray(sequence_eval)
  samples = helper.convert_to_human_readable(id_to_word, indices_arr,
                                             FLAGS.batch_size)
  for sample in samples:
    log.write(sample + '\n')
  log.flush()
  return samples


def write_masked_log(log, id_to_word, sequence_eval, present_eval):
  """Helper function for logging evaluated sequences with the mask shown."""
  indices_arr = np.asarray(sequence_eval)
  samples = convert_to_human_readable(id_to_word, indices_arr, present_eval,
                                      FLAGS.batch_size)
  for sample in samples:
    log.write(sample + '\n')
  log.flush()
  return samples


def generate_logs(sess, model, log, id_to_word, feed):
  """Impute Sequences using the model for a particular feed and log them.

  Args:
    sess: TensorFlow session.
    model: The MaskGAN model.
    log: Writable file object for the samples.
    id_to_word: Dictionary of vocabulary indices to words.
    feed: Feed dict for the session run.

  Returns:
    List of logged human-readable samples.
  """
  # Impute Sequences.
  [p, inputs_eval, sequence_eval] = sess.run(
      [model.present, model.inputs, model.fake_sequence], feed_dict=feed)

  # Add the 0th time-step for coherence.
  first_token = np.expand_dims(inputs_eval[:, 0], axis=1)
  sequence_eval = np.concatenate((first_token, sequence_eval), axis=1)

  # 0th token always present.
  p = np.concatenate((np.ones((FLAGS.batch_size, 1)), p), axis=1)

  if FLAGS.output_masked_logs:
    samples = write_masked_log(log, id_to_word, sequence_eval, p)
  else:
    samples = write_unmasked_log(log, id_to_word, sequence_eval)
  return samples


def generate_samples(hparams, data, id_to_word, log_dir, output_file):
  """Generate samples.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    id_to_word: Dictionary of indices to words.
    log_dir: Log directory.
    output_file: Output file for the samples.
  """
  # Boolean indicating operational mode.
  is_training = False

  # Set a random seed to keep fixed mask.
  np.random.seed(0)

  with tf.Graph().as_default():
    # Construct the model.
    model = train_mask_gan.create_MaskGAN(hparams, is_training)

    ## Retrieve the initial savers.
    init_savers = model_utils.retrieve_init_savers(hparams)

    ## Initial saver function to supervisor.
    init_fn = partial(model_utils.init_fn, init_savers)

    is_chief = FLAGS.task == 0

    # Create the supervisor.  It will take care of initialization, summaries,
    # checkpoints, and recovery.  (Fix: Supervisor lives in tf.train, not at
    # the top-level tf namespace.)
    sv = tf.train.Supervisor(
        logdir=log_dir,
        is_chief=is_chief,
        saver=model.saver,
        global_step=model.global_step,
        recovery_wait_secs=30,
        summary_op=None,
        init_fn=init_fn)

    # Get an initialized, and possibly recovered session.  Launch the
    # services: Checkpointing, Summaries, step counting.
    #
    # When multiple replicas of this program are running the services are
    # only launched by the 'chief' replica.
    with sv.managed_session(
        FLAGS.master, start_standard_services=False) as sess:

      # Generator statefulness over the epoch.
      [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
          [model.eval_initial_state, model.fake_gen_initial_state])

      for n in xrange(FLAGS.number_epochs):
        print('Epoch number: %d' % n)

        iterator = get_iterator(data)

        for x, y, _ in iterator:
          if FLAGS.eval_language_model:
            is_present_rate = 0.
          else:
            is_present_rate = FLAGS.is_present_rate
          tf.logging.info(
              'Evaluating on is_present_rate=%.3f.' % is_present_rate)

          model_utils.assign_percent_real(sess, model.percent_real_update,
                                          model.new_rate, is_present_rate)

          # Randomly mask out tokens.
          p = model_utils.generate_mask()

          eval_feed = {model.inputs: x, model.targets: y, model.present: p}

          if FLAGS.data_set == 'ptb':
            # Statefulness for *evaluation* Generator.
            for i, (c, h) in enumerate(model.eval_initial_state):
              eval_feed[c] = gen_initial_state_eval[i].c
              eval_feed[h] = gen_initial_state_eval[i].h

            # Statefulness for the Generator.
            for i, (c, h) in enumerate(model.fake_gen_initial_state):
              eval_feed[c] = fake_gen_initial_state_eval[i].c
              eval_feed[h] = fake_gen_initial_state_eval[i].h

          [gen_initial_state_eval, fake_gen_initial_state_eval, _] = sess.run(
              [
                  model.eval_final_state, model.fake_gen_final_state,
                  model.global_step
              ],
              feed_dict=eval_feed)

          generate_logs(sess, model, output_file, id_to_word, eval_feed)
      output_file.close()
      print('Closing output_file.')
      return


def main(_):
  hparams = train_mask_gan.create_hparams()
  log_dir = FLAGS.base_directory

  tf.gfile.MakeDirs(FLAGS.output_path)
  output_file = tf.gfile.GFile(
      os.path.join(FLAGS.output_path, 'reviews.txt'), mode='w')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, _, _ = raw_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    train_data, valid_data = raw_data
  else:
    raise NotImplementedError

  # Generating more data on train set.
  if FLAGS.sample_mode == SAMPLE_TRAIN:
    data_set = train_data
  elif FLAGS.sample_mode == SAMPLE_VALIDATION:
    data_set = valid_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  # Fix: dict.iteritems() is Python-2-only; .items() works on both and
  # matches the file's six/print_function compatibility style.
  id_to_word = {v: k for k, v in word_to_id.items()}
  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  generate_samples(hparams, data_set, id_to_word, log_dir, output_file)


if __name__ == '__main__':
  tf.app.run()