# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function

import os

import h5py
import numpy as np
from six.moves import xrange
from synthetic_data_utils import generate_data, generate_rnn
from synthetic_data_utils import get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import tensorflow as tf
from utils import write_datasets
DATA_DIR = "rnn_synth_data_v1.0"

flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "conditioned_rnn_data",
                    "Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 400, "Number of conditions.")
flags.DEFINE_integer("N", 50, "Number of units for the RNN.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Fraction of trials used for training vs validation.")
flags.DEFINE_integer("nreplications", 10,
                     "Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics") | |
flags.DEFINE_float("x0_std", 1.0, | |
"Volume from which to pull initial conditions (affects diversity of dynamics.") | |
flags.DEFINE_float("tau", 0.025, "Time constant of RNN") | |
flags.DEFINE_float("dt", 0.010, "Time bin") | |
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second") | |
FLAGS = flags.FLAGS

rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
            np.random.RandomState(seed=FLAGS.synth_data_seed+2)]
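# One RNG drives trial-level randomness (RNN assignment, initial conditions,
# spiking); the two offset-seed RNGs make the two generated RNNs distinct
# from each other but reproducible.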
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
nreplications = FLAGS.nreplications
E = nreplications * C
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
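# Derived sizes: E trials in total (nreplications spikifications of each of
# the C conditions), covering T seconds in ntimesteps bins of width dt.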
rnn_a = generate_rnn(rnn_rngs[0], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
                     FLAGS.max_firing_rate)
rnn_b = generate_rnn(rnn_rngs[1], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
                     FLAGS.max_firing_rate)
rnns = [rnn_a, rnn_b]
# Pick which RNN is used on each trial.
rnn_to_use = rng.randint(2, size=E)
ext_input = np.repeat(np.expand_dims(rnn_to_use, axis=1), ntimesteps, axis=1)
ext_input = np.expand_dims(ext_input, axis=2)  # these are the "a's" in the paper
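# ext_input has shape (E, ntimesteps, 1): a constant 0/1 signal per trial
# identifying which RNN generated that trial.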
x0s = []
condition_labels = []
condition_number = 0
for c in range(C):
  x0 = FLAGS.x0_std * rng.randn(N, 1)
  x0s.append(np.tile(x0, nreplications))
  for ns in range(nreplications):
    condition_labels.append(condition_number)
  condition_number += 1
x0s = np.concatenate(x0s, axis=1)
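# x0s is N x E: each condition's initial state is tiled across its
# nreplications trials, so replications share the same underlying trajectory.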
P_nxn = rng.randn(N, N) / np.sqrt(N)
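# Random square readout matrix with N(0, 1/N) entries, so the projected
# rates stay on the same scale as the raw units.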
# Generate trials for both RNNs.
rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
                                  input_magnitude=0.0, input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])
rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
                                  input_magnitude=0.0, input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])
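# Corresponding trials in rates_a and rates_b start from identical initial
# conditions with no input to the dynamics; spikify_data then samples spikes
# from the resulting rates.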
# Select each trial's rates/spikes from whichever RNN generated it. Not the
# most efficient way to do this, but E is small enough.
rates = []
spikes = []
for trial in xrange(E):
  if rnn_to_use[trial] == 0:
    rates.append(rates_a[trial])
    spikes.append(spikes_a[trial])
  else:
    rates.append(rates_b[trial])
    spikes.append(spikes_b[trial])
# Split into train and validation sets.
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
    condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
    ext_input, train_inds, valid_inds)

rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)
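# nparray_and_transpose stacks each per-trial list into a single array,
# transposing each trial's matrix into the time-major (trials x time x units)
# layout the downstream model code expects.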
# Add train_ext_input and valid_ext_input.
data = {'train_truth': rates_train,
        'valid_truth': rates_valid,
        'train_data': spikes_train,
        'valid_data': spikes_valid,
        'train_ext_input': np.array(ext_input_train),
        'valid_ext_input': np.array(ext_input_valid),
        'train_percentage': train_percentage,
        'nreplications': nreplications,
        'dt': FLAGS.dt,
        'P_sxn': P_nxn,
        'condition_labels_train': condition_labels_train,
        'condition_labels_valid': condition_labels_valid,
        'conversion_factor': 1.0 / rnn_a['conversion_factor']}
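# The stored 'conversion_factor' is the inverse of the RNN's, so downstream
# code can undo the rate scaling by multiplication.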
# Just one dataset here.
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data

# Write out the dataset.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print('Saved to', os.path.join(FLAGS.save_dir,
                               FLAGS.datafile_name + '_' + dataset_name))
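
# A minimal sketch of reading the output back. This assumes write_datasets
# saves one HDF5 file per dataset, named '<datafile_name>_<dataset_name>',
# with the dict keys above as top-level HDF5 datasets:
#
#   with h5py.File(os.path.join(FLAGS.save_dir,
#                               FLAGS.datafile_name + '_' + dataset_name),
#                  'r') as hf:
#     train_spikes = hf['train_data'][:]  # spike counts, trials x time x units
#     train_rates = hf['train_truth'][:]  # matching ground-truth rates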