# NOTE: scraped page status banner ("Spaces: Running") -- not part of the module.
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create TFRecord files of SequenceExample protos from dataset.

Constructs 3 datasets:
  1. Labeled data for the LSTM classification model, optionally with label gain.
     "*_classification.tfrecords" (for both unidirectional and bidirectional
     models).
  2. Data for the unsupervised LM-LSTM model that predicts the next token.
     "*_lm.tfrecords" (generates forward and reverse data).
  3. Data for the unsupervised SA-LSTM model that uses Seq2Seq.
     "*_sa.tfrecords".
"""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import string | |
# Dependency imports | |
import tensorflow as tf | |
from data import data_utils | |
from data import document_generators | |
# Short alias: the rest of this module refers to data_utils as `data`.
data = data_utils

flags = tf.app.flags
FLAGS = flags.FLAGS

# Flags for input data are in document_generators.py
flags.DEFINE_string(
    'vocab_file',
    '',
    'Path to the vocabulary file. Defaults to FLAGS.output_dir/vocab.txt.')
flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.')

# Config
flags.DEFINE_boolean(
    'label_gain',
    False,
    'Enable linear label gain. If True, sentiment label will be included at '
    'each timestep with linear weight increase.')
def build_shuffling_tf_record_writer(fname):
    """Creates a `data.ShufflingTFRecordWriter` under FLAGS.output_dir.

    Args:
        fname: File name; it is joined onto `FLAGS.output_dir`.

    Returns:
        A `data.ShufflingTFRecordWriter` for the joined path.
    """
    output_path = os.path.join(FLAGS.output_dir, fname)
    return data.ShufflingTFRecordWriter(output_path)
def build_tf_record_writer(fname):
    """Creates a plain `tf.python_io.TFRecordWriter` under FLAGS.output_dir.

    Args:
        fname: File name; it is joined onto `FLAGS.output_dir`.

    Returns:
        A `tf.python_io.TFRecordWriter` for the joined path.
    """
    output_path = os.path.join(FLAGS.output_dir, fname)
    return tf.python_io.TFRecordWriter(output_path)
def build_input_sequence(doc, vocab_ids):
    """Converts a document into a sequence of vocabulary ids.

    Iterates over `document_generators.tokens(doc)` and adds one timestep per
    token that is present in `vocab_ids`; tokens missing from `vocab_ids` are
    dropped. A final timestep holding the id of `data.EOS_TOKEN` is always
    appended. Only the token of each timestep is set here.

    Args:
        doc: Document (defined in `document_generators`) from which to build
            the sequence.
        vocab_ids: dict<term, id>. Must contain `data.EOS_TOKEN`.

    Returns:
        A `data.SequenceWrapper`.
    """
    sequence = data.SequenceWrapper()
    for term in document_generators.tokens(doc):
        if term not in vocab_ids:
            continue
        timestep = sequence.add_timestep()
        timestep.set_token(vocab_ids[term])

    # Add EOS token to end
    eos_timestep = sequence.add_timestep()
    eos_timestep.set_token(vocab_ids[data.EOS_TOKEN])
    return sequence
def make_vocab_ids(vocab_filename):
    """Builds the term -> id mapping.

    If `FLAGS.output_char` is set (the flag is not defined in this file), each
    character of `string.printable` maps to its index and `data.EOS_TOKEN`
    maps to `len(string.printable)`; `vocab_filename` is not read. Otherwise
    each stripped line of the UTF-8 file `vocab_filename` maps to its
    zero-based line index.

    Args:
        vocab_filename: Path to the vocabulary file, one term per line.

    Returns:
        dict<term, id>.
    """
    if not FLAGS.output_char:
        with open(vocab_filename, encoding='utf-8') as vocab_f:
            return {line.strip(): idx for idx, line in enumerate(vocab_f)}

    char_ids = {char: idx for idx, char in enumerate(string.printable)}
    char_ids[data.EOS_TOKEN] = len(string.printable)
    return char_ids
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
    """Generates training data.

    Iterates over the 'train' documents (unlabeled and validation documents
    included) and, for each one, builds the input sequence plus its reverse,
    language-model, reverse language-model and sequence-autoencoder variants.

    * Every document: the LM and seq-AE records are written to
      `writer_lm_all` / `writer_seq_ae_all`.
    * Non-validation documents: the LM, reverse-LM and seq-AE records are
      also written to the training writers.
    * Labeled documents: the classification and bidirectional classification
      records go to the validation writers when `doc.is_validation`,
      otherwise to the training writers. Label gain is applied only when
      `FLAGS.label_gain` is set and the document is not a validation one.

    The writers created here are closed before returning, including when an
    exception escapes the loop. `writer_lm_all` and `writer_seq_ae_all` are
    owned by the caller and are not closed here.

    Args:
        vocab_ids: dict<term, id>, forwarded to `build_input_sequence`.
        writer_lm_all: Writer receiving the LM record of every document.
        writer_seq_ae_all: Writer receiving the seq-AE record of every
            document.
    """
    # Construct training data writers
    writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
    writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
    writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
    writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
    writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
    writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
    writer_bd_valid_class = build_shuffling_tf_record_writer(
        data.VALID_BD_CLASS)

    # Listed in the order in which they are closed.
    owned_writers = (
        writer_lm,
        writer_seq_ae,
        writer_class,
        writer_valid_class,
        writer_rev_lm,
        writer_bd_class,
        writer_bd_valid_class,
    )

    try:
        for doc in document_generators.documents(
                dataset='train', include_unlabeled=True,
                include_validation=True):
            input_seq = build_input_sequence(doc, vocab_ids)
            # build_input_sequence always appends EOS, so a length below 2
            # presumably means no in-vocabulary token was found.
            if len(input_seq) < 2:
                continue

            rev_seq = data.build_reverse_sequence(input_seq)
            lm_seq = data.build_lm_sequence(input_seq)
            rev_lm_seq = data.build_lm_sequence(rev_seq)
            seq_ae_seq = data.build_seq_ae_sequence(input_seq)

            if doc.label is not None:
                # Used for sentiment classification.
                use_label_gain = FLAGS.label_gain and not doc.is_validation
                label_seq = data.build_labeled_sequence(
                    input_seq, doc.label, label_gain=use_label_gain)
                bd_label_seq = data.build_labeled_sequence(
                    data.build_bidirectional_seq(input_seq, rev_seq),
                    doc.label,
                    label_gain=use_label_gain)

                if doc.is_validation:
                    class_writer = writer_valid_class
                    bd_class_writer = writer_bd_valid_class
                else:
                    class_writer = writer_class
                    bd_class_writer = writer_bd_class
                class_writer.write(label_seq.seq.SerializeToString())
                bd_class_writer.write(bd_label_seq.seq.SerializeToString())

            # Write
            lm_seq_ser = lm_seq.seq.SerializeToString()
            seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
            writer_lm_all.write(lm_seq_ser)
            writer_seq_ae_all.write(seq_ae_seq_ser)
            if not doc.is_validation:
                writer_lm.write(lm_seq_ser)
                writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
                writer_seq_ae.write(seq_ae_seq_ser)
    finally:
        # Close writers even if the loop raised, so none is left open.
        for writer in owned_writers:
            writer.close()
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
    """Generates test data.

    Iterates over the 'test' documents (unlabeled documents excluded) and, for
    each one, builds the input sequence plus its reverse, language-model,
    reverse language-model, sequence-autoencoder, classification and
    bidirectional classification variants. Every record is written to the
    matching test writer; the LM and seq-AE records are additionally written
    to `writer_lm_all` / `writer_seq_ae_all`.

    The writers created here are closed before returning, including when an
    exception escapes the loop. `writer_lm_all` and `writer_seq_ae_all` are
    owned by the caller and are not closed here.

    Args:
        vocab_ids: dict<term, id>, forwarded to `build_input_sequence`.
        writer_lm_all: Writer receiving the LM record of every document.
        writer_seq_ae_all: Writer receiving the seq-AE record of every
            document.
    """
    # Construct test data writers
    writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
    writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
    writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
    writer_class = build_tf_record_writer(data.TEST_CLASS)
    writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)

    # Listed in the order in which they are closed.
    owned_writers = (
        writer_lm,
        writer_rev_lm,
        writer_seq_ae,
        writer_class,
        writer_bd_class,
    )

    try:
        for doc in document_generators.documents(
                dataset='test', include_unlabeled=False,
                include_validation=True):
            input_seq = build_input_sequence(doc, vocab_ids)
            # build_input_sequence always appends EOS, so a length below 2
            # presumably means no in-vocabulary token was found.
            if len(input_seq) < 2:
                continue

            rev_seq = data.build_reverse_sequence(input_seq)
            lm_seq = data.build_lm_sequence(input_seq)
            rev_lm_seq = data.build_lm_sequence(rev_seq)
            seq_ae_seq = data.build_seq_ae_sequence(input_seq)
            label_seq = data.build_labeled_sequence(input_seq, doc.label)
            bd_label_seq = data.build_labeled_sequence(
                data.build_bidirectional_seq(input_seq, rev_seq), doc.label)

            # Write
            writer_class.write(label_seq.seq.SerializeToString())
            writer_bd_class.write(bd_label_seq.seq.SerializeToString())
            lm_seq_ser = lm_seq.seq.SerializeToString()
            seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
            writer_lm.write(lm_seq_ser)
            writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
            writer_seq_ae.write(seq_ae_seq_ser)
            writer_lm_all.write(lm_seq_ser)
            writer_seq_ae_all.write(seq_ae_seq_ser)
    finally:
        # Close test writers even if the loop raised, so none is left open.
        for writer in owned_writers:
            writer.close()
def main(_):
    """Builds the vocabulary ids, then writes the training and test records.

    The vocabulary path is `FLAGS.vocab_file` when set, otherwise
    `FLAGS.output_dir/vocab.txt`. The two "all" writers (`data.ALL_LM` and
    `data.ALL_SA`) are shared by the training and test passes and are managed
    by the `with` statement.

    Args:
        _: Unused positional argument.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    tf.logging.info('Assigning vocabulary ids...')
    vocab_path = FLAGS.vocab_file
    if not vocab_path:
        vocab_path = os.path.join(FLAGS.output_dir, 'vocab.txt')
    vocab_ids = make_vocab_ids(vocab_path)

    with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all, \
            build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all:
        tf.logging.info('Generating training data...')
        generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all)

        tf.logging.info('Generating test data...')
        generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
if __name__ == '__main__':
    # Script entry point: hand this module's main() to the TF app runner.
    tf.app.run(main=main)