Spaces:
Running
Running
File size: 6,481 Bytes
0b8359d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to run training for sequential latent variable models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from fivo import ghmm_runners
from fivo import runners
# Shared flags: general settings for the run, the model, and the data.
tf.app.flags.DEFINE_enum(
    "mode", "train",
    ["train", "eval", "sample"],
    "The mode of the binary.")
tf.app.flags.DEFINE_enum(
    "model", "vrnn",
    ["vrnn", "ghmm", "srnn"],
    "Model choice.")
tf.app.flags.DEFINE_integer(
    "latent_size", 64,
    "The size of the latent state of the model.")
tf.app.flags.DEFINE_enum(
    "dataset_type", "pianoroll",
    ["pianoroll", "speech", "pose"],
    "The type of dataset.")
tf.app.flags.DEFINE_string(
    "dataset_path", "",
    "Path to load the dataset from.")
# Left as None here; main() fills in a per-dataset default when it is unset.
tf.app.flags.DEFINE_integer(
    "data_dimension", None,
    "The dimension of each vector in the data sequence. "
    "Defaults to 88 for pianoroll datasets and 200 for speech "
    "datasets. Should not need to be changed except for "
    "testing.")
tf.app.flags.DEFINE_integer(
    "batch_size", 4,
    "Batch size.")
tf.app.flags.DEFINE_integer(
    "num_samples", 4,
    "The number of samples (or particles) for multisample "
    "algorithms.")
tf.app.flags.DEFINE_string(
    "logdir", "/tmp/smc_vi",
    "The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer(
    "random_seed", None,
    "A random seed for seeding the TensorFlow graph.")
tf.app.flags.DEFINE_integer(
    "parallel_iterations", 30,
    "The number of parallel iterations to use for the while "
    "loop that computes the bounds.")
# Training flags: objective, optimizer, and resampling/proposal settings.
tf.app.flags.DEFINE_enum(
    "bound", "fivo",
    ["elbo", "iwae", "fivo", "fivo-aux"],
    "The bound to optimize.")
tf.app.flags.DEFINE_boolean(
    "normalize_by_seq_len", True,
    "If true, normalize the loss by the number of timesteps "
    "per sequence.")
tf.app.flags.DEFINE_float(
    "learning_rate", 0.0002,
    "The learning rate for ADAM.")
tf.app.flags.DEFINE_integer(
    "max_steps", int(1e9),
    "The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer(
    "summarize_every", 50,
    "The number of steps between summaries.")
tf.app.flags.DEFINE_enum(
    "resampling_type", "multinomial",
    ["multinomial", "relaxed"],
    "The resampling strategy to use for training.")
tf.app.flags.DEFINE_float(
    "relaxed_resampling_temperature", 0.5,
    "The relaxation temperature for relaxed resampling.")
tf.app.flags.DEFINE_enum(
    "proposal_type", "filtering",
    ["prior", "filtering", "smoothing",
     "true-filtering", "true-smoothing"],
    "The type of proposal to use. true-filtering and true-smoothing "
    "are only available for the GHMM. The specific implementation "
    "of each proposal type is left to model-writers.")
# Distributed training flags: master address, replica id, and ps job size.
tf.app.flags.DEFINE_string(
    "master", "",
    "The BNS name of the TensorFlow master to use.")
tf.app.flags.DEFINE_integer(
    "task", 0,
    "Task id of the replica running the training.")
tf.app.flags.DEFINE_integer(
    "ps_tasks", 0,
    "Number of tasks in the ps job. If 0 no ps job is used.")
tf.app.flags.DEFINE_boolean(
    "stagger_workers", True,
    "If true, bring one worker online every 1000 steps.")
# Evaluation flags: which data split to evaluate on.
tf.app.flags.DEFINE_enum(
    "split", "train",
    ["train", "test", "valid"],
    "Split to evaluate the model on.")
# Sampling flags: sample/prefix lengths and where samples are written.
tf.app.flags.DEFINE_integer(
    "sample_length", 50,
    "The number of timesteps to sample for.")
tf.app.flags.DEFINE_integer(
    "prefix_length", 25,
    "The number of timesteps to condition the model on "
    "before sampling.")
tf.app.flags.DEFINE_string(
    "sample_out_dir", None,
    "The directory to write the samples to. "
    "Defaults to logdir.")
# GHMM flags: settings that apply to the "ghmm" model choice.
tf.app.flags.DEFINE_float(
    "variance", 0.1,
    "The variance of the ghmm.")
tf.app.flags.DEFINE_integer(
    "num_timesteps", 5,
    "The number of timesteps to run the gmp for.")
# Global flag-values object; the flags defined above are read through it
# (e.g. FLAGS.model, FLAGS.mode) and it is handed to the runner functions.
FLAGS = tf.app.flags.FLAGS
# Fallback values for --data_dimension, used by main() when the flag is left
# unset for the "pianoroll" and "speech" dataset types respectively.
PIANOROLL_DEFAULT_DATA_DIMENSION = 88
SPEECH_DEFAULT_DATA_DIMENSION = 200
def main(unused_argv):
    """Run the runner selected by the --model and --mode flags.

    For the "vrnn" and "srnn" models, an unset --data_dimension is first
    filled in from the dataset type ("pianoroll" or "speech"), then the
    matching function in `runners` is called with FLAGS. For the "ghmm"
    model the matching function in `ghmm_runners` is called with FLAGS.

    Args:
      unused_argv: Argument supplied by the caller (tf.app.run); not used.

    Raises:
      ValueError: If --mode is not one that is dispatched for the chosen
        --model. Only "train" and "eval" are dispatched for "ghmm", so
        "sample" is rejected instead of silently doing nothing.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.model in ["vrnn", "srnn"]:
        if FLAGS.data_dimension is None:
            if FLAGS.dataset_type == "pianoroll":
                FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
            elif FLAGS.dataset_type == "speech":
                FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
            # NOTE(review): no default exists for dataset_type "pose", so
            # data_dimension stays None unless passed explicitly -- confirm
            # that the runners cope with that.
        if FLAGS.mode == "train":
            runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            runners.run_eval(FLAGS)
        elif FLAGS.mode == "sample":
            runners.run_sample(FLAGS)
    elif FLAGS.model == "ghmm":
        if FLAGS.mode == "train":
            ghmm_runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            ghmm_runners.run_eval(FLAGS)
        else:
            # Previously this combination fell through and the binary exited
            # successfully without doing any work.
            raise ValueError(
                "Mode '%s' is not supported for the ghmm model; "
                "use 'train' or 'eval'." % FLAGS.mode)
# Script entry point: hand control to tf.app.run, with main as the callback.
if __name__ == "__main__":
    tf.app.run(main=main)
|