# Hugging Face upload-page artifact, commented out so the file parses as Python:
# 3morrrrr's picture / "Upload 14 files" / commit 569596a verified
from __future__ import print_function
import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import drawing
from data_frame import DataFrame
from rnn_cell import LSTMAttentionCell
from rnn_ops import rnn_free_run
from tf_base_model import TFBaseModel
from tf_utils import time_distributed_dense_layer
class DataReader(object):
    """Loads preprocessed stroke/character arrays and serves minibatches.

    Expects ``x.npy``, ``x_len.npy``, ``c.npy``, ``c_len.npy`` in
    ``data_dir``; splits them 95/5 into train/val (the full set doubles
    as the test split).
    """

    def __init__(self, data_dir):
        data_cols = ['x', 'x_len', 'c', 'c_len']
        arrays = [np.load(os.path.join(data_dir, '{}.npy'.format(i))) for i in data_cols]
        self.test_df = DataFrame(columns=data_cols, data=arrays)
        self.train_df, self.val_df = self.test_df.train_test_split(
            train_size=0.95, random_state=2018)
        print('train size', len(self.train_df))
        print('val size', len(self.val_df))
        print('test size', len(self.test_df))

    def train_batch_generator(self, batch_size):
        # Training cycles (effectively) forever, reshuffling each epoch.
        return self.batch_generator(
            batch_size=batch_size, df=self.train_df,
            shuffle=True, num_epochs=10000, mode='train')

    def val_batch_generator(self, batch_size):
        return self.batch_generator(
            batch_size=batch_size, df=self.val_df,
            shuffle=True, num_epochs=10000, mode='val')

    def test_batch_generator(self, batch_size):
        # Single ordered pass; the final short batch is kept.
        return self.batch_generator(
            batch_size=batch_size, df=self.test_df,
            shuffle=False, num_epochs=1, mode='test')

    def batch_generator(self, batch_size, df, shuffle=True, num_epochs=10000, mode='train'):
        """Yield batches with next-step targets 'y' and length-trimmed arrays."""
        raw_batches = df.batch_generator(
            batch_size=batch_size,
            shuffle=shuffle,
            num_epochs=num_epochs,
            allow_smaller_final_batch=(mode == 'test')
        )
        for batch in raw_batches:
            # One fewer prediction step than raw points, since the target
            # is the input shifted forward by one timestep.
            batch['x_len'] = batch['x_len'] - 1
            longest_x = np.max(batch['x_len'])
            longest_c = np.max(batch['c_len'])
            batch['y'] = batch['x'][:, 1:longest_x + 1, :]
            batch['x'] = batch['x'][:, :longest_x, :]
            batch['c'] = batch['c'][:, :longest_c]
            yield batch
class rnn(TFBaseModel):
    """Mixture-density RNN over pen-stroke sequences with character attention.

    ``calculate_loss`` builds the graph: an ``LSTMAttentionCell`` runs over
    stroke inputs ``x`` while attending over one-hot character sequences
    ``c``; its outputs are projected to ``output_mixture_components*6 + 1``
    units per timestep and interpreted as a bivariate Gaussian mixture plus
    an end-of-stroke Bernoulli. Training/checkpoint plumbing is inherited
    from ``TFBaseModel``.
    """

    def __init__(
        self,
        lstm_size,
        output_mixture_components,
        attention_mixture_components,
        **kwargs
    ):
        # Hidden width handed to LSTMAttentionCell.
        self.lstm_size = lstm_size
        # Number of bivariate Gaussians in the output mixture.
        self.output_mixture_components = output_mixture_components
        # Per component: 1 pi + 2 sigmas + 1 rho + 2 mus = 6 values,
        # plus one shared end-of-stroke logit (split in parse_parameters).
        self.output_units = self.output_mixture_components*6 + 1
        self.attention_mixture_components = attention_mixture_components
        # Remaining kwargs (log_dir, learning_rates, ...) go to TFBaseModel.
        super(rnn, self).__init__(**kwargs)

    def parse_parameters(self, z, eps=1e-8, sigma_eps=1e-4):
        """Split raw projection ``z`` into constrained mixture parameters.

        Split sizes along the last axis are [K, 2K, K, 2K, 1] for
        (pis, sigmas, rhos, mus, es), K = output_mixture_components.
        Constraints: softmax(pis); exp(sigmas) clipped below at sigma_eps;
        tanh(rhos) clipped away from +/-1 by eps; sigmoid(es) clipped away
        from 0 and 1 by eps.

        Note the return order is (pis, mus, sigmas, rhos, es) — deliberately
        different from the split order.
        """
        pis, sigmas, rhos, mus, es = tf.split(
            z,
            [
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1
            ],
            axis=-1
        )
        pis = tf.nn.softmax(pis, axis=-1)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        return pis, mus, sigmas, rhos, es

    def NLL(self, y, lengths, pis, mus, sigmas, rho, es, eps=1e-8):
        """Masked negative log-likelihood of targets under the mixture.

        ``y`` has 3 channels per step; the first two are scored by a
        mixture of bivariate Gaussians, the third (a 0/1 pen bit) by a
        Bernoulli with parameter ``es``. Steps past each sequence's
        length, and any step whose NLL is NaN, are zeroed out.

        Returns:
            (sequence_loss, element_loss): per-sequence mean NLL, and the
            scalar mean NLL over all valid steps (the optimized loss).
        """
        sigma_1, sigma_2 = tf.split(sigmas, 2, axis=2)
        y_1, y_2, y_3 = tf.split(y, 3, axis=2)
        mu_1, mu_2 = tf.split(mus, 2, axis=2)
        # Bivariate Gaussian density: normalizer and (negated) exponent.
        norm = 1.0 / (2*np.pi*sigma_1*sigma_2 * tf.sqrt(1 - tf.square(rho)))
        Z = tf.square((y_1 - mu_1) / (sigma_1)) + \
            tf.square((y_2 - mu_2) / (sigma_2)) - \
            2*rho*(y_1 - mu_1)*(y_2 - mu_2) / (sigma_1*sigma_2)
        exp = -1.0*Z / (2*(1 - tf.square(rho)))
        gaussian_likelihoods = tf.exp(exp) * norm
        gmm_likelihood = tf.reduce_sum(pis * gaussian_likelihoods, 2)
        # Floor the likelihood so tf.log below stays finite.
        gmm_likelihood = tf.clip_by_value(gmm_likelihood, eps, np.inf)
        # Bernoulli term: es where the target pen bit equals 1, else 1 - es.
        bernoulli_likelihood = tf.squeeze(tf.where(tf.equal(tf.ones_like(y_3), y_3), es, 1 - es))
        nll = -(tf.log(gmm_likelihood) + tf.log(bernoulli_likelihood))
        # Keep only in-length, non-NaN positions.
        sequence_mask = tf.logical_and(
            tf.sequence_mask(lengths, maxlen=tf.shape(y)[1]),
            tf.logical_not(tf.is_nan(nll)),
        )
        nll = tf.where(sequence_mask, nll, tf.zeros_like(nll))
        num_valid = tf.reduce_sum(tf.cast(sequence_mask, tf.float32), axis=1)
        # tf.maximum guards against division by zero for all-masked rows.
        sequence_loss = tf.reduce_sum(nll, axis=1) / tf.maximum(num_valid, 1.0)
        element_loss = tf.reduce_sum(nll) / tf.maximum(tf.reduce_sum(num_valid), 1.0)
        return sequence_loss, element_loss

    def sample(self, cell):
        """Free-run ``sample_tsteps`` steps from a zero state (unconditioned)."""
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        # Seed input: zero offsets with the third (pen) channel set to 1.
        initial_input = tf.concat([
            tf.zeros([self.num_samples, 2]),
            tf.ones([self.num_samples, 1]),
        ], axis=1)
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=initial_state,
            initial_input=initial_input,
            scope='rnn'
        )[1]

    def primed_sample(self, cell):
        """Free-run sampling after priming the state on ``x_prime``."""
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        # Run the cell over the priming strokes; keep only the final state.
        primed_state = tf.nn.dynamic_rnn(
            inputs=self.x_prime,
            cell=cell,
            sequence_length=self.x_prime_len,
            dtype=tf.float32,
            initial_state=initial_state,
            scope='rnn'
        )[1]
        # NOTE(review): no initial_input here, unlike sample() —
        # rnn_free_run presumably supplies a default; confirm in rnn_ops.
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=primed_state,
            scope='rnn'
        )[1]

    def calculate_loss(self):
        """Build placeholders, the RNN, the mixture loss, and the sampler.

        Returns the scalar training loss; also exposes
        ``self.sampled_sequence`` (primed or unconditional, chosen at run
        time via the ``self.prime`` placeholder).
        """
        # Stroke inputs/targets: (batch, time, 3); target y is x shifted
        # one step (see DataReader.batch_generator).
        self.x = tf.placeholder(tf.float32, [None, None, 3])
        self.y = tf.placeholder(tf.float32, [None, None, 3])
        self.x_len = tf.placeholder(tf.int32, [None])
        # Character ids and lengths used as attention values.
        self.c = tf.placeholder(tf.int32, [None, None])
        self.c_len = tf.placeholder(tf.int32, [None])
        # Sampling controls.
        self.sample_tsteps = tf.placeholder(tf.int32, [])
        self.num_samples = tf.placeholder(tf.int32, [])
        self.prime = tf.placeholder(tf.bool, [])
        self.x_prime = tf.placeholder(tf.float32, [None, None, 3])
        self.x_prime_len = tf.placeholder(tf.int32, [None])
        # Sampling bias, defaulting to zeros; forwarded to the cell.
        self.bias = tf.placeholder_with_default(
            tf.zeros([self.num_samples], dtype=tf.float32), [None])
        cell = LSTMAttentionCell(
            lstm_size=self.lstm_size,
            num_attn_mixture_components=self.attention_mixture_components,
            attention_values=tf.one_hot(self.c, len(drawing.alphabet)),
            attention_values_lengths=self.c_len,
            num_output_mixture_components=self.output_mixture_components,
            bias=self.bias
        )
        self.initial_state = cell.zero_state(tf.shape(self.x)[0], dtype=tf.float32)
        outputs, self.final_state = tf.nn.dynamic_rnn(
            inputs=self.x,
            cell=cell,
            sequence_length=self.x_len,
            dtype=tf.float32,
            initial_state=self.initial_state,
            scope='rnn'
        )
        # Project RNN outputs to raw mixture parameters, then constrain them.
        params = time_distributed_dense_layer(outputs, self.output_units, scope='rnn/gmm')
        pis, mus, sigmas, rhos, es = self.parse_parameters(params)
        sequence_loss, self.loss = self.NLL(self.y, self.x_len, pis, mus, sigmas, rhos, es)
        # Sampling ops share weights with the training graph (scope 'rnn').
        self.sampled_sequence = tf.cond(
            self.prime,
            lambda: self.primed_sample(cell),
            lambda: self.sample(cell)
        )
        return self.loss
if __name__ == '__main__':
    # Training-schedule and model hyperparameters for TFBaseModel / rnn.
    config = dict(
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
        learning_rates=[.0001, .00005, .00002],
        batch_sizes=[32, 64, 64],
        patiences=[1500, 1000, 500],
        beta1_decays=[.9, .9, .9],
        validation_batch_size=32,
        optimizer='rms',
        num_training_steps=100000,
        warm_start_init_step=0,
        regularization_constant=0.0,
        keep_prob=1.0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=2000,
        log_interval=20,
        grad_clip=10,
        lstm_size=400,
        output_mixture_components=20,
        attention_mixture_components=10,
    )
    reader = DataReader(data_dir='data/processed/')
    model = rnn(reader=reader, **config)
    model.fit()