NCTCMumbai's picture
Upload 2571 files
0b8359d
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""
LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, control inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
The main data structure being passed around is a dataset. This is a dictionary
of data dictionaries.
DATASET: The top level dictionary is simply name (string -> dictionary).
The nested dictionary is the DATA DICTIONARY, which has the following keys:
'train_data' and 'valid_data', whose values are the corresponding training
and validation data with shape
ExTxD, E - # examples, T - # time steps, D - # dimensions in data.
The data dictionary also has a few more keys:
'train_ext_input' and 'valid_ext_input', if there are known external inputs
to the system being modeled, these take on dimensions:
ExTxI, E - # examples, T - # time steps, I = # dimensions in input.
'alignment_matrix_cxf' - If you are using multiple days data, it's possible
that one can align the channels (see manuscript). If so each dataset will
contain this matrix, which will be used for both the input adapter and the
output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag.
'alignment_bias_c' - See alignment_matrix_cxf. This bias will be used as
the offset for the alignment transformation. It will *subtract* off the
bias from the data, so pca style inits can align factors across sessions.
If one runs LFADS on data where the true rates are known for some trials,
(say simulated, testing data, as in the example shipped with the paper), then
one can add three more fields for plotting purposes. These are 'train_truth'
and 'valid_truth', and 'conversion_factor'. These have the same dimensions as
'train_data', and 'valid_data' but represent the underlying rates of the
observations. Finally, if one needs to convert scale for plotting the true
underlying firing rates, there is the 'conversion_factor' key.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import tensorflow as tf
from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput
from distributions import diag_gaussian_log_likelihood
from distributions import KLCost_GaussianGaussian, Poisson
from distributions import LearnableAutoRegressive1Prior
from distributions import KLCost_GaussianGaussianProcessSampled
from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data
from utils import log_sum_exp, flatten
from plot_lfads import plot_lfads
class GRU(object):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  The input and the recurrent state are concatenated and projected with a
  single `linear` call per stage (gates first, then the candidate state).
  """

  def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
               clip_value=np.inf, collections=None):
    """Create a GRU object.

    Args:
      num_units: Number of units in the GRU.
      forget_bias (optional): Constant added to the update-gate
        pre-activation before the sigmoid. Hack to help learning.
      weight_scale (optional): Passed as `alpha` to `linear` for the gate
        projection; weights are scaled by ws/sqrt(#inputs), with ws being
        the weight scale.
        NOTE(review): the candidate projection does not receive this value
        and uses `linear`'s default alpha - confirm this is intended.
      clip_value (optional): The new state is clipped to
        [-clip_value, clip_value].
      collections (optional): List of additional collections variables
        should belong to.
    """
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._weight_scale = weight_scale
    self._clip_value = clip_value
    self._collections = collections

  @property
  def state_size(self):
    """Size of the recurrent state (equal to the number of units)."""
    return self._num_units

  @property
  def output_size(self):
    """Size of the cell output (equal to the number of units)."""
    return self._num_units

  @property
  def state_multiplier(self):
    """Always 1: the state holds a single vector of `num_units` values."""
    return 1

  def output_from_state(self, state):
    """Return the output portion of the state."""
    return state

  def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) function.

    Args:
      inputs: A 2D batch x input_dim tensor of inputs, or None to update the
        cell from the recurrent state alone.
      state: The previous state from the last time step.
      scope (optional): TF variable scope for defined GRU variables.

    Returns:
      A tuple (state, state), where state is the newly computed state at time t.
      It is returned twice to respect an interface that works for LSTMs.
    """
    x = inputs
    h = state
    has_input = x is not None
    if has_input:
      xh = tf.concat(axis=1, values=[x, h])
    else:
      xh = h

    with tf.variable_scope(scope or type(self).__name__):  # "GRU"
      with tf.variable_scope("Gates"):  # Reset gate and update gate.
        # We start with bias of 1.0 to not reset and not update.
        r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh,
                                     2 * self._num_units,
                                     alpha=self._weight_scale,
                                     name="xh_2_ru",
                                     collections=self._collections))
        r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
      with tf.variable_scope("Candidate"):
        rh = r * h
        # Mirror the gate stage: with no input, the candidate is computed
        # from the reset-gated state only instead of concatenating None.
        if has_input:
          xrh = tf.concat(axis=1, values=[x, rh])
        else:
          xrh = rh
        c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
                           collections=self._collections))
      # u close to 1 keeps the previous state; u close to 0 takes the
      # candidate.
      new_h = u * h + (1 - u) * c
      new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)

    return new_h, new_h
class GenGRU(object):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  Variant specialized for the generator. Input->state and state->state
  projections are separate `linear` calls, each with its own weight scale
  and variable collections, which allows l2 regularization on the recurrent
  weights only. It isn't as fast as the concatenating GRU, hence two classes.
  Note that the separate input projection is also subject to the
  1/sqrt(#inputs) scaling in the `linear` helper routine, so inputs are
  implicitly rescaled to be large magnitude if there are fewer inputs than
  recurrent state.
  """

  def __init__(self, num_units, forget_bias=1.0,
               input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
               input_collections=None, recurrent_collections=None):
    """Create a GenGRU object.

    Args:
      num_units: Number of units in the GRU.
      forget_bias (optional): Constant added to the update-gate
        pre-activation before the sigmoid. Hack to help learning.
      input_weight_scale (optional): `alpha` for the input->state
        projections; weights are scaled ws/sqrt(#inputs), with ws being the
        weight scale.
      rec_weight_scale (optional): `alpha` for the state->state projections;
        weights are scaled ws/sqrt(#inputs), with ws being the weight scale.
      clip_value (optional): The new state is clipped to
        [-clip_value, clip_value].
      input_collections (optional): List of additional collections that the
        input->rec variables should belong to.
      recurrent_collections (optional): List of additional collections that
        the rec->rec variables should belong to.
    """
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._input_weight_scale = input_weight_scale
    self._rec_weight_scale = rec_weight_scale
    self._clip_value = clip_value
    self._input_collections = input_collections
    self._rec_collections = recurrent_collections

  @property
  def state_size(self):
    """Size of the recurrent state (equal to the number of units)."""
    return self._num_units

  @property
  def output_size(self):
    """Size of the cell output (equal to the number of units)."""
    return self._num_units

  @property
  def state_multiplier(self):
    """Always 1: the state holds a single vector of `num_units` values."""
    return 1

  def output_from_state(self, state):
    """Return the output portion of the state."""
    return state

  def __call__(self, inputs, state, scope=None):
    """Advance the cell by one time step.

    Args:
      inputs: A 2D batch x input_dim tensor of inputs, or None to update the
        cell from the recurrent state alone.
      state: The previous state from the last time step.
      scope (optional): TF variable scope for defined GRU variables.

    Returns:
      A tuple (state, state), where state is the newly computed state at time t.
      It is returned twice to respect an interface that works for LSTMs.
    """
    prev_state = state
    n_gate_units = 2 * self._num_units

    with tf.variable_scope(scope or type(self).__name__):  # "GenGRU"
      with tf.variable_scope("Gates"):
        # Reset and update gates. Input contributions default to 0.0 so the
        # sums below still work when there is no input.
        if inputs is None:
          reset_in = update_in = 0.0
        else:
          gates_in = linear(inputs,
                            n_gate_units,
                            alpha=self._input_weight_scale,
                            do_bias=False,
                            name="x_2_ru",
                            normalized=False,
                            collections=self._input_collections)
          reset_in, update_in = tf.split(
              axis=1, num_or_size_splits=2, value=gates_in)

        gates_rec = linear(prev_state,
                           n_gate_units,
                           do_bias=True,
                           alpha=self._rec_weight_scale,
                           name="h_2_ru",
                           collections=self._rec_collections)
        reset_rec, update_rec = tf.split(
            axis=1, num_or_size_splits=2, value=gates_rec)

        reset_pre = reset_in + reset_rec
        update_pre = update_in + update_rec
        reset_gate = tf.sigmoid(reset_pre)
        # The forget bias pushes the update gate towards keeping the state.
        update_gate = tf.sigmoid(update_pre + self._forget_bias)

      with tf.variable_scope("Candidate"):
        if inputs is None:
          cand_in = 0.0
        else:
          cand_in = linear(inputs, self._num_units, name="x_2_c",
                           do_bias=False,
                           alpha=self._input_weight_scale,
                           normalized=False,
                           collections=self._input_collections)
        cand_rec = linear(reset_gate * prev_state, self._num_units,
                          name="rh_2_c", do_bias=True,
                          alpha=self._rec_weight_scale,
                          collections=self._rec_collections)
        candidate = tf.tanh(cand_in + cand_rec)

      next_state = update_gate * prev_state + (1 - update_gate) * candidate
      next_state = tf.clip_by_value(
          next_state, -self._clip_value, self._clip_value)

    return next_state, next_state
class LFADS(object):
"""LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, inferred inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
"""
def __init__(self, hps, kind="train", datasets=None):
"""Create an LFADS model.
train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative
log-likelihood using IWAE error (Importance Weighted Autoencoder).
This is the denoising operation.
prior_sample - a model for generation - sampling from priors is used
Args:
hps: The dictionary of hyper parameters.
kind: the type of model to build (see above).
datasets: a dictionary of named data_dictionaries, see top of lfads.py
"""
print("Building graph...")
all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
'prior_sample']
assert kind in all_kinds, 'Wrong kind'
if hps.feedback_factors_or_rates == "rates":
assert len(hps.dataset_names) == 1, \
"Multiple datasets not supported for rate feedback."
num_steps = hps.num_steps
ic_dim = hps.ic_dim
co_dim = hps.co_dim
ext_input_dim = hps.ext_input_dim
cell_class = GRU
gen_cell_class = GenGRU
def makelambda(v):  # Used with tf.case
  """Return a zero-argument callable that always yields `v`.

  Passing `v` as an argument gives every returned callable its own bound
  value, so callables built inside a loop do not all see the last value.
  """
  def _constant():
    return v
  return _constant
# Define the data placeholder, and deal with all parts of the graph
# that are dataset dependent.
self.dataName = tf.placeholder(tf.string, shape=())
# The batch_size to be inferred from data, as normal.
# Additionally, the data_dim will be inferred as well, allowing for a
# single placeholder for all datasets, regardless of data dimension.
if hps.output_dist == 'poisson':
# Enforce correct dtype
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
"Data dtype must be int for poisson output distribution"
data_dtype = tf.int32
elif hps.output_dist == 'gaussian':
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
"Data dtype must be float for gaussian output dsitribution"
data_dtype = tf.float32
else:
assert False, "NIY"
self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
[None, num_steps, None],
name="data")
self.train_step = tf.get_variable("global_step", [], tf.int64,
tf.zeros_initializer(),
trainable=False)
self.hps = hps
ndatasets = hps.ndatasets
factors_dim = hps.factors_dim
self.preds = preds = [None] * ndatasets
self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
self.fns_in_fatcor_bs = fns_in_fac_bs = [None] * ndatasets
self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
self.datasetNames = dataset_names = hps.dataset_names
self.ext_inputs = ext_inputs = None
if len(dataset_names) == 1: # single session
if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
used_in_factors_dim = factors_dim
in_identity_if_poss = False
else:
used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
in_identity_if_poss = True
else: # multisession
used_in_factors_dim = factors_dim
in_identity_if_poss = False
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
in_bias_1xf = None
align_bias_1xc = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
if hps.do_train_readin:
print("Initializing trainable readin matrix with alignment matrix" \
" provided for dataset:", name)
else:
print("Setting non-trainable readin matrix to alignment matrix" \
" provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d
(data_dim x factors_dim), but currently has %d x %d."""%
(data_dim, factors_dim, in_mat_cxf.shape[0],
in_mat_cxf.shape[1]))
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
if hps.do_train_readin:
print("Initializing trainable readin bias with alignment bias " \
"provided for dataset:", name)
else:
print("Setting non-trainable readin bias to alignment bias " \
"provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim:
raise ValueError("""Alignment bias must have dimensions %d
(data_dim), but currently has %d."""%
(data_dim, in_mat_cxf.shape[0]))
if in_mat_cxf is not None and align_bias_1xc is not None:
# (data - alignment_bias) * W_in
# data * W_in - alignment_bias * W_in
# So b = -alignment_bias * W_in to accommodate PCA style offset.
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
if hps.do_train_readin:
# only add to IO transformations collection only if we want it to be
# learnable, because IO_transformations collection will be trained
# when do_train_io_only
collections_readin=['IO_transformations']
else:
collections_readin=None
in_fac_lin = init_linear(data_dim, used_in_factors_dim,
do_bias=True,
mat_init_value=in_mat_cxf,
bias_init_value=in_bias_1xf,
identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name,
collections=collections_readin,
trainable=hps.do_train_readin)
in_fac_W, in_fac_b = in_fac_lin
fns_in_fac_Ws[d] = makelambda(in_fac_W)
fns_in_fac_bs[d] = makelambda(in_fac_b)
with tf.variable_scope("glm"):
out_identity_if_poss = False
if len(dataset_names) == 1 and \
factors_dim == hps.dataset_dims[dataset_names[0]]:
out_identity_if_poss = True
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
out_mat_fxc = None
out_bias_1xc = None
if in_mat_cxf is not None:
out_mat_fxc = in_mat_cxf.T
if align_bias_1xc is not None:
out_bias_1xc = align_bias_1xc
if hps.output_dist == 'poisson':
out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
identity_if_possible=out_identity_if_poss,
normalized=False,
name="fac_2_logrates_"+name,
collections=['IO_transformations'])
out_fac_W, out_fac_b = out_fac_lin
elif hps.output_dist == 'gaussian':
out_fac_lin_mean = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
normalized=False,
name="fac_2_means_"+name,
collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
bias_init_value = np.ones([1, data_dim]).astype(np.float32)
out_fac_lin_logvar = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=mat_init_value,
bias_init_value=bias_init_value,
normalized=False,
name="fac_2_logvars_"+name,
collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
out_fac_W = tf.concat(
axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
out_fac_b = tf.concat(
axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
else:
assert False, "NIY"
preds[d] = tf.equal(tf.constant(name), self.dataName)
data_dim = hps.dataset_dims[name]
fns_out_fac_Ws[d] = makelambda(out_fac_W)
fns_out_fac_bs[d] = makelambda(out_fac_b)
pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws)
pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs)
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)
# External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0:
self.ext_input = tf.placeholder(tf.float32,
[None, num_steps, ext_input_dim],
name="ext_input")
else:
self.ext_input = None
ext_input_bxtxi = self.ext_input
self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
self.batch_size = batch_size = int(hps.batch_size)
self.learning_rate = tf.Variable(float(hps.learning_rate_init),
trainable=False, name="learning_rate")
self.learning_rate_decay_op = self.learning_rate.assign(
self.learning_rate * hps.learning_rate_decay_factor)
# Dropout the data.
dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
if hps.ext_input_dim > 0:
ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
else:
ext_input_do_bxtxi = None
# ENCODERS
def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
                num_steps_to_encode):
  """Run an encoder cell over the data, forwards or backwards in time.

  Args:
    dataset_bxtxd: the data to encode, as a 3 tensor. It is sliced as
      [:, t, :], so time is the second axis (batch x time x data dims).
    enc_cell: encoder cell, called as enc_cell(input, state).
    name: name of encoder, used to build the variable scope names.
    forward_or_reverse: string; "forward" encodes steps in increasing
      order, any other value encodes them in decreasing order.
    num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode

  Returns:
    A list with num_steps_to_encode items; item t is the encoder output
    produced when time step t was consumed, for either direction.
  """
  is_forward = forward_or_reverse == "forward"
  suffix = "_fwd" if is_forward else "_rev"
  step_order = range(num_steps_to_encode)
  if not is_forward:
    step_order = reversed(step_order)
  enc_scope = name + "_enc" + suffix

  with tf.variable_scope(enc_scope, reuse=False):
    # Initial state variable, shared across the batch via tiling.
    init_state_1xn = tf.Variable(tf.zeros([1, enc_cell.state_size]),
                                 name=name + "_enc_t0" + suffix)
    enc_state = tf.tile(init_state_1xn, tf.stack([batch_size, 1]))
    enc_state.set_shape([None, enc_cell.state_size])  # tile loses shape

  outputs_by_step = [None] * num_steps_to_encode
  for step_idx, t in enumerate(step_order):
    # Variables are created on the first step and reused afterwards.
    with tf.variable_scope(enc_scope, reuse=(step_idx > 0 or None)):
      data_t_bxd = dataset_bxtxd[:, t, :]
      in_fac_t_bxf = tf.matmul(data_t_bxd, this_in_fac_W) + this_in_fac_b
      in_fac_t_bxf.set_shape([None, used_in_factors_dim])
      enc_input_t = in_fac_t_bxf
      if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
        # External inputs go to the encoder when they are not injected
        # into the generator.
        ext_t_bxi = ext_input_do_bxtxi[:, t, :]
        enc_input_t = tf.concat(axis=1, values=[in_fac_t_bxf, ext_t_bxi])
      step_out, enc_state = enc_cell(enc_input_t, enc_state)
    outputs_by_step[t] = step_out

  return outputs_by_step
# Encode initial condition means and variances
# ([x_T, x_T-1, ... x_0] and [x_0, x_1, ... x_T] -> g0/c0)
self.ic_enc_fwd = [None] * num_steps
self.ic_enc_rev = [None] * num_steps
if ic_dim > 0:
enc_ic_cell = cell_class(hps.ic_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "forward",
hps.num_steps_for_gen_ic)
ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "reverse",
hps.num_steps_for_gen_ic)
self.ic_enc_fwd = ic_enc_fwd
self.ic_enc_rev = ic_enc_rev
# Encoder control input means and variances, bi-directional encoding so:
# ([x_T, x_T-1, ..., x_0] and [x_0, x_1 ... x_T] -> u_t)
self.ci_enc_fwd = [None] * num_steps
self.ci_enc_rev = [None] * num_steps
if co_dim > 0:
enc_ci_cell = cell_class(hps.ci_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "forward",
hps.num_steps)
if hps.do_causal_controller:
ci_enc_rev = None
else:
ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "reverse",
hps.num_steps)
self.ci_enc_fwd = ci_enc_fwd
self.ci_enc_rev = ci_enc_rev
# STOCHASTIC LATENT VARIABLES, priors and posteriors
# (initial conditions g0, and control inputs, u_t)
# Note that zs represent all the stochastic latent variables.
with tf.variable_scope("z", reuse=False):
self.prior_zs_g0 = None
self.posterior_zs_g0 = None
self.g0s_val = None
if ic_dim > 0:
self.prior_zs_g0 = \
LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
mean_init=0.0,
var_min=hps.ic_prior_var_min,
var_init=hps.ic_prior_var_scale,
var_max=hps.ic_prior_var_max)
ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
ic_enc = tf.nn.dropout(ic_enc, keep_prob)
self.posterior_zs_g0 = \
DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
var_min=hps.ic_post_var_min)
if kind in ["train", "posterior_sample_and_average",
"posterior_push_mean"]:
zs_g0 = self.posterior_zs_g0
else:
zs_g0 = self.prior_zs_g0
if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
self.g0s_val = zs_g0.sample
else:
self.g0s_val = zs_g0.mean
# Priors for controller, 'co' for controller output
self.prior_zs_co = prior_zs_co = [None] * num_steps
self.posterior_zs_co = posterior_zs_co = [None] * num_steps
self.zs_co = zs_co = [None] * num_steps
self.prior_zs_ar_con = None
if co_dim > 0:
# Controller outputs
autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)]
noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)]
self.prior_zs_ar_con = prior_zs_ar_con = \
LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
autocorrelation_taus,
noise_variances,
hps.do_train_prior_ar_atau,
hps.do_train_prior_ar_nvar,
num_steps, "u_prior_ar1")
# CONTROLLER -> GENERATOR -> RATES
# (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t) )
self.controller_outputs = u_t = [None] * num_steps
self.con_ics = con_state = None
self.con_states = con_states = [None] * num_steps
self.con_outs = con_outs = [None] * num_steps
self.gen_inputs = gen_inputs = [None] * num_steps
if co_dim > 0:
# gen_cell_class here for l2 penalty recurrent weights
# didn't split the cell_weight scale here, because I doubt it matters
con_cell = gen_cell_class(hps.con_dim,
input_weight_scale=hps.cell_weight_scale,
rec_weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_con_reg'])
with tf.variable_scope("con", reuse=False):
self.con_ics = tf.tile(
tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
name="c0"),
tf.stack([batch_size, 1]))
self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
con_states[-1] = self.con_ics
gen_cell = gen_cell_class(hps.gen_dim,
input_weight_scale=hps.gen_cell_input_weight_scale,
rec_weight_scale=hps.gen_cell_rec_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_gen_reg'])
with tf.variable_scope("gen", reuse=False):
if ic_dim == 0:
self.gen_ics = tf.tile(
tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
tf.stack([batch_size, 1]))
else:
self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
identity_if_possible=True,
name="g0_2_gen_ic")
self.gen_states = gen_states = [None] * num_steps
self.gen_outs = gen_outs = [None] * num_steps
gen_states[-1] = self.gen_ics
gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
self.factors = factors = [None] * num_steps
factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
self.rates = rates = [None] * num_steps
# rates[-1] is collected to potentially feed back to controller
with tf.variable_scope("glm", reuse=False):
if hps.output_dist == 'poisson':
log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
log_rates_t0.set_shape([None, None])
rates[-1] = tf.exp(log_rates_t0) # rate
rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[-1] = means_t_bxd
else:
assert False, "NIY"
# We support multiple output distributions, for example Poisson, and also
# Gaussian. In these two cases respectively, there are one and two
# parameters (rates vs. mean and variance). So the output_dist_params
# tensor will have variable sizes via tf.concat and tf.split, along the 1st
# dimension. So in the case of gaussian, for example, it'll be
# batch x (D+D), where each D dims is the mean, and then variances,
# respectively. For a distribution with 3 parameters, it would be
# batch x (D+D+D).
self.output_dist_params = dist_params = [None] * num_steps
self.log_p_xgz_b = log_p_xgz_b = 0.0 # log P(x|z)
for t in range(num_steps):
# Controller
if co_dim > 0:
# Build inputs for controller
tlag = t - hps.controller_input_lag
if tlag < 0:
con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
else:
con_in_f_t = ci_enc_fwd[tlag]
if hps.do_causal_controller:
# If controller is causal (wrt to data generation process), then it
# cannot see future data. Thus, excluding ci_enc_rev[t] is obvious.
# Less obvious is the need to exclude factors[t-1]. This arises
# because information flows from g0 through factors to the controller
# input. The g0 encoding is backwards, so we must necessarily exclude
# the factors in order to keep the controller input purely from a
# forward encoding (however unlikely it is that
# g0->factors->controller channel might actually be used in this way).
con_in_list_t = [con_in_f_t]
else:
tlag_rev = t + hps.controller_input_lag
if tlag_rev >= num_steps:
# better than zeros
con_in_r_t = tf.zeros_like(ci_enc_rev[0])
else:
con_in_r_t = ci_enc_rev[tlag_rev]
con_in_list_t = [con_in_f_t, con_in_r_t]
if hps.do_feed_factors_to_controller:
if hps.feedback_factors_or_rates == "factors":
con_in_list_t.append(factors[t-1])
elif hps.feedback_factors_or_rates == "rates":
con_in_list_t.append(rates[t-1])
else:
assert False, "NIY"
con_in_t = tf.concat(axis=1, values=con_in_list_t)
con_in_t = tf.nn.dropout(con_in_t, keep_prob)
with tf.variable_scope("con", reuse=True if t > 0 else None):
con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
posterior_zs_co[t] = \
DiagonalGaussianFromInput(con_outs[t], co_dim,
name="con_to_post_co")
if kind == "train":
u_t[t] = posterior_zs_co[t].sample
elif kind == "posterior_sample_and_average":
u_t[t] = posterior_zs_co[t].sample
elif kind == "posterior_push_mean":
u_t[t] = posterior_zs_co[t].mean
else:
u_t[t] = prior_zs_ar_con.samples_t[t]
# Inputs to the generator (controller output + external input)
if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
if co_dim > 0:
gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
else:
gen_inputs[t] = ext_input_t_bxi
else:
gen_inputs[t] = u_t[t]
# Generator
data_t_bxd = dataset_ph[:,t,:]
with tf.variable_scope("gen", reuse=True if t > 0 else None):
gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
with tf.variable_scope("gen", reuse=True): # ic defined it above
factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
with tf.variable_scope("glm", reuse=True if t > 0 else None):
if hps.output_dist == 'poisson':
log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
log_rates_t.set_shape([None, None])
rates[t] = dist_params[t] = tf.exp(tf.clip_by_value(log_rates_t, -hps._clip_value, hps._clip_value)) # rates feed back
rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[t] = means_t_bxd # rates feed back to controller
dist_params[t] = tf.concat(
axis=1, values=[means_t_bxd, tf.exp(tf.clip_by_value(logvars_t_bxd, -hps._clip_value, hps._clip_value))])
loglikelihood_t = \
diag_gaussian_log_likelihood(data_t_bxd,
means_t_bxd, logvars_t_bxd)
else:
assert False, "NIY"
log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])
# Correlation of inferred inputs cost.
self.corr_cost = tf.constant(0.0)
if hps.co_mean_corr_scale > 0.0:
all_sum_corr = []
for i in range(hps.co_dim):
for j in range(i+1, hps.co_dim):
sum_corr_ij = tf.constant(0.0)
for t in range(num_steps):
u_mean_t = posterior_zs_co[t].mean
sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j]
all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
self.corr_cost = tf.reduce_mean(all_sum_corr) # div by batch and by n*(n-1)/2 pairs
# Variational Lower Bound on posterior, p(z|x), plus reconstruction cost.
# KL and reconstruction costs are normalized only by batch size, not by
# dimension, or by time steps.
kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
self.kl_cost = tf.constant(0.0) # VAE KL cost
self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
self.nll_bound_vae = tf.constant(0.0)
self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
kl_cost_g0_b = 0.0
kl_cost_co_b = 0.0
if ic_dim > 0:
g0_priors = [self.prior_zs_g0]
g0_posts = [self.posterior_zs_g0]
kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
if co_dim > 0:
kl_cost_co_b = \
KLCost_GaussianGaussianProcessSampled(
posterior_zs_co, prior_zs_ar_con).kl_cost_b
kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'reconstruction cost' is negative log likelihood
self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)
lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b
# VAE error averages outside the log
self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)
# IWAE error averages inside the log
k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
self.nll_bound_iwae = -iwae_lb_on_ll
# L2 regularization on the generator, normalized by number of parameters.
self.l2_cost = tf.constant(0.0)
if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
l2_costs = []
l2_numels = []
l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
tf.get_collection('l2_con_reg')]
l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
for v in l2_reg_vars:
numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
numel_f = tf.cast(numel, tf.float32)
l2_numels.append(numel_f)
v_l2 = tf.reduce_sum(v*v)
l2_costs.append(0.5 * l2_scale * v_l2)
self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)
# Compute the cost for training, part of the graph regardless.
# The KL cost can be problematic at the beginning of optimization,
# so we allow an exponential increase in weighting the KL from 0
# to 1.
self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
self.kl_weight = kl_weight = \
tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
self.l2_weight = l2_weight = \
tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)
self.timed_kl_cost = kl_weight * self.kl_cost
self.timed_l2_cost = l2_weight * self.l2_cost
self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
self.cost = self.recon_cost + self.timed_kl_cost + \
self.timed_l2_cost + self.weight_corr_cost
if kind != "train":
# save every so often
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep_lve)
return
# OPTIMIZATION
# train the io matrices only
if self.hps.do_train_io_only:
self.train_vars = tvars = \
tf.get_collection('IO_transformations',
scope=tf.get_variable_scope().name)
# train the encoder only
elif self.hps.do_train_encoder_only:
tvars1 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/ic_enc_*')
tvars2 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/z/ic_enc_*')
self.train_vars = tvars = tvars1 + tvars2
# train all variables
else:
self.train_vars = tvars = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
print("done.")
print("Model Variables (to be optimized): ")
total_params = 0
for i in range(len(tvars)):
shape = tvars[i].get_shape().as_list()
print(" ", i, tvars[i].name, shape)
total_params += np.prod(shape)
print("Total model parameters: ", total_params)
grads = tf.gradients(self.cost, tvars)
grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
epsilon=1e-01)
self.grads = grads
self.grad_global_norm = grad_global_norm
self.train_op = opt.apply_gradients(
zip(grads, tvars), global_step=self.train_step)
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# SUMMARIES, used only during training.
# example summary
self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
name='image_tensor')
self.example_summ = tf.summary.image("LFADS example", self.example_image,
collections=["example_summaries"])
# general training summaries
self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
self.grad_global_norm)
if hps.co_dim > 0:
self.atau_summ = [None] * hps.co_dim
self.pvar_summ = [None] * hps.co_dim
for c in range(hps.co_dim):
self.atau_summ[c] = \
tf.summary.scalar("AR Autocorrelation taus " + str(c),
tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c]))
self.pvar_summ[c] = \
tf.summary.scalar("AR Variances " + str(c),
tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c]))
# cost summaries, separated into different collections for
# training vs validation. We make placeholders for these, because
# even though the graph computes these costs on a per-batch basis,
# we want to report the more reliable metric of per-epoch cost.
kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
collections=["train_summaries"])
self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
collections=["valid_summaries"])
l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
collections=["train_summaries"])
recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
recon_cost_ph,
collections=["train_summaries"])
self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
recon_cost_ph,
collections=["valid_summaries"])
total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
collections=["train_summaries"])
self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
collections=["valid_summaries"])
self.kl_cost_ph = kl_cost_ph
self.l2_cost_ph = l2_cost_ph
self.recon_cost_ph = recon_cost_ph
self.total_cost_ph = total_cost_ph
# Merged summaries, for easy coding later.
self.merged_examples = tf.summary.merge_all(key="example_summaries")
self.merged_generic = tf.summary.merge_all() # default key is 'summaries'
self.merged_train = tf.summary.merge_all(key="train_summaries")
self.merged_valid = tf.summary.merge_all(key="valid_summaries")
session = tf.get_default_session()
self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
self.writer = tf.summary.FileWriter(self.logfile)
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
                    keep_prob=None):
  """Assemble the feed dictionary for a single session.run() call.

  Args:
    train_name: The key into the datasets; it is fed to the dataName
      placeholder (per the original notes, this sets the tf.case statement
      for the proper readin / readout matrices).
    data_bxtxd: The data tensor; it must have exactly three dimensions.
    ext_input_bxtxi (optional): The external input tensor. It is fed only
      when the model has an external input placeholder AND a value is given.
    keep_prob (optional): The drop out keep probability. When None, the
      value stored in hps.keep_prob is fed instead.

  Returns:
    The feed dictionary with TF tensors as keys and data as values, for use
    with tf.Session.run()
  """
  # The unpacking is only a sanity check: it raises unless the batch is 3-D.
  _, _, _ = data_bxtxd.shape

  feeds = {self.dataName: train_name, self.dataset_ph: data_bxtxd}

  if not (self.ext_input is None or ext_input_bxtxi is None):
    feeds[self.ext_input] = ext_input_bxtxi

  feeds[self.keep_prob] = self.hps.keep_prob if keep_prob is None else keep_prob
  return feeds
@staticmethod
def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
              example_idxs=None):
  """Select a batch of examples, either at random or by explicit index.

  Args:
    data_extxd: The data to model, numpy tensors with shape:
      # examples x # time steps x # dimensions
    ext_input_extxi (optional): The external inputs, numpy tensor with shape:
      # examples x # time steps x # external input dimensions
    batch_size: Number of examples to draw when example_idxs is not given.
      The draw uses np.random.choice with its default settings, i.e. with
      replacement.
    example_idxs (optional): The example indices used to select examples.
      When given, batch_size is ignored.

  Returns:
    A tuple with two parts:
      1. Batched data numpy tensor with shape:
        batch_size x # time steps x # dimensions
      2. Batched external input numpy tensor with shape:
        batch_size x # time steps x # external input dims
        (None when no external input was passed in).
  """
  assert not (batch_size is None and example_idxs is None), "Problems"

  num_examples, _, _ = data_extxd.shape
  if example_idxs is None:
    example_idxs = np.random.choice(num_examples, batch_size)

  if ext_input_extxi is None:
    batch_ext_input = None
  else:
    batch_ext_input = ext_input_extxi[example_idxs, :, :]

  return data_extxd[example_idxs, :, :], batch_ext_input
@staticmethod
def example_idxs_mod_batch_size(nexamples, batch_size):
  """Given a number of examples, E, and a batch_size, B, generate indices
  [0, 1, 2, ... B-1;
   [B, B+1, ... 2*B-1;
   ...
  ]
  returning those indices as a 2-dim tensor shaped like E/B x B.

  If E % B != 0, an extra (partial) row is completed so that the remainder
  of the examples is included. The padding entries are distinct example
  indices drawn at random and then sorted in increasing order (see
  randomize_example_idxs_mod_batch_size for fully randomized behavior).

  Args:
    nexamples: The number of examples to batch up.
    batch_size: The size of the batch.

  Returns:
    A tuple with two parts:
      1. The 2-dim index tensor described above.
      2. batch_size - (nexamples % batch_size): the number of padding
        entries when E % B != 0. Note that it equals batch_size (not 0)
        when E % B == 0, in which case nothing is padded.

  Raises:
    ValueError: from np.reshape, if fewer examples exist than padding entries
      are needed, so the last row cannot be completed.
  """
  bmrem = batch_size - (nexamples % batch_size)
  bmrem_examples = []
  if bmrem < batch_size:
    ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32)
    bmrem_examples = np.sort(ridxs)
  # list(range(...)): on Python 3 a range object cannot be concatenated with
  # a list, so build a real list first (a no-op copy on Python 2).
  example_idxs = list(range(nexamples)) + list(bmrem_examples)
  example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
  return example_idxs_e_x_edivb, bmrem
@staticmethod
def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
  """Indices 0 .. nexamples-1, randomized, in 2D form of
  shape = ceil(nexamples / batch_size) x batch_size. When nexamples is not
  a multiple of batch_size, the last batch is completed with distinct
  indices drawn at random (without replacement) from 0 .. nexamples-1, so
  those examples appear twice in the epoch.

  Args:
    nexamples: number of examples to randomize; must be strictly greater
      than batch_size.
    batch_size: number of elements in batch

  Returns:
    A tuple with two parts:
      1. The randomized, properly shaped indices.
      2. batch_size - nexamples % batch_size: the number of padding entries
        when nexamples is not a multiple of batch_size. Note that it equals
        batch_size (not 0) when it is a multiple and nothing is padded.
  """
  assert nexamples > batch_size, "Problems"
  bmrem = batch_size - nexamples % batch_size
  bmrem_examples = []
  if bmrem < batch_size:
    # Passing the int makes np.random.choice sample from arange(nexamples).
    bmrem_examples = np.random.choice(nexamples, size=bmrem, replace=False)
  # list(range(...)): on Python 3 a range object cannot be concatenated with
  # a list, so build a real list first (a no-op copy on Python 2).
  example_idxs = list(range(nexamples)) + list(bmrem_examples)
  mixed_example_idxs = np.random.permutation(example_idxs)
  example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
  return example_idxs_e_x_edivb, bmrem
def shuffle_spikes_in_time(self, data_bxtxd):
  """Shuffle the spikes in the temporal dimension. This is useful to
  help the LFADS system avoid overfitting to individual spikes or fast
  oscillations found in the data that are irrelevant to behavior. A
  pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
  enough to pick up dynamics that you may not want.

  Every spike is moved independently by an integer number of time bins drawn
  uniformly from [-w, w], with w = hps.temporal_spike_jitter_width. Spikes
  pushed past either end of the trial are reflected back inside, so the
  number of spikes per (example, channel) is preserved.

  Args:
    data_bxtxd: numpy array of spike count data to be shuffled.

  Returns:
    S_bxtxd, a float numpy array with the same dimensions and total counts
    as data_bxtxd, but shuffled appropriately. When the jitter width is 0
    the input array itself is returned, untouched.
  """
  B, T, N = data_bxtxd.shape
  w = self.hps.temporal_spike_jitter_width

  if w == 0:
    return data_bxtxd

  # np.max yields a numpy scalar (a float when counts are stored as floats),
  # which range() refuses; convert to a plain int. An empty array has no
  # spikes to move.
  max_counts = int(np.max(data_bxtxd)) if data_bxtxd.size else 0
  S_bxtxd = np.zeros([B, T, N])

  # Intuitively, shuffle spike occurrences, 0 or 1, but since we have counts,
  # do it over and over again up to the max count.
  for mc in range(1, max_counts + 1):
    idxs = np.nonzero(data_bxtxd >= mc)
    nfound = len(idxs[0])
    # randint excludes its upper bound, so use w + 1 to make the jitter
    # symmetric; (-w, w) would bias spikes towards earlier time bins.
    shuffles_incrs_in_time = np.random.randint(-w, w + 1, size=nfound)

    shuffle_tidxs = idxs[1] + shuffles_incrs_in_time

    # Reflect on the boundaries to not lose mass.
    too_early = shuffle_tidxs < 0
    shuffle_tidxs[too_early] = -shuffle_tidxs[too_early]
    too_late = shuffle_tidxs > T - 1
    shuffle_tidxs[too_late] = (T - 1) - (shuffle_tidxs[too_late] - (T - 1))

    # Unbuffered accumulate: several spikes may land on the same bin and each
    # one has to be counted.
    np.add.at(S_bxtxd, (idxs[0], shuffle_tidxs, idxs[2]), 1)

  return S_bxtxd
def shuffle_and_flatten_datasets(self, datasets, kind='train'):
  """Build one epoch's worth of (dataset name, batch indices) pairs.

  LFADS supports multiple datasets in the same dynamical model, and each
  may have its own data dimensionality, so examples from different data
  dictionaries cannot share a batch. Instead, random batches are formed
  within each data dictionary, tagged with the dataset name, and the tagged
  batches are then shuffled together. The name travels with each batch so
  that the correct in/out matrices can be selected when the graph is fed.

  Args:
    datasets: A dict of data dicts.  The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
    kind: 'train' or 'valid'; selects the '<kind>_data' entry of each data
      dictionary.

  Returns:
    A flat list, in which each element is a pair ('name', indices).
  """
  batch_size = self.hps.batch_size
  data_key = kind + '_data'

  name_batch_pairs = []
  for name, data_dict in datasets.items():
    nexamples, _, _ = data_dict[data_key].shape
    batched_idxs, _ = \
        self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
    # One pair per row (i.e. per batch) of the index matrix.
    name_batch_pairs.extend([(name, batch_idxs) for batch_idxs in batched_idxs])

  np.random.shuffle(name_batch_pairs)  # shuffle in place
  return name_batch_pairs
def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
  """Run one full training pass over the datasets and report the costs.

  Args:
    datasets: A dict of data dicts.  The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
    batch_size (optional):  The batch_size to use (not read by this method).
    do_save_ckpt (optional): Should the routine save a checkpoint on this
      training epoch?

  Returns:
    A tuple with 6 float values:
    (total cost of the epoch, epoch reconstruction cost,
    epoch kl cost, KL weight used this training epoch,
    total l2 cost on generator, and the corresponding weight).
    The first three are means over the batches of the epoch; the last three
    are the values reported by the final batch.
  """
  fetches = [self.cost, self.recon_cost,
             self.kl_cost, self.kl_weight,
             self.l2_cost, self.l2_weight,
             self.train_op]
  batch_results = self.run_epoch(datasets, fetches, kind="train")

  # normalizing by batch done in distributions.py
  nbatches = len(batch_results)
  cost_sum = 0.0
  recon_sum = 0.0
  kl_sum = 0.0
  for result in batch_results:
    cost_sum += result[0]
    recon_sum += result[1]
    kl_sum += result[2]

  final_result = batch_results[-1]
  kl_weight = final_result[3]
  l2_cost = final_result[4]
  l2_weight = final_result[5]

  epoch_total_cost = cost_sum / nbatches
  epoch_recon_cost = recon_sum / nbatches
  epoch_kl_cost = kl_sum / nbatches

  if do_save_ckpt:
    checkpoint_path = os.path.join(self.hps.lfads_save_dir,
                                   self.hps.checkpoint_name + '.ckpt')
    self.seso_saver.save(tf.get_default_session(), checkpoint_path,
                         global_step=self.train_step)

  return (epoch_total_cost, epoch_recon_cost, epoch_kl_cost,
          kl_weight, l2_cost, l2_weight)
def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
              do_collect=True, keep_prob=None):
  """Run the model through the entire dataset once.

  The epoch is a shuffled sequence of (dataset name, example indices)
  batches produced by shuffle_and_flatten_datasets. For a Poisson output
  distribution with hps.temporal_spike_jitter_width > 0, the spikes of each
  batch are jittered in time before being fed.

  Args:
    datasets: A dict of data dicts.  The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
    ops_to_eval: A list of tensorflow operations that will be evaluated in
      the tf.session.run() call.
    kind (optional): 'train' or 'valid'; selects the '<kind>_data' and
      '<kind>_ext_input' entries of each data dictionary.
    batch_size (optional):  The batch_size to use (not read by this method;
      batches are sized by shuffle_and_flatten_datasets).
    do_collect (optional): Should the routine collect all session.run
      output as a list, and return it?
    keep_prob (optional): The dropout keep probability.

  Returns:
    A list of lists, the internal list is the return for the ops for each
    session.run() call.  The outer list collects over the epoch. The list is
    empty when do_collect is False.
  """
  hps = self.hps
  all_name_example_idx_pairs = \
    self.shuffle_and_flatten_datasets(datasets, kind)

  kind_data = kind + '_data'
  kind_ext_input = kind + '_ext_input'
  do_jitter = (hps.output_dist == 'poisson' and
               hps.temporal_spike_jitter_width > 0)

  session = tf.get_default_session()
  evaled_ops_list = []
  for name, example_idxs in all_name_example_idx_pairs:
    data_dict = datasets[name]
    data_bxtxd, ext_input_bxtxi = self.get_batch(data_dict[kind_data],
                                                 data_dict[kind_ext_input],
                                                 example_idxs=example_idxs)
    if do_jitter:
      # Jitter only the rows that are actually fed. Spikes are jittered
      # independently of one another, so this matches jittering the whole
      # dataset and then selecting, without paying for the full ExTxD array
      # on every single batch.
      data_bxtxd = self.shuffle_spikes_in_time(data_bxtxd)

    feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
                                     keep_prob=keep_prob)
    evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
    if do_collect:
      evaled_ops_list.append(evaled_ops_np)

  return evaled_ops_list
def summarize_all(self, datasets, summary_values):
  """Plot and summarize stuff in tensorboard.

  Writes the per-epoch training (and, if available, validation) cost
  summaries to the summary writer, prints a one-line report, and appends
  the same numbers to '<hps.csv_log>.csv' in hps.lfads_save_dir when
  hps.csv_log is set.

  Note that everything done in the current function is otherwise done on
  a single, randomly selected dataset (except for summary_values, which are
  passed in.)

  Args:
    datasets, the dictionary of datasets used in the study (not read by
      this method).
    summary_values:  These summary values are created from the training loop,
      and so summarize the entire set of datasets. Required keys:
      'tr_kl_cost', 'tr_recon_cost', 'tr_total_cost', 'kl_weight',
      'l2_weight', 'l2_cost', 'has_any_valid_set' and 'nepochs'; plus
      'ev_kl_cost', 'ev_recon_cost' and 'ev_total_cost' when
      'has_any_valid_set' is true.
  """
  tr_kl_cost = summary_values['tr_kl_cost']
  tr_recon_cost = summary_values['tr_recon_cost']
  tr_total_cost = summary_values['tr_total_cost']
  kl_weight = summary_values['kl_weight']
  l2_weight = summary_values['l2_weight']
  l2_cost = summary_values['l2_cost']
  has_any_valid_set = summary_values['has_any_valid_set']
  nepochs = summary_values['nepochs']

  session = tf.get_default_session()
  train_summ, train_step = session.run(
      [self.merged_train, self.train_step],
      feed_dict={self.l2_cost_ph: l2_cost,
                 self.kl_cost_ph: tr_kl_cost,
                 self.recon_cost_ph: tr_recon_cost,
                 self.total_cost_ph: tr_total_cost})
  self.writer.add_summary(train_summ, train_step)

  # The messages below are assembled from adjacent string literals. A
  # backslash continuation inside a literal would splice the next source
  # line, including its indentation, into the logged text.
  if has_any_valid_set:
    ev_kl_cost = summary_values['ev_kl_cost']
    ev_recon_cost = summary_values['ev_recon_cost']
    ev_total_cost = summary_values['ev_total_cost']
    eval_summ = session.run(self.merged_valid,
                            feed_dict={self.kl_cost_ph: ev_kl_cost,
                                       self.recon_cost_ph: ev_recon_cost,
                                       self.total_cost_ph: ev_total_cost})
    self.writer.add_summary(eval_summ, train_step)

    report_values = (nepochs, train_step, tr_total_cost, ev_total_cost,
                     tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
                     l2_cost, kl_weight, l2_weight)
    print(("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f "
           "recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f, "
           "kl weight: %.2f, l2 weight: %.2f") % report_values)
    csv_outstr = ("epoch,%d, step,%d, total,%.2f,%.2f, "
                  "recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, "
                  "klweight,%.2f, l2weight,%.2f\n") % report_values
  else:
    report_values = (nepochs, train_step, tr_total_cost, tr_recon_cost,
                     tr_kl_cost, l2_cost, kl_weight, l2_weight)
    print(("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f, "
           "l2: %.5f, kl weight: %.2f, l2 weight: %.2f") % report_values)
    csv_outstr = ("epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, "
                  "l2,%.5f, klweight,%.2f, l2weight,%.2f\n") % report_values

  if self.hps.csv_log:
    csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv')
    with open(csv_file, "a") as myfile:
      myfile.write(csv_outstr)
def plot_single_example(self, datasets):
  """Plot an image relating to a randomly chosen, specific example.  We use
  posterior sample and average by taking one example, and filling a whole
  batch with that example, sample from the posterior, and then average the
  quantities.

  One dataset is picked at random, then one training example from it (and
  one validation example, when the dataset has validation data). The
  generic summaries and the rendered example image are written to the
  summary writer at the current training step.

  Args:
    datasets: A dict of data dicts.  The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
  """
  hps = self.hps
  # dict.keys() is a view on Python 3 and cannot be permuted by numpy, so
  # materialize the names and pick one by position.
  all_data_names = list(datasets.keys())
  data_name = all_data_names[np.random.choice(len(all_data_names))]
  data_dict = datasets[data_name]

  has_valid_set = data_dict['valid_data'] is not None
  cf = 1.0                  # plotting concern

  # posterior sample and average here
  E, _, _ = data_dict['train_data'].shape
  eidx = np.random.choice(E)
  # A whole batch made of the same example.
  example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)

  train_data_bxtxd, train_ext_input_bxtxi = \
      self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
                     example_idxs=example_idxs)

  truth_train_data_bxtxd = None
  if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
    truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
                                               example_idxs=example_idxs)
    cf = data_dict['conversion_factor']

  # plotter does averaging
  train_model_values = self.eval_model_runs_batch(data_name,
                                                  train_data_bxtxd,
                                                  train_ext_input_bxtxi,
                                                  do_average_batch=False)

  train_step = train_model_values['train_steps']
  feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
                                   train_ext_input_bxtxi, keep_prob=1.0)

  session = tf.get_default_session()
  generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
  self.writer.add_summary(generic_summ, train_step)

  valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
  truth_valid_data_bxtxd = None
  if has_valid_set:
    E, _, _ = data_dict['valid_data'].shape
    eidx = np.random.choice(E)
    example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
    valid_data_bxtxd, valid_ext_input_bxtxi = \
        self.get_batch(data_dict['valid_data'],
                       data_dict['valid_ext_input'],
                       example_idxs=example_idxs)
    if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
      truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'],
                                                 example_idxs=example_idxs)

    # plotter does averaging
    valid_model_values = self.eval_model_runs_batch(data_name,
                                                    valid_data_bxtxd,
                                                    valid_ext_input_bxtxi,
                                                    do_average_batch=False)

  example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
                             train_model_vals=train_model_values,
                             train_ext_input_bxtxi=train_ext_input_bxtxi,
                             train_truth_bxtxd=truth_train_data_bxtxd,
                             valid_bxtxd=valid_data_bxtxd,
                             valid_model_vals=valid_model_values,
                             valid_ext_input_bxtxi=valid_ext_input_bxtxi,
                             valid_truth_bxtxd=truth_valid_data_bxtxd,
                             bidx=None, cf=cf, output_dist=hps.output_dist)
  # The image placeholder has a leading axis of size 1.
  example_image = np.expand_dims(example_image, axis=0)
  example_summ = session.run(self.merged_examples,
                             feed_dict={self.example_image : example_image})
  # Tag the image with the training step, like the generic summary above, so
  # successive example images are recorded at distinct steps.
  self.writer.add_summary(example_summ, train_step)
def train_model(self, datasets):
  """Train the model, print per-epoch information, and save checkpoints.

  Loop over training epochs. The function that actually does the
  training is train_epoch.  This function iterates over the training
  data, one epoch at a time.  The learning rate schedule is such
  that it will stay the same until the cost goes up in comparison to
  the last few values, then it will drop. Training stops once the learning
  rate falls below hps.learning_rate_stop.

  Args:
    datasets: A dict of data dicts.  The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
  """
  hps = self.hps
  has_any_valid_set = False
  for data_dict in datasets.values():
    if data_dict['valid_data'] is not None:
      has_any_valid_set = True
      break

  session = tf.get_default_session()
  lr = session.run(self.learning_rate)
  lr_stop = hps.learning_rate_stop
  i = -1
  train_costs = []
  valid_costs = []
  ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
  # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
  lowest_ev_cost = np.inf
  while True:
    i += 1
    # A regular ("save every so often") checkpoint every 10th epoch.
    do_save_ckpt = (i % 10 == 0)
    tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
        self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)

    # Evaluate the validation cost, and potentially save.  Note that this
    # routine will not save a validation checkpoint until the kl weight and
    # l2 weights are equal to 1.0.
    if has_any_valid_set:
      ev_total_cost, ev_recon_cost, ev_kl_cost = \
          self.eval_cost_epoch(datasets, kind='valid')
      valid_costs.append(ev_total_cost)

      # > 1 may give more consistent results, but not the actual lowest vae.
      # == 1 gives the lowest vae seen so far.
      n_lve = 1
      run_avg_lve = np.mean(valid_costs[-n_lve:])

      # conditions for saving checkpoints:
      #   KL weight must have finished stepping (>=1.0), AND
      #   L2 weight must have finished stepping OR L2 is not being used, AND
      #   more than n_lve validation costs have been collected, i.e. at
      #     least one validation epoch precedes the averaging window, AND
      #   the current run has a lower LVE than previous runs.
      kl_ramp_done = kl_weight >= 1.0
      l2_ramp_done = (l2_weight >= 1.0 or
                      (self.hps.l2_gen_scale == 0.0 and
                       self.hps.l2_con_scale == 0.0))
      is_new_lowest = (len(valid_costs) > n_lve and
                       run_avg_lve < lowest_ev_cost)
      if kl_ramp_done and l2_ramp_done and is_new_lowest:
        lowest_ev_cost = run_avg_lve
        checkpoint_path = os.path.join(self.hps.lfads_save_dir,
                                       self.hps.checkpoint_name + '_lve.ckpt')
        self.lve_saver.save(session, checkpoint_path,
                            global_step=self.train_step,
                            latest_filename='checkpoint_lve')

    # Plot and summarize.
    values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set,
              'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost,
              'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost,
              'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost,
              'l2_weight':l2_weight, 'kl_weight':kl_weight,
              'l2_cost':l2_cost}
    self.summarize_all(datasets, values)
    self.plot_single_example(datasets)

    # Manage learning rate: decay when this epoch's training cost is worse
    # than every one of the last n_lr recorded costs.
    train_res = tr_total_cost
    n_lr = hps.learning_rate_n_to_compare
    if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
      _ = session.run(self.learning_rate_decay_op)
      lr = session.run(self.learning_rate)
      print(" Decreasing learning rate to %f." % lr)
      # Force the system to run n_lr times while at this lr: with inf in
      # the comparison window no cost can exceed its maximum.
      train_costs.append(np.inf)
    else:
      train_costs.append(train_res)

    if lr < lr_stop:
      print("Stopping optimization based on learning rate criteria.")
      break
def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
                    batch_size=None):
  """Average the model costs over one pass through the datasets.

  The pass is run with the dropout keep probability set to 1.0.

  Args:
    datasets: A dict of data dicts (training and validation data) used for
      training and evaluation of the model, respectively.
    kind (optional): 'train' or 'valid'; which split to evaluate.
    ext_input_extxi (optional): not read by this method.
    batch_size (optional): not read by this method.

  Returns:
    a 3 tuple of costs:
      (epoch total cost, epoch reconstruction cost, epoch KL cost)
  """
  fetches = [self.cost, self.recon_cost, self.kl_cost]
  batch_results = self.run_epoch(datasets, fetches, kind=kind,
                                 keep_prob=1.0)

  # normalizing by batch done in distributions.py
  nbatches = len(batch_results)
  # Running sums, in fetch order: total, reconstruction, KL.
  sums = [0.0, 0.0, 0.0]
  for result in batch_results:
    for j in range(3):
      sums[j] += result[j]

  return sums[0] / nbatches, sums[1] / nbatches, sums[2] / nbatches
def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
                          do_eval_cost=False, do_average_batch=False):
  """Returns all the goodies for the entire model, per batch.

  data_bxtxd and ext_input_bxtxi may hold fewer than hps.batch_size examples
  along their first dimension, in which case they are zero-padded up to the
  batch size before the graph is run, and the per-example outputs are
  truncated back afterwards. The scalar fetches (cost and the two NLL
  bounds) are returned as computed on the padded batch.

  Args:
    data_name: The name of the data dict, to select which in/out matrices
      to use.
    data_bxtxd:  Numpy array training data with shape:
      batch_size x # time steps x # dimensions
    ext_input_bxtxi: Numpy array training external input with shape:
      batch_size x # time steps x # external input dims
    do_eval_cost (optional): Not read by this method's body; both the VAE
      and the IWAE (Importance Weighted Autoencoder) log likelihood bounds
      are always returned.
    do_average_batch (optional): average over the batch, useful for getting
      good IWAE costs, and model outputs for a single data point.

  Returns:
    A dictionary with the outputs of the model decoder, namely:
      'gen_ics', 'gen_states', 'factors', 'output_dist_params', 'costs',
      'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'; plus
      'prior_g0_mean', 'prior_g0_logvar', 'post_g0_mean', 'post_g0_logvar'
      when hps.ic_dim > 0, and 'controller_outputs' when hps.co_dim > 0.
  """
  session = tf.get_default_session()
  hps = self.hps

  # if fewer than batch_size provided, pad to batch_size
  nprovided, _, _ = data_bxtxd.shape
  is_short_batch = nprovided < hps.batch_size
  if is_short_batch:
    pad_widths = ((0, hps.batch_size - nprovided), (0, 0), (0, 0))
    data_bxtxd = np.pad(data_bxtxd, pad_widths,
                        mode='constant', constant_values=0)
    if ext_input_bxtxi is not None:
      ext_input_bxtxi = np.pad(ext_input_bxtxi, pad_widths,
                               mode='constant', constant_values=0)

  feed_dict = self.build_feed_dict(data_name, data_bxtxd,
                                   ext_input_bxtxi, keep_prob=1.0)

  has_ics = self.hps.ic_dim > 0
  has_controller = self.hps.co_dim > 0

  # Output keys and the graph values behind them, kept in lock step.
  # Non-temporal signals will be batch x dim.
  # Temporal signals are list length T with elements batch x dim.
  names = ['gen_ics', 'gen_states', 'factors', 'output_dist_params',
           'costs', 'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps']
  fetches = [self.gen_ics, self.gen_states, self.factors,
             self.output_dist_params,
             self.cost, self.nll_bound_vae, self.nll_bound_iwae,
             self.train_step]  # not train_op!
  if has_ics:
    names += ['prior_g0_mean', 'prior_g0_logvar',
              'post_g0_mean', 'post_g0_logvar']
    fetches += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
                self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
  if has_controller:
    names.append('controller_outputs')
    fetches.append(self.controller_outputs)

  flat_fetches, fidxs = flatten(fetches)
  flat_values = session.run(flat_fetches, feed_dict=feed_dict)

  # Regroup the flat results: the k-th entry of fidxs lists the flat
  # positions that belong to the k-th fetch.
  model_vals = {}
  for k, name in enumerate(names):
    model_vals[name] = [flat_values[f] for f in fidxs[k]]

  # Convert temporal signals to full tensors, not lists of tensors in time
  # dim; for everything else take the single item out of its list.
  temporal_names = ('gen_states', 'factors', 'output_dist_params',
                    'controller_outputs')
  for name in names:
    if name in temporal_names:
      model_vals[name] = list_t_bxn_to_tensor_bxtxn(model_vals[name])
    else:
      model_vals[name] = model_vals[name][0]

  # Per-example outputs, mapped to the number of explicit trailing ':' used
  # when slicing the batch axis. The scalar fetches are not listed.
  trailing_slices = {'gen_ics': 1, 'gen_states': 1,
                     'factors': 2, 'output_dist_params': 2,
                     'prior_g0_mean': 1, 'prior_g0_logvar': 1,
                     'post_g0_mean': 1, 'post_g0_logvar': 1,
                     'controller_outputs': 2}
  per_example_names = [name for name in names if name in trailing_slices]

  # slice out the trials in case < batch_size provided
  if is_short_batch:
    idx = np.arange(nprovided)
    for name in per_example_names:
      selector = (idx,) + (slice(None),) * trailing_slices[name]
      model_vals[name] = model_vals[name][selector]

  if do_average_batch:
    for name in per_example_names:
      model_vals[name] = np.mean(model_vals[name], axis=0)

  return model_vals
def eval_model_runs_avg_epoch(self, data_name, data_extxd,
                              ext_input_extxi=None):
  """Estimates sample-averaged model outputs, one example at a time.

  For each processed example, a batch of batch_size copies of that example's
  index is requested from self.get_batch and run through
  self.eval_model_runs_batch with do_average_batch=True.  The batch-averaged
  values returned for that example are stored in row es_idx of the outputs.
  At most hps.ps_nexamples_to_process examples are processed.

  Args:
    data_name: The name of the data dict, to select which in/out matrices
      to use.
    data_extxd: Numpy array training data with shape:
      # examples x # time steps x # dimensions
    ext_input_extxi (optional): Numpy array training external input with
      shape: # examples x # time steps x # external input dims

  Returns:
    A dictionary of numpy arrays whose first dimension indexes the processed
    examples.  Keys: 'gen_ics', 'gen_states', 'factors',
    'output_dist_params' (last dim D for 'poisson', 2*D for 'gaussian'),
    'costs', 'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'; plus
    'prior_g0_mean', 'prior_g0_logvar', 'post_g0_mean', 'post_g0_logvar'
    when hps.ic_dim > 0, and 'controller_outputs' when hps.co_dim > 0.
  """
  hps = self.hps
  batch_size = hps.batch_size
  nexamples, ntime, data_dim = data_extxd.shape
  nprocess = min(hps.ps_nexamples_to_process, nexamples)

  # Preallocate every output directly in the result dictionary.
  model_runs = {}
  if hps.ic_dim > 0:
    for key in ('prior_g0_mean', 'prior_g0_logvar',
                'post_g0_mean', 'post_g0_logvar'):
      model_runs[key] = np.zeros([nprocess, hps.ic_dim])
  model_runs['gen_ics'] = np.zeros([nprocess, hps.gen_dim])
  if hps.co_dim > 0:
    model_runs['controller_outputs'] = np.zeros([nprocess, ntime, hps.co_dim])
  model_runs['gen_states'] = np.zeros([nprocess, ntime, hps.gen_dim])
  model_runs['factors'] = np.zeros([nprocess, ntime, hps.factors_dim])
  if hps.output_dist == 'poisson':
    out_width = data_dim
  elif hps.output_dist == 'gaussian':
    # Two parameters per data dimension.
    out_width = data_dim + data_dim
  else:
    assert False, "NIY"
  model_runs['output_dist_params'] = np.zeros([nprocess, ntime, out_width])
  for key in ('costs', 'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'):
    model_runs[key] = np.zeros(nprocess)

  for es_idx in range(nprocess):
    print("Running %d of %d." % (es_idx+1, nprocess))
    # batch_size copies of the same example index.
    example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
    data_bxtxd, ext_input_bxtxi = self.get_batch(
        data_extxd, ext_input_extxi,
        batch_size=batch_size, example_idxs=example_idxs)
    model_values = self.eval_model_runs_batch(
        data_name, data_bxtxd, ext_input_bxtxi,
        do_eval_cost=True, do_average_batch=True)
    # Every preallocated output has a same-named entry in model_values.
    for key, store in model_runs.items():
      store[es_idx] = model_values[key]
    print('bound nll(vae): %.3f, bound nll(iwae): %.3f'
          % (model_runs['nll_bound_vaes'][es_idx],
             model_runs['nll_bound_iwaes'][es_idx]))

  return model_runs
def eval_model_runs_push_mean(self, data_name, data_extxd,
                              ext_input_extxi=None):
  """Returns values of interest for the model by pushing the means through.

  The mean values for both initial conditions and the control inputs are
  pushed through the model instead of sampling (as is done in
  eval_model_runs_avg_epoch).  This is a quick and approximate alternative
  to sampling from the posterior many times and averaging.

  Trials are fed to the model in consecutive groups of at most batch_size;
  each group is run through self.eval_model_runs_batch with
  do_average_batch=False and its results are written to the rows of those
  trials.  At most hps.ps_nexamples_to_process trials are processed.

  Args:
    data_name: The name of the data dict, to select which in/out matrices
      to use.
    data_extxd: Numpy array training data with shape:
      # examples x # time steps x # dimensions
    ext_input_extxi (optional): Numpy array training external input with
      shape: # examples x # time steps x # external input dims

  Returns:
    A dictionary of numpy arrays whose first dimension indexes the processed
    trials.  Keys: 'gen_ics', 'gen_states', 'factors', 'output_dist_params'
    (last dim D for 'poisson', 2*D for 'gaussian'), 'costs',
    'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'; plus
    'prior_g0_mean', 'prior_g0_logvar', 'post_g0_mean', 'post_g0_logvar'
    when hps.ic_dim > 0, and 'controller_outputs' when hps.co_dim > 0.
    You probably do not want the LL associated values ('costs' and the
    nll bounds) when pushing the mean instead of sampling.
  """
  hps = self.hps
  batch_size = hps.batch_size
  nexamples, ntime, data_dim = data_extxd.shape
  nprocess = hps.ps_nexamples_to_process
  if nprocess > nexamples:
    print("Setting number of posterior samples to process to : ", nexamples)
    nprocess = nexamples

  # Preallocate every output directly in the result dictionary.
  model_runs = {}
  if hps.ic_dim > 0:
    for key in ('prior_g0_mean', 'prior_g0_logvar',
                'post_g0_mean', 'post_g0_logvar'):
      model_runs[key] = np.zeros([nprocess, hps.ic_dim])
  model_runs['gen_ics'] = np.zeros([nprocess, hps.gen_dim])
  if hps.co_dim > 0:
    model_runs['controller_outputs'] = np.zeros([nprocess, ntime, hps.co_dim])
  model_runs['gen_states'] = np.zeros([nprocess, ntime, hps.gen_dim])
  model_runs['factors'] = np.zeros([nprocess, ntime, hps.factors_dim])
  if hps.output_dist == 'poisson':
    out_width = data_dim
  elif hps.output_dist == 'gaussian':
    # Two parameters per data dimension.
    out_width = data_dim + data_dim
  else:
    assert False, "NIY"
  model_runs['output_dist_params'] = np.zeros([nprocess, ntime, out_width])
  for key in ('costs', 'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'):
    model_runs[key] = np.zeros(nprocess)

  # Walk the trials in consecutive groups of batch_size; the last group may
  # contain fewer trials.
  for batch_idx, start in enumerate(range(0, nprocess, batch_size)):
    es_idx = np.arange(start, min(start + batch_size, nprocess),
                       dtype=np.int32)
    print("Running trial batch %d with %d trials" % (batch_idx+1,
                                                     len(es_idx)))
    data_bxtxd, ext_input_bxtxi = self.get_batch(
        data_extxd, ext_input_extxi,
        batch_size=batch_size, example_idxs=es_idx)
    model_values = self.eval_model_runs_batch(
        data_name, data_bxtxd, ext_input_bxtxi,
        do_eval_cost=True, do_average_batch=False)
    # TODO
    # model_values['costs'] and other costs come out as scalars, summed over
    # all the trials in the batch. what we want is the per-trial costs.
    # Until then the batch-level scalar is broadcast to every trial's row.
    for key, store in model_runs.items():
      store[es_idx] = model_values[key]

  return model_runs
def write_model_runs(self, datasets, output_fname=None, push_mean=False):
  """Run the model on every dataset split and save the computed values.

  For each named dataset, the 'train' and 'valid' splits are evaluated and
  the resulting dictionary is written with write_data (gzip compression) to
  a file inside hps.lfads_save_dir.  The file name is
  <stem><data_name>_<train|valid>_<hps.kind>, where <stem> is output_fname
  if given and "model_runs_" otherwise.

  LFADS generates a number of outputs for each example, and these are all
  saved.  They are:
    The mean and variance of the prior of g0.
    The mean and variance of approximate posterior of g0.
    The control inputs (if enabled)
    The initial conditions, g0, for all examples.
    The generator states for all time.
    The factors for all time.
    The output distribution parameters (e.g. rates) for all time.

  Args:
    datasets: a dictionary of named data_dictionaries, see top of lfads.py.
      Each data dictionary must provide 'train_data', 'train_ext_input',
      'valid_data' and 'valid_ext_input'.
    output_fname: a file name stem for the output files.
    push_mean: if False (default), generates batch_size samples for each
      trial and averages the results (eval_model_runs_avg_epoch).  If True,
      runs each trial once without noise, pushing the posterior mean initial
      conditions and control inputs through the trained model
      (eval_model_runs_push_mean).  False is used for
      posterior_sample_and_average, True is used for posterior_push_mean.
  """
  hps = self.hps
  kind = hps.kind
  stem = output_fname if output_fname else "model_runs_"

  for data_name, data_dict in datasets.items():
    splits = (
        ('train', data_dict['train_data'], data_dict['train_ext_input']),
        ('valid', data_dict['valid_data'], data_dict['valid_ext_input']),
    )
    for data_kind, data_extxd, ext_input_extxi in splits:
      fname = stem + '_'.join([data_name, data_kind, kind])
      print("Writing data for %s data and kind %s." % (data_name, data_kind))

      run_model = (self.eval_model_runs_push_mean if push_mean
                   else self.eval_model_runs_avg_epoch)
      model_runs = run_model(data_name, data_extxd, ext_input_extxi)

      write_data(os.path.join(hps.lfads_save_dir, fname), model_runs,
                 compression='gzip')
      print("Done.")
def write_model_samples(self, dataset_name, output_fname=None):
  """Use the prior distribution to generate batch_size number of samples
  from the model, and save them.

  The model tensors are evaluated in the default TF session, feeding only
  the dataset name and keep_prob = 1.0.  The results are written with
  write_data (gzip compression) to
  os.path.join(hps.lfads_save_dir, output_fname).

  LFADS generates a number of outputs for each sample, and these are all
  saved.  They are:
    The mean and variance of the prior of g0.
    The control inputs (if enabled)
    The initial conditions, g0, for all examples.
    The generator states for all time.
    The factors for all time.
    The output distribution parameters (e.g. rates) for all time.

  Args:
    dataset_name: The name of the dataset to grab the factors -> rates
      alignment matrices from.
    output_fname: The name of the file in which to save the generated
      samples.  NOTE(review): although this defaults to None, the value is
      passed straight to os.path.join, so callers need to supply a name.
  """
  hps = self.hps
  print("Generating %d samples" % (hps.batch_size))

  # Fetch groups, in order: 0 factors, 1 gen_states, 2 gen_ics, 3 cost,
  # 4 output_dist_params, then the optional prior quantities.
  fetches = [self.factors, self.gen_states, self.gen_ics,
             self.cost, self.output_dist_params]
  if hps.ic_dim > 0:
    fetches.extend([self.prior_zs_g0.mean, self.prior_zs_g0.logvar])
  if hps.co_dim > 0:
    fetches.append(self.prior_zs_ar_con.samples_t)

  flat_fetches, group_idxs = flatten(fetches)
  session = tf.get_default_session()
  feed_dict = {self.dataName: dataset_name, self.keep_prob: 1.0}
  flat_results = session.run(flat_fetches, feed_dict=feed_dict)

  def group(k):
    """Evaluated values belonging to the k-th entry of `fetches`."""
    return [flat_results[i] for i in group_idxs[k]]

  # [0] takes the single item out of the non-temporal groups; the temporal
  # groups are lists over time and are converted to full tensors.
  model_vals = {}
  model_vals['gen_ics'] = group(2)[0]
  model_vals['gen_states'] = list_t_bxn_to_tensor_bxtxn(group(1))
  model_vals['factors'] = list_t_bxn_to_tensor_bxtxn(group(0))
  model_vals['output_dist_params'] = list_t_bxn_to_tensor_bxtxn(group(4))
  model_vals['costs'] = group(3)[0].reshape(1)

  next_group = 5
  if hps.ic_dim > 0:
    model_vals['prior_g0_mean'] = group(next_group)[0]
    model_vals['prior_g0_logvar'] = group(next_group + 1)[0]
    next_group += 2
  if hps.co_dim > 0:
    model_vals['prior_zs_ar_con'] = list_t_bxn_to_tensor_bxtxn(
        group(next_group))

  full_fname = os.path.join(hps.lfads_save_dir, output_fname)
  write_data(full_fname, model_vals, compression='gzip')
  print("Done.")
@staticmethod
def eval_model_parameters(use_nested=True, include_strs=None):
  """Evaluate and return all of the TF variables in the model.

  All global TF variables are evaluated in the default session and then
  filtered by name.

  Args:
    use_nested (optional): For returning values, use a nested dictionary,
      based on variable scoping, or return all variables in a flat
      dictionary.
    include_strs (optional): A string, or a list of strings, to use as a
      filter to reduce the number of variables returned.  A variable name
      must contain at least one of these strings as a sub-string in order to
      be returned.  When omitted (or empty), the filter is ["LFADS"].

  Returns:
    The parameters of the model.  This can be in a flat dictionary keyed by
    the full variable name, or a nested dictionary, where the nesting
    follows the '/'-separated parts of the variable name.
  """
  # Normalize the filter.  A bare string is treated as a single filter
  # rather than being iterated character by character.
  if not include_strs:
    name_filters = ["LFADS"]
  elif isinstance(include_strs, str):
    name_filters = [include_strs]
  else:
    name_filters = list(include_strs)

  all_tf_vars = tf.global_variables()
  session = tf.get_default_session()
  all_tf_vars_eval = session.run(all_tf_vars)

  vars_dict = {}
  for var, var_eval in zip(all_tf_vars, all_tf_vars_eval):
    # Keep the variable only if one of the filters occurs in its name.
    if not any(s in var.name for s in name_filters):
      continue
    if isinstance(var_eval, np.ndarray):
      e = var_eval
    else:  # for H5PY
      print(var.name, " is not numpy array, saving as numpy array\n"
            "with value: ", var_eval, type(var_eval))
      e = np.array(var_eval)
      print(e, type(e))
    vars_dict[var.name] = e

  if not use_nested:
    return vars_dict

  # Build one nested dictionary level per scope part; the last part of the
  # name holds the value.
  nested_vars_dict = {}
  for var_name, value in vars_dict.items():
    parts = var_name.split('/')
    current_dict = nested_vars_dict
    for part in parts[:-1]:
      current_dict = current_dict.setdefault(part, {})
    current_dict[parts[-1]] = value

  return nested_vars_dict
@staticmethod
def spikify_rates(rates_bxtxd):
  """Randomly spikify underlying rates according to a Poisson distribution.

  Every entry of the input is used as the rate (lambda) of an independent
  Poisson draw from the global np.random state.  The draw is done in one
  vectorized call instead of one Python-level call per element.

  Args:
    rates_bxtxd: a 3-dimensional numpy array of Poisson rates; the first and
      last dimensions must be non-empty.

  Returns:
    A numpy int32 array with the same shape as rates_bxtxd, but with the
    event counts.

  Raises:
    AssertionError: if the first or last dimension of rates_bxtxd is empty.
  """
  B, T, N = rates_bxtxd.shape
  assert all([B > 0, N > 0]), "problems"

  # np.random.poisson accepts an array of rates and draws one count per
  # entry; cast to int32 to keep the original output dtype.
  return np.random.poisson(rates_bxtxd).astype(np.int32)