# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""
LFADS - Latent Factor Analysis via Dynamical Systems.

LFADS is an unsupervised method to decompose time series data into various
factors, such as an initial condition, a generative dynamical system, control
inputs to that generator, and a low-dimensional description of the observed
data, called the factors. Additionally, the observations have a noise model
(in this case Poisson), so a denoised version of the observations is also
created (e.g. the underlying rates of a Poisson distribution given the
observed event counts).

The main data structure being passed around is a dataset: a dictionary of data
dictionaries.

DATASET: The top-level dictionary simply maps name (string) -> data dictionary.

The nested dictionary is the DATA DICTIONARY, which has the following keys:
  'train_data' and 'valid_data', whose values are the corresponding training
    and validation data with shape ExTxD,
    E - # examples, T - # time steps, D - # dimensions in data.
The data dictionary also has a few more keys:
  'train_ext_input' and 'valid_ext_input' - if there are known external inputs
    to the system being modeled, these take on dimensions ExTxI,
    E - # examples, T - # time steps, I - # dimensions in input.
  'alignment_matrix_cxf' - if you are using data from multiple days, it is
    possible to align the channels (see manuscript). If so, each dataset will
    contain this matrix, which is used for both the input adapter and the
    output adapter of that dataset. These matrices, if provided, must be of
    size [data_dim x factors], where data_dim is the number of neurons
    recorded on that day and factors is chosen and set through the
    '--factors' flag.
  'alignment_bias_c' - see alignment_matrix_cxf. This bias is used as the
    offset for the alignment transformation: it is *subtracted* from the data,
    so that PCA-style initializations can align factors across sessions.

If one runs LFADS on data where the true rates are known for some trials
(say simulated, testing data, as in the example shipped with the paper), then
one can add three more fields for plotting purposes: 'train_truth',
'valid_truth', and 'conversion_factor'. The first two have the same dimensions
as 'train_data' and 'valid_data' but represent the underlying rates of the
observations. Finally, if one needs to convert scale for plotting the true
underlying firing rates, there is the 'conversion_factor' key.
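
For illustration only, a minimal single-session dataset could be assembled as
follows (the session name, array sizes, and use of random data are
hypothetical; the keys follow the description above):

  import numpy as np
  E_train, E_valid, T, D = 800, 200, 100, 50   # examples, time steps, data dims
  data_dict = {
      'train_data': np.random.poisson(1.0, size=(E_train, T, D)).astype(np.int32),
      'valid_data': np.random.poisson(1.0, size=(E_valid, T, D)).astype(np.int32),
      'train_ext_input': None,  # or an E x T x I float array of known inputs
      'valid_ext_input': None,
  }
  datasets = {'my_session': data_dict}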
""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import os import tensorflow as tf from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput from distributions import diag_gaussian_log_likelihood from distributions import KLCost_GaussianGaussian, Poisson from distributions import LearnableAutoRegressive1Prior from distributions import KLCost_GaussianGaussianProcessSampled from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data from utils import log_sum_exp, flatten from plot_lfads import plot_lfads class GRU(object): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). """ def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0, clip_value=np.inf, collections=None): """Create a GRU object. Args: num_units: Number of units in the GRU forget_bias (optional): Hack to help learning. weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with ws being the weight scale. clip_value (optional): if the recurrent values grow above this value, clip them. collections (optional): List of additonal collections variables should belong to. """ self._num_units = num_units self._forget_bias = forget_bias self._weight_scale = weight_scale self._clip_value = clip_value self._collections = collections @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units @property def state_multiplier(self): return 1 def output_from_state(self, state): """Return the output portion of the state.""" return state def __call__(self, inputs, state, scope=None): """Gated recurrent unit (GRU) function. Args: inputs: A 2D batch x input_dim tensor of inputs. state: The previous state from the last time step. scope (optional): TF variable scope for defined GRU variables. Returns: A tuple (state, state), where state is the newly computed state at time t. It is returned twice to respect an interface that works for LSTMs. """ x = inputs h = state if inputs is not None: xh = tf.concat(axis=1, values=[x, h]) else: xh = h with tf.variable_scope(scope or type(self).__name__): # "GRU" with tf.variable_scope("Gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh, 2 * self._num_units, alpha=self._weight_scale, name="xh_2_ru", collections=self._collections)) r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias) with tf.variable_scope("Candidate"): xrh = tf.concat(axis=1, values=[x, r * h]) c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c", collections=self._collections)) new_h = u * h + (1 - u) * c new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value) return new_h, new_h class GenGRU(object): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). This version is specialized for the generator, but isn't as fast, so we have two. Note this allows for l2 regularization on the recurrent weights, but also implicitly rescales the inputs via the 1/sqrt(input) scaling in the linear helper routine to be large magnitude, if there are fewer inputs than recurrent state. """ def __init__(self, num_units, forget_bias=1.0, input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf, input_collections=None, recurrent_collections=None): """Create a GRU object. Args: num_units: Number of units in the GRU forget_bias (optional): Hack to help learning. 
input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with ws being the weight scale. rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with ws being the weight scale. clip_value (optional): if the recurrent values grow above this value, clip them. input_collections (optional): List of additonal collections variables that input->rec weights should belong to. recurrent_collections (optional): List of additonal collections variables that rec->rec weights should belong to. """ self._num_units = num_units self._forget_bias = forget_bias self._input_weight_scale = input_weight_scale self._rec_weight_scale = rec_weight_scale self._clip_value = clip_value self._input_collections = input_collections self._rec_collections = recurrent_collections @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units @property def state_multiplier(self): return 1 def output_from_state(self, state): """Return the output portion of the state.""" return state def __call__(self, inputs, state, scope=None): """Gated recurrent unit (GRU) function. Args: inputs: A 2D batch x input_dim tensor of inputs. state: The previous state from the last time step. scope (optional): TF variable scope for defined GRU variables. Returns: A tuple (state, state), where state is the newly computed state at time t. It is returned twice to respect an interface that works for LSTMs. """ x = inputs h = state with tf.variable_scope(scope or type(self).__name__): # "GRU" with tf.variable_scope("Gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. r_x = u_x = 0.0 if x is not None: r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x, 2 * self._num_units, alpha=self._input_weight_scale, do_bias=False, name="x_2_ru", normalized=False, collections=self._input_collections)) r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h, 2 * self._num_units, do_bias=True, alpha=self._rec_weight_scale, name="h_2_ru", collections=self._rec_collections)) r = r_x + r_h u = u_x + u_h r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias) with tf.variable_scope("Candidate"): c_x = 0.0 if x is not None: c_x = linear(x, self._num_units, name="x_2_c", do_bias=False, alpha=self._input_weight_scale, normalized=False, collections=self._input_collections) c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True, alpha=self._rec_weight_scale, collections=self._rec_collections) c = tf.tanh(c_x + c_rh) new_h = u * h + (1 - u) * c new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value) return new_h, new_h class LFADS(object): """LFADS - Latent Factor Analysis via Dynamical Systems. LFADS is an unsupervised method to decompose time series data into various factors, such as an initial condition, a generative dynamical system, inferred inputs to that generator, and a low dimensional description of the observed data, called the factors. Additoinally, the observations have a noise model (in this case Poisson), so a denoised version of the observations is also created (e.g. underlying rates of a Poisson distribution given the observed event counts). """ def __init__(self, hps, kind="train", datasets=None): """Create an LFADS model. train - a model for training, sampling of posteriors is used posterior_sample_and_average - sample from the posterior, this is used for evaluating the expected value of the outputs of LFADS, given a specific input, by averaging over multiple samples from the approx posterior. 
Also used for the lower bound on the negative log-likelihood using IWAE error (Importance Weighed Auto-encoder). This is the denoising operation. prior_sample - a model for generation - sampling from priors is used Args: hps: The dictionary of hyper parameters. kind: the type of model to build (see above). datasets: a dictionary of named data_dictionaries, see top of lfads.py """ print("Building graph...") all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean', 'prior_sample'] assert kind in all_kinds, 'Wrong kind' if hps.feedback_factors_or_rates == "rates": assert len(hps.dataset_names) == 1, \ "Multiple datasets not supported for rate feedback." num_steps = hps.num_steps ic_dim = hps.ic_dim co_dim = hps.co_dim ext_input_dim = hps.ext_input_dim cell_class = GRU gen_cell_class = GenGRU def makelambda(v): # Used with tf.case return lambda: v # Define the data placeholder, and deal with all parts of the graph # that are dataset dependent. self.dataName = tf.placeholder(tf.string, shape=()) # The batch_size to be inferred from data, as normal. # Additionally, the data_dim will be inferred as well, allowing for a # single placeholder for all datasets, regardless of data dimension. if hps.output_dist == 'poisson': # Enforce correct dtype assert np.issubdtype( datasets[hps.dataset_names[0]]['train_data'].dtype, int), \ "Data dtype must be int for poisson output distribution" data_dtype = tf.int32 elif hps.output_dist == 'gaussian': assert np.issubdtype( datasets[hps.dataset_names[0]]['train_data'].dtype, float), \ "Data dtype must be float for gaussian output dsitribution" data_dtype = tf.float32 else: assert False, "NIY" self.dataset_ph = dataset_ph = tf.placeholder(data_dtype, [None, num_steps, None], name="data") self.train_step = tf.get_variable("global_step", [], tf.int64, tf.zeros_initializer(), trainable=False) self.hps = hps ndatasets = hps.ndatasets factors_dim = hps.factors_dim self.preds = preds = [None] * ndatasets self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets self.fns_in_fatcor_bs = fns_in_fac_bs = [None] * ndatasets self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets self.datasetNames = dataset_names = hps.dataset_names self.ext_inputs = ext_inputs = None if len(dataset_names) == 1: # single session if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys(): used_in_factors_dim = factors_dim in_identity_if_poss = False else: used_in_factors_dim = hps.dataset_dims[dataset_names[0]] in_identity_if_poss = True else: # multisession used_in_factors_dim = factors_dim in_identity_if_poss = False for d, name in enumerate(dataset_names): data_dim = hps.dataset_dims[name] in_mat_cxf = None in_bias_1xf = None align_bias_1xc = None if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): dataset = datasets[name] if hps.do_train_readin: print("Initializing trainable readin matrix with alignment matrix" \ " provided for dataset:", name) else: print("Setting non-trainable readin matrix to alignment matrix" \ " provided for dataset:", name) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) if in_mat_cxf.shape != (data_dim, factors_dim): raise ValueError("""Alignment matrix must have dimensions %d x %d (data_dim x factors_dim), but currently has %d x %d."""% (data_dim, factors_dim, in_mat_cxf.shape[0], in_mat_cxf.shape[1])) if datasets and 'alignment_bias_c' in datasets[name].keys(): dataset = datasets[name] if hps.do_train_readin: print("Initializing trainable readin bias with 
alignment bias " \ "provided for dataset:", name) else: print("Setting non-trainable readin bias to alignment bias " \ "provided for dataset:", name) align_bias_c = dataset['alignment_bias_c'].astype(np.float32) align_bias_1xc = np.expand_dims(align_bias_c, axis=0) if align_bias_1xc.shape[1] != data_dim: raise ValueError("""Alignment bias must have dimensions %d (data_dim), but currently has %d."""% (data_dim, in_mat_cxf.shape[0])) if in_mat_cxf is not None and align_bias_1xc is not None: # (data - alignment_bias) * W_in # data * W_in - alignment_bias * W_in # So b = -alignment_bias * W_in to accommodate PCA style offset. in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf) if hps.do_train_readin: # only add to IO transformations collection only if we want it to be # learnable, because IO_transformations collection will be trained # when do_train_io_only collections_readin=['IO_transformations'] else: collections_readin=None in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True, mat_init_value=in_mat_cxf, bias_init_value=in_bias_1xf, identity_if_possible=in_identity_if_poss, normalized=False, name="x_2_infac_"+name, collections=collections_readin, trainable=hps.do_train_readin) in_fac_W, in_fac_b = in_fac_lin fns_in_fac_Ws[d] = makelambda(in_fac_W) fns_in_fac_bs[d] = makelambda(in_fac_b) with tf.variable_scope("glm"): out_identity_if_poss = False if len(dataset_names) == 1 and \ factors_dim == hps.dataset_dims[dataset_names[0]]: out_identity_if_poss = True for d, name in enumerate(dataset_names): data_dim = hps.dataset_dims[name] in_mat_cxf = None if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): dataset = datasets[name] in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) if datasets and 'alignment_bias_c' in datasets[name].keys(): dataset = datasets[name] align_bias_c = dataset['alignment_bias_c'].astype(np.float32) align_bias_1xc = np.expand_dims(align_bias_c, axis=0) out_mat_fxc = None out_bias_1xc = None if in_mat_cxf is not None: out_mat_fxc = in_mat_cxf.T if align_bias_1xc is not None: out_bias_1xc = align_bias_1xc if hps.output_dist == 'poisson': out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True, mat_init_value=out_mat_fxc, bias_init_value=out_bias_1xc, identity_if_possible=out_identity_if_poss, normalized=False, name="fac_2_logrates_"+name, collections=['IO_transformations']) out_fac_W, out_fac_b = out_fac_lin elif hps.output_dist == 'gaussian': out_fac_lin_mean = \ init_linear(factors_dim, data_dim, do_bias=True, mat_init_value=out_mat_fxc, bias_init_value=out_bias_1xc, normalized=False, name="fac_2_means_"+name, collections=['IO_transformations']) out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32) bias_init_value = np.ones([1, data_dim]).astype(np.float32) out_fac_lin_logvar = \ init_linear(factors_dim, data_dim, do_bias=True, mat_init_value=mat_init_value, bias_init_value=bias_init_value, normalized=False, name="fac_2_logvars_"+name, collections=['IO_transformations']) out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar out_fac_W = tf.concat( axis=1, values=[out_fac_W_mean, out_fac_W_logvar]) out_fac_b = tf.concat( axis=1, values=[out_fac_b_mean, out_fac_b_logvar]) else: assert False, "NIY" preds[d] = tf.equal(tf.constant(name), self.dataName) data_dim = hps.dataset_dims[name] fns_out_fac_Ws[d] = makelambda(out_fac_W) fns_out_fac_bs[d] = makelambda(out_fac_b) pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws) 
pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs) pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws) pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs) this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True) this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True) this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True) this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True) # External inputs (not changing by dataset, by definition). if hps.ext_input_dim > 0: self.ext_input = tf.placeholder(tf.float32, [None, num_steps, ext_input_dim], name="ext_input") else: self.ext_input = None ext_input_bxtxi = self.ext_input self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob") self.batch_size = batch_size = int(hps.batch_size) self.learning_rate = tf.Variable(float(hps.learning_rate_init), trainable=False, name="learning_rate") self.learning_rate_decay_op = self.learning_rate.assign( self.learning_rate * hps.learning_rate_decay_factor) # Dropout the data. dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob) if hps.ext_input_dim > 0: ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob) else: ext_input_do_bxtxi = None # ENCODERS def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse, num_steps_to_encode): """Encode data for LFADS Args: dataset_bxtxd - the data to encode, as a 3 tensor, with dims time x batch x data dims. enc_cell: encoder cell name: name of encoder forward_or_reverse: string, encode in forward or reverse direction num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode Returns: encoded data as a list with num_steps_to_encode items, in order """ if forward_or_reverse == "forward": dstr = "_fwd" time_fwd_or_rev = range(num_steps_to_encode) else: dstr = "_rev" time_fwd_or_rev = reversed(range(num_steps_to_encode)) with tf.variable_scope(name+"_enc"+dstr, reuse=False): enc_state = tf.tile( tf.Variable(tf.zeros([1, enc_cell.state_size]), name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1])) enc_state.set_shape([None, enc_cell.state_size]) # tile loses shape enc_outs = [None] * num_steps_to_encode for i, t in enumerate(time_fwd_or_rev): with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None): dataset_t_bxd = dataset_bxtxd[:,t,:] in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b in_fac_t_bxf.set_shape([None, used_in_factors_dim]) if ext_input_dim > 0 and not hps.inject_ext_input_to_gen: ext_input_t_bxi = ext_input_do_bxtxi[:,t,:] enc_input_t_bxfpe = tf.concat( axis=1, values=[in_fac_t_bxf, ext_input_t_bxi]) else: enc_input_t_bxfpe = in_fac_t_bxf enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state) enc_outs[t] = enc_out return enc_outs # Encode initial condition means and variances # ([x_T, x_T-1, ... x_0] and [x_0, x_1, ... x_T] -> g0/c0) self.ic_enc_fwd = [None] * num_steps self.ic_enc_rev = [None] * num_steps if ic_dim > 0: enc_ic_cell = cell_class(hps.ic_enc_dim, weight_scale=hps.cell_weight_scale, clip_value=hps.cell_clip_value) ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell, "ic", "forward", hps.num_steps_for_gen_ic) ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell, "ic", "reverse", hps.num_steps_for_gen_ic) self.ic_enc_fwd = ic_enc_fwd self.ic_enc_rev = ic_enc_rev # Encoder control input means and variances, bi-directional encoding so: # ([x_T, x_T-1, ..., x_0] and [x_0, x_1 ... 
x_T] -> u_t) self.ci_enc_fwd = [None] * num_steps self.ci_enc_rev = [None] * num_steps if co_dim > 0: enc_ci_cell = cell_class(hps.ci_enc_dim, weight_scale=hps.cell_weight_scale, clip_value=hps.cell_clip_value) ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell, "ci", "forward", hps.num_steps) if hps.do_causal_controller: ci_enc_rev = None else: ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell, "ci", "reverse", hps.num_steps) self.ci_enc_fwd = ci_enc_fwd self.ci_enc_rev = ci_enc_rev # STOCHASTIC LATENT VARIABLES, priors and posteriors # (initial conditions g0, and control inputs, u_t) # Note that zs represent all the stochastic latent variables. with tf.variable_scope("z", reuse=False): self.prior_zs_g0 = None self.posterior_zs_g0 = None self.g0s_val = None if ic_dim > 0: self.prior_zs_g0 = \ LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0", mean_init=0.0, var_min=hps.ic_prior_var_min, var_init=hps.ic_prior_var_scale, var_max=hps.ic_prior_var_max) ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]]) ic_enc = tf.nn.dropout(ic_enc, keep_prob) self.posterior_zs_g0 = \ DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0", var_min=hps.ic_post_var_min) if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]: zs_g0 = self.posterior_zs_g0 else: zs_g0 = self.prior_zs_g0 if kind in ["train", "posterior_sample_and_average", "prior_sample"]: self.g0s_val = zs_g0.sample else: self.g0s_val = zs_g0.mean # Priors for controller, 'co' for controller output self.prior_zs_co = prior_zs_co = [None] * num_steps self.posterior_zs_co = posterior_zs_co = [None] * num_steps self.zs_co = zs_co = [None] * num_steps self.prior_zs_ar_con = None if co_dim > 0: # Controller outputs autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)] noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)] self.prior_zs_ar_con = prior_zs_ar_con = \ LearnableAutoRegressive1Prior(batch_size, hps.co_dim, autocorrelation_taus, noise_variances, hps.do_train_prior_ar_atau, hps.do_train_prior_ar_nvar, num_steps, "u_prior_ar1") # CONTROLLER -> GENERATOR -> RATES # (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t) ) self.controller_outputs = u_t = [None] * num_steps self.con_ics = con_state = None self.con_states = con_states = [None] * num_steps self.con_outs = con_outs = [None] * num_steps self.gen_inputs = gen_inputs = [None] * num_steps if co_dim > 0: # gen_cell_class here for l2 penalty recurrent weights # didn't split the cell_weight scale here, because I doubt it matters con_cell = gen_cell_class(hps.con_dim, input_weight_scale=hps.cell_weight_scale, rec_weight_scale=hps.cell_weight_scale, clip_value=hps.cell_clip_value, recurrent_collections=['l2_con_reg']) with tf.variable_scope("con", reuse=False): self.con_ics = tf.tile( tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), name="c0"), tf.stack([batch_size, 1])) self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape con_states[-1] = self.con_ics gen_cell = gen_cell_class(hps.gen_dim, input_weight_scale=hps.gen_cell_input_weight_scale, rec_weight_scale=hps.gen_cell_rec_weight_scale, clip_value=hps.cell_clip_value, recurrent_collections=['l2_gen_reg']) with tf.variable_scope("gen", reuse=False): if ic_dim == 0: self.gen_ics = tf.tile( tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"), tf.stack([batch_size, 1])) else: self.gen_ics = linear(self.g0s_val, gen_cell.state_size, identity_if_possible=True, name="g0_2_gen_ic") self.gen_states = gen_states = [None] 
* num_steps self.gen_outs = gen_outs = [None] * num_steps gen_states[-1] = self.gen_ics gen_outs[-1] = gen_cell.output_from_state(gen_states[-1]) self.factors = factors = [None] * num_steps factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False, normalized=True, name="gen_2_fac") self.rates = rates = [None] * num_steps # rates[-1] is collected to potentially feed back to controller with tf.variable_scope("glm", reuse=False): if hps.output_dist == 'poisson': log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b log_rates_t0.set_shape([None, None]) rates[-1] = tf.exp(log_rates_t0) # rate rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]]) elif hps.output_dist == 'gaussian': mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b mean_n_logvars.set_shape([None, None]) means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2, value=mean_n_logvars) rates[-1] = means_t_bxd else: assert False, "NIY" # We support multiple output distributions, for example Poisson, and also # Gaussian. In these two cases respectively, there are one and two # parameters (rates vs. mean and variance). So the output_dist_params # tensor will variable sizes via tf.concat and tf.split, along the 1st # dimension. So in the case of gaussian, for example, it'll be # batch x (D+D), where each D dims is the mean, and then variances, # respectively. For a distribution with 3 parameters, it would be # batch x (D+D+D). self.output_dist_params = dist_params = [None] * num_steps self.log_p_xgz_b = log_p_xgz_b = 0.0 # log P(x|z) for t in range(num_steps): # Controller if co_dim > 0: # Build inputs for controller tlag = t - hps.controller_input_lag if tlag < 0: con_in_f_t = tf.zeros_like(ci_enc_fwd[0]) else: con_in_f_t = ci_enc_fwd[tlag] if hps.do_causal_controller: # If controller is causal (wrt to data generation process), then it # cannot see future data. Thus, excluding ci_enc_rev[t] is obvious. # Less obvious is the need to exclude factors[t-1]. This arises # because information flows from g0 through factors to the controller # input. The g0 encoding is backwards, so we must necessarily exclude # the factors in order to keep the controller input purely from a # forward encoding (however unlikely it is that # g0->factors->controller channel might actually be used in this way). 
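              # (Note that in this causal case the reverse encoder is never
              # built at all: ci_enc_rev is set to None above, so only the
              # lagged forward encoding con_in_f_t is available here.)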
con_in_list_t = [con_in_f_t] else: tlag_rev = t + hps.controller_input_lag if tlag_rev >= num_steps: # better than zeros con_in_r_t = tf.zeros_like(ci_enc_rev[0]) else: con_in_r_t = ci_enc_rev[tlag_rev] con_in_list_t = [con_in_f_t, con_in_r_t] if hps.do_feed_factors_to_controller: if hps.feedback_factors_or_rates == "factors": con_in_list_t.append(factors[t-1]) elif hps.feedback_factors_or_rates == "rates": con_in_list_t.append(rates[t-1]) else: assert False, "NIY" con_in_t = tf.concat(axis=1, values=con_in_list_t) con_in_t = tf.nn.dropout(con_in_t, keep_prob) with tf.variable_scope("con", reuse=True if t > 0 else None): con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1]) posterior_zs_co[t] = \ DiagonalGaussianFromInput(con_outs[t], co_dim, name="con_to_post_co") if kind == "train": u_t[t] = posterior_zs_co[t].sample elif kind == "posterior_sample_and_average": u_t[t] = posterior_zs_co[t].sample elif kind == "posterior_push_mean": u_t[t] = posterior_zs_co[t].mean else: u_t[t] = prior_zs_ar_con.samples_t[t] # Inputs to the generator (controller output + external input) if ext_input_dim > 0 and hps.inject_ext_input_to_gen: ext_input_t_bxi = ext_input_do_bxtxi[:,t,:] if co_dim > 0: gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi]) else: gen_inputs[t] = ext_input_t_bxi else: gen_inputs[t] = u_t[t] # Generator data_t_bxd = dataset_ph[:,t,:] with tf.variable_scope("gen", reuse=True if t > 0 else None): gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1]) gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob) with tf.variable_scope("gen", reuse=True): # ic defined it above factors[t] = linear(gen_outs[t], factors_dim, do_bias=False, normalized=True, name="gen_2_fac") with tf.variable_scope("glm", reuse=True if t > 0 else None): if hps.output_dist == 'poisson': log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b log_rates_t.set_shape([None, None]) rates[t] = dist_params[t] = tf.exp(tf.clip_by_value(log_rates_t, -hps._clip_value, hps._clip_value)) # rates feed back rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]]) loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd) elif hps.output_dist == 'gaussian': mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b mean_n_logvars.set_shape([None, None]) means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2, value=mean_n_logvars) rates[t] = means_t_bxd # rates feed back to controller dist_params[t] = tf.concat( axis=1, values=[means_t_bxd, tf.exp(tf.clip_by_value(logvars_t_bxd, -hps._clip_value, hps._clip_value))]) loglikelihood_t = \ diag_gaussian_log_likelihood(data_t_bxd, means_t_bxd, logvars_t_bxd) else: assert False, "NIY" log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1]) # Correlation of inferred inputs cost. self.corr_cost = tf.constant(0.0) if hps.co_mean_corr_scale > 0.0: all_sum_corr = [] for i in range(hps.co_dim): for j in range(i+1, hps.co_dim): sum_corr_ij = tf.constant(0.0) for t in range(num_steps): u_mean_t = posterior_zs_co[t].mean sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j] all_sum_corr.append(0.5 * tf.square(sum_corr_ij)) self.corr_cost = tf.reduce_mean(all_sum_corr) # div by batch and by n*(n-1)/2 pairs # Variational Lower Bound on posterior, p(z|x), plus reconstruction cost. # KL and reconstruction costs are normalized only by batch size, not by # dimension, or by time steps. 
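    # As a sketch of the quantities assembled below (no new computation here),
    # the per-example lower bound on the log likelihood is
    #   L(x) = E_q(z|x)[ log p(x|z) ]
    #          - kl_ic_weight * KL(q(g0|x) || p(g0))
    #          - kl_co_weight * KL(q(u_t|x) || p(u_t)),
    # the reported VAE bound on the NLL is -mean_b L(x_b), and the IWAE-style
    # bound averages inside the log: -(log_sum_exp_b L(x_b) - log B).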
kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32) kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32) self.kl_cost = tf.constant(0.0) # VAE KL cost self.recon_cost = tf.constant(0.0) # VAE reconstruction cost self.nll_bound_vae = tf.constant(0.0) self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost. if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]: kl_cost_g0_b = 0.0 kl_cost_co_b = 0.0 if ic_dim > 0: g0_priors = [self.prior_zs_g0] g0_posts = [self.posterior_zs_g0] kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b if co_dim > 0: kl_cost_co_b = \ KLCost_GaussianGaussianProcessSampled( posterior_zs_co, prior_zs_ar_con).kl_cost_b kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b # L = -KL + log p(x|z), to maximize bound on likelihood # -L = KL - log p(x|z), to minimize bound on NLL # so 'reconstruction cost' is negative log likelihood self.recon_cost = - tf.reduce_mean(log_p_xgz_b) self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b) lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b # VAE error averages outside the log self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b) # IWAE error averages inside the log k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32) iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b) self.nll_bound_iwae = -iwae_lb_on_ll # L2 regularization on the generator, normalized by number of parameters. self.l2_cost = tf.constant(0.0) if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0: l2_costs = [] l2_numels = [] l2_reg_var_lists = [tf.get_collection('l2_gen_reg'), tf.get_collection('l2_con_reg')] l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale] for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales): for v in l2_reg_vars: numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v))) numel_f = tf.cast(numel, tf.float32) l2_numels.append(numel_f) v_l2 = tf.reduce_sum(v*v) l2_costs.append(0.5 * l2_scale * v_l2) self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels) # Compute the cost for training, part of the graph regardless. # The KL cost can be problematic at the beginning of optimization, # so we allow an exponential increase in weighting the KL from 0 # to 1. 
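    # Concretely, the ramp implemented below is linear in the step count: with
    # hypothetical values kl_start_step=0 and kl_increase_steps=2000, the KL
    # weight is min(step / 2000, 1.0), i.e. 0.25 at step 500 and clamped at
    # 1.0 from step 2000 onward.  The L2 weight follows the same schedule
    # using l2_start_step and l2_increase_steps.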
self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0) self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0) kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32) l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32) kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32) l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32) self.kl_weight = kl_weight = \ tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0) self.l2_weight = l2_weight = \ tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0) self.timed_kl_cost = kl_weight * self.kl_cost self.timed_l2_cost = l2_weight * self.l2_cost self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost self.cost = self.recon_cost + self.timed_kl_cost + \ self.timed_l2_cost + self.weight_corr_cost if kind != "train": # save every so often self.seso_saver = tf.train.Saver(tf.global_variables(), max_to_keep=hps.max_ckpt_to_keep) # lowest validation error self.lve_saver = tf.train.Saver(tf.global_variables(), max_to_keep=hps.max_ckpt_to_keep_lve) return # OPTIMIZATION # train the io matrices only if self.hps.do_train_io_only: self.train_vars = tvars = \ tf.get_collection('IO_transformations', scope=tf.get_variable_scope().name) # train the encoder only elif self.hps.do_train_encoder_only: tvars1 = \ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='LFADS/ic_enc_*') tvars2 = \ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='LFADS/z/ic_enc_*') self.train_vars = tvars = tvars1 + tvars2 # train all variables else: self.train_vars = tvars = \ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name) print("done.") print("Model Variables (to be optimized): ") total_params = 0 for i in range(len(tvars)): shape = tvars[i].get_shape().as_list() print(" ", i, tvars[i].name, shape) total_params += np.prod(shape) print("Total model parameters: ", total_params) grads = tf.gradients(self.cost, tvars) grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm) opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-01) self.grads = grads self.grad_global_norm = grad_global_norm self.train_op = opt.apply_gradients( zip(grads, tvars), global_step=self.train_step) self.seso_saver = tf.train.Saver(tf.global_variables(), max_to_keep=hps.max_ckpt_to_keep) # lowest validation error self.lve_saver = tf.train.Saver(tf.global_variables(), max_to_keep=hps.max_ckpt_to_keep) # SUMMARIES, used only during training. 
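    # These summaries are written under hps.lfads_save_dir by self.writer (a
    # tf.summary.FileWriter created at the end of this section); the per-epoch
    # cost summaries are fed via placeholders in summarize_all(), and the
    # example image summary in plot_single_example().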
# example summary self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3], name='image_tensor') self.example_summ = tf.summary.image("LFADS example", self.example_image, collections=["example_summaries"]) # general training summaries self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate) self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight) self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight) self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost) self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm", self.grad_global_norm) if hps.co_dim > 0: self.atau_summ = [None] * hps.co_dim self.pvar_summ = [None] * hps.co_dim for c in range(hps.co_dim): self.atau_summ[c] = \ tf.summary.scalar("AR Autocorrelation taus " + str(c), tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c])) self.pvar_summ[c] = \ tf.summary.scalar("AR Variances " + str(c), tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c])) # cost summaries, separated into different collections for # training vs validation. We make placeholders for these, because # even though the graph computes these costs on a per-batch basis, # we want to report the more reliable metric of per-epoch cost. kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph') self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph, collections=["train_summaries"]) self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph, collections=["valid_summaries"]) l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph') self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph, collections=["train_summaries"]) recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph') self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)", recon_cost_ph, collections=["train_summaries"]) self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)", recon_cost_ph, collections=["valid_summaries"]) total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph') self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph, collections=["train_summaries"]) self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph, collections=["valid_summaries"]) self.kl_cost_ph = kl_cost_ph self.l2_cost_ph = l2_cost_ph self.recon_cost_ph = recon_cost_ph self.total_cost_ph = total_cost_ph # Merged summaries, for easy coding later. self.merged_examples = tf.summary.merge_all(key="example_summaries") self.merged_generic = tf.summary.merge_all() # default key is 'summaries' self.merged_train = tf.summary.merge_all(key="train_summaries") self.merged_valid = tf.summary.merge_all(key="valid_summaries") session = tf.get_default_session() self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log") self.writer = tf.summary.FileWriter(self.logfile) def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None, keep_prob=None): """Build the feed dictionary, handles cases where there is no value defined. Args: train_name: The key into the datasets, to set the tf.case statement for the proper readin / readout matrices. data_bxtxd: The data tensor ext_input_bxtxi (optional): The external input tensor keep_prob: The drop out keep probability. 
Returns: The feed dictionary with TF tensors as keys and data as values, for use with tf.Session.run() """ feed_dict = {} B, T, _ = data_bxtxd.shape feed_dict[self.dataName] = train_name feed_dict[self.dataset_ph] = data_bxtxd if self.ext_input is not None and ext_input_bxtxi is not None: feed_dict[self.ext_input] = ext_input_bxtxi if keep_prob is None: feed_dict[self.keep_prob] = self.hps.keep_prob else: feed_dict[self.keep_prob] = keep_prob return feed_dict @staticmethod def get_batch(data_extxd, ext_input_extxi=None, batch_size=None, example_idxs=None): """Get a batch of data, either randomly chosen, or specified directly. Args: data_extxd: The data to model, numpy tensors with shape: # examples x # time steps x # dimensions ext_input_extxi (optional): The external inputs, numpy tensor with shape: # examples x # time steps x # external input dimensions batch_size: The size of the batch to return example_idxs (optional): The example indices used to select examples. Returns: A tuple with two parts: 1. Batched data numpy tensor with shape: batch_size x # time steps x # dimensions 2. Batched external input numpy tensor with shape: batch_size x # time steps x # external input dims """ assert batch_size is not None or example_idxs is not None, "Problems" E, T, D = data_extxd.shape if example_idxs is None: example_idxs = np.random.choice(E, batch_size) ext_input_bxtxi = None if ext_input_extxi is not None: ext_input_bxtxi = ext_input_extxi[example_idxs,:,:] return data_extxd[example_idxs,:,:], ext_input_bxtxi @staticmethod def example_idxs_mod_batch_size(nexamples, batch_size): """Given a number of examples, E, and a batch_size, B, generate indices [0, 1, 2, ... B-1; [B, B+1, ... 2*B-1; ... ] returning those indices as a 2-dim tensor shaped like E/B x B. Note that shape is only correct if E % B == 0. If not, then an extra row is generated so that the remainder of examples is included. The extra examples are explicitly to to the zero index (see randomize_example_idxs_mod_batch_size) for randomized behavior. Args: nexamples: The number of examples to batch up. batch_size: The size of the batch. Returns: 2-dim tensor as described above. """ bmrem = batch_size - (nexamples % batch_size) bmrem_examples = [] if bmrem < batch_size: #bmrem_examples = np.zeros(bmrem, dtype=np.int32) ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32) bmrem_examples = np.sort(ridxs) example_idxs = range(nexamples) + list(bmrem_examples) example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size]) return example_idxs_e_x_edivb, bmrem @staticmethod def randomize_example_idxs_mod_batch_size(nexamples, batch_size): """Indices 1:nexamples, randomized, in 2D form of shape = (nexamples / batch_size) x batch_size. The remainder is managed by drawing randomly from 1:nexamples. Args: nexamples: number of examples to randomize batch_size: number of elements in batch Returns: The randomized, properly shaped indicies. """ assert nexamples > batch_size, "Problems" bmrem = batch_size - nexamples % batch_size bmrem_examples = [] if bmrem < batch_size: bmrem_examples = np.random.choice(range(nexamples), size=bmrem, replace=False) example_idxs = range(nexamples) + list(bmrem_examples) mixed_example_idxs = np.random.permutation(example_idxs) example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size]) return example_idxs_e_x_edivb, bmrem def shuffle_spikes_in_time(self, data_bxtxd): """Shuffle the spikes in the temporal dimension. 
This is useful to help the LFADS system avoid overfitting to individual spikes or fast oscillations found in the data that are irrelevant to behavior. A pure 'tabula rasa' approach would avoid this, but LFADS is sensitive enough to pick up dynamics that you may not want. Args: data_bxtxd: numpy array of spike count data to be shuffled. Returns: S_bxtxd, a numpy array with the same dimensions and contents as data_bxtxd, but shuffled appropriately. """ B, T, N = data_bxtxd.shape w = self.hps.temporal_spike_jitter_width if w == 0: return data_bxtxd max_counts = np.max(data_bxtxd) S_bxtxd = np.zeros([B,T,N]) # Intuitively, shuffle spike occurances, 0 or 1, but since we have counts, # Do it over and over again up to the max count. for mc in range(1,max_counts+1): idxs = np.nonzero(data_bxtxd >= mc) data_ones = np.zeros_like(data_bxtxd) data_ones[data_bxtxd >= mc] = 1 nfound = len(idxs[0]) shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound) shuffle_tidxs = idxs[1].copy() shuffle_tidxs += shuffles_incrs_in_time # Reflect on the boundaries to not lose mass. shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0] shuffle_tidxs[shuffle_tidxs > T-1] = \ (T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1)) for iii in zip(idxs[0], shuffle_tidxs, idxs[2]): S_bxtxd[iii] += 1 return S_bxtxd def shuffle_and_flatten_datasets(self, datasets, kind='train'): """Since LFADS supports multiple datasets in the same dynamical model, we have to be careful to use all the data in a single training epoch. But since the datasets my have different data dimensionality, we cannot batch examples from data dictionaries together. Instead, we generate random batches within each data dictionary, and then randomize these batches while holding onto the dataname, so that when it's time to feed the graph, the correct in/out matrices can be selected, per batch. Args: datasets: A dict of data dicts. The dataset dict is simply a name(string)-> data dictionary mapping (See top of lfads.py). kind: 'train' or 'valid' Returns: A flat list, in which each element is a pair ('name', indices). """ batch_size = self.hps.batch_size ndatasets = len(datasets) random_example_idxs = {} epoch_idxs = {} all_name_example_idx_pairs = [] kind_data = kind + '_data' for name, data_dict in datasets.items(): nexamples, ntime, data_dim = data_dict[kind_data].shape epoch_idxs[name] = 0 random_example_idxs, _ = \ self.randomize_example_idxs_mod_batch_size(nexamples, batch_size) epoch_size = random_example_idxs.shape[0] names = [name] * epoch_size all_name_example_idx_pairs += zip(names, random_example_idxs) np.random.shuffle(all_name_example_idx_pairs) # shuffle in place return all_name_example_idx_pairs def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True): """Train the model through the entire dataset once. Args: datasets: A dict of data dicts. The dataset dict is simply a name(string)-> data dictionary mapping (See top of lfads.py). batch_size (optional): The batch_size to use do_save_ckpt (optional): Should the routine save a checkpoint on this training epoch? Returns: A tuple with 6 float values: (total cost of the epoch, epoch reconstruction cost, epoch kl cost, KL weight used this training epoch, total l2 cost on generator, and the corresponding weight). 
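
    For example, train_model() unpacks this return value as:

      tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
          self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)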
""" ops_to_eval = [self.cost, self.recon_cost, self.kl_cost, self.kl_weight, self.l2_cost, self.l2_weight, self.train_op] collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train") total_cost = total_recon_cost = total_kl_cost = 0.0 # normalizing by batch done in distributions.py epoch_size = len(collected_op_values) for op_values in collected_op_values: total_cost += op_values[0] total_recon_cost += op_values[1] total_kl_cost += op_values[2] kl_weight = collected_op_values[-1][3] l2_cost = collected_op_values[-1][4] l2_weight = collected_op_values[-1][5] epoch_total_cost = total_cost / epoch_size epoch_recon_cost = total_recon_cost / epoch_size epoch_kl_cost = total_kl_cost / epoch_size if do_save_ckpt: session = tf.get_default_session() checkpoint_path = os.path.join(self.hps.lfads_save_dir, self.hps.checkpoint_name + '.ckpt') self.seso_saver.save(session, checkpoint_path, global_step=self.train_step) return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \ kl_weight, l2_cost, l2_weight def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None, do_collect=True, keep_prob=None): """Run the model through the entire dataset once. Args: datasets: A dict of data dicts. The dataset dict is simply a name(string)-> data dictionary mapping (See top of lfads.py). ops_to_eval: A list of tensorflow operations that will be evaluated in the tf.session.run() call. batch_size (optional): The batch_size to use do_collect (optional): Should the routine collect all session.run output as a list, and return it? keep_prob (optional): The dropout keep probability. Returns: A list of lists, the internal list is the return for the ops for each session.run() call. The outer list collects over the epoch. """ hps = self.hps all_name_example_idx_pairs = \ self.shuffle_and_flatten_datasets(datasets, kind) kind_data = kind + '_data' kind_ext_input = kind + '_ext_input' total_cost = total_recon_cost = total_kl_cost = 0.0 session = tf.get_default_session() epoch_size = len(all_name_example_idx_pairs) evaled_ops_list = [] for name, example_idxs in all_name_example_idx_pairs: data_dict = datasets[name] data_extxd = data_dict[kind_data] if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0: data_extxd = self.shuffle_spikes_in_time(data_extxd) ext_input_extxi = data_dict[kind_ext_input] data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi, example_idxs=example_idxs) feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi, keep_prob=keep_prob) evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict) if do_collect: evaled_ops_list.append(evaled_ops_np) return evaled_ops_list def summarize_all(self, datasets, summary_values): """Plot and summarize stuff in tensorboard. Note that everything done in the current function is otherwise done on a single, randomly selected dataset (except for summary_values, which are passed in.) Args: datasets, the dictionary of datasets used in the study. summary_values: These summary values are created from the training loop, and so summarize the entire set of datasets. 
""" hps = self.hps tr_kl_cost = summary_values['tr_kl_cost'] tr_recon_cost = summary_values['tr_recon_cost'] tr_total_cost = summary_values['tr_total_cost'] kl_weight = summary_values['kl_weight'] l2_weight = summary_values['l2_weight'] l2_cost = summary_values['l2_cost'] has_any_valid_set = summary_values['has_any_valid_set'] i = summary_values['nepochs'] session = tf.get_default_session() train_summ, train_step = session.run([self.merged_train, self.train_step], feed_dict={self.l2_cost_ph:l2_cost, self.kl_cost_ph:tr_kl_cost, self.recon_cost_ph:tr_recon_cost, self.total_cost_ph:tr_total_cost}) self.writer.add_summary(train_summ, train_step) if has_any_valid_set: ev_kl_cost = summary_values['ev_kl_cost'] ev_recon_cost = summary_values['ev_recon_cost'] ev_total_cost = summary_values['ev_total_cost'] eval_summ = session.run(self.merged_valid, feed_dict={self.kl_cost_ph:ev_kl_cost, self.recon_cost_ph:ev_recon_cost, self.total_cost_ph:ev_total_cost}) self.writer.add_summary(eval_summ, train_step) print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\ recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\ kl weight: %.2f, l2 weight: %.2f" % \ (i, train_step, tr_total_cost, ev_total_cost, tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost, l2_cost, kl_weight, l2_weight)) csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \ recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \ klweight,%.2f, l2weight,%.2f\n"% \ (i, train_step, tr_total_cost, ev_total_cost, tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost, l2_cost, kl_weight, l2_weight) else: print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\ l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \ (i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost, l2_cost, kl_weight, l2_weight)) csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \ l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \ (i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost, l2_cost, kl_weight, l2_weight) if self.hps.csv_log: csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv') with open(csv_file, "a") as myfile: myfile.write(csv_outstr) def plot_single_example(self, datasets): """Plot an image relating to a randomly chosen, specific example. We use posterior sample and average by taking one example, and filling a whole batch with that example, sample from the posterior, and then average the quantities. 
""" hps = self.hps all_data_names = datasets.keys() data_name = np.random.permutation(all_data_names)[0] data_dict = datasets[data_name] has_valid_set = True if data_dict['valid_data'] is not None else False cf = 1.0 # plotting concern # posterior sample and average here E, _, _ = data_dict['train_data'].shape eidx = np.random.choice(E) example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32) train_data_bxtxd, train_ext_input_bxtxi = \ self.get_batch(data_dict['train_data'], data_dict['train_ext_input'], example_idxs=example_idxs) truth_train_data_bxtxd = None if 'train_truth' in data_dict and data_dict['train_truth'] is not None: truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'], example_idxs=example_idxs) cf = data_dict['conversion_factor'] # plotter does averaging train_model_values = self.eval_model_runs_batch(data_name, train_data_bxtxd, train_ext_input_bxtxi, do_average_batch=False) train_step = train_model_values['train_steps'] feed_dict = self.build_feed_dict(data_name, train_data_bxtxd, train_ext_input_bxtxi, keep_prob=1.0) session = tf.get_default_session() generic_summ = session.run(self.merged_generic, feed_dict=feed_dict) self.writer.add_summary(generic_summ, train_step) valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None truth_valid_data_bxtxd = None if has_valid_set: E, _, _ = data_dict['valid_data'].shape eidx = np.random.choice(E) example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32) valid_data_bxtxd, valid_ext_input_bxtxi = \ self.get_batch(data_dict['valid_data'], data_dict['valid_ext_input'], example_idxs=example_idxs) if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None: truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'], example_idxs=example_idxs) else: truth_valid_data_bxtxd = None # plotter does averaging valid_model_values = self.eval_model_runs_batch(data_name, valid_data_bxtxd, valid_ext_input_bxtxi, do_average_batch=False) example_image = plot_lfads(train_bxtxd=train_data_bxtxd, train_model_vals=train_model_values, train_ext_input_bxtxi=train_ext_input_bxtxi, train_truth_bxtxd=truth_train_data_bxtxd, valid_bxtxd=valid_data_bxtxd, valid_model_vals=valid_model_values, valid_ext_input_bxtxi=valid_ext_input_bxtxi, valid_truth_bxtxd=truth_valid_data_bxtxd, bidx=None, cf=cf, output_dist=hps.output_dist) example_image = np.expand_dims(example_image, axis=0) example_summ = session.run(self.merged_examples, feed_dict={self.example_image : example_image}) self.writer.add_summary(example_summ) def train_model(self, datasets): """Train the model, print per-epoch information, and save checkpoints. Loop over training epochs. The function that actually does the training is train_epoch. This function iterates over the training data, one epoch at a time. The learning rate schedule is such that it will stay the same until the cost goes up in comparison to the last few values, then it will drop. Args: datasets: A dict of data dicts. The dataset dict is simply a name(string)-> data dictionary mapping (See top of lfads.py). 
""" hps = self.hps has_any_valid_set = False for data_dict in datasets.values(): if data_dict['valid_data'] is not None: has_any_valid_set = True break session = tf.get_default_session() lr = session.run(self.learning_rate) lr_stop = hps.learning_rate_stop i = -1 train_costs = [] valid_costs = [] ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0 lowest_ev_cost = np.Inf while True: i += 1 do_save_ckpt = True if i % 10 ==0 else False tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \ self.train_epoch(datasets, do_save_ckpt=do_save_ckpt) # Evaluate the validation cost, and potentially save. Note that this # routine will not save a validation checkpoint until the kl weight and # l2 weights are equal to 1.0. if has_any_valid_set: ev_total_cost, ev_recon_cost, ev_kl_cost = \ self.eval_cost_epoch(datasets, kind='valid') valid_costs.append(ev_total_cost) # > 1 may give more consistent results, but not the actual lowest vae. # == 1 gives the lowest vae seen so far. n_lve = 1 run_avg_lve = np.mean(valid_costs[-n_lve:]) # conditions for saving checkpoints: # KL weight must have finished stepping (>=1.0), AND # L2 weight must have finished stepping OR L2 is not being used, AND # the current run has a lower LVE than previous runs AND # len(valid_costs > n_lve) (not sure what that does) if kl_weight >= 1.0 and \ (l2_weight >= 1.0 or \ (self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \ and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost): lowest_ev_cost = run_avg_lve checkpoint_path = os.path.join(self.hps.lfads_save_dir, self.hps.checkpoint_name + '_lve.ckpt') self.lve_saver.save(session, checkpoint_path, global_step=self.train_step, latest_filename='checkpoint_lve') # Plot and summarize. values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set, 'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost, 'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost, 'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost, 'l2_weight':l2_weight, 'kl_weight':kl_weight, 'l2_cost':l2_cost} self.summarize_all(datasets, values) self.plot_single_example(datasets) # Manage learning rate. train_res = tr_total_cost n_lr = hps.learning_rate_n_to_compare if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]): _ = session.run(self.learning_rate_decay_op) lr = session.run(self.learning_rate) print(" Decreasing learning rate to %f." % lr) # Force the system to run n_lr times while at this lr. train_costs.append(np.inf) else: train_costs.append(train_res) if lr < lr_stop: print("Stopping optimization based on learning rate criteria.") break def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None, batch_size=None): """Evaluate the cost of the epoch. Args: data_dict: The dictionary of data (training and validation) used for training and evaluation of the model, respectively. 
Returns: a 3 tuple of costs: (epoch total cost, epoch reconstruction cost, epoch KL cost) """ ops_to_eval = [self.cost, self.recon_cost, self.kl_cost] collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind, keep_prob=1.0) total_cost = total_recon_cost = total_kl_cost = 0.0 # normalizing by batch done in distributions.py epoch_size = len(collected_op_values) for op_values in collected_op_values: total_cost += op_values[0] total_recon_cost += op_values[1] total_kl_cost += op_values[2] epoch_total_cost = total_cost / epoch_size epoch_recon_cost = total_recon_cost / epoch_size epoch_kl_cost = total_kl_cost / epoch_size return epoch_total_cost, epoch_recon_cost, epoch_kl_cost def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None, do_eval_cost=False, do_average_batch=False): """Returns all the goodies for the entire model, per batch. If data_bxtxd and ext_input_bxtxi can have fewer than batch_size along dim 1 in which case this handles the padding and truncating automatically Args: data_name: The name of the data dict, to select which in/out matrices to use. data_bxtxd: Numpy array training data with shape: batch_size x # time steps x # dimensions ext_input_bxtxi: Numpy array training external input with shape: batch_size x # time steps x # external input dims do_eval_cost (optional): If true, the IWAE (Importance Weighted Autoencoder) log likeihood bound, instead of the VAE version. do_average_batch (optional): average over the batch, useful for getting good IWAE costs, and model outputs for a single data point. Returns: A dictionary with the outputs of the model decoder, namely: prior g0 mean, prior g0 variance, approx. posterior mean, approx posterior mean, the generator initial conditions, the control inputs (if enabled), the state of the generator, the factors, and the rates. """ session = tf.get_default_session() # if fewer than batch_size provided, pad to batch_size hps = self.hps batch_size = hps.batch_size E, _, _ = data_bxtxd.shape if E < hps.batch_size: data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)), mode='constant', constant_values=0) if ext_input_bxtxi is not None: ext_input_bxtxi = np.pad(ext_input_bxtxi, ((0, hps.batch_size-E), (0, 0), (0, 0)), mode='constant', constant_values=0) feed_dict = self.build_feed_dict(data_name, data_bxtxd, ext_input_bxtxi, keep_prob=1.0) # Non-temporal signals will be batch x dim. # Temporal signals are list length T with elements batch x dim. tf_vals = [self.gen_ics, self.gen_states, self.factors, self.output_dist_params] tf_vals.append(self.cost) tf_vals.append(self.nll_bound_vae) tf_vals.append(self.nll_bound_iwae) tf_vals.append(self.train_step) # not train_op! 
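    # Note: evaluation order matters -- the fidxs bookkeeping after the
    # session.run() call below unpacks np_vals_flat in exactly the same order
    # that tf_vals is assembled here.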
if self.hps.ic_dim > 0: tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar, self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar] if self.hps.co_dim > 0: tf_vals.append(self.controller_outputs) tf_vals_flat, fidxs = flatten(tf_vals) np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict) ff = 0 gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1 train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1 if self.hps.ic_dim > 0: prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1 prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 if self.hps.co_dim > 0: controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 # [0] are to take out the non-temporal items from lists gen_ics = gen_ics[0] costs = costs[0] nll_bound_vaes = nll_bound_vaes[0] nll_bound_iwaes = nll_bound_iwaes[0] train_steps = train_steps[0] # Convert to full tensors, not lists of tensors in time dim. gen_states = list_t_bxn_to_tensor_bxtxn(gen_states) factors = list_t_bxn_to_tensor_bxtxn(factors) out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params) if self.hps.ic_dim > 0: # select first time point prior_g0_mean = prior_g0_mean[0] prior_g0_logvar = prior_g0_logvar[0] post_g0_mean = post_g0_mean[0] post_g0_logvar = post_g0_logvar[0] if self.hps.co_dim > 0: controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs) # slice out the trials in case < batch_size provided if E < hps.batch_size: idx = np.arange(E) gen_ics = gen_ics[idx, :] gen_states = gen_states[idx, :] factors = factors[idx, :, :] out_dist_params = out_dist_params[idx, :, :] if self.hps.ic_dim > 0: prior_g0_mean = prior_g0_mean[idx, :] prior_g0_logvar = prior_g0_logvar[idx, :] post_g0_mean = post_g0_mean[idx, :] post_g0_logvar = post_g0_logvar[idx, :] if self.hps.co_dim > 0: controller_outputs = controller_outputs[idx, :, :] if do_average_batch: gen_ics = np.mean(gen_ics, axis=0) gen_states = np.mean(gen_states, axis=0) factors = np.mean(factors, axis=0) out_dist_params = np.mean(out_dist_params, axis=0) if self.hps.ic_dim > 0: prior_g0_mean = np.mean(prior_g0_mean, axis=0) prior_g0_logvar = np.mean(prior_g0_logvar, axis=0) post_g0_mean = np.mean(post_g0_mean, axis=0) post_g0_logvar = np.mean(post_g0_logvar, axis=0) if self.hps.co_dim > 0: controller_outputs = np.mean(controller_outputs, axis=0) model_vals = {} model_vals['gen_ics'] = gen_ics model_vals['gen_states'] = gen_states model_vals['factors'] = factors model_vals['output_dist_params'] = out_dist_params model_vals['costs'] = costs model_vals['nll_bound_vaes'] = nll_bound_vaes model_vals['nll_bound_iwaes'] = nll_bound_iwaes model_vals['train_steps'] = train_steps if self.hps.ic_dim > 0: model_vals['prior_g0_mean'] = prior_g0_mean model_vals['prior_g0_logvar'] = prior_g0_logvar model_vals['post_g0_mean'] = post_g0_mean model_vals['post_g0_logvar'] = post_g0_logvar if self.hps.co_dim > 0: model_vals['controller_outputs'] = controller_outputs return model_vals def eval_model_runs_avg_epoch(self, data_name, data_extxd, ext_input_extxi=None): """Returns all the expected value 
    of the goodies for the entire model.

    The expected value is taken over hidden (z) variables, namely the initial
    conditions and the control inputs.  The expected value is approximate: it
    is computed by drawing batch_size samples for each example.

    Args:
      data_name: The name of the data dict, to select which in/out matrices
        to use.
      data_extxd: Numpy array training data with shape:
        # examples x # time steps x # dimensions
      ext_input_extxi (optional): Numpy array training external input with
        shape: # examples x # time steps x # external input dims

    Returns:
      A dictionary with the averaged outputs of the model decoder, namely:
        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
        posterior variance, the generator initial conditions, the control
        inputs (if enabled), the state of the generator, the factors, and the
        output distribution parameters, e.g. (rates or mean and variances).
    """
    hps = self.hps
    batch_size = hps.batch_size
    E, T, D = data_extxd.shape
    E_to_process = hps.ps_nexamples_to_process
    if E_to_process > E:
      E_to_process = E

    if hps.ic_dim > 0:
      prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
      prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
      post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
      post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])

    if hps.co_dim > 0:
      controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
    gen_ics = np.zeros([E_to_process, hps.gen_dim])
    gen_states = np.zeros([E_to_process, T, hps.gen_dim])
    factors = np.zeros([E_to_process, T, hps.factors_dim])

    if hps.output_dist == 'poisson':
      out_dist_params = np.zeros([E_to_process, T, D])
    elif hps.output_dist == 'gaussian':
      out_dist_params = np.zeros([E_to_process, T, D+D])
    else:
      assert False, "NIY"

    costs = np.zeros(E_to_process)
    nll_bound_vaes = np.zeros(E_to_process)
    nll_bound_iwaes = np.zeros(E_to_process)
    train_steps = np.zeros(E_to_process)
    for es_idx in range(E_to_process):
      print("Running %d of %d." % (es_idx+1, E_to_process))
      example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
      data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
                                                   ext_input_extxi,
                                                   batch_size=batch_size,
                                                   example_idxs=example_idxs)
      model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
                                                ext_input_bxtxi,
                                                do_eval_cost=True,
                                                do_average_batch=True)

      if self.hps.ic_dim > 0:
        prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
        prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
        post_g0_mean[es_idx,:] = model_values['post_g0_mean']
        post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
      gen_ics[es_idx,:] = model_values['gen_ics']

      if self.hps.co_dim > 0:
        controller_outputs[es_idx,:,:] = model_values['controller_outputs']
      gen_states[es_idx,:,:] = model_values['gen_states']
      factors[es_idx,:,:] = model_values['factors']
      out_dist_params[es_idx,:,:] = model_values['output_dist_params']
      costs[es_idx] = model_values['costs']
      nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
      nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
      train_steps[es_idx] = model_values['train_steps']
      print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
            % (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx]))

    model_runs = {}
    if self.hps.ic_dim > 0:
      model_runs['prior_g0_mean'] = prior_g0_mean
      model_runs['prior_g0_logvar'] = prior_g0_logvar
      model_runs['post_g0_mean'] = post_g0_mean
      model_runs['post_g0_logvar'] = post_g0_logvar
    model_runs['gen_ics'] = gen_ics

    if self.hps.co_dim > 0:
      model_runs['controller_outputs'] = controller_outputs
    model_runs['gen_states'] = gen_states
    model_runs['factors'] = factors
    model_runs['output_dist_params'] = out_dist_params
    model_runs['costs'] = costs
    model_runs['nll_bound_vaes'] = nll_bound_vaes
    model_runs['nll_bound_iwaes'] = nll_bound_iwaes
    model_runs['train_steps'] = train_steps
    return model_runs

  def eval_model_runs_push_mean(self, data_name, data_extxd,
                                ext_input_extxi=None):
    """Returns values of interest for the model by pushing the means through
    the model.

    The mean values for both initial conditions and the control inputs are
    pushed through the model instead of sampling (as is done in
    eval_model_runs_avg_epoch).  This is a quick and approximate way of
    estimating these values, compared with sampling from the posterior many
    times and then averaging those values of interest.

    Internally, a total of batch_size trials are run through the model at once.

    Args:
      data_name: The name of the data dict, to select which in/out matrices
        to use.
      data_extxd: Numpy array training data with shape:
        # examples x # time steps x # dimensions
      ext_input_extxi (optional): Numpy array training external input with
        shape: # examples x # time steps x # external input dims

    Returns:
      A dictionary with the estimated outputs of the model decoder, namely:
        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
        posterior variance, the generator initial conditions, the control
        inputs (if enabled), the state of the generator, the factors, and the
        output distribution parameters, e.g. (rates or mean and variances).
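
    Example (illustrative only; the model instance, dataset name, and data
    array below are hypothetical and assumed to already exist):
      model_runs = model.eval_model_runs_push_mean('my_dataset', data_extxd)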
""" hps = self.hps batch_size = hps.batch_size E, T, D = data_extxd.shape E_to_process = hps.ps_nexamples_to_process if E_to_process > E: print("Setting number of posterior samples to process to : ", E) E_to_process = E if hps.ic_dim > 0: prior_g0_mean = np.zeros([E_to_process, hps.ic_dim]) prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim]) post_g0_mean = np.zeros([E_to_process, hps.ic_dim]) post_g0_logvar = np.zeros([E_to_process, hps.ic_dim]) if hps.co_dim > 0: controller_outputs = np.zeros([E_to_process, T, hps.co_dim]) gen_ics = np.zeros([E_to_process, hps.gen_dim]) gen_states = np.zeros([E_to_process, T, hps.gen_dim]) factors = np.zeros([E_to_process, T, hps.factors_dim]) if hps.output_dist == 'poisson': out_dist_params = np.zeros([E_to_process, T, D]) elif hps.output_dist == 'gaussian': out_dist_params = np.zeros([E_to_process, T, D+D]) else: assert False, "NIY" costs = np.zeros(E_to_process) nll_bound_vaes = np.zeros(E_to_process) nll_bound_iwaes = np.zeros(E_to_process) train_steps = np.zeros(E_to_process) # generator that will yield 0:N in groups of per items, e.g. # (0:per-1), (per:2*per-1), ..., with the last group containing <= per items # this will be used to feed per=batch_size trials into the model at a time def trial_batches(N, per): for i in range(0, N, per): yield np.arange(i, min(i+per, N), dtype=np.int32) for batch_idx, es_idx in enumerate(trial_batches(E_to_process, hps.batch_size)): print("Running trial batch %d with %d trials" % (batch_idx+1, len(es_idx))) data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi, batch_size=batch_size, example_idxs=es_idx) model_values = self.eval_model_runs_batch(data_name, data_bxtxd, ext_input_bxtxi, do_eval_cost=True, do_average_batch=False) if self.hps.ic_dim > 0: prior_g0_mean[es_idx,:] = model_values['prior_g0_mean'] prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar'] post_g0_mean[es_idx,:] = model_values['post_g0_mean'] post_g0_logvar[es_idx,:] = model_values['post_g0_logvar'] gen_ics[es_idx,:] = model_values['gen_ics'] if self.hps.co_dim > 0: controller_outputs[es_idx,:,:] = model_values['controller_outputs'] gen_states[es_idx,:,:] = model_values['gen_states'] factors[es_idx,:,:] = model_values['factors'] out_dist_params[es_idx,:,:] = model_values['output_dist_params'] # TODO # model_values['costs'] and other costs come out as scalars, summed over # all the trials in the batch. what we want is the per-trial costs costs[es_idx] = model_values['costs'] nll_bound_vaes[es_idx] = model_values['nll_bound_vaes'] nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes'] train_steps[es_idx] = model_values['train_steps'] model_runs = {} if self.hps.ic_dim > 0: model_runs['prior_g0_mean'] = prior_g0_mean model_runs['prior_g0_logvar'] = prior_g0_logvar model_runs['post_g0_mean'] = post_g0_mean model_runs['post_g0_logvar'] = post_g0_logvar model_runs['gen_ics'] = gen_ics if self.hps.co_dim > 0: model_runs['controller_outputs'] = controller_outputs model_runs['gen_states'] = gen_states model_runs['factors'] = factors model_runs['output_dist_params'] = out_dist_params # You probably do not want the LL associated values when pushing the mean # instead of sampling. model_runs['costs'] = costs model_runs['nll_bound_vaes'] = nll_bound_vaes model_runs['nll_bound_iwaes'] = nll_bound_iwaes model_runs['train_steps'] = train_steps return model_runs def write_model_runs(self, datasets, output_fname=None, push_mean=False): """Run the model on the data in data_dict, and save the computed values. 

    LFADS generates a number of outputs for each example, and these are all
    saved.  They are:
      The mean and variance of the prior of g0.
      The mean and variance of the approximate posterior of g0.
      The control inputs (if enabled)
      The initial conditions, g0, for all examples.
      The generator states for all time.
      The factors for all time.
      The output distribution parameters (e.g. rates) for all time.

    Args:
      datasets: a dictionary of named data_dictionaries, see top of lfads.py
      output_fname: a file name stem for the output files.
      push_mean: If False (default), generates batch_size samples for each
        trial and averages the results.  If True, runs each trial once without
        noise, pushing the posterior mean initial conditions and control
        inputs through the trained model.  False is used for
        posterior_sample_and_average, True is used for posterior_push_mean.
    """
    hps = self.hps
    kind = hps.kind

    for data_name, data_dict in datasets.items():
      data_tuple = [('train', data_dict['train_data'],
                     data_dict['train_ext_input']),
                    ('valid', data_dict['valid_data'],
                     data_dict['valid_ext_input'])]
      for data_kind, data_extxd, ext_input_extxi in data_tuple:
        if not output_fname:
          fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind
        else:
          fname = output_fname + data_name + '_' + data_kind + '_' + kind

        print("Writing data for %s data and kind %s." % (data_name, data_kind))
        if push_mean:
          model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
                                                      ext_input_extxi)
        else:
          model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
                                                      ext_input_extxi)
        full_fname = os.path.join(hps.lfads_save_dir, fname)
        write_data(full_fname, model_runs, compression='gzip')
        print("Done.")

  def write_model_samples(self, dataset_name, output_fname=None):
    """Use the prior distribution to generate batch_size samples from the
    model.

    LFADS generates a number of outputs for each sample, and these are all
    saved.  They are:
      The mean and variance of the prior of g0.
      The control inputs (if enabled)
      The initial conditions, g0, for all examples.
      The generator states for all time.
      The factors for all time.
      The output distribution parameters (e.g. rates) for all time.

    Args:
      dataset_name: The name of the dataset to grab the factors -> rates
        alignment matrices from.
      output_fname: The name of the file in which to save the generated
        samples.
    """
    hps = self.hps
    batch_size = hps.batch_size

    print("Generating %d samples" % (batch_size))
    tf_vals = [self.factors, self.gen_states, self.gen_ics,
               self.cost, self.output_dist_params]
    if hps.ic_dim > 0:
      tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
    if hps.co_dim > 0:
      tf_vals += [self.prior_zs_ar_con.samples_t]
    tf_vals_flat, fidxs = flatten(tf_vals)

    session = tf.get_default_session()
    feed_dict = {}
    feed_dict[self.dataName] = dataset_name
    feed_dict[self.keep_prob] = 1.0

    np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)

    ff = 0
    factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    if hps.ic_dim > 0:
      prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
      prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    if hps.co_dim > 0:
      prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1

    # [0] is to take the non-temporal items out of their lists.
    gen_ics = gen_ics[0]
    costs = costs[0]

    # Convert to full tensors, not lists of tensors in time dim.
    gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
    factors = list_t_bxn_to_tensor_bxtxn(factors)
    output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params)
    if hps.ic_dim > 0:
      prior_g0_mean = prior_g0_mean[0]
      prior_g0_logvar = prior_g0_logvar[0]
    if hps.co_dim > 0:
      prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con)

    model_vals = {}
    model_vals['gen_ics'] = gen_ics
    model_vals['gen_states'] = gen_states
    model_vals['factors'] = factors
    model_vals['output_dist_params'] = output_dist_params
    model_vals['costs'] = costs.reshape(1)
    if hps.ic_dim > 0:
      model_vals['prior_g0_mean'] = prior_g0_mean
      model_vals['prior_g0_logvar'] = prior_g0_logvar
    if hps.co_dim > 0:
      model_vals['prior_zs_ar_con'] = prior_zs_ar_con

    full_fname = os.path.join(hps.lfads_save_dir, output_fname)
    write_data(full_fname, model_vals, compression='gzip')
    print("Done.")

  @staticmethod
  def eval_model_parameters(use_nested=True, include_strs=None):
    """Evaluate and return all of the TF variables in the model.

    Args:
      use_nested (optional): For returning values, use a nested dictionary,
        based on variable scoping, or return all variables in a flat
        dictionary.
      include_strs (optional): A list of strings to use as a filter, to reduce
        the number of variables returned.  A variable name must contain at
        least one string in include_strs as a sub-string in order to be
        returned.

    Returns:
      The parameters of the model.  This can be in a flat dictionary, or a
      nested dictionary, where the nesting is by variable scope.
    """
    all_tf_vars = tf.global_variables()
    session = tf.get_default_session()
    all_tf_vars_eval = session.run(all_tf_vars)
    vars_dict = {}
    strs = ["LFADS"]
    if include_strs:
      strs += include_strs

    for var, var_eval in zip(all_tf_vars, all_tf_vars_eval):
      if any(s in var.name for s in strs):
        if not isinstance(var_eval, np.ndarray):  # for H5PY
          print(var.name, "is not a numpy array, saving as a numpy array",
                "with value:", var_eval, type(var_eval))
          e = np.array(var_eval)
          print(e, type(e))
        else:
          e = var_eval
        vars_dict[var.name] = e

    if not use_nested:
      return vars_dict

    var_names = vars_dict.keys()
    nested_vars_dict = {}
    current_dict = nested_vars_dict
    for var_name in var_names:
      var_split_name_list = var_name.split('/')
      split_name_list_len = len(var_split_name_list)
      current_dict = nested_vars_dict
      for p, part in enumerate(var_split_name_list):
        if p < split_name_list_len - 1:
          if part in current_dict:
            current_dict = current_dict[part]
          else:
            current_dict[part] = {}
            current_dict = current_dict[part]
        else:
          current_dict[part] = vars_dict[var_name]

    return nested_vars_dict

  @staticmethod
  def spikify_rates(rates_bxtxd):
    """Randomly spikify underlying rates according to a Poisson distribution.

    Args:
      rates_bxtxd: a numpy array of underlying rates with shape:
        # batch x # time steps x # dimensions

    Returns:
      A numpy array with the same shape as rates_bxtxd, but with the event
      counts.
    """
    B, T, N = rates_bxtxd.shape
    assert all([B > 0, N > 0]), "rates_bxtxd must have nonzero batch and data dims"

    # The rate can change at every (b, t, n), so each count is sampled
    # individually.
    spikes_bxtxd = np.zeros([B, T, N], dtype=np.int32)
    for b in range(B):
      for t in range(T):
        for n in range(N):
          rate = rates_bxtxd[b, t, n]
          count = np.random.poisson(rate)
          spikes_bxtxd[b, t, n] = count

    return spikes_bxtxd
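
# An illustrative usage sketch for spikify_rates (kept as a comment so nothing
# runs at import time).  It assumes the enclosing model class is named LFADS,
# as suggested by this module's docstring; the rates below are made-up data:
#
#   import numpy as np
#   rates_bxtxd = 0.05 * np.ones([2, 100, 30])        # batch x time x units
#   spikes_bxtxd = LFADS.spikify_rates(rates_bxtxd)    # Poisson counts
#   assert spikes_bxtxd.shape == rates_bxtxd.shape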