# NCTC / models /research /maskgan /sample_shuffler.py
# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# raw
# history blame
# 3.16 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shuffle samples for human evaluation.
Local launch commands (each example is a single command; the flags are
wrapped onto separate lines for readability only):

  PTB samples:
    python sample_shuffler.py
      --input_ml_path=/tmp/ptb/seq2seq_vd_shareemb_forreal_55_3
      --input_gan_path=/tmp/ptb/MaskGAN_PTB_ari_avg_56.29_v2.0.0
      --output_file_name=/tmp/ptb/shuffled_output.txt

  IMDB samples:
    python sample_shuffler.py
      --input_ml_path=/tmp/generate_samples/MaskGAN_IMDB_Benchmark_87.1_v0.3.0
      --input_gan_path=/tmp/generate_samples/MaskGAN_IMDB_v1.0.1
      --output_file_name=/tmp/imdb/shuffled_output.txt
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import numpy as np
import tensorflow as tf
# Command-line flags for this script.
#
# NOTE: despite the 'Model output directory.' help text, `input_ml_path` and
# `input_gan_path` are the directories this script reads from: main() opens
# `reviews.txt` inside each of them.
tf.app.flags.DEFINE_string('input_ml_path', '/tmp', 'Model output directory.')
tf.app.flags.DEFINE_string('input_gan_path', '/tmp', 'Model output directory.')
# Path of the text file that generate_output() writes the shuffled pairs to.
tf.app.flags.DEFINE_string('output_file_name', '/tmp/ptb/shuffled_output.txt',
                           'Model output file.')
# NOTE(review): `output_masked_logs` and `number_epochs` are defined here but
# are not read anywhere in the visible code of this file.
tf.app.flags.DEFINE_boolean(
    'output_masked_logs', False,
    'Whether to display for human evaluation (show masking).')
tf.app.flags.DEFINE_integer('number_epochs', 1,
                            'The number of epochs to produce.')
# Shared handle through which the parsed flag values are read.
FLAGS = tf.app.flags.FLAGS
def shuffle_samples(input_file_1, input_file_2):
    """Pair up the lines of two files, randomizing which line comes first.

    The two inputs are iterated in lockstep with `zip`, so pairing stops as
    soon as the shorter input is exhausted; surplus lines in the longer input
    are ignored. For every pair a value of 1 or 2 is drawn: 1 keeps the
    `input_file_1` line first, 2 puts the `input_file_2` line first.

    Side effects: the global NumPy random state is re-seeded with 0 on every
    call (so the ordering is reproducible), and both inputs are closed before
    returning, including when iteration raises.

    Args:
      input_file_1: Iterable of lines that also provides `close()`.
      input_file_2: Iterable of lines that also provides `close()`.

    Returns:
      A list of `(rand, first_line, second_line)` tuples, where `rand` is the
      drawn value (1 or 2) recording which input supplied `first_line`. Lines
      are stored exactly as read, i.e. with any trailing newline intact.
    """
    shuffled = []
    try:
        # Set a random seed to keep fixed mask.
        np.random.seed(0)
        for line_1, line_2 in zip(input_file_1, input_file_2):
            # randint's upper bound is exclusive, so this yields 1 or 2.
            rand = np.random.randint(1, 3)
            if rand == 1:
                shuffled.append((rand, line_1, line_2))
            else:
                shuffled.append((rand, line_2, line_1))
    finally:
        # Release both handles even if reading failed part-way through; the
        # nested finally makes sure a failing first close() cannot skip the
        # second one.
        try:
            input_file_1.close()
        finally:
            input_file_2.close()
    return shuffled
def generate_output(shuffled_tuples, output_file_name):
    """Write shuffled sample pairs to a text file, one record per pair.

    Each record is written as a newline followed by
    `<tag>, <first sample>, <second sample>`, with trailing whitespace
    stripped from both samples. Because the newline precedes each record, the
    output starts with an empty line and does not end with a newline.

    Args:
      shuffled_tuples: Iterable of 3-item sequences `(tag, first, second)`
        where `first` and `second` are strings, as produced by
        `shuffle_samples`.
      output_file_name: Path of the file to write; opened in mode 'w' through
        `tf.gfile.GFile`.
    """
    # The `with` block closes (and thereby flushes) the file even when
    # formatting or writing one of the records raises.
    with tf.gfile.GFile(output_file_name, mode='w') as output_file:
        for tup in shuffled_tuples:
            formatted_tuple = ('\n{:<1}, {:<1}, {:<1}').format(
                tup[0], tup[1].rstrip(), tup[2].rstrip())
            output_file.write(formatted_tuple)
def main(_):
    """Shuffle the ML and GAN sample files and write the combined output.

    Reads `reviews.txt` from the directories given by `--input_ml_path` and
    `--input_gan_path`, pairs their lines in a randomized order via
    `shuffle_samples`, and writes the result to `--output_file_name` via
    `generate_output`.

    Args:
      _: Unused positional argument supplied by the app runner.
    """
    ml_samples_path = os.path.join(FLAGS.input_ml_path, 'reviews.txt')
    gan_samples_path = os.path.join(FLAGS.input_gan_path, 'reviews.txt')

    # Own the input handles here so they are released on every path, even if
    # shuffle_samples raises before it reaches its own close() calls.
    # NOTE(review): shuffle_samples also closes both files on success, so the
    # exit of these `with` blocks closes them a second time; GFile.close() is
    # expected to be a no-op on an already closed file -- confirm against the
    # TensorFlow version in use.
    with tf.gfile.GFile(ml_samples_path, mode='r') as ml_samples_file:
        with tf.gfile.GFile(gan_samples_path, mode='r') as gan_samples_file:
            # Generate shuffled tuples.
            shuffled_tuples = shuffle_samples(
                ml_samples_file, gan_samples_file)

    # Output to file.
    generate_output(shuffled_tuples, FLAGS.output_file_name)
# Script entry point: only start the app runner when this file is executed
# directly, not when it is imported. Per the TensorFlow documentation,
# tf.app.run() parses the command-line flags and then calls this module's
# main().
if __name__ == '__main__':
    tf.app.run()