# Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import tensorflow as tf import numpy as np from scipy.misc import logsumexp import tensorflow.contrib.slim as slim from tensorflow.python.ops import init_ops import utils as U try: xrange # Python 2 except NameError: xrange = range # Python 3 FLAGS = tf.flags.FLAGS Q_COLLECTION = "q_collection" P_COLLECTION = "p_collection" class SBN(object): # REINFORCE def __init__(self, hparams, activation_func=tf.nn.sigmoid, mean_xs = None, eval_mode=False): self.eval_mode = eval_mode self.hparams = hparams self.mean_xs = mean_xs self.train_bias= -np.log(1./np.clip(mean_xs, 0.001, 0.999)-1.).astype(np.float32) self.activation_func = activation_func self.n_samples = tf.placeholder('int32') self.x = tf.placeholder('float', [None, self.hparams.n_input]) self._x = tf.tile(self.x, [self.n_samples, 1]) self.batch_size = tf.shape(self._x)[0] self.uniform_samples = dict() self.uniform_samples_v = dict() self.prior = tf.Variable(tf.zeros([self.hparams.n_hidden], dtype=tf.float32), name='p_prior', collections=[tf.GraphKeys.GLOBAL_VARIABLES, P_COLLECTION]) self.run_recognition_network = False self.run_generator_network = False # Initialize temperature self.pre_temperature_variable = tf.Variable( np.log(self.hparams.temperature), trainable=False, dtype=tf.float32) self.temperature_variable = tf.exp(self.pre_temperature_variable) self.global_step = tf.Variable(0, trainable=False) self.baseline_loss = [] self.ema = tf.train.ExponentialMovingAverage(decay=0.999) self.maintain_ema_ops = [] self.optimizer_class = tf.train.AdamOptimizer( learning_rate=1*self.hparams.learning_rate, beta2=self.hparams.beta2) self._generate_randomness() self._create_network() def initialize(self, sess): self.sess = sess def _create_eta(self, shape=[], collection='CV'): return 2 * tf.sigmoid(tf.Variable(tf.zeros(shape), trainable=False, collections=[collection, tf.GraphKeys.GLOBAL_VARIABLES, Q_COLLECTION])) def _create_baseline(self, n_output=1, n_hidden=100, is_zero_init=False, collection='BASELINE'): # center input h = self._x if self.mean_xs is not None: h -= self.mean_xs if is_zero_init: initializer = init_ops.zeros_initializer() else: initializer = slim.variance_scaling_initializer() with slim.arg_scope([slim.fully_connected], variables_collections=[collection, Q_COLLECTION], trainable=False, weights_initializer=initializer): h = slim.fully_connected(h, n_hidden, activation_fn=tf.nn.tanh) baseline = slim.fully_connected(h, n_output, activation_fn=None) if n_output == 1: baseline = tf.reshape(baseline, [-1]) # very important to reshape return baseline def _create_transformation(self, input, n_output, reuse, scope_prefix): """Create the deterministic transformation between stochastic layers. If self.hparam.nonlinear: 2 x tanh layers Else: 1 x linear layer """ if self.hparams.nonlinear: h = slim.fully_connected(input, self.hparams.n_hidden, reuse=reuse, activation_fn=tf.nn.tanh, scope='%s_nonlinear_1' % scope_prefix) h = slim.fully_connected(h, self.hparams.n_hidden, reuse=reuse, activation_fn=tf.nn.tanh, scope='%s_nonlinear_2' % scope_prefix) h = slim.fully_connected(h, n_output, reuse=reuse, activation_fn=None, scope='%s' % scope_prefix) else: h = slim.fully_connected(input, n_output, reuse=reuse, activation_fn=None, scope='%s' % scope_prefix) return h def _recognition_network(self, sampler=None, log_likelihood_func=None): """x values -> samples from Q and return log Q(h|x).""" samples = {} reuse = None if not self.run_recognition_network else True # Set defaults if sampler is None: sampler = self._random_sample if log_likelihood_func is None: log_likelihood_func = lambda sample, log_params: ( U.binary_log_likelihood(sample['activation'], log_params)) logQ = [] if self.hparams.task in ['sbn', 'omni']: # Initialize the edge case samples[-1] = {'activation': self._x} if self.mean_xs is not None: samples[-1]['activation'] -= self.mean_xs # center the input samples[-1]['activation'] = (samples[-1]['activation'] + 1)/2.0 with slim.arg_scope([slim.fully_connected], weights_initializer=slim.variance_scaling_initializer(), variables_collections=[Q_COLLECTION]): for i in xrange(self.hparams.n_layer): # Set up the input to the layer input = 2.0*samples[i-1]['activation'] - 1.0 # Create the conditional distribution (output is the logits) h = self._create_transformation(input, n_output=self.hparams.n_hidden, reuse=reuse, scope_prefix='q_%d' % i) samples[i] = sampler(h, self.uniform_samples[i], i) logQ.append(log_likelihood_func(samples[i], h)) self.run_recognition_network = True return logQ, samples elif self.hparams.task == 'sp': # Initialize the edge case samples[-1] = {'activation': tf.split(self._x, num_or_size_splits=2, axis=1)[0]} # top half of digit if self.mean_xs is not None: samples[-1]['activation'] -= np.split(self.mean_xs, 2, 0)[0] # center the input samples[-1]['activation'] = (samples[-1]['activation'] + 1)/2.0 with slim.arg_scope([slim.fully_connected], weights_initializer=slim.variance_scaling_initializer(), variables_collections=[Q_COLLECTION]): for i in xrange(self.hparams.n_layer): # Set up the input to the layer input = 2.0*samples[i-1]['activation'] - 1.0 # Create the conditional distribution (output is the logits) h = self._create_transformation(input, n_output=self.hparams.n_hidden, reuse=reuse, scope_prefix='q_%d' % i) samples[i] = sampler(h, self.uniform_samples[i], i) logQ.append(log_likelihood_func(samples[i], h)) self.run_recognition_network = True return logQ, samples def _generator_network(self, samples, logQ, log_likelihood_func=None): '''Returns learning signal and function. This is the implementation for SBNs for the ELBO. Args: samples: dictionary of sampled latent variables logQ: list of log q(h_i) terms log_likelihood_func: function used to compute log probs for the latent variables Returns: learning_signal: the "reward" function function_term: part of the function that depends on the parameters and needs to have the gradient taken through ''' reuse=None if not self.run_generator_network else True if self.hparams.task in ['sbn', 'omni']: if log_likelihood_func is None: log_likelihood_func = lambda sample, log_params: ( U.binary_log_likelihood(sample['activation'], log_params)) logPPrior = log_likelihood_func( samples[self.hparams.n_layer-1], tf.expand_dims(self.prior, 0)) with slim.arg_scope([slim.fully_connected], weights_initializer=slim.variance_scaling_initializer(), variables_collections=[P_COLLECTION]): for i in reversed(xrange(self.hparams.n_layer)): if i == 0: n_output = self.hparams.n_input else: n_output = self.hparams.n_hidden input = 2.0*samples[i]['activation']-1.0 h = self._create_transformation(input, n_output, reuse=reuse, scope_prefix='p_%d' % i) if i == 0: # Assume output is binary logP = U.binary_log_likelihood(self._x, h + self.train_bias) else: logPPrior += log_likelihood_func(samples[i-1], h) self.run_generator_network = True return logP + logPPrior - tf.add_n(logQ), logP + logPPrior elif self.hparams.task == 'sp': with slim.arg_scope([slim.fully_connected], weights_initializer=slim.variance_scaling_initializer(), variables_collections=[P_COLLECTION]): n_output = int(self.hparams.n_input/2) i = self.hparams.n_layer - 1 # use the last layer input = 2.0*samples[i]['activation']-1.0 h = self._create_transformation(input, n_output, reuse=reuse, scope_prefix='p_%d' % i) # Predict on the lower half of the image logP = U.binary_log_likelihood(tf.split(self._x, num_or_size_splits=2, axis=1)[1], h + np.split(self.train_bias, 2, 0)[1]) self.run_generator_network = True return logP, logP def _create_loss(self): # Hard loss logQHard, samples = self._recognition_network() reinforce_learning_signal, reinforce_model_grad = self._generator_network(samples, logQHard) logQHard = tf.add_n(logQHard) # REINFORCE learning_signal = tf.stop_gradient(U.center(reinforce_learning_signal)) self.optimizerLoss = -(learning_signal*logQHard + reinforce_model_grad) self.lHat = map(tf.reduce_mean, [ reinforce_learning_signal, U.rms(learning_signal), ]) return reinforce_learning_signal def _reshape(self, t): return tf.transpose(tf.reshape(t, [self.n_samples, -1])) def compute_tensor_variance(self, t): """Compute the mean per component variance. Use a moving average to estimate the required moments. """ t_sq = tf.reduce_mean(tf.square(t)) self.maintain_ema_ops.append(self.ema.apply([t, t_sq])) # mean per component variance variance_estimator = (self.ema.average(t_sq) - tf.reduce_mean( tf.square(self.ema.average(t)))) return variance_estimator def _create_train_op(self, grads_and_vars, extra_grads_and_vars=[]): ''' Args: grads_and_vars: gradients to apply and compute running average variance extra_grads_and_vars: gradients to apply (not used to compute average variance) ''' # Variance summaries first_moment = U.vectorize(grads_and_vars, skip_none=True) second_moment = tf.square(first_moment) self.maintain_ema_ops.append(self.ema.apply([first_moment, second_moment])) # Add baseline losses if len(self.baseline_loss) > 0: mean_baseline_loss = tf.reduce_mean(tf.add_n(self.baseline_loss)) extra_grads_and_vars += self.optimizer_class.compute_gradients( mean_baseline_loss, var_list=tf.get_collection('BASELINE')) # Ensure that all required tensors are computed before updates are executed extra_optimizer = tf.train.AdamOptimizer( learning_rate=10*self.hparams.learning_rate, beta2=self.hparams.beta2) with tf.control_dependencies( [tf.group(*[g for g, _ in (grads_and_vars + extra_grads_and_vars) if g is not None])]): # Filter out the P_COLLECTION variables if we're in eval mode if self.eval_mode: grads_and_vars = [(g, v) for g, v in grads_and_vars if v not in tf.get_collection(P_COLLECTION)] train_op = self.optimizer_class.apply_gradients(grads_and_vars, global_step=self.global_step) if len(extra_grads_and_vars) > 0: extra_train_op = extra_optimizer.apply_gradients(extra_grads_and_vars) else: extra_train_op = tf.no_op() self.optimizer = tf.group(train_op, extra_train_op, *self.maintain_ema_ops) # per parameter variance variance_estimator = (self.ema.average(second_moment) - tf.square(self.ema.average(first_moment))) self.grad_variance = tf.reduce_mean(variance_estimator) def _create_network(self): logF = self._create_loss() self.optimizerLoss = tf.reduce_mean(self.optimizerLoss) # Setup optimizer grads_and_vars = self.optimizer_class.compute_gradients(self.optimizerLoss) self._create_train_op(grads_and_vars) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) def partial_fit(self, X, n_samples=1): if hasattr(self, 'grad_variances'): grad_variance_field_to_return = self.grad_variances else: grad_variance_field_to_return = self.grad_variance _, res, grad_variance, step, temperature = self.sess.run( (self.optimizer, self.lHat, grad_variance_field_to_return, self.global_step, self.temperature_variable), feed_dict={self.x: X, self.n_samples: n_samples}) return res, grad_variance, step, temperature def partial_grad(self, X, n_samples=1): control_variate_grads, step = self.sess.run( (self.control_variate_grads, self.global_step), feed_dict={self.x: X, self.n_samples: n_samples}) return control_variate_grads, step def partial_eval(self, X, n_samples=5): if n_samples < 1000: res, iwae = self.sess.run( (self.lHat, self.iwae), feed_dict={self.x: X, self.n_samples: n_samples}) res = [iwae] + res else: # special case to handle OOM assert n_samples % 100 == 0, "When using large # of samples, it must be divisble by 100" res = [] for i in xrange(int(n_samples/100)): logF, = self.sess.run( (self.logF,), feed_dict={self.x: X, self.n_samples: 100}) res.append(logsumexp(logF, axis=1)) res = [np.mean(logsumexp(res, axis=0) - np.log(n_samples))] return res # Random samplers def _mean_sample(self, log_alpha, _, layer): """Returns mean of random variables parameterized by log_alpha.""" mu = tf.nn.sigmoid(log_alpha) return { 'preactivation': mu, 'activation': mu, 'log_param': log_alpha, } def _generate_randomness(self): for i in xrange(self.hparams.n_layer): self.uniform_samples[i] = tf.stop_gradient(tf.random_uniform( [self.batch_size, self.hparams.n_hidden])) def _u_to_v(self, log_alpha, u, eps = 1e-8): """Convert u to tied randomness in v.""" u_prime = tf.nn.sigmoid(-log_alpha) # g(u') = 0 v_1 = (u - u_prime) / tf.clip_by_value(1 - u_prime, eps, 1) v_1 = tf.clip_by_value(v_1, 0, 1) v_1 = tf.stop_gradient(v_1) v_1 = v_1*(1 - u_prime) + u_prime v_0 = u / tf.clip_by_value(u_prime, eps, 1) v_0 = tf.clip_by_value(v_0, 0, 1) v_0 = tf.stop_gradient(v_0) v_0 = v_0 * u_prime v = tf.where(u > u_prime, v_1, v_0) v = tf.check_numerics(v, 'v sampling is not numerically stable.') v = v + tf.stop_gradient(-v + u) # v and u are the same up to numerical errors return v def _random_sample(self, log_alpha, u, layer): """Returns sampled random variables parameterized by log_alpha.""" # Generate tied randomness for later if layer not in self.uniform_samples_v: self.uniform_samples_v[layer] = self._u_to_v(log_alpha, u) # Sample random variable underlying softmax/argmax x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u) samples = tf.stop_gradient(tf.to_float(x > 0)) return { 'preactivation': x, 'activation': samples, 'log_param': log_alpha, } def _random_sample_soft(self, log_alpha, u, layer, temperature=None): """Returns sampled random variables parameterized by log_alpha.""" if temperature is None: temperature = self.hparams.temperature # Sample random variable underlying softmax/argmax x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u) x /= tf.expand_dims(temperature, -1) if self.hparams.muprop_relaxation: y = tf.nn.sigmoid(x + log_alpha * tf.expand_dims(temperature/(temperature + 1), -1)) else: y = tf.nn.sigmoid(x) return { 'preactivation': x, 'activation': y, 'log_param': log_alpha } def _random_sample_soft_v(self, log_alpha, _, layer, temperature=None): """Returns sampled random variables parameterized by log_alpha.""" v = self.uniform_samples_v[layer] return self._random_sample_soft(log_alpha, v, layer, temperature) def get_gumbel_gradient(self): logQ, softSamples = self._recognition_network(sampler=self._random_sample_soft) logQ = tf.add_n(logQ) logPPrior, logP = self._generator_network(softSamples) softELBO = logPPrior + logP - logQ gumbel_gradient = (self.optimizer_class. compute_gradients(softELBO)) debug = { 'softELBO': softELBO, } return gumbel_gradient, debug # samplers used for quadratic version def _random_sample_switch(self, log_alpha, u, layer, switch_layer, temperature=None): """Run partial discrete, then continuous path. Args: switch_layer: this layer and beyond will be continuous """ if layer < switch_layer: return self._random_sample(log_alpha, u, layer) else: return self._random_sample_soft(log_alpha, u, layer, temperature) def _random_sample_switch_v(self, log_alpha, u, layer, switch_layer, temperature=None): """Run partial discrete, then continuous path. Args: switch_layer: this layer and beyond will be continuous """ if layer < switch_layer: return self._random_sample(log_alpha, u, layer) else: return self._random_sample_soft_v(log_alpha, u, layer, temperature) # ##### # Gradient computation # ##### def get_nvil_gradient(self): """Compute the NVIL gradient.""" # Hard loss logQHard, samples = self._recognition_network() ELBO, reinforce_model_grad = self._generator_network(samples, logQHard) logQHard = tf.add_n(logQHard) # Add baselines (no variance normalization) learning_signal = tf.stop_gradient(ELBO) - self._create_baseline() # Set up losses self.baseline_loss.append(tf.square(learning_signal)) optimizerLoss = -(tf.stop_gradient(learning_signal)*logQHard + reinforce_model_grad) optimizerLoss = tf.reduce_mean(optimizerLoss) nvil_gradient = self.optimizer_class.compute_gradients(optimizerLoss) debug = { 'ELBO': ELBO, 'RMS of centered learning signal': U.rms(learning_signal), } return nvil_gradient, debug def get_simple_muprop_gradient(self): """ Computes the simple muprop gradient. This muprop control variate does not include the linear term. """ # Hard loss logQHard, hardSamples = self._recognition_network() hardELBO, reinforce_model_grad = self._generator_network(hardSamples, logQHard) # Soft loss logQ, muSamples = self._recognition_network(sampler=self._mean_sample) muELBO, _ = self._generator_network(muSamples, logQ) scaling_baseline = self._create_eta(collection='BASELINE') learning_signal = (hardELBO - scaling_baseline * muELBO - self._create_baseline()) self.baseline_loss.append(tf.square(learning_signal)) optimizerLoss = -(tf.stop_gradient(learning_signal) * tf.add_n(logQHard) + reinforce_model_grad) optimizerLoss = tf.reduce_mean(optimizerLoss) simple_muprop_gradient = (self.optimizer_class. compute_gradients(optimizerLoss)) debug = { 'ELBO': hardELBO, 'muELBO': muELBO, 'RMS': U.rms(learning_signal), } return simple_muprop_gradient, debug def get_muprop_gradient(self): """ random sample function that actually returns mean new forward pass that returns logQ as a list can get x_i from samples """ # Hard loss logQHard, hardSamples = self._recognition_network() hardELBO, reinforce_model_grad = self._generator_network(hardSamples, logQHard) # Soft loss logQ, muSamples = self._recognition_network(sampler=self._mean_sample) muELBO, _ = self._generator_network(muSamples, logQ) # Compute gradients muELBOGrads = tf.gradients(tf.reduce_sum(muELBO), [ muSamples[i]['activation'] for i in xrange(self.hparams.n_layer) ]) # Compute MuProp gradient estimates learning_signal = hardELBO optimizerLoss = 0.0 learning_signals = [] for i in xrange(self.hparams.n_layer): dfDiff = tf.reduce_sum( muELBOGrads[i] * (hardSamples[i]['activation'] - muSamples[i]['activation']), axis=1) dfMu = tf.reduce_sum( tf.stop_gradient(muELBOGrads[i]) * tf.nn.sigmoid(hardSamples[i]['log_param']), axis=1) scaling_baseline_0 = self._create_eta(collection='BASELINE') scaling_baseline_1 = self._create_eta(collection='BASELINE') learning_signals.append(learning_signal - scaling_baseline_0 * muELBO - scaling_baseline_1 * dfDiff - self._create_baseline()) self.baseline_loss.append(tf.square(learning_signals[i])) optimizerLoss += ( logQHard[i] * tf.stop_gradient(learning_signals[i]) + tf.stop_gradient(scaling_baseline_1) * dfMu) optimizerLoss += reinforce_model_grad optimizerLoss *= -1 optimizerLoss = tf.reduce_mean(optimizerLoss) muprop_gradient = self.optimizer_class.compute_gradients(optimizerLoss) debug = { 'ELBO': hardELBO, 'muELBO': muELBO, } debug.update(dict([ ('RMS learning signal layer %d' % i, U.rms(learning_signal)) for (i, learning_signal) in enumerate(learning_signals)])) return muprop_gradient, debug # REBAR gradient helper functions def _create_gumbel_control_variate(self, logQHard, temperature=None): '''Calculate gumbel control variate. ''' if temperature is None: temperature = self.hparams.temperature logQ, softSamples = self._recognition_network(sampler=functools.partial( self._random_sample_soft, temperature=temperature)) softELBO, _ = self._generator_network(softSamples, logQ) logQ = tf.add_n(logQ) # Generate the softELBO_v (should be the same value but different grads) logQ_v, softSamples_v = self._recognition_network(sampler=functools.partial( self._random_sample_soft_v, temperature=temperature)) softELBO_v, _ = self._generator_network(softSamples_v, logQ_v) logQ_v = tf.add_n(logQ_v) # Compute losses learning_signal = tf.stop_gradient(softELBO_v) # Control variate h = (tf.stop_gradient(learning_signal) * tf.add_n(logQHard) - softELBO + softELBO_v) extra = (softELBO_v, -softELBO + softELBO_v) return h, extra def _create_gumbel_control_variate_quadratic(self, logQHard, temperature=None): '''Calculate gumbel control variate. ''' if temperature is None: temperature = self.hparams.temperature h = 0 extra = [] for layer in xrange(self.hparams.n_layer): logQ, softSamples = self._recognition_network(sampler=functools.partial( self._random_sample_switch, switch_layer=layer, temperature=temperature)) softELBO, _ = self._generator_network(softSamples, logQ) # Generate the softELBO_v (should be the same value but different grads) logQ_v, softSamples_v = self._recognition_network(sampler=functools.partial( self._random_sample_switch_v, switch_layer=layer, temperature=temperature)) softELBO_v, _ = self._generator_network(softSamples_v, logQ_v) # Compute losses learning_signal = tf.stop_gradient(softELBO_v) # Control variate h += (tf.stop_gradient(learning_signal) * logQHard[layer] - softELBO + softELBO_v) extra.append((softELBO_v, -softELBO + softELBO_v)) return h, extra def _create_hard_elbo(self): logQHard, hardSamples = self._recognition_network() hardELBO, reinforce_model_grad = self._generator_network(hardSamples, logQHard) reinforce_learning_signal = tf.stop_gradient(hardELBO) # Center learning signal baseline = self._create_baseline(collection='CV') reinforce_learning_signal = tf.stop_gradient(reinforce_learning_signal) - baseline nvil_gradient = (tf.stop_gradient(hardELBO) - baseline) * tf.add_n(logQHard) + reinforce_model_grad return hardELBO, nvil_gradient, logQHard def multiply_by_eta(self, h_grads, eta): # Modifies eta res = [] eta_statistics = [] for (g, v) in h_grads: if g is None: res.append((g, v)) else: if 'network' not in eta: eta['network'] = self._create_eta() res.append((g*eta['network'], v)) eta_statistics.append(eta['network']) return res, eta_statistics def multiply_by_eta_per_layer(self, h_grads, eta): # Modifies eta res = [] eta_statistics = [] for (g, v) in h_grads: if g is None: res.append((g, v)) else: if v not in eta: eta[v] = self._create_eta() res.append((g*eta[v], v)) eta_statistics.append(eta[v]) return res, eta_statistics def multiply_by_eta_per_unit(self, h_grads, eta): # Modifies eta res = [] eta_statistics = [] for (g, v) in h_grads: if g is None: res.append((g, v)) else: if v not in eta: g_shape = g.shape_as_list() assert len(g_shape) <= 2, 'Gradient has too many dimensions' if len(g_shape) == 1: eta[v] = self._create_eta(g_shape) else: eta[v] = self._create_eta([1, g_shape[1]]) h_grads.append((g*eta[v], v)) eta_statistics.extend(tf.nn.moments(tf.squeeze(eta[v]), axes=[0])) return res, eta_statistics def get_dynamic_rebar_gradient(self): """Get the dynamic rebar gradient (t, eta optimized).""" tiled_pre_temperature = tf.tile([self.pre_temperature_variable], [self.batch_size]) temperature = tf.exp(tiled_pre_temperature) hardELBO, nvil_gradient, logQHard = self._create_hard_elbo() if self.hparams.quadratic: gumbel_cv, extra = self._create_gumbel_control_variate_quadratic(logQHard, temperature=temperature) else: gumbel_cv, extra = self._create_gumbel_control_variate(logQHard, temperature=temperature) f_grads = self.optimizer_class.compute_gradients(tf.reduce_mean(-nvil_gradient)) eta = {} h_grads, eta_statistics = self.multiply_by_eta_per_layer( self.optimizer_class.compute_gradients(tf.reduce_mean(gumbel_cv)), eta) model_grads = U.add_grads_and_vars(f_grads, h_grads) total_grads = model_grads # Construct the variance objective g = U.vectorize(model_grads, set_none_to_zero=True) self.maintain_ema_ops.append(self.ema.apply([g])) gbar = 0 #tf.stop_gradient(self.ema.average(g)) variance_objective = tf.reduce_mean(tf.square(g - gbar)) reinf_g_t = 0 if self.hparams.quadratic: for layer in xrange(self.hparams.n_layer): gumbel_learning_signal, _ = extra[layer] df_dt = tf.gradients(gumbel_learning_signal, tiled_pre_temperature)[0] reinf_g_t_i, _ = self.multiply_by_eta_per_layer( self.optimizer_class.compute_gradients(tf.reduce_mean(tf.stop_gradient(df_dt) * logQHard[layer])), eta) reinf_g_t += U.vectorize(reinf_g_t_i, set_none_to_zero=True) reparam = tf.add_n([reparam_i for _, reparam_i in extra]) else: gumbel_learning_signal, reparam = extra df_dt = tf.gradients(gumbel_learning_signal, tiled_pre_temperature)[0] reinf_g_t, _ = self.multiply_by_eta_per_layer( self.optimizer_class.compute_gradients(tf.reduce_mean(tf.stop_gradient(df_dt) * tf.add_n(logQHard))), eta) reinf_g_t = U.vectorize(reinf_g_t, set_none_to_zero=True) reparam_g, _ = self.multiply_by_eta_per_layer( self.optimizer_class.compute_gradients(tf.reduce_mean(reparam)), eta) reparam_g = U.vectorize(reparam_g, set_none_to_zero=True) reparam_g_t = tf.gradients(tf.reduce_mean(2*tf.stop_gradient(g - gbar)*reparam_g), self.pre_temperature_variable)[0] variance_objective_grad = tf.reduce_mean(2*(g - gbar)*reinf_g_t) + reparam_g_t debug = { 'ELBO': hardELBO, 'etas': eta_statistics, 'variance_objective': variance_objective, } return total_grads, debug, variance_objective, variance_objective_grad def get_rebar_gradient(self): """Get the rebar gradient.""" hardELBO, nvil_gradient, logQHard = self._create_hard_elbo() if self.hparams.quadratic: gumbel_cv, _ = self._create_gumbel_control_variate_quadratic(logQHard) else: gumbel_cv, _ = self._create_gumbel_control_variate(logQHard) f_grads = self.optimizer_class.compute_gradients(tf.reduce_mean(-nvil_gradient)) eta = {} h_grads, eta_statistics = self.multiply_by_eta_per_layer( self.optimizer_class.compute_gradients(tf.reduce_mean(gumbel_cv)), eta) model_grads = U.add_grads_and_vars(f_grads, h_grads) total_grads = model_grads # Construct the variance objective variance_objective = tf.reduce_mean(tf.square(U.vectorize(model_grads, set_none_to_zero=True))) debug = { 'ELBO': hardELBO, 'etas': eta_statistics, 'variance_objective': variance_objective, } return total_grads, debug, variance_objective ### # Create varaints ### class SBNSimpleMuProp(SBN): def _create_loss(self): simple_muprop_gradient, debug = self.get_simple_muprop_gradient() self.lHat = map(tf.reduce_mean, [ debug['ELBO'], debug['muELBO'], ]) return debug['ELBO'], simple_muprop_gradient def _create_network(self): logF, loss_grads = self._create_loss() self._create_train_op(loss_grads) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) class SBNMuProp(SBN): def _create_loss(self): muprop_gradient, debug = self.get_muprop_gradient() self.lHat = map(tf.reduce_mean, [ debug['ELBO'], debug['muELBO'], ]) return debug['ELBO'], muprop_gradient def _create_network(self): logF, loss_grads = self._create_loss() self._create_train_op(loss_grads) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) class SBNNVIL(SBN): def _create_loss(self): nvil_gradient, debug = self.get_nvil_gradient() self.lHat = map(tf.reduce_mean, [ debug['ELBO'], ]) return debug['ELBO'], nvil_gradient def _create_network(self): logF, loss_grads = self._create_loss() self._create_train_op(loss_grads) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) class SBNRebar(SBN): def _create_loss(self): rebar_gradient, debug, variance_objective = self.get_rebar_gradient() self.lHat = map(tf.reduce_mean, [ debug['ELBO'], ]) self.lHat.extend(map(tf.reduce_mean, debug['etas'])) return debug['ELBO'], rebar_gradient, variance_objective def _create_network(self): logF, loss_grads, variance_objective = self._create_loss() # Create additional updates for control variates and temperature eta_grads = (self.optimizer_class.compute_gradients(variance_objective, var_list=tf.get_collection('CV'))) self._create_train_op(loss_grads, eta_grads) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) class SBNDynamicRebar(SBN): def _create_loss(self): rebar_gradient, debug, variance_objective, variance_objective_grad = self.get_dynamic_rebar_gradient() self.lHat = map(tf.reduce_mean, [ debug['ELBO'], self.temperature_variable, ]) self.lHat.extend(debug['etas']) return debug['ELBO'], rebar_gradient, variance_objective, variance_objective_grad def _create_network(self): logF, loss_grads, variance_objective, variance_objective_grad = self._create_loss() # Create additional updates for control variates and temperature eta_grads = (self.optimizer_class.compute_gradients(variance_objective, var_list=tf.get_collection('CV')) + [(variance_objective_grad, self.pre_temperature_variable)]) self._create_train_op(loss_grads, eta_grads) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) class SBNTrackGradVariances(SBN): """Follow NVIL, compute gradient variances for NVIL, MuProp and REBAR.""" def compute_gradient_moments(self, grads_and_vars): first_moment = U.vectorize(grads_and_vars, set_none_to_zero=True) second_moment = tf.square(first_moment) self.maintain_ema_ops.append(self.ema.apply([first_moment, second_moment])) return self.ema.average(first_moment), self.ema.average(second_moment) def _create_loss(self): self.losses = [ ('NVIL', self.get_nvil_gradient), ('SimpleMuProp', self.get_simple_muprop_gradient), ('MuProp', self.get_muprop_gradient), ] moments = [] for k, v in self.losses: print(k) gradient, debug = v() if k == 'SimpleMuProp': ELBO = debug['ELBO'] gradient_to_follow = gradient moments.append(self.compute_gradient_moments( gradient)) self.losses.append(('DynamicREBAR', self.get_dynamic_rebar_gradient)) dynamic_rebar_gradient, _, variance_objective, variance_objective_grad = self.get_dynamic_rebar_gradient() moments.append(self.compute_gradient_moments(dynamic_rebar_gradient)) self.losses.append(('REBAR', self.get_rebar_gradient)) rebar_gradient, _, variance_objective2 = self.get_rebar_gradient() moments.append(self.compute_gradient_moments(rebar_gradient)) mu = tf.reduce_mean(tf.stack([f for f, _ in moments]), axis=0) self.grad_variances = [] deviations = [] for f, s in moments: self.grad_variances.append(tf.reduce_mean(s - tf.square(mu))) deviations.append(tf.reduce_mean(tf.square(f - mu))) self.lHat = map(tf.reduce_mean, [ ELBO, self.temperature_variable, variance_objective_grad, variance_objective_grad*variance_objective_grad, ]) self.lHat.extend(deviations) self.lHat.append(tf.log(tf.reduce_mean(mu*mu))) # self.lHat.extend(map(tf.log, grad_variances)) return ELBO, gradient_to_follow, variance_objective + variance_objective2, variance_objective_grad def _create_network(self): logF, loss_grads, variance_objective, variance_objective_grad = self._create_loss() eta_grads = (self.optimizer_class.compute_gradients(variance_objective, var_list=tf.get_collection('CV')) + [(variance_objective_grad, self.pre_temperature_variable)]) self._create_train_op(loss_grads, eta_grads) # Create IWAE lower bound for evaluation self.logF = self._reshape(logF) self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) - tf.log(tf.to_float(self.n_samples))) class SBNGumbel(SBN): def _random_sample_soft(self, log_alpha, u, layer, temperature=None): """Returns sampled random variables parameterized by log_alpha.""" if temperature is None: temperature = self.hparams.temperature # Sample random variable underlying softmax/argmax x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u) x /= temperature if self.hparams.muprop_relaxation: x += temperature/(temperature + 1)*log_alpha y = tf.nn.sigmoid(x) return { 'preactivation': x, 'activation': y, 'log_param': log_alpha } def _create_loss(self): # Hard loss logQHard, hardSamples = self._recognition_network() hardELBO, _ = self._generator_network(hardSamples, logQHard) logQ, softSamples = self._recognition_network(sampler=self._random_sample_soft) softELBO, _ = self._generator_network(softSamples, logQ) self.optimizerLoss = -softELBO self.lHat = map(tf.reduce_mean, [ hardELBO, softELBO, ]) return hardELBO default_hparams = tf.contrib.training.HParams(model='SBNGumbel', n_hidden=200, n_input=784, n_layer=1, nonlinear=False, learning_rate=0.001, temperature=0.5, n_samples=1, batch_size=24, trial=1, muprop_relaxation=True, dynamic_b=False, # dynamic binarization quadratic=True, beta2=0.99999, task='sbn', )