# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function

import h5py
import numpy as np
import os

from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal


def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
  """Create a (vanilla) RNN with a set of hyperparameters for generating
  chaotic data.

  Args:
    rng: numpy random number generator
    N: number of hidden units
    g: scaling of recurrent weight matrix in g W, with W ~ N(0,1/N)
    tau: time scale of individual unit dynamics
    dt: time step for equation updates
    max_firing_rate: how to rescale the -1,1 firing rates

  Returns:
    the dictionary of these parameters, plus some others.
  """
  rnn = {}
  rnn['N'] = N
  rnn['W'] = rng.randn(N,N)/np.sqrt(N)
  rnn['Bin'] = rng.randn(N)/np.sqrt(1.0)
  rnn['Bin2'] = rng.randn(N)/np.sqrt(1.0)
  rnn['b'] = np.zeros(N)
  rnn['g'] = g
  rnn['tau'] = tau
  rnn['dt'] = dt
  rnn['max_firing_rate'] = max_firing_rate
  mfr = rnn['max_firing_rate']   # spikes / sec
  nbins_per_sec = 1.0/rnn['dt']  # bins / sec
  # Used for plotting in LFADS
  rnn['conversion_factor'] = mfr / nbins_per_sec  # spikes / bin
  return rnn
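

# A minimal usage sketch for generate_rnn (not part of the original pipeline);
# the hyperparameter values below are illustrative assumptions chosen only to
# show the expected argument types and the returned dictionary fields.
def _example_generate_rnn():
  rng = np.random.RandomState(seed=0)
  rnn = generate_rnn(rng, N=50, g=1.5, tau=0.025, dt=0.010,
                     max_firing_rate=30.0)
  # rnn['W'] is the N x N recurrent matrix; rnn['conversion_factor'] converts
  # normalized rates into expected spikes per bin (here 30.0 * 0.010 = 0.3).
  print(rnn['W'].shape, rnn['conversion_factor'])
  return rnn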


def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
                  input_times=None):
  """Generates data from a randomly initialized RNN.

  Args:
    rnn: the rnn
    T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded down).
    E: total number of examples
    x0s: NxE matrix of initial conditions, one column per example
    P_sxn: SxN projection matrix used to subsample the N units down to S channels
    input_magnitude: magnitude of the pulse input, if any
    input_times: length-E list of time steps at which the pulse is delivered

  Returns:
    A 3-tuple of (data_e, x0s, inputs_e), where data_e is a list of length E
    of SxT arrays of the network being run.
  """
  N = rnn['N']
  def run_rnn(rnn, x0, ntime_steps, input_time=None):
    rs = np.zeros([N,ntime_steps])
    x_tm1 = x0
    r_tm1 = np.tanh(x0)
    tau = rnn['tau']
    dt = rnn['dt']
    alpha = (1.0-dt/tau)
    W = dt/tau*rnn['W']*rnn['g']
    Bin = dt/tau*rnn['Bin']
    Bin2 = dt/tau*rnn['Bin2']
    b = dt/tau*rnn['b']

    us = np.zeros([1, ntime_steps])
    for t in range(ntime_steps):
      x_t = alpha*x_tm1 + np.dot(W,r_tm1) + b
      if input_time is not None and t == input_time:
        us[0,t] = input_magnitude
        x_t += Bin * us[0,t]  # DCS is this what was used?
      r_t = np.tanh(x_t)
      x_tm1 = x_t
      r_tm1 = r_t
      rs[:,t] = r_t
    return rs, us

  if P_sxn is None:
    P_sxn = np.eye(N)
  ntime_steps = int(T / rnn['dt'])
  data_e = []
  inputs_e = []
  for e in range(E):
    input_time = input_times[e] if input_times is not None else None
    r_nxt, u_uxt = run_rnn(rnn, x0s[:,e], ntime_steps, input_time)
    r_sxt = np.dot(P_sxn, r_nxt)
    inputs_e.append(u_uxt)
    data_e.append(r_sxt)

  S = P_sxn.shape[0]
  data_e = normalize_rates(data_e, E, S)

  return data_e, x0s, inputs_e
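

# A minimal sketch of driving generate_data (illustrative values, not the
# settings used by the original LFADS data-generation scripts): initial
# conditions are drawn at random and one pulse input per example is assumed.
def _example_generate_data():
  rng = np.random.RandomState(seed=0)
  rnn = generate_rnn(rng, N=50, g=1.5, tau=0.025, dt=0.010,
                     max_firing_rate=30.0)
  E = 4                               # number of example trials
  T = 1.0                             # seconds per trial
  x0s = 0.5 * rng.randn(rnn['N'], E)  # one initial condition per trial
  input_times = [10] * E              # pulse at time step 10 in every trial
  rates_e, x0s, inputs_e = generate_data(rnn, T=T, E=E, x0s=x0s,
                                         input_magnitude=1.0,
                                         input_times=input_times)
  # Each rates_e[e] is N x T/dt (here 50 x 100), normalized to [0, 1] per channel.
  print(rates_e[0].shape, inputs_e[0].shape)
  return rates_e, inputs_e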


def normalize_rates(data_e, E, S):
  # Normalization, made more complex because of the P matrices.
  # Normalize by min and max in each channel. This normalization will
  # cause offset differences between identical rnn runs that receive
  # their inputs at different times.
  for e in range(E):
    r_sxt = data_e[e]
    for i in range(S):
      rmin = np.min(r_sxt[i,:])
      rmax = np.max(r_sxt[i,:])
      assert rmax - rmin != 0, 'Something wrong'
      r_sxt[i,:] = (r_sxt[i,:] - rmin)/(rmax-rmin)
    data_e[e] = r_sxt
  return data_e


def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
  """Apply spikes to a continuous dataset whose values are between 0.0 and 1.0.

  Args:
    data_e: nexamples length list of NxT trials
    dt: how often the data are sampled
    max_firing_rate: the firing rate that is associated with a value of 1.0

  Returns:
    spikes_e: a list of length E of the data represented as spikes,
    sampled from the underlying Poisson process.
  """
  E = len(data_e)
  spikes_e = []
  for e in range(E):
    data = data_e[e]
    N,T = data.shape
    data_s = np.zeros([N,T]).astype(int)  # builtin int; np.int is deprecated
    for n in range(N):
      f = data[n,:]
      s = rng.poisson(f*max_firing_rate*dt, size=T)
      data_s[n,:] = s
    spikes_e.append(data_s)
  return spikes_e
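

# A small sketch connecting normalized rates to spikify_data (parameter values
# are assumptions for illustration): rates in [0, 1] are scaled by
# max_firing_rate and sampled as Poisson counts per time bin.
def _example_spikify():
  rng = np.random.RandomState(seed=1)
  rates_e = [rng.rand(5, 20) for _ in range(3)]  # stand-in for normalized rates
  spikes_e = spikify_data(rates_e, rng, dt=0.010, max_firing_rate=30.0)
  # Each entry is an integer NxT array of spike counts; the mean count in a bin
  # is approximately rate * max_firing_rate * dt (at most 0.3 here).
  print(spikes_e[0].shape, spikes_e[0].dtype)
  return spikes_e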


def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
  """Apply gaussian noise to a continuous dataset whose values are between
  0.0 and 1.0.

  Args:
    data_e: nexamples length list of NxT trials
    dt: how often the data are sampled
    max_firing_rate: the firing rate that is associated with a value of 1.0

  Returns:
    gauss_e: a list of length E of the data with noise.
  """
  E = len(data_e)
  mfr = max_firing_rate
  gauss_e = []
  for e in range(E):
    data = data_e[e]
    N,T = data.shape
    # Use the passed-in rng (rather than np.random) so results are reproducible.
    noisy_data = data * mfr + rng.randn(N,T) * (5.0*mfr) * np.sqrt(dt)
    gauss_e.append(noisy_data)
  return gauss_e


def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
  """Split the numbers between 0 and num_trials-1 into two portions for
  training and validation, based on the train fraction.

  Args:
    num_trials: the number of trials
    train_fraction: (e.g. .80)
    nreplications: the number of spiking trials per initial condition

  Returns:
    a 2-tuple of two lists: the training indices and validation indices
  """
  train_inds = []
  valid_inds = []
  for i in range(num_trials):
    # This line divides up the trials so that within one initial condition,
    # the randomness of spikifying the condition is shared among both
    # training and validation data splits.
    if (i % nreplications)+1 > train_fraction * nreplications:
      valid_inds.append(i)
    else:
      train_inds.append(i)
  return train_inds, valid_inds


def split_list_by_inds(data, inds1, inds2):
  """Take the data, a list, and split it up based on the indices in inds1 and
  inds2.

  Args:
    data: the list of data to split
    inds1: the first list of indices
    inds2: the second list of indices

  Returns:
    a 2-tuple of two lists.
  """
  if data is None or len(data) == 0:
    return [], []
  else:
    dout1 = [data[i] for i in inds1]
    dout2 = [data[i] for i in inds2]
    return dout1, dout2
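

# A brief sketch of the train/validation split helpers (trial counts below are
# illustrative assumptions): with nreplications spikified trials per initial
# condition, each block of replications is split by the same train fraction.
def _example_train_valid_split():
  num_trials = 20   # e.g. 5 initial conditions x 4 replications
  nreplications = 4
  train_inds, valid_inds = get_train_n_valid_inds(num_trials,
                                                  train_fraction=0.75,
                                                  nreplications=nreplications)
  # Within every block of 4 trials, the first 3 go to training and the last to
  # validation: train_inds = [0, 1, 2, 4, 5, 6, ...], valid_inds = [3, 7, ...].
  trials = list(range(num_trials))  # stand-in for a list of spike arrays
  train_trials, valid_trials = split_list_by_inds(trials, train_inds, valid_inds)
  print(len(train_trials), len(valid_trials))
  return train_trials, valid_trials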


def nparray_and_transpose(data_a_b_c):
  """Convert the list of items in data to a numpy array, and transpose it.

  Args:
    data_a_b_c: a nested, nested list of length a, with sublist length
      b, with sublist length c.

  Returns:
    a numpy 3-tensor with dimensions a x c x b
  """
  data_axbxc = np.array([datum_b_c for datum_b_c in data_a_b_c])
  data_axcxb = np.transpose(data_axbxc, axes=[0,2,1])
  return data_axcxb
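

# A one-line shape check for nparray_and_transpose (shapes are illustrative):
# a list of E arrays of shape N x T becomes a single E x T x N array, i.e. the
# trials x time x channels layout used downstream.
def _example_nparray_and_transpose():
  data_e = [np.zeros([5, 20]) for _ in range(3)]  # 3 trials, 5 channels x 20 steps
  data_extxn = nparray_and_transpose(data_e)
  print(data_extxn.shape)  # (3, 20, 5)
  return data_extxn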


def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
  """Create a matrix that aligns the datasets a bit, under
  the assumption that each dataset is observing the same underlying dynamical
  system.

  Args:
    datasets: The dictionary of dataset structures.
    npcs: The number of pcs for each, basically like lfads factors.
    ntime (optional): Number of time steps to take in each sample.
    nsamples (optional): Number of samples to take for each dataset.

  Returns:
    The dataset structures, with the field alignment_matrix_cxf added.
    This is nchannels x npcs dimension.
  """
  nchannels_all = 0
  channel_idxs = {}
  conditions_all = {}
  nconditions_all = 0
  for name, dataset in datasets.items():
    cidxs = np.where(dataset['P_sxn'])[1]  # non-zero entries in columns
    channel_idxs[name] = [cidxs[0], cidxs[-1]+1]
    nchannels_all += cidxs[-1]+1 - cidxs[0]
    conditions_all[name] = np.unique(dataset['condition_labels_train'])

  all_conditions_list = \
    np.unique(np.ndarray.flatten(np.array(list(conditions_all.values()))))
  nconditions_all = all_conditions_list.shape[0]

  if ntime is None:
    ntime = dataset['train_data'].shape[1]
  if nsamples is None:
    nsamples = dataset['train_data'].shape[0]

  # In the data workup in the paper, Chethan did intra condition
  # averaging, so let's do that here.
  avg_data_all = {}
  for name, conditions in conditions_all.items():
    dataset = datasets[name]
    avg_data_all[name] = {}
    for cname in conditions:
      td_idxs = np.argwhere(np.array(dataset['condition_labels_train'])==cname)
      data = np.squeeze(dataset['train_data'][td_idxs,:,:], axis=1)
      avg_data = np.mean(data, axis=0)
      avg_data_all[name][cname] = avg_data

  # Stack the condition-averaged data from all datasets into one
  # channels x (time * conditions) matrix.
  all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
  for name, dataset in datasets.items():
    cidx_s = channel_idxs[name][0]
    cidx_f = channel_idxs[name][1]
    for cname in conditions_all[name]:
      cidxs = np.argwhere(all_conditions_list == cname)
      if cidxs.shape[0] > 0:
        cidx = cidxs[0][0]
        all_tidxs = np.arange(0, ntime+1) + cidx*ntime
        all_data_nxtc[cidx_s:cidx_f, all_tidxs[0]:all_tidxs[-1]] = \
          avg_data_all[name][cname].T

  # A bit of filtering. We don't care about spectral properties, or
  # filtering artifacts, simply correlate time steps a bit.
  filt_len = 6
  bc_filt = np.ones([filt_len])/float(filt_len)
  for c in range(nchannels_all):
    all_data_nxtc[c,:] = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc[c,:])

  # Compute the PCs.
  all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
  all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
  corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
  evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
  sidxs = np.flipud(np.argsort(evals_n))  # sort such that 0th is highest
  evals_n = evals_n[sidxs]
  evecs_nxn = evecs_nxn[:,sidxs]

  # Project all the channels data onto the low-D PCA basis, where
  # low-d is the npcs parameter.
  all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)

  # Now for each dataset, we regress the channel data onto the top
  # pcs, and this will be our alignment matrix for that dataset.
  # |B - A*W|^2
  for name, dataset in datasets.items():
    cidx_s = channel_idxs[name][0]
    cidx_f = channel_idxs[name][1]
    all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f,:]  # ch for channel
    W_chxp, _, _, _ = \
      np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T, rcond=None)
    dataset['alignment_matrix_cxf'] = W_chxp
    alignment_bias_cx1 = all_data_mean_nx1[cidx_s:cidx_f]
    dataset['alignment_bias_c'] = np.squeeze(alignment_bias_cx1, axis=1)

  do_debug_plot = False
  if do_debug_plot:
    pc_vecs = evecs_nxn[:,0:npcs]
    ntoplot = 400

    plt.figure()
    plt.plot(np.log10(evals_n), '-x')
    plt.figure()
    plt.subplot(311)
    plt.imshow(all_data_pca_pxtc)
    plt.colorbar()

    plt.subplot(312)
    plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
    plt.colorbar()

    plt.subplot(313)
    plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)
    plt.colorbar()

    import pdb
    pdb.set_trace()

  return datasets
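

# A hedged sketch of the dataset dictionary add_alignment_projections expects
# (all names, sizes, and labels below are illustrative assumptions): each
# dataset observes a distinct block of channels of one underlying system via
# its P_sxn matrix, and carries train_data plus per-trial condition labels.
def _example_add_alignment_projections():
  rng = np.random.RandomState(seed=2)
  N_total, S, T, ntrials, npcs = 40, 20, 50, 8, 3
  datasets = {}
  for k, name in enumerate(['dataset0', 'dataset1']):
    # Each dataset observes a contiguous block of S of the N_total channels.
    P_sxn = np.zeros([S, N_total])
    P_sxn[:, k*S:(k+1)*S] = np.eye(S)
    datasets[name] = {
        'P_sxn': P_sxn,
        'train_data': rng.rand(ntrials, T, S),
        'condition_labels_train': [i % 2 for i in range(ntrials)],
    }
  datasets = add_alignment_projections(datasets, npcs=npcs)
  # Each dataset gains an S x npcs alignment matrix and a length-S bias.
  print(datasets['dataset0']['alignment_matrix_cxf'].shape)  # (20, 3)
  return datasets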