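"""Handwriting-generation RNN.

An LSTM with a mixture-of-Gaussians attention window over a character
sequence; the output layer parameterizes a mixture of bivariate Gaussians
over pen offsets plus a Bernoulli end-of-stroke probability, in the style of
Graves (2013), "Generating Sequences With Recurrent Neural Networks".
"""
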
from __future__ import print_function

import os

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

import drawing
from data_frame import DataFrame
from rnn_cell import LSTMAttentionCell
from rnn_ops import rnn_free_run
from tf_base_model import TFBaseModel
from tf_utils import time_distributed_dense_layer
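

# Assumed layout of data_dir (one .npy array per column; the preprocessing
# that produces these files lives outside this module):
#   x     - float stroke sequences, shape [N, T, 3]: (dx, dy, end_of_stroke)
#   x_len - int length of each stroke sequence
#   c     - int-encoded character sequences (indices into drawing.alphabet)
#   c_len - int length of each character sequence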
class DataReader(object):

    def __init__(self, data_dir):
        data_cols = ['x', 'x_len', 'c', 'c_len']
        data = [np.load(os.path.join(data_dir, '{}.npy'.format(i))) for i in data_cols]

        # The full dataset doubles as the test set; train/val are a 95/5 split of it.
        self.test_df = DataFrame(columns=data_cols, data=data)
        self.train_df, self.val_df = self.test_df.train_test_split(train_size=0.95, random_state=2018)

        print('train size', len(self.train_df))
        print('val size', len(self.val_df))
        print('test size', len(self.test_df))

    def train_batch_generator(self, batch_size):
        return self.batch_generator(
            batch_size=batch_size,
            df=self.train_df,
            shuffle=True,
            num_epochs=10000,
            mode='train'
        )

    def val_batch_generator(self, batch_size):
        return self.batch_generator(
            batch_size=batch_size,
            df=self.val_df,
            shuffle=True,
            num_epochs=10000,
            mode='val'
        )

    def test_batch_generator(self, batch_size):
        return self.batch_generator(
            batch_size=batch_size,
            df=self.test_df,
            shuffle=False,
            num_epochs=1,
            mode='test'
        )

    def batch_generator(self, batch_size, df, shuffle=True, num_epochs=10000, mode='train'):
        gen = df.batch_generator(
            batch_size=batch_size,
            shuffle=shuffle,
            num_epochs=num_epochs,
            allow_smaller_final_batch=(mode == 'test')
        )
        for batch in gen:
            # Next-step prediction: the target y is x shifted forward by one
            # timestep, so the usable sequence length drops by one. Sequences
            # are also trimmed to the longest length in the batch.
            batch['x_len'] = batch['x_len'] - 1
            max_x_len = np.max(batch['x_len'])
            max_c_len = np.max(batch['c_len'])
            batch['y'] = batch['x'][:, 1:max_x_len + 1, :]
            batch['x'] = batch['x'][:, :max_x_len, :]
            batch['c'] = batch['c'][:, :max_c_len]
            yield batch
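

# The model couples an attention LSTM cell with a time-distributed dense
# output layer whose activations are interpreted as mixture-density-network
# parameters (see parse_parameters and NLL below).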
class rnn(TFBaseModel):

    def __init__(
        self,
        lstm_size,
        output_mixture_components,
        attention_mixture_components,
        **kwargs
    ):
        self.lstm_size = lstm_size
        self.output_mixture_components = output_mixture_components
        self.output_units = self.output_mixture_components*6 + 1
        self.attention_mixture_components = attention_mixture_components
        super(rnn, self).__init__(**kwargs)
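
    # Each mixture component needs 6 output units (1 weight, 2 means,
    # 2 standard deviations, 1 correlation), plus a single end-of-stroke
    # logit per timestep; hence output_units = 6*components + 1 above.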
    def parse_parameters(self, z, eps=1e-8, sigma_eps=1e-4):
        pis, sigmas, rhos, mus, es = tf.split(
            z,
            [
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1
            ],
            axis=-1
        )
        # Squash the raw activations into valid parameter ranges, clipping away
        # from the boundaries to keep the likelihood numerically stable.
        pis = tf.nn.softmax(pis, axis=-1)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        return pis, mus, sigmas, rhos, es
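
    # Bivariate Gaussian density used in NLL, per mixture component:
    #   Z = ((y1-mu1)/s1)^2 + ((y2-mu2)/s2)^2 - 2*rho*(y1-mu1)*(y2-mu2)/(s1*s2)
    #   N(y) = exp(-Z / (2*(1 - rho^2))) / (2*pi*s1*s2*sqrt(1 - rho^2))
    # The per-step likelihood is sum_k pi_k * N_k, times a Bernoulli term for
    # the end-of-stroke bit.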
    def NLL(self, y, lengths, pis, mus, sigmas, rho, es, eps=1e-8):
        sigma_1, sigma_2 = tf.split(sigmas, 2, axis=2)
        y_1, y_2, y_3 = tf.split(y, 3, axis=2)
        mu_1, mu_2 = tf.split(mus, 2, axis=2)

        norm = 1.0 / (2*np.pi*sigma_1*sigma_2 * tf.sqrt(1 - tf.square(rho)))
        Z = tf.square((y_1 - mu_1) / (sigma_1)) + \
            tf.square((y_2 - mu_2) / (sigma_2)) - \
            2*rho*(y_1 - mu_1)*(y_2 - mu_2) / (sigma_1*sigma_2)
        exp = -1.0*Z / (2*(1 - tf.square(rho)))

        gaussian_likelihoods = tf.exp(exp) * norm
        gmm_likelihood = tf.reduce_sum(pis * gaussian_likelihoods, 2)
        gmm_likelihood = tf.clip_by_value(gmm_likelihood, eps, np.inf)

        # Bernoulli likelihood of the end-of-stroke bit. Squeezing axis=2
        # explicitly (rather than all singleton dims) keeps batches of size 1 safe.
        bernoulli_likelihood = tf.squeeze(tf.where(tf.equal(tf.ones_like(y_3), y_3), es, 1 - es), axis=2)

        nll = -(tf.log(gmm_likelihood) + tf.log(bernoulli_likelihood))

        # Mask out padding steps and any NaN terms before averaging.
        sequence_mask = tf.logical_and(
            tf.sequence_mask(lengths, maxlen=tf.shape(y)[1]),
            tf.logical_not(tf.is_nan(nll)),
        )
        nll = tf.where(sequence_mask, nll, tf.zeros_like(nll))
        num_valid = tf.reduce_sum(tf.cast(sequence_mask, tf.float32), axis=1)

        sequence_loss = tf.reduce_sum(nll, axis=1) / tf.maximum(num_valid, 1.0)
        element_loss = tf.reduce_sum(nll) / tf.maximum(tf.reduce_sum(num_valid), 1.0)
        return sequence_loss, element_loss
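
    # Unconditional sampling: free-run the cell from a zero state. The initial
    # input [0, 0, 1] is a zero pen offset with the end-of-stroke bit set.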
    def sample(self, cell):
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        initial_input = tf.concat([
            tf.zeros([self.num_samples, 2]),
            tf.ones([self.num_samples, 1]),
        ], axis=1)
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=initial_state,
            initial_input=initial_input,
            scope='rnn'
        )[1]
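
    # Primed sampling: first run the cell over a priming sequence x_prime
    # (conditioning the state on an existing handwriting sample), then
    # free-run from the resulting state.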
    def primed_sample(self, cell):
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        primed_state = tf.nn.dynamic_rnn(
            inputs=self.x_prime,
            cell=cell,
            sequence_length=self.x_prime_len,
            dtype=tf.float32,
            initial_state=initial_state,
            scope='rnn'
        )[1]
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=primed_state,
            scope='rnn'
        )[1]
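
    # Builds the graph: placeholders for training and sampling, the attention
    # cell, the mixture-density output head, and the masked NLL loss that
    # TFBaseModel is expected to minimize.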
    def calculate_loss(self):
        # Training inputs.
        self.x = tf.placeholder(tf.float32, [None, None, 3])
        self.y = tf.placeholder(tf.float32, [None, None, 3])
        self.x_len = tf.placeholder(tf.int32, [None])
        self.c = tf.placeholder(tf.int32, [None, None])
        self.c_len = tf.placeholder(tf.int32, [None])

        # Sampling inputs.
        self.sample_tsteps = tf.placeholder(tf.int32, [])
        self.num_samples = tf.placeholder(tf.int32, [])
        self.prime = tf.placeholder(tf.bool, [])
        self.x_prime = tf.placeholder(tf.float32, [None, None, 3])
        self.x_prime_len = tf.placeholder(tf.int32, [None])
        self.bias = tf.placeholder_with_default(
            tf.zeros([self.num_samples], dtype=tf.float32), [None])

        cell = LSTMAttentionCell(
            lstm_size=self.lstm_size,
            num_attn_mixture_components=self.attention_mixture_components,
            attention_values=tf.one_hot(self.c, len(drawing.alphabet)),
            attention_values_lengths=self.c_len,
            num_output_mixture_components=self.output_mixture_components,
            bias=self.bias
        )
        self.initial_state = cell.zero_state(tf.shape(self.x)[0], dtype=tf.float32)
        outputs, self.final_state = tf.nn.dynamic_rnn(
            inputs=self.x,
            cell=cell,
            sequence_length=self.x_len,
            dtype=tf.float32,
            initial_state=self.initial_state,
            scope='rnn'
        )

        params = time_distributed_dense_layer(outputs, self.output_units, scope='rnn/gmm')
        pis, mus, sigmas, rhos, es = self.parse_parameters(params)
        sequence_loss, self.loss = self.NLL(self.y, self.x_len, pis, mus, sigmas, rhos, es)

        self.sampled_sequence = tf.cond(
            self.prime,
            lambda: self.primed_sample(cell),
            lambda: self.sample(cell)
        )
        return self.loss
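

# Training driver. The learning_rates / batch_sizes / patiences / beta1_decays
# lists appear to define a three-stage schedule (one entry per stage); the
# scheduling itself lives in TFBaseModel.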
if __name__ == '__main__':
    dr = DataReader(data_dir='data/processed/')

    nn = rnn(
        reader=dr,
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
        learning_rates=[.0001, .00005, .00002],
        batch_sizes=[32, 64, 64],
        patiences=[1500, 1000, 500],
        beta1_decays=[.9, .9, .9],
        validation_batch_size=32,
        optimizer='rms',
        num_training_steps=100000,
        warm_start_init_step=0,
        regularization_constant=0.0,
        keep_prob=1.0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=2000,
        log_interval=20,
        grad_clip=10,
        lstm_size=400,
        output_mixture_components=20,
        attention_mixture_components=10
    )
    nn.fit()
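
    # A minimal sampling sketch after training (assumptions: a checkpoint
    # exists under 'checkpoints', and encode_text is a hypothetical helper
    # mapping a string to indices into drawing.alphabet):
    #
    #   text = 'a sample of text'
    #   with tf.Session() as sess:
    #       tf.train.Saver().restore(sess, tf.train.latest_checkpoint('checkpoints'))
    #       strokes = sess.run(nn.sampled_sequence, feed_dict={
    #           nn.prime: False,
    #           nn.num_samples: 1,
    #           nn.sample_tsteps: 800,
    #           nn.c: [encode_text(text)],   # shape [1, len(text)]
    #           nn.c_len: [len(text)],
    #       })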