# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graphs.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from collections import defaultdict | |
import operator | |
import os | |
import random | |
import shutil | |
import string | |
import tempfile | |
# Dependency imports | |
import tensorflow as tf | |
import graphs | |
from data import data_utils | |
flags = tf.app.flags | |
FLAGS = flags.FLAGS | |
data = data_utils | |
flags.DEFINE_integer('task', 0, 'Task id; needed for SyncReplicas test') | |
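

# The helpers below build a small synthetic dataset (a random vocabulary,
# random token sequences, and their word frequencies) so the graph tests can
# run without any real data.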
def _build_random_vocabulary(vocab_size=100):
  """Builds and returns a dict<term, id>."""
  vocab = set()
  while len(vocab) < (vocab_size - 1):
    rand_word = ''.join(
        random.choice(string.ascii_lowercase)
        for _ in range(random.randint(1, 10)))
    vocab.add(rand_word)

  vocab_ids = {word: i for i, word in enumerate(vocab)}
  vocab_ids[data.EOS_TOKEN] = vocab_size - 1
  return vocab_ids
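
# Illustrative only (contents vary per run): _build_random_vocabulary(4)
# might return something like {'qx': 0, 'fo': 1, 'zb': 2, data.EOS_TOKEN: 3};
# only the EOS id (vocab_size - 1) is deterministic.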


def _build_random_sequence(vocab_ids):
  seq_len = random.randint(10, 200)
  # list() is needed under Python 3, where dict.values() returns a view that
  # random.choice cannot index.
  ids = list(vocab_ids.values())
  seq = data.SequenceWrapper()
  for token_id in [random.choice(ids) for _ in range(seq_len)]:
    seq.add_timestep().set_token(token_id)
  return seq


def _build_vocab_frequencies(seqs, vocab_ids):
  vocab_freqs = defaultdict(int)
  # dict.iteritems() is Python 2 only; items() works under both versions.
  ids_to_words = {i: word for word, i in vocab_ids.items()}
  for seq in seqs:
    for timestep in seq:
      vocab_freqs[ids_to_words[timestep.token]] += 1

  vocab_freqs[data.EOS_TOKEN] = 0  # EOS never appears in the raw sequences.
  return vocab_freqs
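

# GraphsTest.setUpClass uses the helpers above to write TFRecord sequence
# files plus vocab.txt and vocab_freq.txt into a temporary data_dir, which
# tearDownClass removes.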
class GraphsTest(tf.test.TestCase):
  """Test graph construction methods."""

  @classmethod
  def setUpClass(cls):
    # Make model small
    FLAGS.batch_size = 2
    FLAGS.num_timesteps = 3
    FLAGS.embedding_dims = 4
    FLAGS.rnn_num_layers = 2
    FLAGS.rnn_cell_size = 4
    FLAGS.cl_num_layers = 2
    FLAGS.cl_hidden_size = 4
    FLAGS.vocab_size = 10
    # Set input/output flags
    FLAGS.data_dir = tempfile.mkdtemp()

    # Build and write sequence files.
    vocab_ids = _build_random_vocabulary(FLAGS.vocab_size)
    seqs = [_build_random_sequence(vocab_ids) for _ in range(5)]
    seqs_label = [
        data.build_labeled_sequence(seq, random.choice([True, False]))
        for seq in seqs
    ]
    seqs_lm = [data.build_lm_sequence(seq) for seq in seqs]
    seqs_ae = [data.build_seq_ae_sequence(seq) for seq in seqs]
    seqs_rev = [data.build_reverse_sequence(seq) for seq in seqs]
    seqs_bidir = [
        data.build_bidirectional_seq(seq, rev)
        for seq, rev in zip(seqs, seqs_rev)
    ]
    seqs_bidir_label = [
        data.build_labeled_sequence(bd_seq, random.choice([True, False]))
        for bd_seq in seqs_bidir
    ]

    filenames = [
        data.TRAIN_CLASS, data.TRAIN_LM, data.TRAIN_SA, data.TEST_CLASS,
        data.TRAIN_REV_LM, data.TRAIN_BD_CLASS, data.TEST_BD_CLASS
    ]
    seq_lists = [
        seqs_label, seqs_lm, seqs_ae, seqs_label, seqs_rev, seqs_bidir,
        seqs_bidir_label
    ]
    for fname, seq_list in zip(filenames, seq_lists):
      with tf.python_io.TFRecordWriter(
          os.path.join(FLAGS.data_dir, fname)) as writer:
        for seq in seq_list:
          writer.write(seq.seq.SerializeToString())

    # Write vocab.txt and vocab_freq.txt
    vocab_freqs = _build_vocab_frequencies(seqs, vocab_ids)
    ordered_vocab_freqs = sorted(
        vocab_freqs.items(), key=operator.itemgetter(1), reverse=True)
    with open(os.path.join(FLAGS.data_dir, 'vocab.txt'), 'w') as vocab_f:
      with open(os.path.join(FLAGS.data_dir, 'vocab_freq.txt'), 'w') as freq_f:
        for word, freq in ordered_vocab_freqs:
          vocab_f.write('{}\n'.format(word))
          freq_f.write('{}\n'.format(freq))

  @classmethod
  def tearDownClass(cls):
    shutil.rmtree(FLAGS.data_dir)
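
  # Each test starts from a fresh default graph; setUp resets the flags that
  # individual tests override.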
  def setUp(self):
    # Reset FLAGS
    FLAGS.rnn_num_layers = 1
    FLAGS.sync_replicas = False
    FLAGS.adv_training_method = None
    FLAGS.num_candidate_samples = -1
    FLAGS.num_classes = 2
    FLAGS.use_seq2seq_autoencoder = False

    # Reset Graph
    tf.reset_default_graph()

  def testClassifierGraph(self):
    FLAGS.rnn_num_layers = 2
    model = graphs.VatxtModel()
    train_op, _, _ = model.classifier_training()
    # Pretrained vars: 1 embedding matrix plus 2 variables (weights and
    # biases) per LSTM layer.
    self.assertEqual(
        len(model.pretrained_variables), 1 + 2 * FLAGS.rnn_num_layers)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # The input pipeline is queue-based; without queue runners the
      # sess.run(train_op) call below would block indefinitely.
      tf.train.start_queue_runners(sess)
      sess.run(train_op)

  def testLanguageModelGraph(self):
    train_op, _, _ = graphs.VatxtModel().language_model_training()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.train.start_queue_runners(sess)
      sess.run(train_op)

  def testMulticlass(self):
    FLAGS.num_classes = 10
    graphs.VatxtModel().classifier_graph()

  def testATMethods(self):
    at_methods = [None, 'rp', 'at', 'vat', 'atvat']
    for method in at_methods:
      FLAGS.adv_training_method = method
      with tf.Graph().as_default():
        graphs.VatxtModel().classifier_graph()

        # Ensure variables have been reused: 1 embedding matrix, 2 variables
        # per LSTM layer, 2 per classifier hidden layer, and 2 for the logits
        # layer.
        expected_num_vars = (1 + 2 * FLAGS.rnn_num_layers +
                             2 * FLAGS.cl_num_layers + 2)
        self.assertEqual(len(tf.trainable_variables()), expected_num_vars)

  def testSyncReplicas(self):
    FLAGS.sync_replicas = True
    graphs.VatxtModel().language_model_training()

  def testCandidateSampling(self):
    FLAGS.num_candidate_samples = 10
    graphs.VatxtModel().language_model_training()

  def testSeqAE(self):
    FLAGS.use_seq2seq_autoencoder = True
    graphs.VatxtModel().language_model_training()

  def testBidirLM(self):
    graphs.VatxtBidirModel().language_model_graph()

  def testBidirClassifier(self):
    at_methods = [None, 'rp', 'at', 'vat', 'atvat']
    for method in at_methods:
      FLAGS.adv_training_method = method
      with tf.Graph().as_default():
        graphs.VatxtBidirModel().classifier_graph()

        # Ensure variables have been reused: 1 embedding matrix, 2 variables
        # per LSTM layer for each of the 2 directions, 2 per classifier
        # hidden layer, and 2 for the logits layer.
        expected_num_vars = (1 + 2 * 2 * FLAGS.rnn_num_layers +
                             2 * FLAGS.cl_num_layers + 2)
        self.assertEqual(len(tf.trainable_variables()), expected_num_vars)

  def testEvalGraph(self):
    _, _ = graphs.VatxtModel().eval_graph()

  def testBidirEvalGraph(self):
    _, _ = graphs.VatxtBidirModel().eval_graph()


if __name__ == '__main__':
  tf.test.main()