# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create TFRecord files of SequenceExample protos from dataset.
Constructs 3 datasets:
1. Labeled data for the LSTM classification model, optionally with label gain.
"*_classification.tfrecords" (for both unidirectional and bidirectional
models).
2. Data for the unsupervised LM-LSTM model that predicts the next token.
"*_lm.tfrecords" (generates forward and reverse data).
3. Data for the unsupervised SA-LSTM model that uses Seq2Seq.
"*_sa.tfrecords".
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import string

# Dependency imports

import tensorflow as tf

from data import data_utils
from data import document_generators
# Short alias: the rest of this module refers to helpers as `data.<name>`.
data = data_utils

flags = tf.app.flags
FLAGS = flags.FLAGS

# Flags for input data are in document_generators.py
# NOTE(review): FLAGS.output_char (read by make_vocab_ids) is not defined in
# this section; presumably it is one of those input-data flags.
flags.DEFINE_string(
    'vocab_file', '',
    'Path to the vocabulary file. Defaults to FLAGS.output_dir/vocab.txt.')
flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.')

# Config
flags.DEFINE_boolean(
    'label_gain', False,
    'Enable linear label gain. If True, sentiment label will be included at '
    'each timestep with linear weight increase.')
def build_shuffling_tf_record_writer(fname):
  """Creates a `data.ShufflingTFRecordWriter` for a file in FLAGS.output_dir.

  Args:
    fname: File name, joined onto FLAGS.output_dir.

  Returns:
    A new `data.ShufflingTFRecordWriter` targeting that path.
  """
  target_path = os.path.join(FLAGS.output_dir, fname)
  return data.ShufflingTFRecordWriter(target_path)
def build_tf_record_writer(fname):
  """Creates a plain (non-shuffling) TFRecordWriter in FLAGS.output_dir.

  Args:
    fname: File name, joined onto FLAGS.output_dir.

  Returns:
    A new `tf.python_io.TFRecordWriter` targeting that path.
  """
  target_path = os.path.join(FLAGS.output_dir, fname)
  return tf.python_io.TFRecordWriter(target_path)
def build_input_sequence(doc, vocab_ids):
  """Converts a document into a `data.SequenceWrapper` of vocabulary ids.

  Tokenization is delegated to `document_generators.tokens(doc)`. Each token
  found in `vocab_ids` becomes one timestep holding its id. Tokens missing
  from `vocab_ids` are dropped. A final timestep holding the id of
  `data.EOS_TOKEN` is always appended.

  Args:
    doc: Document (defined in `document_generators`) from which to build the
      sequence.
    vocab_ids: dict<term, id>. It must contain `data.EOS_TOKEN`; otherwise the
      final lookup raises KeyError.

  Returns:
    The populated `data.SequenceWrapper`.
  """
  sequence = data.SequenceWrapper()
  in_vocab_ids = (vocab_ids[tok]
                  for tok in document_generators.tokens(doc)
                  if tok in vocab_ids)
  for token_id in in_vocab_ids:
    sequence.add_timestep().set_token(token_id)
  # Every sequence is terminated with EOS, even when no token was kept.
  eos_id = vocab_ids[data.EOS_TOKEN]
  sequence.add_timestep().set_token(eos_id)
  return sequence
def make_vocab_ids(vocab_filename):
  """Builds the term -> integer id mapping used to encode documents.

  If FLAGS.output_char is set, the vocabulary is character-level and
  `vocab_filename` is ignored. Each character of `string.printable` gets its
  index in that string as its id, and `data.EOS_TOKEN` gets the next id,
  `len(string.printable)`.

  Otherwise the vocabulary is read from `vocab_filename` as UTF-8 text, one
  term per line. Each line is stripped of surrounding whitespace and mapped
  to its 0-based line number. If the same term appears on several lines, the
  id of its last occurrence wins.

  Args:
    vocab_filename: Path to the vocabulary file (one term per line).

  Returns:
    dict<term, id>.
  """
  if FLAGS.output_char:
    ret = {char: i for i, char in enumerate(string.printable)}
    ret[data.EOS_TOKEN] = len(string.printable)
    return ret
  with open(vocab_filename, encoding='utf-8') as vocab_f:
    return {line.strip(): i for i, line in enumerate(vocab_f)}
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates training data.

  Iterates over the 'train' documents, including unlabeled and validation
  documents. For every document whose input sequence has at least 2
  timesteps, it writes:
    * LM and seq-AE records to `writer_lm_all` / `writer_seq_ae_all`;
    * for non-validation documents, LM, reverse-LM and seq-AE records to the
      TRAIN_LM / TRAIN_REV_LM / TRAIN_SA files;
    * for labeled documents, unidirectional and bidirectional classification
      records to TRAIN_CLASS / TRAIN_BD_CLASS, or to VALID_CLASS /
      VALID_BD_CLASS for validation documents.

  All per-split writers opened here are closed before returning, including
  when an exception is raised. They are closed in reverse creation order.

  Args:
    vocab_ids: dict<term, id>, as returned by `make_vocab_ids`.
    writer_lm_all: Writer receiving the LM records of every document.
    writer_seq_ae_all: Writer receiving the seq-AE records of every document.
  """
  # ExitStack guarantees every writer's close() runs even if record
  # generation fails, mirroring how `main` manages the *_all writers.
  with contextlib.ExitStack() as stack:

    def managed(writer):
      stack.callback(writer.close)
      return writer

    # Construct training data writers
    writer_lm = managed(build_shuffling_tf_record_writer(data.TRAIN_LM))
    writer_seq_ae = managed(build_shuffling_tf_record_writer(data.TRAIN_SA))
    writer_class = managed(build_shuffling_tf_record_writer(data.TRAIN_CLASS))
    writer_valid_class = managed(build_tf_record_writer(data.VALID_CLASS))
    writer_rev_lm = managed(
        build_shuffling_tf_record_writer(data.TRAIN_REV_LM))
    writer_bd_class = managed(
        build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS))
    # NOTE(review): unlike VALID_CLASS above, the bidirectional validation
    # writer shuffles. Kept as-is to preserve output; confirm it is intended.
    writer_bd_valid_class = managed(
        build_shuffling_tf_record_writer(data.VALID_BD_CLASS))

    for doc in document_generators.documents(
        dataset='train', include_unlabeled=True, include_validation=True):
      input_seq = build_input_sequence(doc, vocab_ids)
      # build_input_sequence always appends EOS, so a length below 2
      # presumably means no token of the document was in the vocabulary.
      if len(input_seq) < 2:
        continue
      rev_seq = data.build_reverse_sequence(input_seq)
      lm_seq = data.build_lm_sequence(input_seq)
      rev_lm_seq = data.build_lm_sequence(rev_seq)
      seq_ae_seq = data.build_seq_ae_sequence(input_seq)

      if doc.label is not None:
        # Used for sentiment classification. Label gain applies only to
        # non-validation documents.
        label_gain = FLAGS.label_gain and not doc.is_validation
        label_seq = data.build_labeled_sequence(
            input_seq, doc.label, label_gain=label_gain)
        bd_label_seq = data.build_labeled_sequence(
            data.build_bidirectional_seq(input_seq, rev_seq),
            doc.label,
            label_gain=label_gain)
        class_writer = writer_valid_class if doc.is_validation else writer_class
        bd_class_writer = (writer_bd_valid_class
                           if doc.is_validation else writer_bd_class)
        class_writer.write(label_seq.seq.SerializeToString())
        bd_class_writer.write(bd_label_seq.seq.SerializeToString())

      # Write
      lm_seq_ser = lm_seq.seq.SerializeToString()
      seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
      writer_lm_all.write(lm_seq_ser)
      writer_seq_ae_all.write(seq_ae_seq_ser)
      if not doc.is_validation:
        writer_lm.write(lm_seq_ser)
        writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
        writer_seq_ae.write(seq_ae_seq_ser)
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates test data.

  Iterates over the 'test' documents, excluding unlabeled ones and including
  validation ones. For every document whose input sequence has at least 2
  timesteps, it writes:
    * classification and bidirectional classification records to
      TEST_CLASS / TEST_BD_CLASS;
    * LM, reverse-LM and seq-AE records to TEST_LM / TEST_REV_LM / TEST_SA;
    * LM and seq-AE records to `writer_lm_all` / `writer_seq_ae_all`.

  All per-split writers opened here are closed before returning, including
  when an exception is raised. They are closed in reverse creation order.

  Args:
    vocab_ids: dict<term, id>, as returned by `make_vocab_ids`.
    writer_lm_all: Writer receiving the LM records of every document.
    writer_seq_ae_all: Writer receiving the seq-AE records of every document.
  """
  # ExitStack guarantees every writer's close() runs even if record
  # generation fails, mirroring how `main` manages the *_all writers.
  with contextlib.ExitStack() as stack:

    def managed(writer):
      stack.callback(writer.close)
      return writer

    # Construct test data writers
    writer_lm = managed(build_shuffling_tf_record_writer(data.TEST_LM))
    writer_rev_lm = managed(
        build_shuffling_tf_record_writer(data.TEST_REV_LM))
    writer_seq_ae = managed(build_shuffling_tf_record_writer(data.TEST_SA))
    writer_class = managed(build_tf_record_writer(data.TEST_CLASS))
    writer_bd_class = managed(
        build_shuffling_tf_record_writer(data.TEST_BD_CLASS))

    for doc in document_generators.documents(
        dataset='test', include_unlabeled=False, include_validation=True):
      input_seq = build_input_sequence(doc, vocab_ids)
      # build_input_sequence always appends EOS, so a length below 2
      # presumably means no token of the document was in the vocabulary.
      if len(input_seq) < 2:
        continue
      rev_seq = data.build_reverse_sequence(input_seq)
      lm_seq = data.build_lm_sequence(input_seq)
      rev_lm_seq = data.build_lm_sequence(rev_seq)
      seq_ae_seq = data.build_seq_ae_sequence(input_seq)
      # label_gain is not passed here, so build_labeled_sequence's default
      # applies. doc.label is assumed non-None because include_unlabeled=False.
      label_seq = data.build_labeled_sequence(input_seq, doc.label)
      bd_label_seq = data.build_labeled_sequence(
          data.build_bidirectional_seq(input_seq, rev_seq), doc.label)

      # Write
      writer_class.write(label_seq.seq.SerializeToString())
      writer_bd_class.write(bd_label_seq.seq.SerializeToString())
      lm_seq_ser = lm_seq.seq.SerializeToString()
      seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
      writer_lm.write(lm_seq_ser)
      writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
      writer_seq_ae.write(seq_ae_seq_ser)
      writer_lm_all.write(lm_seq_ser)
      writer_seq_ae_all.write(seq_ae_seq_ser)
def main(_):
  """Builds the vocabulary mapping, then writes all train and test records.

  The vocabulary file is FLAGS.vocab_file, or `<FLAGS.output_dir>/vocab.txt`
  when that flag is empty.

  Args:
    _: Unused positional argument (presumably argv, passed by tf.app.run).
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  tf.logging.info('Assigning vocabulary ids...')
  vocab_path = FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt')
  vocab_ids = make_vocab_ids(vocab_path)

  # The *_all writers collect records from both splits, so they stay open
  # across both generate_* calls and are closed on leaving this block.
  with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all, \
      build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all:
    tf.logging.info('Generating training data...')
    generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all)

    tf.logging.info('Generating test data...')
    generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
# Script entry point.
if __name__ == '__main__':
  # tf.app.run() is expected to parse command-line flags and then dispatch
  # to this module's main(); confirm against the TF 1.x version in use.
  tf.app.run()