# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RNN model with embeddings"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class NamignizerModel(object):
    """The Namignizer model ~ strongly based on PTB"""

    def __init__(self, is_training, config):
        self.batch_size = batch_size = config.batch_size
        self.num_steps = num_steps = config.num_steps
        size = config.hidden_size
        # will always be 27
        vocab_size = config.vocab_size

        # placeholders for inputs
        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
        # weights for the loss function
        self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])

        # lstm for our RNN cell (GRU supported too)
        lstm_cells = []
        for layer in range(config.num_layers):
            lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
            if is_training and config.keep_prob < 1:
                lstm_cell = tf.contrib.rnn.DropoutWrapper(
                    lstm_cell, output_keep_prob=config.keep_prob)
            lstm_cells.append(lstm_cell)
        cell = tf.contrib.rnn.MultiRNNCell(lstm_cells)

        self._initial_state = cell.zero_state(batch_size, tf.float32)

        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [vocab_size, size])
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)

        if is_training and config.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, config.keep_prob)
        outputs = []
        state = self._initial_state
        with tf.variable_scope("RNN"):
            for time_step in range(num_steps):
                if time_step > 0:
                    tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, size])
        softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
        softmax_b = tf.get_variable("softmax_b", [vocab_size])
        logits = tf.matmul(output, softmax_w) + softmax_b
        loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
            [logits],
            [tf.reshape(self._targets, [-1])],
            [self._weights])
        self._loss = loss
        self._cost = cost = tf.reduce_sum(loss) / batch_size
        self._final_state = state

        # probabilities of each letter
        self._activations = tf.nn.softmax(logits)

        # ability to save the model
        self.saver = tf.train.Saver(tf.global_variables())
        if not is_training:
            return

        self._lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))
    @property
    def input_data(self):
        return self._input_data

    @property
    def targets(self):
        return self._targets

    @property
    def activations(self):
        return self._activations

    @property
    def weights(self):
        return self._weights

    @property
    def initial_state(self):
        return self._initial_state

    @property
    def cost(self):
        return self._cost

    @property
    def loss(self):
        return self._loss

    @property
    def final_state(self):
        return self._final_state

    @property
    def lr(self):
        return self._lr

    @property
    def train_op(self):
        return self._train_op
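

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds the graph
# with a hypothetical config object and sets the learning rate. The field
# names on `Config` are assumptions that simply mirror the attributes
# __init__ reads from `config`; the real training script supplies its own
# configuration class. Requires TensorFlow 1.x (tf.contrib / tf.placeholder).
if __name__ == "__main__":
    import collections

    Config = collections.namedtuple(
        "Config",
        ["batch_size", "num_steps", "hidden_size", "vocab_size",
         "num_layers", "keep_prob", "max_grad_norm"])
    # small, illustrative values only
    demo_config = Config(batch_size=2, num_steps=5, hidden_size=32,
                         vocab_size=27, num_layers=2, keep_prob=0.9,
                         max_grad_norm=5)

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-0.1, 0.1)
        with tf.variable_scope("model", initializer=initializer):
            model = NamignizerModel(is_training=True, config=demo_config)
        session.run(tf.global_variables_initializer())
        # update the learning-rate variable via the assign_lr helper
        model.assign_lr(session, 1.0)
        print("learning rate:", session.run(model.lr))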