NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
35.5 kB
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Collection of trainable optimizers for meta-optimization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import state_ops
from learned_optimizer.optimizer import rnn_cells
from learned_optimizer.optimizer import trainable_optimizer as opt
from learned_optimizer.optimizer import utils
# Default was 0.1
tf.app.flags.DEFINE_float("biasgrucell_scale", 0.5,
"""The scale for the internal BiasGRUCell vars.""")
# Default was 0
tf.app.flags.DEFINE_float("biasgrucell_gate_bias_init", 2.2,
"""The bias for the internal BiasGRUCell reset and
update gate variables.""")
# Default was 1e-3
tf.app.flags.DEFINE_float("hrnn_rnn_readout_scale", 0.5,
"""The initialization scale for the RNN readouts.""")
tf.app.flags.DEFINE_float("hrnn_default_decay_var_init", 2.2,
"""The default initializer value for any decay/
momentum style variables and constants.
sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.01.""")
# Default was 2.2
tf.app.flags.DEFINE_float("scale_decay_bias_init", 3.2,
"""The initialization for the scale decay bias. This
is the initial bias for the timescale for the
exponential avg of the mean square gradients.""")
tf.app.flags.DEFINE_float("learning_rate_momentum_logit_init", 3.2,
"""Initialization for the learning rate momentum.""")
# Default was 0.1
tf.app.flags.DEFINE_float("hrnn_affine_scale", 0.5,
"""The initialization scale for the weight matrix of
the bias variables in layer0 and 1 of the hrnn.""")
FLAGS = tf.flags.FLAGS
class HierarchicalRNN(opt.TrainableOptimizer):
"""3 level hierarchical RNN.
Optionally uses second order gradient information and has decoupled evaluation
and update locations.
"""
def __init__(self, level_sizes, init_lr_range=(1e-6, 1e-2),
learnable_decay=True, dynamic_output_scale=True,
use_attention=False, use_log_objective=True,
num_gradient_scales=4, zero_init_lr_weights=True,
use_log_means_squared=True, use_relative_lr=True,
use_extreme_indicator=False, max_log_lr=33,
obj_train_max_multiplier=-1, use_problem_lr_mean=False,
use_gradient_shortcut=False, use_lr_shortcut=False,
use_grad_products=False, use_multiple_scale_decays=False,
learnable_inp_decay=True, learnable_rnn_init=True,
random_seed=None, **kwargs):
"""Initializes the RNN per-parameter optimizer.
The hierarchy consists of up to three levels:
Level 0: per parameter RNN
Level 1: per tensor RNN
Level 2: global RNN
Args:
level_sizes: list or tuple with 1, 2, or 3 integers, the number of units
in each RNN in the hierarchy (level0, level1, level2).
length 1: only coordinatewise rnn's will be used
length 2: coordinatewise and tensor-level rnn's will be used
length 3: a single global-level rnn will be used in addition to
coordinatewise and tensor-level
init_lr_range: the range in which to initialize the learning rates
learnable_decay: whether to learn weights that dynamically modulate the
input scale via RMS style decay
dynamic_output_scale: whether to learn weights that dynamically modulate
the output scale
use_attention: whether to use attention to train the optimizer
use_log_objective: whether to train on the log of the objective
num_gradient_scales: the number of scales to use for gradient history
zero_init_lr_weights: whether to initialize the lr weights to zero
use_log_means_squared: whether to track the log of the means_squared,
used as a measure of signal vs. noise in gradient.
use_relative_lr: whether to use the relative learning rate as an
input during training (requires learnable_decay=True)
use_extreme_indicator: whether to use the extreme indicator for learning
rates as an input during training (requires learnable_decay=True)
max_log_lr: the maximum log learning rate allowed during train or test
obj_train_max_multiplier: max objective increase during a training run
use_problem_lr_mean: whether to use the mean over all learning rates in
the problem when calculating the relative learning rate as opposed to
the per-tensor mean
use_gradient_shortcut: Whether to add a learned affine projection of the
gradient to the update delta in addition to the gradient function
computed by the RNN
use_lr_shortcut: Whether to add as input the difference between the log lr
and the desired log lr (1e-3)
use_grad_products: Whether to use gradient products in the rnn input.
Only applicable if num_gradient_scales > 1
use_multiple_scale_decays: Whether to use multiple scales for the scale
decay, as with input decay
learnable_inp_decay: Whether to learn the input decay weights and bias.
learnable_rnn_init: Whether to learn the RNN state initialization.
random_seed: Random seed for random variable initializers. (Default: None)
**kwargs: args passed to TrainableOptimizer's constructor
Raises:
ValueError: If level_sizes is not a length 1, 2, or 3 list.
ValueError: If there are any non-integer sizes in level_sizes.
ValueError: If the init lr range is not of length 2.
ValueError: If the init lr range is not a valid range (min > max).
"""
if len(level_sizes) not in [1, 2, 3]:
raise ValueError("HierarchicalRNN only supports 1, 2, or 3 levels in the "
"hierarchy, but {} were requested.".format(
len(level_sizes)))
if any(not isinstance(level, int) for level in level_sizes):
raise ValueError("Level sizes must be integer values, were {}".format(
level_sizes))
if len(init_lr_range) != 2:
raise ValueError(
"Initial LR range must be len 2, was {}".format(len(init_lr_range)))
if init_lr_range[0] > init_lr_range[1]:
raise ValueError("Initial LR range min is greater than max.")
self.learnable_decay = learnable_decay
self.dynamic_output_scale = dynamic_output_scale
self.use_attention = use_attention
self.use_log_objective = use_log_objective
self.num_gradient_scales = num_gradient_scales
self.zero_init_lr_weights = zero_init_lr_weights
self.use_log_means_squared = use_log_means_squared
self.use_relative_lr = use_relative_lr
self.use_extreme_indicator = use_extreme_indicator
self.max_log_lr = max_log_lr
self.use_problem_lr_mean = use_problem_lr_mean
self.use_gradient_shortcut = use_gradient_shortcut
self.use_lr_shortcut = use_lr_shortcut
self.use_grad_products = use_grad_products
self.use_multiple_scale_decays = use_multiple_scale_decays
self.learnable_inp_decay = learnable_inp_decay
self.learnable_rnn_init = learnable_rnn_init
self.random_seed = random_seed
self.num_layers = len(level_sizes)
self.init_lr_range = init_lr_range
self.reuse_vars = None
self.reuse_global_state = None
self.cells = []
self.init_vectors = []
with tf.variable_scope(opt.OPTIMIZER_SCOPE):
self._initialize_rnn_cells(level_sizes)
# get the cell size for the per-parameter RNN (Level 0)
cell_size = level_sizes[0]
# Random normal initialization scaled by the output size. This is the
# scale for the RNN *readouts*. RNN internal weight scale is set in the
# BiasGRUCell call.
scale_factor = FLAGS.hrnn_rnn_readout_scale / math.sqrt(cell_size)
scaled_init = tf.random_normal_initializer(0., scale_factor,
seed=self.random_seed)
# weights for projecting the hidden state to a parameter update
self.update_weights = tf.get_variable("update_weights",
shape=(cell_size, 1),
initializer=scaled_init)
if self.use_attention:
# weights for projecting the hidden state to the location at which the
# gradient is attended
self.attention_weights = tf.get_variable(
"attention_weights",
initializer=self.update_weights.initialized_value())
# weights for projecting the hidden state to the RMS decay term
self._initialize_scale_decay((cell_size, 1), scaled_init)
self._initialize_input_decay((cell_size, 1), scaled_init)
self._initialize_lr((cell_size, 1), scaled_init)
state_keys = ["parameter", "layer", "scl_decay", "inp_decay", "true_param"]
if self.dynamic_output_scale:
state_keys.append("log_learning_rate")
for i in range(self.num_gradient_scales):
state_keys.append("grad_accum{}".format(i + 1))
state_keys.append("ms{}".format(i + 1))
super(HierarchicalRNN, self).__init__(
"hRNN", state_keys, use_attention=use_attention,
use_log_objective=use_log_objective,
obj_train_max_multiplier=obj_train_max_multiplier, **kwargs)
def _initialize_rnn_cells(self, level_sizes):
"""Initializes the RNN cells to use in the hierarchical RNN."""
# RNN Cell layers (0 -> lowest, 1 -> middle, 2 -> global)
for level in range(self.num_layers):
scope = "Level{}_RNN".format(level)
with tf.variable_scope(scope):
hcell = rnn_cells.BiasGRUCell(
level_sizes[level],
scale=FLAGS.biasgrucell_scale,
gate_bias_init=FLAGS.biasgrucell_gate_bias_init,
random_seed=self.random_seed)
self.cells.append(hcell)
if self.learnable_rnn_init:
self.init_vectors.append(tf.Variable(
tf.random_uniform([1, hcell.state_size], -1., 1.,
seed=self.random_seed),
name="init_vector"))
else:
self.init_vectors.append(
tf.random_uniform([1, hcell.state_size], -1., 1.,
seed=self.random_seed))
def _initialize_scale_decay(self, weights_tensor_shape, scaled_init):
"""Initializes the scale decay weights and bias variables or tensors.
Args:
weights_tensor_shape: The shape the weight tensor should take.
scaled_init: The scaled initialization for the weights tensor.
"""
if self.learnable_decay:
self.scl_decay_weights = tf.get_variable("scl_decay_weights",
shape=weights_tensor_shape,
initializer=scaled_init)
scl_decay_bias_init = tf.constant_initializer(
FLAGS.scale_decay_bias_init)
self.scl_decay_bias = tf.get_variable("scl_decay_bias",
shape=(1,),
initializer=scl_decay_bias_init)
else:
self.scl_decay_weights = tf.zeros_like(self.update_weights)
self.scl_decay_bias = tf.log(0.93 / (1. - 0.93))
def _initialize_input_decay(self, weights_tensor_shape, scaled_init):
"""Initializes the input scale decay weights and bias variables or tensors.
Args:
weights_tensor_shape: The shape the weight tensor should take.
scaled_init: The scaled initialization for the weights tensor.
"""
if (self.learnable_decay and self.num_gradient_scales > 1 and
self.learnable_inp_decay):
self.inp_decay_weights = tf.get_variable("inp_decay_weights",
shape=weights_tensor_shape,
initializer=scaled_init)
inp_decay_bias_init = tf.constant_initializer(
FLAGS.hrnn_default_decay_var_init)
self.inp_decay_bias = tf.get_variable("inp_decay_bias",
shape=(1,),
initializer=inp_decay_bias_init)
else:
self.inp_decay_weights = tf.zeros_like(self.update_weights)
self.inp_decay_bias = tf.log(0.89 / (1. - 0.89))
def _initialize_lr(self, weights_tensor_shape, scaled_init):
"""Initializes the learning rate weights and bias variables or tensors.
Args:
weights_tensor_shape: The shape the weight tensor should take.
scaled_init: The scaled initialization for the weights tensor.
"""
if self.dynamic_output_scale:
zero_init = tf.constant_initializer(0.)
wt_init = zero_init if self.zero_init_lr_weights else scaled_init
self.lr_weights = tf.get_variable("learning_rate_weights",
shape=weights_tensor_shape,
initializer=wt_init)
self.lr_bias = tf.get_variable("learning_rate_bias", shape=(1,),
initializer=zero_init)
else:
self.lr_weights = tf.zeros_like(self.update_weights)
self.lr_bias = tf.zeros([1, 1])
def _initialize_state(self, var):
"""Return a dictionary mapping names of state variables to their values."""
var_vectorized = tf.reshape(var, [-1, 1])
ndim = var_vectorized.get_shape().as_list()[0]
state = {
# parameter init tensor is [var_ndim x layer0_cell_size]
"parameter": tf.ones([ndim, 1]) * self.init_vectors[0],
"scl_decay": tf.zeros_like(var_vectorized),
"inp_decay": tf.zeros_like(var_vectorized),
"true_param": var,
}
if self.num_layers > 1:
# layer init tensor is [1 x layer1_cell_size]
state["layer"] = tf.ones([1, 1]) * self.init_vectors[1]
if self.dynamic_output_scale:
min_lr = self.init_lr_range[0]
max_lr = self.init_lr_range[1]
if min_lr == max_lr:
log_init_lr = tf.log(min_lr * tf.ones_like(var_vectorized))
else:
# Use a random offset to increase the likelihood that the average of the
# LRs for this variable is different from the LRs for other variables.
actual_vals = tf.random_uniform(var_vectorized.get_shape().as_list(),
np.log(min_lr) / 2.,
np.log(max_lr) / 2.,
seed=self.random_seed)
offset = tf.random_uniform((), np.log(min_lr) / 2., np.log(max_lr) / 2.,
seed=self.random_seed)
log_init_lr = actual_vals + offset
# Clip the log learning rate to the flag at the top end, and to
# (log(min int32) - 1) at the bottom
clipped = tf.clip_by_value(log_init_lr, -33, self.max_log_lr)
state["log_learning_rate"] = clipped
for i in range(self.num_gradient_scales):
state["grad_accum{}".format(i + 1)] = tf.zeros_like(var_vectorized)
state["ms{}".format(i + 1)] = tf.zeros_like(var_vectorized)
return state
def _initialize_global_state(self):
if self.num_layers < 3:
return []
rnn_global_init = tf.ones([1, 1]) * self.init_vectors[2]
return [rnn_global_init]
def _compute_updates(self, params, grads, states, global_state):
# Store the updated parameters and states.
updated_params = []
updated_attention = []
updated_states = []
with tf.variable_scope(opt.OPTIMIZER_SCOPE):
mean_log_lr = self._compute_mean_log_lr(states)
# Iterate over the layers.
for param, grad_unflat, state in zip(params, grads, states):
with tf.variable_scope("PerTensor", reuse=self.reuse_vars):
self.reuse_vars = True
grad = tf.reshape(grad_unflat, [-1, 1])
# Create the RNN input. We will optionally extend it with additional
# features such as curvature and gradient signal vs. noise.
(grads_scaled, mean_squared_gradients,
grads_accum) = self._compute_scaled_and_ms_grads(grad, state)
rnn_input = [g for g in grads_scaled]
self._extend_rnn_input(rnn_input, state, grads_scaled,
mean_squared_gradients, mean_log_lr)
# Concatenate any features we've collected.
rnn_input_tensor = tf.concat(rnn_input, 1)
layer_state, new_param_state = self._update_rnn_cells(
state, global_state, rnn_input_tensor,
len(rnn_input) != len(grads_scaled))
(scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
attention_delta) = self._compute_rnn_state_projections(
state, new_param_state, grads_scaled)
# Apply updates and store state variables.
if self.use_attention:
truth = state["true_param"]
updated_param = truth - update_step
attention_step = tf.reshape(lr_attend * attention_delta,
truth.get_shape())
updated_attention.append(truth - attention_step)
else:
updated_param = param - update_step
updated_attention.append(updated_param)
updated_params.append(updated_param)
# Collect the new state.
new_state = {
"parameter": new_param_state,
"scl_decay": scl_decay,
"inp_decay": inp_decay,
"true_param": updated_param,
}
if layer_state is not None:
new_state["layer"] = layer_state
if self.dynamic_output_scale:
new_state["log_learning_rate"] = new_log_lr
for i in range(self.num_gradient_scales):
new_state["grad_accum{}".format(i + 1)] = grads_accum[i]
new_state["ms{}".format(i + 1)] = mean_squared_gradients[i]
updated_states.append(new_state)
updated_global_state = self._compute_updated_global_state([layer_state],
global_state)
return (updated_params, updated_states, [updated_global_state],
updated_attention)
def _compute_mean_log_lr(self, states):
"""Computes the mean log learning rate across all variables."""
if self.use_problem_lr_mean and self.use_relative_lr:
sum_log_lr = 0.
count_log_lr = 0.
for state in states:
sum_log_lr += tf.reduce_sum(state["log_learning_rate"])
# Note: get_shape().num_elements()=num elements in the original tensor.
count_log_lr += state["log_learning_rate"].get_shape().num_elements()
return sum_log_lr / count_log_lr
def _compute_scaled_and_ms_grads(self, grad, state):
"""Computes the scaled gradient and the mean squared gradients.
Gradients are also accumulated across different timescales if appropriate.
Args:
grad: The gradient tensor for this layer.
state: The optimizer state for this layer.
Returns:
The scaled gradients, mean squared gradients, and accumulated gradients.
"""
input_decays = [state["inp_decay"]]
scale_decays = [state["scl_decay"]]
if self.use_multiple_scale_decays and self.num_gradient_scales > 1:
for i in range(self.num_gradient_scales - 1):
scale_decays.append(tf.sqrt(scale_decays[i]))
for i in range(self.num_gradient_scales - 1):
# Each accumulator on twice the timescale of the one before.
input_decays.append(tf.sqrt(input_decays[i]))
grads_accum = []
grads_scaled = []
mean_squared_gradients = []
# populate the scaled gradients and associated mean_squared values
if self.num_gradient_scales > 0:
for i, decay in enumerate(input_decays):
if self.num_gradient_scales == 1:
# We don't accumulate if no scales, just take the current gradient.
grad_accum = grad
else:
# The state vars are 1-indexed.
old_accum = state["grad_accum{}".format(i + 1)]
grad_accum = grad * (1. - decay) + old_accum * decay
grads_accum.append(grad_accum)
sd = scale_decays[i if self.use_multiple_scale_decays else 0]
grad_scaled, ms = utils.rms_scaling(grad_accum, sd,
state["ms{}".format(i + 1)],
update_ms=True)
grads_scaled.append(grad_scaled)
mean_squared_gradients.append(ms)
return grads_scaled, mean_squared_gradients, grads_accum
def _extend_rnn_input(self, rnn_input, state, grads_scaled,
mean_squared_gradients, mean_log_lr):
"""Computes additional rnn inputs and adds them to the rnn_input list."""
if self.num_gradient_scales > 1 and self.use_grad_products:
# This gives a measure of curvature relative to input averaging
# lengthscale and to the learning rate
grad_products = [a * b for a, b in
zip(grads_scaled[:-1], grads_scaled[1:])]
rnn_input.extend([g for g in grad_products])
if self.use_log_means_squared:
log_means_squared = [tf.log(ms + 1e-16)
for ms in mean_squared_gradients]
avg = tf.reduce_mean(log_means_squared, axis=0)
# This gives a measure of the signal vs. noise contribution to the
# gradient, at the current averaging lengthscale. If all the noise
# is averaged out, and if updates are small, these will be 0.
mean_log_means_squared = [m - avg for m in log_means_squared]
rnn_input.extend([m for m in mean_log_means_squared])
if self.use_relative_lr or self.use_extreme_indicator:
if not self.dynamic_output_scale:
raise Exception("Relative LR and Extreme Indicator features "
"require dynamic_output_scale to be set to True.")
log_lr_vec = tf.reshape(state["log_learning_rate"], [-1, 1])
if self.use_relative_lr:
if self.use_problem_lr_mean:
# Learning rate of this dimension vs. rest of target problem.
relative_lr = log_lr_vec - mean_log_lr
else:
# Learning rate of this dimension vs. rest of tensor.
relative_lr = log_lr_vec - tf.reduce_mean(log_lr_vec)
rnn_input.append(relative_lr)
if self.use_extreme_indicator:
# Indicator of extremely large or extremely small learning rate.
extreme_indicator = (tf.nn.relu(log_lr_vec - tf.log(1.)) -
tf.nn.relu(tf.log(1e-6) - log_lr_vec))
rnn_input.append(extreme_indicator)
if self.use_lr_shortcut:
log_lr_vec = tf.reshape(state["log_learning_rate"], [-1, 1])
rnn_input.append(log_lr_vec - tf.log(1e-3))
def _update_rnn_cells(self, state, global_state, rnn_input_tensor,
use_additional_features):
"""Updates the component RNN cells with the given state and tensor.
Args:
state: The current state of the optimizer.
global_state: The current global RNN state.
rnn_input_tensor: The input tensor to the RNN.
use_additional_features: Whether the rnn input tensor contains additional
features beyond the scaled gradients (affects whether the rnn input
tensor is used as input to the RNN.)
Returns:
layer_state: The new state of the per-tensor RNN.
new_param_state: The new state of the per-parameter RNN.
"""
# lowest level (per parameter)
# input -> gradient for this parameter
# bias -> output from the layer RNN
with tf.variable_scope("Layer0_RNN"):
total_bias = None
if self.num_layers > 1:
sz = 3 * self.cells[0].state_size # size of the concatenated bias
param_bias = utils.affine([state["layer"]], sz,
scope="Param/Affine",
scale=FLAGS.hrnn_affine_scale,
random_seed=self.random_seed)
total_bias = param_bias
if self.num_layers == 3:
global_bias = utils.affine(global_state, sz,
scope="Global/Affine",
scale=FLAGS.hrnn_affine_scale,
random_seed=self.random_seed)
total_bias += global_bias
new_param_state, _ = self.cells[0](
rnn_input_tensor, state["parameter"], bias=total_bias)
if self.num_layers > 1:
# middle level (per layer)
# input -> average hidden state from each parameter in this layer
# bias -> output from the RNN at the global level
with tf.variable_scope("Layer1_RNN"):
if not use_additional_features:
# Restore old behavior and only add the mean of the new params.
layer_input = tf.reduce_mean(new_param_state, 0, keep_dims=True)
else:
layer_input = tf.reduce_mean(
tf.concat((new_param_state, rnn_input_tensor), 1), 0,
keep_dims=True)
if self.num_layers == 3:
sz = 3 * self.cells[1].state_size
layer_bias = utils.affine(global_state, sz,
scale=FLAGS.hrnn_affine_scale,
random_seed=self.random_seed)
layer_state, _ = self.cells[1](
layer_input, state["layer"], bias=layer_bias)
else:
layer_state, _ = self.cells[1](layer_input, state["layer"])
else:
layer_state = None
return layer_state, new_param_state
def _compute_rnn_state_projections(self, state, new_param_state,
grads_scaled):
"""Computes the RNN state-based updates to parameters and update steps."""
# Compute the update direction (a linear projection of the RNN output).
update_weights = self.update_weights
update_delta = utils.project(new_param_state, update_weights)
if self.use_gradient_shortcut:
# Include an affine projection of just the direction of the gradient
# so that RNN hidden states are freed up to store more complex
# functions of the gradient and other parameters.
grads_scaled_tensor = tf.concat([g for g in grads_scaled], 1)
update_delta += utils.affine(grads_scaled_tensor, 1,
scope="GradsToDelta",
include_bias=False,
vec_mean=1. / len(grads_scaled),
random_seed=self.random_seed)
if self.dynamic_output_scale:
denom = tf.sqrt(tf.reduce_mean(update_delta ** 2) + 1e-16)
update_delta /= denom
if self.use_attention:
attention_weights = self.attention_weights
attention_delta = utils.project(new_param_state,
attention_weights)
if self.use_gradient_shortcut:
attention_delta += utils.affine(grads_scaled_tensor, 1,
scope="GradsToAttnDelta",
include_bias=False,
vec_mean=1. / len(grads_scaled),
random_seed=self.random_seed)
if self.dynamic_output_scale:
attention_delta /= tf.sqrt(
tf.reduce_mean(attention_delta ** 2) + 1e-16)
else:
attention_delta = None
# The updated decay is an affine projection of the hidden state.
scl_decay = utils.project(new_param_state, self.scl_decay_weights,
bias=self.scl_decay_bias,
activation=tf.nn.sigmoid)
# This is only used if learnable_decay and num_gradient_scales > 1
inp_decay = utils.project(new_param_state, self.inp_decay_weights,
bias=self.inp_decay_bias,
activation=tf.nn.sigmoid)
# Also update the learning rate.
lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
state, new_param_state)
update_step = tf.reshape(lr_param * update_delta,
state["true_param"].get_shape())
return (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
attention_delta)
def _compute_new_learning_rate(self, state, new_param_state):
if self.dynamic_output_scale:
# Compute the change in learning rate (an affine projection of the
# RNN state, passed through a sigmoid or log depending on flags).
# Update the learning rate, w/ momentum.
lr_change = utils.project(new_param_state, self.lr_weights,
bias=self.lr_bias)
step_log_lr = state["log_learning_rate"] + lr_change
# Clip the log learning rate to the flag at the top end, and to
# (log(min int32) - 1) at the bottom
# Check out this hack: we want to be able to compute the gradient
# of the downstream result w.r.t lr weights and bias, even if the
# value of step_log_lr is outside the clip range. So we clip,
# subtract off step_log_lr, and wrap all that in a stop_gradient so
# TF never tries to take the gradient of the clip... or the
# subtraction. Then we add BACK step_log_lr so that downstream still
# receives the clipped value. But the GRADIENT of step_log_lr will
# be the gradient of the unclipped value, which we added back in
# after stop_gradients.
step_log_lr += tf.stop_gradient(
tf.clip_by_value(step_log_lr, -33, self.max_log_lr)
- step_log_lr)
lr_momentum_logit = tf.get_variable(
"learning_rate_momentum_logit",
initializer=FLAGS.learning_rate_momentum_logit_init)
lrm = tf.nn.sigmoid(lr_momentum_logit)
new_log_lr = (lrm * state["log_learning_rate"] +
(1. - lrm) * step_log_lr)
param_stepsize_offset = tf.get_variable("param_stepsize_offset",
initializer=-1.)
lr_param = tf.exp(step_log_lr + param_stepsize_offset)
lr_attend = tf.exp(step_log_lr) if self.use_attention else lr_param
else:
# Dynamic output scale is off, LR param is always 1.
lr_param = 2. * utils.project(new_param_state, self.lr_weights,
bias=self.lr_bias,
activation=tf.nn.sigmoid)
new_log_lr = None
lr_attend = lr_param
return lr_param, lr_attend, new_log_lr
def _compute_updated_global_state(self, layer_states, global_state):
"""Computes the new global state gives the layers states and old state.
Args:
layer_states: The current layer states.
global_state: The old global state.
Returns:
The updated global state.
"""
updated_global_state = []
if self.num_layers == 3:
# highest (global) layer
# input -> average hidden state from each layer-specific RNN
# bias -> None
with tf.variable_scope("Layer2_RNN", reuse=self.reuse_global_state):
self.reuse_global_state = True
global_input = tf.reduce_mean(tf.concat(layer_states, 0), 0,
keep_dims=True)
updated_global_state, _ = self.cells[2](global_input, global_state[0])
return updated_global_state
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Overwrites the tf.train.Optimizer interface for applying gradients."""
# Pull out the variables.
grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
for g, v in grads_and_vars:
if not isinstance(g, (tf.Tensor, tf.IndexedSlices, type(None))):
raise TypeError(
"Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
if not isinstance(v, tf.Variable):
raise TypeError(
"Variable must be a tf.Variable: %s" % v)
if g is not None:
self._assert_valid_dtypes([g, v])
var_list = [v for g, v in grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s" %
(grads_and_vars,))
# Create slots for the variables.
with tf.control_dependencies(None):
self._create_slots(var_list)
# Store update ops in this list.
with tf.op_scope([], name, self._name) as name:
# Prepare the global state.
with tf.variable_scope(self._name, reuse=self.reuse_global_state):
gs = self._initialize_global_state()
if gs:
global_state = [tf.get_variable("global_state", initializer=gs[0])]
else:
global_state = []
# Get the states for each variable in the list.
states = [{key: self.get_slot(var, key) for key in self.get_slot_names()}
for var in var_list]
# Compute updated values.
grads, params = zip(*grads_and_vars)
args = (params, grads, states, global_state)
updates = self._compute_updates(*args)
new_params, new_states, new_global_state, new_attention = updates
# Assign op for new global state.
update_ops = [tf.assign(gs, ngs)
for gs, ngs in zip(global_state, new_global_state)]
# Create the assign ops for the params and state variables.
args = (params, states, new_params, new_attention, new_states)
for var, state, new_var, new_var_attend, new_state in zip(*args):
# Assign updates to the state variables.
state_assign_ops = [tf.assign(state_var, new_state[key])
for key, state_var in state.items()]
# Update the parameter.
with tf.control_dependencies(state_assign_ops):
if self.use_attention:
# Assign to the attended location, rather than the actual location
# so that the gradients are computed where attention is.
param_update_op = var.assign(new_var_attend)
else:
param_update_op = var.assign(new_var)
with tf.name_scope("update_" + var.op.name): #, tf.colocate_with(var):
update_ops.append(param_update_op)
real_params = [self.get_slot(var, "true_param") for var in var_list]
if global_step is None:
# NOTE: if using the optimizer in a non-test-optimizer setting (e.g.
# on Inception), remove the real_params return value. Otherwise
# the code will throw an error.
return self._finish(update_ops, name), real_params
else:
with tf.control_dependencies([self._finish(update_ops, "update")]):
return state_ops.assign_add(global_step, 1, name=name).op, real_params