# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PTB data loader and helpers.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import collections | |
import os | |
# Dependency imports | |
import numpy as np | |
import tensorflow as tf | |
EOS_INDEX = 0 | |


def _read_words(filename):
  # Written for the TF 1.x / Python 2 setting of the original code, where
  # GFile.read() returns bytes; hence the explicit decode. Newlines become
  # "<eos>" tokens so sentence boundaries survive tokenization.
  with tf.gfile.GFile(filename, "r") as f:
    return f.read().decode("utf-8").replace("\n", "<eos>").split()


def build_vocab(filename):
  data = _read_words(filename)

  # Order words by descending frequency, breaking ties alphabetically, so the
  # assigned ids are deterministic across runs.
  counter = collections.Counter(data)
  count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

  words, _ = list(zip(*count_pairs))
  word_to_id = dict(zip(words, range(len(words))))
  print("<eos>:", word_to_id["<eos>"])

  # Side effect: record the "<eos>" id for use as the fill value in
  # ptb_iterator().
  global EOS_INDEX
  EOS_INDEX = word_to_id["<eos>"]

  return word_to_id


def _file_to_word_ids(filename, word_to_id):
  data = _read_words(filename)
  # Silently drops out-of-vocabulary words, i.e. words not seen in the
  # training file that built word_to_id.
  return [word_to_id[word] for word in data if word in word_to_id]
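

# Example (a sketch, not part of the original file; the path is hypothetical):
#
#   word_to_id = build_vocab("/tmp/ptb/ptb.train.txt")  # also sets EOS_INDEX
#   valid_ids = _file_to_word_ids("/tmp/ptb/ptb.valid.txt", word_to_id)
#   # valid_ids drops any valid-set word missing from the training vocabulary.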


def ptb_raw_data(data_path=None):
  """Load PTB raw data from data directory "data_path".

  Reads PTB text files and converts strings to integer ids. Mini-batching
  is done separately by ptb_iterator().

  The PTB dataset comes from Tomas Mikolov's webpage:

  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz

  Args:
    data_path: string path to the directory where simple-examples.tgz has
      been extracted.

  Returns:
    tuple (train_data, valid_data, test_data, vocabulary)
    where each of the data objects can be passed to ptb_iterator().
  """
  train_path = os.path.join(data_path, "ptb.train.txt")
  valid_path = os.path.join(data_path, "ptb.valid.txt")
  test_path = os.path.join(data_path, "ptb.test.txt")

  # The vocabulary is built from the training file only.
  word_to_id = build_vocab(train_path)
  train_data = _file_to_word_ids(train_path, word_to_id)
  valid_data = _file_to_word_ids(valid_path, word_to_id)
  test_data = _file_to_word_ids(test_path, word_to_id)
  vocabulary = len(word_to_id)
  return train_data, valid_data, test_data, vocabulary
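

# Example usage (a sketch, not part of the original file; "/tmp/ptb" is a
# hypothetical path to the extracted simple-examples data):
#
#   train_data, valid_data, test_data, vocab_size = ptb_raw_data("/tmp/ptb")
#   print("train tokens: %d  vocab: %d" % (len(train_data), vocab_size))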


def ptb_iterator(raw_data, batch_size, num_steps, epoch_size_override=None):
  """Iterate on the raw PTB data.

  This generates batch_size pointers into the raw PTB data, and allows
  minibatch iteration along these pointers.

  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
    epoch_size_override: optional int; if set, yield this many batches per
      epoch instead of deriving the epoch size from the data length.

  Yields:
    Triples (x, y, w) of batched data, each a matrix of shape
    [batch_size, num_steps]. y is x time-shifted by one step (the
    next-token targets), and w is a weight matrix of ones.

  Raises:
    ValueError: if batch_size or num_steps are too high.
  """
  raw_data = np.array(raw_data, dtype=np.int32)

  data_len = len(raw_data)
  batch_len = data_len // batch_size
  # Lay the id stream out as batch_size contiguous rows; any tail shorter
  # than batch_len is dropped, and the EOS_INDEX fill is fully overwritten.
  data = np.full([batch_size, batch_len], EOS_INDEX, dtype=np.int32)
  for i in range(batch_size):
    data[i] = raw_data[batch_len * i:batch_len * (i + 1)]

  if epoch_size_override:
    epoch_size = epoch_size_override
  else:
    epoch_size = (batch_len - 1) // num_steps

  if epoch_size == 0:
    raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

  # print("Number of batches per epoch: %d" % epoch_size)
  for i in range(epoch_size):
    x = data[:, i * num_steps:(i + 1) * num_steps]
    y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
    w = np.ones_like(x)
    yield (x, y, w)
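

if __name__ == "__main__":
  # Minimal smoke test (a sketch, not part of the original file). A toy id
  # sequence stands in for real PTB ids: 100 tokens with batch_size=4 give
  # batch_len=25 and epoch_size=(25 - 1) // 5 = 4 batches.
  toy_ids = list(range(100))
  for x, y, w in ptb_iterator(toy_ids, batch_size=4, num_steps=5):
    assert x.shape == (4, 5)
    assert (y[:, :-1] == x[:, 1:]).all()  # y is x shifted by one step
    assert (w == 1).all()
  print("ptb_iterator smoke test passed.")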