NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
12.5 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from util.nn import fc_layer as fc
def _get_lstm_cell(num_layers, lstm_dim):
    """Return a MultiRNNCell that stacks `num_layers` BasicLSTMCells.

    Every layer has `lstm_dim` units. Both the per-layer cells and the stacked
    cell are built with `state_is_tuple=True`.
    """
    stacked_layers = []
    for _ in range(num_layers):
        stacked_layers.append(
            tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True))
    return tf.contrib.rnn.MultiRNNCell(stacked_layers, state_is_tuple=True)
class AttentionSeq2Seq:
    """Attentional LSTM encoder-decoder producing a decoder token sequence.

    The encoder runs a multi-layer LSTM over a time-major input token
    sequence. The decoder is a multi-layer LSTM driven by `tf.nn.raw_rnn`.
    At every step it attends over the encoder outputs and predicts the next
    decoder (layout) token by one of three rules:
      - sampling from the token distribution,
      - greedy argmax,
      - copying a ground-truth layout.

    All graph construction happens in `__init__`. The resulting tensors are
    exposed as attributes (`predicted_tokens`, `token_probs`, `neg_entropy`,
    `atts`, `word_vecs`, and the intermediate encoder tensors).
    """

    def __init__(self,
                 config,
                 text_seq_batch,
                 seq_length_batch,
                 num_vocab_txt,
                 num_vocab_nmn,
                 EOS_token,
                 decoder_sampling,
                 embedding_mat,
                 use_gt_layout=None,
                 gt_layout_batch=None,
                 scope='encoder_decoder',
                 reuse=None):
        """Build the encoder and decoder graphs under `scope`.

        Args:
          config: object exposing `T_decoder`, `embed_dim_txt`,
            `embed_dim_nmn`, `lstm_dim` and `num_layers`.
          text_seq_batch: time-major tensor of input token ids, shape [T, N].
          seq_length_batch: per-example input lengths, passed as
            `sequence_length` to dynamic_rnn (presumably shape [N]).
          num_vocab_txt: input vocabulary size (stored as
            `encoder_num_vocab`).
          num_vocab_nmn: decoder vocabulary size, not including <go>.
          EOS_token: id of the <eos> token in the decoder vocabulary.
          decoder_sampling: Python bool. True samples decoder tokens;
            False decodes greedily.
          embedding_mat: input embedding matrix, looked up with
            `text_seq_batch`. Its row width is expected to equal
            `config.embed_dim_txt`.
          use_gt_layout: None, or a bool tensor. When it is True the decoder
            emits `gt_layout_batch` instead of its own predictions.
          gt_layout_batch: ground-truth decoder tokens, shape [T_decoder, N].
            Only read when `use_gt_layout` is not None.
          scope: variable scope wrapping the encoder and the decoder.
          reuse: forwarded to `tf.variable_scope`.
        """
        self.T_decoder = config.T_decoder
        self.encoder_num_vocab = num_vocab_txt
        self.encoder_embed_dim = config.embed_dim_txt
        self.decoder_num_vocab = num_vocab_nmn
        self.decoder_embed_dim = config.embed_dim_nmn
        self.lstm_dim = config.lstm_dim
        self.num_layers = config.num_layers
        self.EOS_token = EOS_token
        self.decoder_sampling = decoder_sampling
        self.embedding_mat = embedding_mat
        with tf.variable_scope(scope, reuse=reuse):
            self._build_encoder(text_seq_batch, seq_length_batch)
            self._build_decoder(use_gt_layout, gt_layout_batch)

    def _build_encoder(self,
                       text_seq_batch,
                       seq_length_batch,
                       scope='encoder',
                       reuse=None):
        """Embed the input sequence and encode it with a multi-layer LSTM.

        Sets the following attributes:
          T_encoder, N: dynamic time and batch sizes of `text_seq_batch`.
          embedded_input_seq: input embeddings, shape [T, N, D].
          encoder_outputs: LSTM outputs, shape [T, N, lstm_dim].
          encoder_states: final LSTM states.
          encoder_h_transformed: outputs projected through the `fc` layer,
            shape [T, N, lstm_dim]; consumed by the decoder's attention.
          seq_not_finished: float mask of shape [T, N, 1], 1 where t is
            less than the sequence length and 0 otherwise.
        """
        lstm_dim = self.lstm_dim
        num_layers = self.num_layers
        with tf.variable_scope(scope, reuse=reuse):
            T = tf.shape(text_seq_batch)[0]
            N = tf.shape(text_seq_batch)[1]
            self.T_encoder = T
            self.N = N

            # text_seq has shape [T, N] and embedded_seq has shape [T, N, D]
            embedded_seq = tf.nn.embedding_lookup(self.embedding_mat,
                                                  text_seq_batch)
            self.embedded_input_seq = embedded_seq

            # The RNN
            cell = _get_lstm_cell(num_layers, lstm_dim)

            # encoder_outputs has shape [T, N, lstm_dim]
            encoder_outputs, encoder_states = tf.nn.dynamic_rnn(
                cell,
                embedded_seq,
                seq_length_batch,
                dtype=tf.float32,
                time_major=True,
                scope='lstm')
            self.encoder_outputs = encoder_outputs
            self.encoder_states = encoder_states

            # Transform the encoder outputs for the attention alignments.
            # The outputs are flattened to [T * N, lstm_dim] for the fc layer,
            # then reshaped back to [T, N, lstm_dim].
            encoder_h_transformed = fc(
                'encoder_h_transform',
                tf.reshape(encoder_outputs, [-1, lstm_dim]),
                output_dim=lstm_dim)
            encoder_h_transformed = tf.reshape(encoder_h_transformed,
                                               [T, N, lstm_dim])
            self.encoder_h_transformed = encoder_h_transformed

            # seq_not_finished is a shape [T, N, 1] tensor, where
            # seq_not_finished[t, n] is 1 iff sequence n is not finished at
            # time t, and 0 otherwise. It broadcasts [T, 1, 1] against [N, 1].
            seq_not_finished = tf.less(
                tf.range(T)[:, tf.newaxis, tf.newaxis],
                seq_length_batch[:, tf.newaxis])
            seq_not_finished = tf.cast(seq_not_finished, tf.float32)
            self.seq_not_finished = seq_not_finished

    def _build_decoder(self,
                       use_gt_layout,
                       gt_layout_batch,
                       scope='decoder',
                       reuse=None):
        """Build the attentional decoder with `tf.nn.raw_rnn`.

        Decoding uses beam size 1. It samples if `self.decoder_sampling` is
        True and takes the argmax otherwise.
          - Inputs: the first input is a learned <go> embedding. Every later
            input is the embedding of the token predicted at the previous
            step.
          - Prediction: at each step the decoder computes an additive
            attention over the encoder outputs. It predicts a token from the
            concatenation of the cell output and the attended context.
          - Ground truth: if `use_gt_layout` is not None (a bool tensor),
            whenever it is True the tokens of `gt_layout_batch`
            ([T_max, N]) are emitted, regardless of their probability.
          - After <eos>: once a sequence has predicted <eos>, every later step
            predicts <eos> with probability 1 and zero entropy.

        T_max (`config.T_decoder`) is the maximum decoded length, including
        <eos>. The decoder vocabulary size does not include <go>.

        Sets the following attributes:
          atts: attention maps, shape [T_decoder, T_encoder, N, 1].
          predicted_tokens, token_probs: shape [T_decoder, N].
          neg_entropy: shape [N], summed over decoder steps.
          word_vecs: attention-weighted sums of the input embeddings,
            shape [T_decoder, N, D].
        """
        N = self.N
        encoder_states = self.encoder_states
        T_max = self.T_decoder
        lstm_dim = self.lstm_dim
        num_layers = self.num_layers
        EOS_token = self.EOS_token
        sampling = self.decoder_sampling

        with tf.variable_scope(scope, reuse=reuse):
            embedding_mat = tf.get_variable(
                'embedding_mat',
                [self.decoder_num_vocab, self.decoder_embed_dim])
            # we use a separate embedding for <go>, as it is only used at the
            # beginning of the sequence
            go_embedding = tf.get_variable('go_embedding',
                                           [1, self.decoder_embed_dim])

            # Parameters of the attention score
            # v . tanh(W_a h_dec + b_a + encoder_h_transformed)
            with tf.variable_scope('att_prediction'):
                v = tf.get_variable('v', [lstm_dim])
                W_a = tf.get_variable(
                    'weights', [lstm_dim, lstm_dim],
                    initializer=tf.contrib.layers.xavier_initializer())
                b_a = tf.get_variable(
                    'biases', lstm_dim, initializer=tf.constant_initializer(0.))

            # The parameters to predict the next token from
            # [cell_output; attended context], hence the lstm_dim * 2 input
            with tf.variable_scope('token_prediction'):
                W_y = tf.get_variable(
                    'weights', [lstm_dim * 2, self.decoder_num_vocab],
                    initializer=tf.contrib.layers.xavier_initializer())
                b_y = tf.get_variable(
                    'biases',
                    self.decoder_num_vocab,
                    initializer=tf.constant_initializer(0.))

            # Attentional decoding
            # loop_fn is called at time t BEFORE the cell execution at time t,
            # and its next_input is used as the input at time t (not t+1)
            # c.f. https://www.tensorflow.org/api_docs/python/tf/nn/raw_rnn
            mask_range = tf.reshape(
                tf.range(self.decoder_num_vocab, dtype=tf.int32), [1, -1])
            all_eos_pred = EOS_token * tf.ones([N], tf.int32)
            all_one_prob = tf.ones([N], tf.float32)
            all_zero_entropy = tf.zeros([N], tf.float32)
            if use_gt_layout is not None:
                # 1 selects the ground-truth token, 0 the model's prediction
                gt_layout_mult = tf.cast(use_gt_layout, tf.int32)
                pred_layout_mult = 1 - gt_layout_mult

            def loop_fn(time, cell_output, cell_state, loop_state):
                """raw_rnn loop function (see the raw_rnn docs for the contract)."""
                if cell_output is None:  # time == 0
                    next_cell_state = encoder_states
                    next_input = tf.tile(go_embedding, [N, 1])
                else:  # time > 0
                    next_cell_state = cell_state

                    # compute the attention map over the input sequence
                    # att_raw has shape [T, N, 1]
                    att_raw = tf.reduce_sum(
                        tf.tanh(
                            tf.nn.xw_plus_b(cell_output, W_a, b_a) +
                            self.encoder_h_transformed) * v,
                        axis=2,
                        keep_dims=True)
                    # softmax along the first dimension (T), then keep and
                    # renormalize only positions where the sequence is not
                    # finished; att has shape [T, N, 1].
                    # NOTE(review): a zero-length input sequence makes the
                    # renormalization 0 / 0.
                    att = tf.nn.softmax(att_raw, dim=0) * self.seq_not_finished
                    att = att / tf.reduce_sum(att, axis=0, keep_dims=True)
                    # d2 (attended context) has shape [N, lstm_dim]
                    d2 = tf.reduce_sum(att * self.encoder_outputs, axis=0)

                    # token_scores has shape [N, num_vocab]
                    token_scores = tf.nn.xw_plus_b(
                        tf.concat([cell_output, d2], axis=1), W_y, b_y)

                    # predict the next token (behavior depending on parameters)
                    # predicted_token has shape [N]
                    if sampling:
                        predicted_token = tf.cast(
                            tf.reshape(tf.multinomial(token_scores, 1), [-1]),
                            tf.int32)
                    else:
                        predicted_token = tf.cast(
                            tf.argmax(token_scores, 1), tf.int32)
                    if use_gt_layout is not None:
                        # this step's token is written at index time - 1, so
                        # read the ground truth from the same index
                        predicted_token = (
                            gt_layout_batch[time - 1] * gt_layout_mult +
                            predicted_token * pred_layout_mult)

                    # token_prob has shape [N], the probability of the
                    # predicted token. Although token_prob is not needed for
                    # predicting the next token, it is needed in output (for
                    # policy gradient training).
                    # mask has shape [N, num_vocab], True at predicted_token
                    mask = tf.equal(mask_range,
                                    tf.reshape(predicted_token, [-1, 1]))
                    all_token_probs = tf.nn.softmax(token_scores)
                    # Use log_softmax instead of tf.log(softmax): a probability
                    # that underflows to 0 would give log(0) = -inf and
                    # 0 * -inf = NaN, poisoning neg_entropy and its gradient.
                    all_token_log_probs = tf.nn.log_softmax(token_scores)
                    token_prob = tf.reduce_sum(
                        all_token_probs * tf.cast(mask, tf.float32), axis=1)
                    neg_entropy = tf.reduce_sum(
                        all_token_probs * all_token_log_probs, axis=1)

                    # is_eos_predicted is a [N] bool tensor, indicating whether
                    # <eos> has already been predicted previously in each
                    # sequence
                    is_eos_predicted = loop_state[2]
                    predicted_token_old = predicted_token
                    # if <eos> has already been predicted, now predict <eos>
                    # with prob 1 (and zero entropy)
                    predicted_token = tf.where(is_eos_predicted, all_eos_pred,
                                               predicted_token)
                    token_prob = tf.where(is_eos_predicted, all_one_prob,
                                          token_prob)
                    neg_entropy = tf.where(is_eos_predicted, all_zero_entropy,
                                           neg_entropy)
                    is_eos_predicted = tf.logical_or(
                        is_eos_predicted,
                        tf.equal(predicted_token_old, EOS_token))

                    # the prediction is from the cell output of the last
                    # timestep (t-1); feed it as input into timestep t
                    next_input = tf.nn.embedding_lookup(embedding_mat,
                                                        predicted_token)

                elements_finished = tf.greater_equal(time, T_max)

                # loop_state is a 5-tuple, representing
                #   1) the predicted_tokens
                #   2) the prob of predicted_tokens
                #   3) whether <eos> has already been predicted
                #   4) the negative entropy of policy (accumulated across
                #      timesteps)
                #   5) the attention
                if loop_state is None:  # time == 0
                    # Write the predicted token into the output
                    predicted_token_array = tf.TensorArray(
                        dtype=tf.int32, size=T_max, infer_shape=False)
                    token_prob_array = tf.TensorArray(
                        dtype=tf.float32, size=T_max, infer_shape=False)
                    att_array = tf.TensorArray(
                        dtype=tf.float32, size=T_max, infer_shape=False)
                    next_loop_state = (predicted_token_array,
                                       token_prob_array,
                                       tf.zeros([N], dtype=tf.bool),
                                       tf.zeros([N], dtype=tf.float32),
                                       att_array)
                else:  # time > 0
                    t_write = time - 1
                    next_loop_state = (
                        loop_state[0].write(t_write, predicted_token),
                        loop_state[1].write(t_write, token_prob),
                        is_eos_predicted,
                        loop_state[3] + neg_entropy,
                        loop_state[4].write(t_write, att))
                return (elements_finished, next_input, next_cell_state,
                        cell_output, next_loop_state)

            # The RNN
            cell = _get_lstm_cell(num_layers, lstm_dim)
            _, _, decodes_ta = tf.nn.raw_rnn(cell, loop_fn, scope='lstm')
            predicted_tokens = decodes_ta[0].stack()
            token_probs = decodes_ta[1].stack()
            neg_entropy = decodes_ta[3]
            # atts has shape [T_decoder, T_encoder, N, 1]
            atts = decodes_ta[4].stack()
            self.atts = atts
            # word_vecs has shape [T_decoder, N, D]: attention-weighted sum of
            # the input embeddings over the encoder time axis
            word_vecs = tf.reduce_sum(atts * self.embedded_input_seq, axis=1)

            predicted_tokens.set_shape([None, None])
            token_probs.set_shape([None, None])
            neg_entropy.set_shape([None])
            word_vecs.set_shape([None, None, self.encoder_embed_dim])

            self.predicted_tokens = predicted_tokens
            self.token_probs = token_probs
            self.neg_entropy = neg_entropy
            self.word_vecs = word_vecs