NCTC / models /research /lfads /synth_data /synthetic_data_utils.py
NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
11.4 kB
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
    """Create a (vanilla) RNN with a bunch of hyper parameters for generating
    chaotic data.

    Args:
        rng: numpy random number generator (must provide ``randn``).
        N: number of hidden units.
        g: scaling of the recurrent weight matrix in g W, with W ~ N(0, 1/N).
        tau: time scale of individual unit dynamics.
        dt: time step for equation updates.
        max_firing_rate: how to rescale the -1,1 firing rates.

    Returns:
        A dictionary holding the parameters above plus the random weights
        'W' (N x N), 'Bin' and 'Bin2' (length N), a zero bias 'b' and the
        derived 'conversion_factor' (spikes / bin).
    """
    # The draw order (W, then Bin, then Bin2) fixes the rng stream, so it
    # must not change if results are to stay reproducible.
    recurrent_w = rng.randn(N, N) / np.sqrt(N)
    input_w = rng.randn(N) / np.sqrt(1.0)
    input_w2 = rng.randn(N) / np.sqrt(1.0)

    bins_per_second = 1.0 / dt
    return {
        'N': N,
        'W': recurrent_w,
        'Bin': input_w,
        'Bin2': input_w2,
        'b': np.zeros(N),
        'g': g,
        'tau': tau,
        'dt': dt,
        'max_firing_rate': max_firing_rate,
        # spikes / sec divided by bins / sec = spikes / bin; used for plotting
        # in LFADS.
        'conversion_factor': max_firing_rate / bins_per_second,
    }
def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
                  input_times=None):
    """Generate data from a randomly initialized RNN.

    Args:
        rnn: the rnn parameter dictionary (as built by generate_rnn).
        T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded
            down).
        E: total number of examples.
        x0s: initial hidden states, indexed as x0s[:, e] for example e
            (presumably an N x E array). Required despite the None default.
        P_sxn: optional S x N matrix mapping the N units onto S observed
            channels (subsampling N). Defaults to the N x N identity.
        input_magnitude: size of the single input pulse delivered through
            rnn['Bin'].
        input_times: optional sequence of length E giving, per example, the
            time step at which the pulse is delivered. None means no pulse.

    Returns:
        A 3-tuple (data_e, x0s, inputs_e):
          data_e: list of length E of S x ntime_steps arrays, min/max
            normalized per channel by normalize_rates.
          x0s: the initial conditions that were passed in.
          inputs_e: list of length E of 1 x ntime_steps input arrays.

    Raises:
        ValueError: if x0s is None.
    """
    if x0s is None:
        # Without this check x0s[:, e] below fails with an opaque TypeError.
        raise ValueError('generate_data requires initial conditions x0s '
                         '(indexed as x0s[:, e]).')
    N = rnn['N']

    def run_rnn(rnn, x0, ntime_steps, input_time=None):
        """Integrate the RNN from x0; return (rates N x T, inputs 1 x T)."""
        rs = np.zeros([N, ntime_steps])
        x_tm1 = x0
        r_tm1 = np.tanh(x0)
        tau = rnn['tau']
        dt = rnn['dt']
        # Discretized leaky dynamics:
        #   x_t = (1 - dt/tau) x_{t-1} + dt/tau (g W r_{t-1} + b + Bin u_t)
        alpha = (1.0 - dt / tau)
        W = dt / tau * rnn['W'] * rnn['g']
        Bin = dt / tau * rnn['Bin']
        b = dt / tau * rnn['b']
        us = np.zeros([1, ntime_steps])
        for t in range(ntime_steps):
            x_t = alpha * x_tm1 + np.dot(W, r_tm1) + b
            if input_time is not None and t == input_time:
                # A single pulse at input_time, delivered through Bin only
                # (rnn['Bin2'] is not used by this generator).
                us[0, t] = input_magnitude
                x_t += Bin * us[0, t]
            r_t = np.tanh(x_t)
            x_tm1 = x_t
            r_tm1 = r_t
            rs[:, t] = r_t
        return rs, us

    if P_sxn is None:
        P_sxn = np.eye(N)
    ntime_steps = int(T / rnn['dt'])
    data_e = []
    inputs_e = []
    for e in range(E):
        input_time = input_times[e] if input_times is not None else None
        r_nxt, u_uxt = run_rnn(rnn, x0s[:, e], ntime_steps, input_time)
        inputs_e.append(u_uxt)
        data_e.append(np.dot(P_sxn, r_nxt))
    S = P_sxn.shape[0]
    data_e = normalize_rates(data_e, E, S)
    return data_e, x0s, inputs_e
def normalize_rates(data_e, E, S):
    """Rescale every channel of every example to span [0, 1], in place.

    Normalization uses the min and max of each channel, which is made more
    complex because of the P matrices. It causes offset differences between
    identical rnn runs with different t hits (presumably the input pulse
    times).

    Args:
        data_e: list of S x T arrays. The first E entries are modified in
            place.
        E: number of examples to normalize.
        S: number of channels (rows) to normalize in each example.

    Returns:
        data_e, the same list object, holding the normalized arrays.
    """
    for example_idx in range(E):
        channels = data_e[example_idx]
        for row in range(S):
            trace = channels[row, :]
            lo = np.min(trace)
            hi = np.max(trace)
            # A constant channel cannot be rescaled (division by zero).
            assert hi - lo != 0, 'Something wrong'
            channels[row, :] = (trace - lo) / (hi - lo)
        data_e[example_idx] = channels
    return data_e
def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
    """Apply spikes to a continuous dataset whose values are between 0.0 and 1.0.

    Args:
        data_e: nexamples length list of N x T trials.
        rng: numpy random number generator used for the Poisson draws.
        dt: how often the data are sampled.
        max_firing_rate: the firing rate that is associated with a value of 1.0.

    Returns:
        spikes_e: a list of length nexamples of integer N x T arrays of spike
        counts, sampled from a poisson process with per-bin mean
        data * max_firing_rate * dt.
    """
    spikes_e = []
    for data in data_e:
        N, T = data.shape
        # Builtin int: the np.int alias was deprecated in NumPy 1.20 and
        # removed in 1.24.
        data_s = np.zeros([N, T], dtype=int)
        # Draw one row at a time so the rng stream is consumed in the same
        # order as before.
        for n in range(N):
            data_s[n, :] = rng.poisson(data[n, :] * max_firing_rate * dt,
                                       size=T)
        spikes_e.append(data_s)
    return spikes_e
def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
    """Apply gaussian noise to a continuous dataset whose values are between
    0.0 and 1.0.

    Args:
        data_e: nexamples length list of N x T trials.
        rng: numpy random number generator providing ``randn`` (e.g.
            np.random.RandomState). It is the only source of randomness, so
            results are reproducible for a seeded rng.
        dt: how often the data are sampled.
        max_firing_rate: the firing rate that is associated with a value of 1.0.

    Returns:
        gauss_e: a list of length nexamples of N x T arrays equal to
        data * max_firing_rate plus gaussian noise with standard deviation
        5 * max_firing_rate * sqrt(dt).
    """
    mfr = max_firing_rate
    gauss_e = []
    for data in data_e:
        N, T = data.shape
        # Use the caller's rng (previously the global np.random state was
        # used, which ignored `rng` and broke reproducibility).
        noisy_data = data * mfr + rng.randn(N, T) * (5.0 * mfr) * np.sqrt(dt)
        gauss_e.append(noisy_data)
    return gauss_e
def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
    """Split the numbers 0 .. num_trials-1 into training and validation
    indices, based on the train fraction.

    Trials are assumed to come in consecutive groups of `nreplications`
    (spiking trials sharing one initial condition). Within each group, the
    first replications up to train_fraction go to training and the rest to
    validation.

    Args:
        num_trials: the number of trials.
        train_fraction: e.g. .80.
        nreplications: the number of spiking trials per initial condition.

    Returns:
        a 2-tuple of two lists: the training indices and validation indices.
    """
    # Splitting within each initial condition means the randomness of
    # spikifying that condition is shared among both training and
    # validation data.
    is_valid = [(trial % nreplications) + 1 > train_fraction * nreplications
                for trial in range(num_trials)]
    train_inds = [trial for trial, valid in enumerate(is_valid) if not valid]
    valid_inds = [trial for trial, valid in enumerate(is_valid) if valid]
    return train_inds, valid_inds
def split_list_by_inds(data, inds1, inds2):
    """Split `data` into two lists selected by the indices in inds1 and inds2.

    Args:
        data: the list (or other indexable sequence) to split. May be None.
        inds1: the first list of indices.
        inds2: the second list of indices.

    Returns:
        A 2-tuple of two lists, ([data[i] for i in inds1],
        [data[i] for i in inds2]). Returns ([], []) when data is None or empty.
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if data is None or len(data) == 0:
        return [], []
    return tuple([data[idx] for idx in inds] for inds in (inds1, inds2))
def nparray_and_transpose(data_a_b_c):
    """Convert a nested list to a numpy array and swap its last two axes.

    Args:
        data_a_b_c: a nested, nested list of length a, with sublist length b,
            with sublist length c.

    Returns:
        a numpy 3-tensor with dimensions a x c x b.
    """
    stacked_axbxc = np.array(list(data_a_b_c))
    return stacked_axbxc.transpose(0, 2, 1)
def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
    """Create a matrix that aligns the datasets a bit, under the assumption
    that each dataset is observing the same underlying dynamical system.

    Procedure:
      1. The condition-averaged training data of all datasets are stacked
         along the channel axis, with each dataset placed in the column range
         of its P_sxn matrix.
      2. The stacked data are lightly smoothed in time and reduced to their
         top `npcs` principal components.
      3. Each dataset's channels are regressed onto those PCs by least
         squares.

    Args:
        datasets: The dictionary of dataset structures. Each entry must have
            'P_sxn', 'condition_labels_train' and 'train_data'. Here
            'train_data' is indexed as train_data[trial, time, channel].
        npcs: The number of pcs for each, basically like lfads factors.
        ntime (optional): Number of time steps to take in each sample.
            Defaults to train_data.shape[1] of the last dataset in iteration
            order.
        nsamples (optional): Unused; kept for backward compatibility.

    Returns:
        The dataset structures, each with two fields added:
          'alignment_matrix_cxf': # channels x npcs.
          'alignment_bias_c': length # channels.

    Raises:
        ValueError: if `datasets` is empty.
    """
    if not datasets:
        raise ValueError('add_alignment_projections needs at least one dataset.')

    # Each dataset covers columns [start, stop) of the stacked channel axis,
    # taken from the first and last non-zero column indices of its P_sxn.
    channel_idxs = {}
    conditions_all = {}
    nchannels_all = 0
    for name, dataset in datasets.items():
        cidxs = np.where(dataset['P_sxn'])[1]  # non-zero entries in columns
        channel_idxs[name] = [cidxs[0], cidxs[-1] + 1]
        nchannels_all += cidxs[-1] + 1 - cidxs[0]
        conditions_all[name] = np.unique(dataset['condition_labels_train'])

    # Union of condition labels over all datasets. Concatenate explicitly:
    # np.array(dict.values()) yields a 0-d object array in Python 3, which
    # silently dropped every condition and left the stacked data all zeros.
    all_conditions_list = np.unique(
        np.concatenate(list(conditions_all.values())))
    nconditions_all = all_conditions_list.shape[0]

    if ntime is None:
        # Same choice as before: the last dataset in iteration order.
        ntime = list(datasets.values())[-1]['train_data'].shape[1]

    # In the data workup in the paper, Chethan did intra condition averaging,
    # so do that here. avg_data_all[name][cname] has shape (time, channels).
    avg_data_all = {}
    for name, conditions in conditions_all.items():
        dataset = datasets[name]
        labels = np.asarray(dataset['condition_labels_train'])
        avg_data_all[name] = {}
        for cname in conditions:
            trial_idxs = np.flatnonzero(labels == cname)
            avg_data_all[name][cname] = np.mean(
                dataset['train_data'][trial_idxs], axis=0)

    # Stacked matrix: rows are channels of all datasets; columns are blocks
    # of ntime steps, one block per condition in all_conditions_list order.
    all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
    for name in datasets:
        cidx_s, cidx_f = channel_idxs[name]
        for cname in conditions_all[name]:
            matches = np.flatnonzero(all_conditions_list == cname)
            if matches.size == 0:
                continue
            t0 = matches[0] * ntime
            all_data_nxtc[cidx_s:cidx_f, t0:t0 + ntime] = \
                avg_data_all[name][cname].T

    # A bit of filtering. We don't care about spectral properties, or
    # filtering artifacts, simply correlate time steps a bit: a boxcar of
    # length filt_len applied forward and backward (filtfilt) along time.
    # filtfilt's default padding needs rows longer than 3 * filt_len.
    filt_len = 6
    bc_filt = np.ones([filt_len]) / float(filt_len)
    for c in range(nchannels_all):
        all_data_nxtc[c, :] = scipy.signal.filtfilt(bc_filt, [1.0],
                                                    all_data_nxtc[c, :])

    # Compute the PCs of the mean-subtracted stacked data.
    all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
    all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
    corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
    evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
    sidxs = np.argsort(evals_n)[::-1]  # sort such that 0th is highest
    evecs_nxn = evecs_nxn[:, sidxs]

    # Project all the channels data onto the low-D PCA basis, where low-d is
    # the npcs parameter.
    all_data_pca_pxtc = np.dot(evecs_nxn[:, :npcs].T, all_data_zm_nxtc)

    # For each dataset, regress its channel data onto the top pcs by
    # minimizing |B - A*W|^2. W is that dataset's alignment matrix and the
    # channel means are its bias.
    for name, dataset in datasets.items():
        cidx_s, cidx_f = channel_idxs[name]
        all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f, :]  # ch for channel
        W_chxp, _, _, _ = np.linalg.lstsq(all_data_zm_chxtc.T,
                                          all_data_pca_pxtc.T)
        dataset['alignment_matrix_cxf'] = W_chxp
        dataset['alignment_bias_c'] = np.squeeze(
            all_data_mean_nx1[cidx_s:cidx_f], axis=1)

    return datasets