# Handwriting_Model_Inf / rnn_cell.py
# Uploaded by 3morrrrr ("Upload 14 files"), commit 569596a (verified).
from collections import namedtuple
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
import numpy as np
from tf_utils import dense_layer, shape
# Recurrent state threaded through LSTMAttentionCell from one time step to
# the next. Field order matters: state_size, zero_state and __call__ all build
# instances positionally.
LSTMAttentionCellState = namedtuple(
    'LSTMAttentionCellState',
    (
        'h1', 'c1',                # hidden / cell state of the first LSTM
        'h2', 'c2',                # hidden / cell state of the second LSTM
        'h3', 'c3',                # hidden / cell state of the third LSTM
        'alpha', 'beta', 'kappa',  # attention-window mixture parameters
        'w',                       # attention read vector from the last step
        'phi',                     # per-position attention weights
    ),
)
class LSTMAttentionCell(tf.nn.rnn_cell.RNNCell):
    """Three stacked LSTMs with a Gaussian-mixture soft attention window.

    Each step proceeds as follows:

    - The first LSTM reads the previous attention read vector ``w`` together
      with the step input.
    - Its output parameterises ``num_attn_mixture_components`` unnormalised
      Gaussian bumps over the positions of ``attention_values``.
    - The resulting weighted read ``w`` is fed, with the input, to the second
      and third LSTMs.

    ``output_function`` maps the third LSTM's hidden state to a sample from a
    mixture of bivariate normals plus a Bernoulli-sampled flag.
    ``termination_condition`` decides per row whether generation should stop.
    """

    def __init__(
        self,
        lstm_size,
        num_attn_mixture_components,
        attention_values,
        attention_values_lengths,
        num_output_mixture_components,
        bias,
        reuse=None,
    ):
        """Store configuration and derive shape quantities from ``attention_values``.

        Args:
            lstm_size: Number of units in each of the three LSTM layers.
            num_attn_mixture_components: Number of Gaussian components in the
                attention window.
            attention_values: Tensor the attention window reads from. It is
                indexed as (batch, position, feature): axis 0 gives the batch
                size, axis 1 the sequence length and axis 2 the window size.
            attention_values_lengths: Valid length of each row of
                ``attention_values``. Positions at or beyond it are masked out
                of the attention read.
            num_output_mixture_components: Number of bivariate-normal
                components in the output distribution.
            bias: Sampling bias applied in ``_parse_parameters`` after
                ``tf.expand_dims(bias, 1)``. Presumably a (batch,) tensor;
                TODO confirm with the caller.
            reuse: Stored as ``self.reuse``; not read anywhere else in this
                class.
        """
        # NOTE(review): RNNCell.__init__ is never called, so any base-class
        # initialisation is skipped. Confirm nothing downstream relies on it.
        self.reuse = reuse
        self.lstm_size = lstm_size
        self.num_attn_mixture_components = num_attn_mixture_components
        self.attention_values = attention_values
        self.attention_values_lengths = attention_values_lengths
        # Feature size of attention_values (axis 2), via the project's `shape`
        # helper. Presumably a static Python int; confirm in tf_utils.
        self.window_size = shape(self.attention_values, 2)
        # Sequence length and batch size are read dynamically (scalar tensors).
        self.char_len = tf.shape(attention_values)[1]
        self.batch_size = tf.shape(attention_values)[0]
        self.num_output_mixture_components = num_output_mixture_components
        # Each component needs 6 values: 1 weight, 2 means, 2 std devs and
        # 1 correlation. One extra column holds the flag logit.
        # This matches the split in _parse_parameters.
        self.output_units = 6*self.num_output_mixture_components + 1
        self.bias = bias

    @property
    def state_size(self):
        """Size of each LSTMAttentionCellState field, in field order.

        The ``phi`` entry is ``self.char_len``, which is a dynamic scalar
        tensor rather than a Python int.
        """
        return LSTMAttentionCellState(
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.num_attn_mixture_components,
            self.num_attn_mixture_components,
            self.num_attn_mixture_components,
            self.window_size,
            self.char_len,
        )

    @property
    def output_size(self):
        """Width of the per-step output, which is the third LSTM's hidden size."""
        return self.lstm_size

    def zero_state(self, batch_size, dtype):
        """Return an all-zero LSTMAttentionCellState with ``batch_size`` rows.

        ``dtype`` matches the RNNCell.zero_state signature but is ignored.
        Every field uses ``tf.zeros``' default dtype (float32), which is also
        the dtype the attention computation in ``__call__`` casts to.
        """
        return LSTMAttentionCellState(
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.num_attn_mixture_components]),
            tf.zeros([batch_size, self.num_attn_mixture_components]),
            tf.zeros([batch_size, self.num_attn_mixture_components]),
            tf.zeros([batch_size, self.window_size]),
            tf.zeros([batch_size, self.char_len]),
        )

    def __call__(self, inputs, state, scope=None):
        """Advance the cell by one time step.

        Args:
            inputs: Input for this step. It is concatenated along axis 1 with
                rank-2 state tensors, so presumably (batch, input_dim).
            state: Previous LSTMAttentionCellState.
            scope: Optional variable-scope name; defaults to the class name.
                Variables are created or reused with ``tf.AUTO_REUSE``.

        Returns:
            ``(s3_out, new_state)``: the third LSTM's output and the updated
            LSTMAttentionCellState.
        """
        # NOTE(review): fresh LSTMCell objects are constructed on every call,
        # relying on the enclosing AUTO_REUSE scope for weight sharing.
        # Variable names presumably depend on the creation order below
        # (cell1, 'attention', cell2, cell3). Keep that order stable so
        # existing checkpoints still load.
        with tf.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE):
            # lstm 1: previous attention read vector + current input.
            s1_in = tf.concat([state.w, inputs], axis=1)
            cell1 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            # LSTMCell takes its tuple state in (c, h) order.
            s1_out, s1_state = cell1(s1_in, state=(state.c1, state.h1))
            # attention (K = num_attn_mixture_components)
            attention_inputs = tf.concat([state.w, inputs, s1_out], axis=1)
            attention_params = dense_layer(attention_inputs, 3*self.num_attn_mixture_components, scope='attention')
            # softplus keeps all three window parameters positive.
            alpha, beta, kappa = tf.split(tf.nn.softplus(attention_params), 3, axis=1)
            # The increment is positive, so each window centre (kappa) can only
            # move forward. 25.0 is an unexplained step-size constant.
            kappa = state.kappa + kappa / 25.0
            # Floor the width parameter so the division below cannot blow up.
            beta = tf.clip_by_value(beta, .01, np.inf)
            # Keep the (batch, K) versions for the returned state. The
            # (batch, K, 1) versions broadcast against positions below.
            kappa_flat, alpha_flat, beta_flat = kappa, alpha, beta
            kappa, alpha, beta = tf.expand_dims(kappa, 2), tf.expand_dims(alpha, 2), tf.expand_dims(beta, 2)
            # u[b, k, j] = j: the index of every attention position, tiled to
            # shape (batch, K, char_len).
            enum = tf.reshape(tf.range(self.char_len), (1, 1, self.char_len))
            u = tf.cast(tf.tile(enum, (self.batch_size, self.num_attn_mixture_components, 1)), tf.float32)
            # phi[b, j] = sum_k alpha_k * exp(-(kappa_k - j)^2 / beta_k)
            phi_flat = tf.reduce_sum(alpha*tf.exp(-tf.square(kappa - u) / beta), axis=1)
            phi = tf.expand_dims(phi_flat, 2)
            # Zero out positions beyond each row's valid length.
            sequence_mask = tf.cast(tf.sequence_mask(self.attention_values_lengths, maxlen=self.char_len), tf.float32)
            sequence_mask = tf.expand_dims(sequence_mask, 2)
            # Attention read vector: phi-weighted sum over positions.
            w = tf.reduce_sum(phi*self.attention_values*sequence_mask, axis=1)
            # lstm 2: input + first-layer output + attention read.
            s2_in = tf.concat([inputs, s1_out, w], axis=1)
            cell2 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s2_out, s2_state = cell2(s2_in, state=(state.c2, state.h2))
            # lstm 3: input + second-layer output + attention read.
            s3_in = tf.concat([inputs, s2_out, w], axis=1)
            cell3 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s3_out, s3_state = cell3(s3_in, state=(state.c3, state.h3))
            # Positional construction; must follow the namedtuple field order
            # (h before c for each layer).
            new_state = LSTMAttentionCellState(
                s1_state.h,
                s1_state.c,
                s2_state.h,
                s2_state.c,
                s3_state.h,
                s3_state.c,
                alpha_flat,
                beta_flat,
                kappa_flat,
                w,
                phi_flat,
            )
        return s3_out, new_state

    def output_function(self, state):
        """Sample one output point from the mixture defined by ``state.h3``.

        The method:

        - projects ``state.h3`` to the mixture parameters (scope 'gmm');
        - builds one bivariate normal per component, plus a Bernoulli over
          ``es``;
        - samples a component index per row and returns that component's
          sampled coordinates.

        Returns:
            float32 tensor of shape (batch, 3): two sampled coordinates
            followed by the sampled Bernoulli flag (0. or 1.).
        """
        params = dense_layer(state.h3, self.output_units, scope='gmm', reuse=tf.AUTO_REUSE)
        pis, mus, sigmas, rhos, es = self._parse_parameters(params)
        # mus is laid out as [mu1 for all K | mu2 for all K].
        # Regroup it to shape (batch, K, 2).
        mu1, mu2 = tf.split(mus, 2, axis=1)
        mus = tf.stack([mu1, mu2], axis=2)
        sigma1, sigma2 = tf.split(sigmas, 2, axis=1)
        # Row-major 2x2 covariance [[s1^2, rho*s1*s2], [rho*s1*s2, s2^2]] for
        # every component, reshaped to (batch, K, 2, 2).
        covar_matrix = [tf.square(sigma1), rhos*sigma1*sigma2,
                        rhos*sigma1*sigma2, tf.square(sigma2)]
        covar_matrix = tf.stack(covar_matrix, axis=2)
        covar_matrix = tf.reshape(covar_matrix, (self.batch_size, self.num_output_mixture_components, 2, 2))
        # NOTE(review): newer TFP releases deprecate
        # MultivariateNormalFullCovariance in favour of MultivariateNormalTriL
        # with a Cholesky factor. Check against the installed version.
        mvn = tfd.MultivariateNormalFullCovariance(loc=mus, covariance_matrix=covar_matrix)
        b = tfd.Bernoulli(probs=es)
        c = tfd.Categorical(probs=pis)
        sampled_e = b.sample()
        # Draw a point from every component, then keep only the one chosen by
        # the categorical sample for each row.
        sampled_coords = mvn.sample()
        sampled_idx = c.sample()
        idx = tf.stack([tf.range(self.batch_size), sampled_idx], axis=1)
        coords = tf.gather_nd(sampled_coords, idx)
        return tf.concat([coords, tf.cast(sampled_e, tf.float32)], axis=1)

    def termination_condition(self, state):
        """Return a per-row boolean tensor that is True when generation should stop.

        A row stops when either condition holds:

        - its attention peak (argmax of ``state.phi``) has moved beyond the
          last valid position;
        - the peak is at or beyond the last valid position and a freshly
          sampled flag from ``output_function`` equals 1.
        """
        # Position with the largest attention weight, per row.
        char_idx = tf.cast(tf.argmax(state.phi, axis=1), tf.int32)
        final_char = char_idx >= self.attention_values_lengths - 1
        past_final_char = char_idx >= self.attention_values_lengths
        # NOTE(review): this calls output_function again, which draws new
        # random samples. If the caller samples the emitted point through
        # output_function as well, the flag checked here is an independent
        # draw, not the one actually emitted.
        output = self.output_function(state)
        es = tf.cast(output[:, 2], tf.int32)
        is_eos = tf.equal(es, tf.ones_like(es))
        return tf.logical_or(tf.logical_and(final_char, is_eos), past_final_char)

    def _parse_parameters(self, gmm_params, eps=1e-8, sigma_eps=1e-4):
        """Split the raw 'gmm' projection into constrained mixture parameters.

        Args:
            gmm_params: Tensor whose last axis has ``self.output_units``
                columns, laid out as
                [pis (K) | sigmas (2K) | rhos (K) | mus (2K) | es (1)],
                with K = ``self.num_output_mixture_components``.
            eps: Margin that keeps ``rhos`` within [eps - 1, 1 - eps] and
                clips ``es`` into [eps, 1 - eps] before thresholding.
            sigma_eps: Lower bound on the standard deviations.

        Returns:
            ``(pis, mus, sigmas, rhos, es)``. This order differs from the input
            layout.

            - pis: softmax weights, with entries below 0.01 set to 0 and not
              renormalised.
            - mus: the raw means, unconstrained.
            - sigmas: at least ``sigma_eps``.
            - rhos: ``tanh`` correlations.
            - es: a sigmoid probability, with values below 0.01 set to 0.
        """
        pis, sigmas, rhos, mus, es = tf.split(
            gmm_params,
            [
                1*self.num_output_mixture_components,
                2*self.num_output_mixture_components,
                1*self.num_output_mixture_components,
                2*self.num_output_mixture_components,
                1
            ],
            axis=-1
        )
        # Sampling bias, for bias > 0:
        # - scaling the weight logits by (1 + bias) sharpens the softmax;
        # - subtracting bias from the log std devs shrinks them after exp().
        pis = pis*(1 + tf.expand_dims(self.bias, 1))
        sigmas = sigmas - tf.expand_dims(self.bias, 1)
        pis = tf.nn.softmax(pis, axis=-1)
        # NOTE(review): after this, pis no longer sum exactly to 1. They are
        # passed as-is to tfd.Categorical(probs=...) in output_function.
        pis = tf.where(pis < .01, tf.zeros_like(pis), pis)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        es = tf.where(es < .01, tf.zeros_like(es), es)
        return pis, mus, sigmas, rhos, es