Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Shuffle samples for human evaluation.

Reads `reviews.txt` from --input_ml_path and --input_gan_path, pairs their
lines in order, randomly swaps the order within each pair, and writes the
pairs to --output_file_name.

Local launch commands (each is a single shell command, wrapped here for
readability):

  python sample_shuffler.py
    --input_ml_path=/tmp/ptb/seq2seq_vd_shareemb_forreal_55_3
    --input_gan_path=/tmp/ptb/MaskGAN_PTB_ari_avg_56.29_v2.0.0
    --output_file_name=/tmp/ptb/shuffled_output.txt

  python sample_shuffler.py
    --input_ml_path=/tmp/generate_samples/MaskGAN_IMDB_Benchmark_87.1_v0.3.0
    --input_gan_path=/tmp/generate_samples/MaskGAN_IMDB_v1.0.1
    --output_file_name=/tmp/imdb/shuffled_output.txt
"""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
# Dependency imports | |
import numpy as np | |
import tensorflow as tf | |
# Command-line flags. They are registered at import time; their values are
# accessed through FLAGS, which tf.app.run() populates from the command line.

# Input directories; main() reads 'reviews.txt' from each of them.
tf.app.flags.DEFINE_string(
    'input_ml_path',
    '/tmp',
    'Model output directory.')
tf.app.flags.DEFINE_string(
    'input_gan_path',
    '/tmp',
    'Model output directory.')

# Destination file for the shuffled sample pairs.
tf.app.flags.DEFINE_string(
    'output_file_name',
    '/tmp/ptb/shuffled_output.txt',
    'Model output file.')

# NOTE(review): the two flags below are not read by any code visible in this
# file; confirm nothing else depends on them before removing.
tf.app.flags.DEFINE_boolean(
    'output_masked_logs',
    False,
    'Whether to display for human evaluation (show masking).')
tf.app.flags.DEFINE_integer(
    'number_epochs',
    1,
    'The number of epochs to produce.')

FLAGS = tf.app.flags.FLAGS
def shuffle_samples(input_file_1, input_file_2, seed=0):
  """Randomly order paired samples drawn from two sources.

  Lines are read from both inputs in lockstep. For each pair a random flag in
  {1, 2} decides the order: flag 1 keeps `(line_1, line_2)`, flag 2 stores
  `(line_2, line_1)`. The flag is kept with the pair so the source of each
  sample can be recovered later.

  Args:
    input_file_1: Iterable of lines that also has a `close()` method
      (e.g. a `tf.gfile.GFile` opened for reading).
    input_file_2: Same as `input_file_1`, for the second source.
    seed: Seed for a private NumPy random generator. It makes the ordering
      reproducible without modifying NumPy's global random state. The default
      of 0 reproduces the ordering of the previous global-seed implementation.

  Returns:
    A list of `(flag, first_line, second_line)` tuples. Lines are stored
    exactly as yielded by the inputs (e.g. with trailing newlines). Pairing
    stops at the shorter input, following `zip` semantics.

  Both inputs are closed before returning, even if iteration raises.
  """
  # A private RandomState yields exactly the same draws as calling
  # `np.random.seed(seed)` followed by `np.random.randint`, but leaves the
  # process-wide generator untouched for other callers.
  rng = np.random.RandomState(seed)
  shuffled = []
  try:
    for line_1, line_2 in zip(input_file_1, input_file_2):
      rand = rng.randint(1, 3)  # Upper bound is exclusive: yields 1 or 2.
      if rand == 1:
        shuffled.append((rand, line_1, line_2))
      else:
        shuffled.append((rand, line_2, line_1))
  finally:
    # The function has always closed its inputs; the finally block keeps that
    # contract on the error path too.
    input_file_1.close()
    input_file_2.close()
  return shuffled
def generate_output(shuffled_tuples, output_file_name):
  """Write shuffled sample tuples to a text file.

  Each tuple `(flag, first, second)` (as produced by `shuffle_samples`) becomes
  one record `<flag>, <first>, <second>`, with trailing whitespace stripped
  from both samples.

  Args:
    shuffled_tuples: Iterable of 3-element tuples `(flag, first, second)`.
    output_file_name: Path of the file to (over)write via `tf.gfile.GFile`.
  """
  # The context manager closes the file even if formatting or a write raises
  # part-way through.
  with tf.gfile.GFile(output_file_name, mode='w') as output_file:
    for tup in shuffled_tuples:
      # Each record is prefixed (not suffixed) with a newline, so the file
      # starts with an empty line and has no trailing newline. This is kept
      # as-is so the output format does not change. `{:<1}` pads an empty
      # field to a single space.
      formatted_tuple = ('\n{:<1}, {:<1}, {:<1}').format(
          tup[0], tup[1].rstrip(), tup[2].rstrip())
      output_file.write(formatted_tuple)
def main(_):
  """Pair, shuffle and write the two models' samples.

  Reads `reviews.txt` from --input_ml_path and --input_gan_path, shuffles the
  line pairs with `shuffle_samples`, and writes them to --output_file_name
  with `generate_output`.

  Args:
    _: Leftover command-line arguments passed by tf.app.run(); unused.
  """
  ml_path = os.path.join(FLAGS.input_ml_path, 'reviews.txt')
  gan_path = os.path.join(FLAGS.input_gan_path, 'reviews.txt')

  # The ML samples go first so that flag 1 in the output means
  # "ML sample listed first". shuffle_samples closes both files.
  shuffled_tuples = shuffle_samples(
      tf.gfile.GFile(ml_path, mode='r'),
      tf.gfile.GFile(gan_path, mode='r'))

  generate_output(shuffled_tuples, FLAGS.output_file_name)


if __name__ == '__main__':
  # Passing main explicitly is equivalent to tf.app.run()'s default lookup of
  # __main__.main when this file is run as a script.
  tf.app.run(main=main)