# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""IMDB data loader and helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports
import numpy as np
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_boolean('prefix_label', True,
                            'Whether to prepend each review with a token '
                            'encoding its class label.')

np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)

# Id reserved for the '<eos>' token.  Class labels are encoded as the ids
# immediately above it (EOS_INDEX + 1 + label).
EOS_INDEX = 88892


def _read_words(filename, use_prefix=True):
  """Parses a TFRecord file of SequenceExamples into lists of token ids.

  Each review is returned as a list of int64 token ids.  If
  FLAGS.prefix_label and use_prefix are both true, the review's class
  label is encoded as an extra leading token with id EOS_INDEX + 1 + label.
  """
  all_words = []
  sequence_example = tf.train.SequenceExample()
  for r in tf.python_io.tf_record_iterator(filename):
    sequence_example.ParseFromString(r)

    if FLAGS.prefix_label and use_prefix:
      # Encode the class label as a single leading token.
      label = sequence_example.context.feature['class'].int64_list.value[0]
      review_words = [EOS_INDEX + 1 + label]
    else:
      review_words = []
    review_words.extend([
        f.int64_list.value[0]
        for f in sequence_example.feature_lists.feature_list['token_id'].feature
    ])
    all_words.append(review_words)
  return all_words
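

# For reference, _read_words expects records shaped roughly like the
# following (a sketch inferred from the parsing code above; the writer
# that produces these records lives elsewhere in the pipeline):
#
#   example = tf.train.SequenceExample()
#   example.context.feature['class'].int64_list.value.append(1)
#   for token_id in [12, 407, EOS_INDEX]:
#     example.feature_lists.feature_list['token_id'].feature.add(
#     ).int64_list.value.append(token_id)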


def build_vocab(vocab_file):
  """Builds a word -> id mapping from a newline-delimited vocabulary file."""
  word_to_id = {}
  with tf.gfile.GFile(vocab_file, 'r') as f:
    for index, word in enumerate(f):
      word_to_id[word.strip()] = index
  word_to_id['<eos>'] = EOS_INDEX
  return word_to_id
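

# Example usage (a sketch; the path and the '<unk>' token are assumptions,
# and the file's line order must match the ids used in the records):
#
#   word_to_id = build_vocab('/tmp/imdb/vocab.txt')
#   unk_id = word_to_id.get('<unk>', 0)
#   ids = [word_to_id.get(w, unk_id) for w in 'a great movie'.split()]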


def imdb_raw_data(data_path=None):
  """Loads IMDB raw data from the directory `data_path`.

  Reads IMDB TFRecord files containing integer token ids.

  Args:
    data_path: string path to the directory containing train_lm.tfrecords
      and test_lm.tfrecords.

  Returns:
    tuple (train_data, valid_data), where each element is a list of
    reviews and each review is a list of token ids.  Both can be passed
    to imdb_iterator.
  """
  train_path = os.path.join(data_path, 'train_lm.tfrecords')
  valid_path = os.path.join(data_path, 'test_lm.tfrecords')

  train_data = _read_words(train_path)
  valid_data = _read_words(valid_path)
  return train_data, valid_data
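

# A quick shape check (a sketch; '/tmp/imdb' is a placeholder path):
#
#   train_data, valid_data = imdb_raw_data('/tmp/imdb')
#   # train_data is a list of reviews; each review is a list of int ids,
#   # optionally starting with the label token (see _read_words).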


def imdb_iterator(raw_data, batch_size, num_steps, epoch_size_override=None):
  """Iterates on the raw IMDB data.

  This generates batch_size pointers into the raw IMDB data, and allows
  minibatch iteration along these pointers.

  Args:
    raw_data: one of the raw data outputs from imdb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
    epoch_size_override: unused.

  Yields:
    Tuples (x, y, w) of batched data, each a matrix of shape
    [batch_size, num_steps].  y is x time-shifted by one step
    (y[t] = x[t + 1]), i.e. the prediction targets.  w is a weight matrix
    with 1 indicating a word was present and 0 indicating padding.
  """
  del epoch_size_override
  data_len = len(raw_data)
  # The trailing batch is dropped.
  num_batches = data_len // batch_size - 1

  for batch in range(num_batches):
    x = np.zeros([batch_size, num_steps], dtype=np.int32)
    y = np.zeros([batch_size, num_steps], dtype=np.int32)
    w = np.zeros([batch_size, num_steps], dtype=np.float32)
    for i in range(batch_size):
      data_index = batch * batch_size + i
      example = raw_data[data_index]

      if len(example) > num_steps:
        # Truncate long reviews to num_steps inputs and targets.
        final_x = example[:num_steps]
        final_y = example[1:(num_steps + 1)]
        w[i] = 1
      else:
        # Pad short reviews with <eos> and zero out the padded weights.
        to_fill_in = num_steps - len(example)
        final_x = example + [EOS_INDEX] * to_fill_in
        final_y = final_x[1:] + [EOS_INDEX]
        w[i] = [1] * len(example) + [0] * to_fill_in
      x[i] = final_x
      y[i] = final_y

    yield (x, y, w)
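

# Example usage (a minimal sketch; the path, batch_size and num_steps
# values are illustrative only):
#
#   train_data, _ = imdb_raw_data('/tmp/imdb')
#   for x, y, w in imdb_iterator(train_data, batch_size=32, num_steps=40):
#     # x: [32, 40] input ids; y: x shifted by one step; w: 1.0 for real
#     # tokens, 0.0 for <eos> padding.
#     ...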