Spaces:
Running
Running
# Copyright 2017 Google, Inc. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Collection of trainable optimizers for meta-optimization.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import math | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.python.ops import state_ops | |
from learned_optimizer.optimizer import rnn_cells | |
from learned_optimizer.optimizer import trainable_optimizer as opt | |
from learned_optimizer.optimizer import utils | |
# The "Previous default" notes below record the value each flag had before
# the current default was chosen.
# Previous default: 0.1
tf.app.flags.DEFINE_float("biasgrucell_scale", 0.5,
                          """The scale for the internal BiasGRUCell vars.""")
# Previous default: 0
tf.app.flags.DEFINE_float("biasgrucell_gate_bias_init", 2.2,
                          """The bias for the internal BiasGRUCell reset and
                          update gate variables.""")
# Previous default: 1e-3
tf.app.flags.DEFINE_float("hrnn_rnn_readout_scale", 0.5,
                          """The initialization scale for the RNN readouts.""")
# sigmoid(2.2) = 0.900 and sigmoid(-2.2) = 0.100; the help text below states
# the corrected value for the negative case.
tf.app.flags.DEFINE_float("hrnn_default_decay_var_init", 2.2,
                          """The default initializer value for any decay/
                          momentum style variables and constants.
                          sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.1.""")
# Previous default: 2.2. Note sigmoid(3.2) ~ 0.96.
tf.app.flags.DEFINE_float("scale_decay_bias_init", 3.2,
                          """The initialization for the scale decay bias. This
                          is the initial bias for the timescale for the
                          exponential avg of the mean square gradients.""")
tf.app.flags.DEFINE_float("learning_rate_momentum_logit_init", 3.2,
                          """Initialization for the learning rate momentum.""")
# Previous default: 0.1
tf.app.flags.DEFINE_float("hrnn_affine_scale", 0.5,
                          """The initialization scale for the weight matrix of
                          the bias variables in layer0 and 1 of the hrnn.""")
# Use the same flags module the definitions above register with.
FLAGS = tf.app.flags.FLAGS
class HierarchicalRNN(opt.TrainableOptimizer):
  """3 level hierarchical RNN.

  Optionally uses second order gradient information and has decoupled evaluation
  and update locations.
  """

  # Lower clip bound applied to every log learning rate (exp(-33) ~ 4.7e-15).
  # The upper bound is the max_log_lr constructor argument.
  _MIN_LOG_LR = -33

  def __init__(self, level_sizes, init_lr_range=(1e-6, 1e-2),
               learnable_decay=True, dynamic_output_scale=True,
               use_attention=False, use_log_objective=True,
               num_gradient_scales=4, zero_init_lr_weights=True,
               use_log_means_squared=True, use_relative_lr=True,
               use_extreme_indicator=False, max_log_lr=33,
               obj_train_max_multiplier=-1, use_problem_lr_mean=False,
               use_gradient_shortcut=False, use_lr_shortcut=False,
               use_grad_products=False, use_multiple_scale_decays=False,
               learnable_inp_decay=True, learnable_rnn_init=True,
               random_seed=None, **kwargs):
    """Initializes the RNN per-parameter optimizer.

    The hierarchy consists of up to three levels:
    Level 0: per parameter RNN
    Level 1: per tensor RNN
    Level 2: global RNN

    Args:
      level_sizes: list or tuple with 1, 2, or 3 integers (Python or numpy),
        the number of units in each RNN in the hierarchy (level0, level1,
        level2).
        length 1: only coordinatewise rnn's will be used
        length 2: coordinatewise and tensor-level rnn's will be used
        length 3: a single global-level rnn will be used in addition to
           coordinatewise and tensor-level
      init_lr_range: the range in which to initialize the learning rates
      learnable_decay: whether to learn weights that dynamically modulate the
        input scale via RMS style decay
      dynamic_output_scale: whether to learn weights that dynamically modulate
        the output scale
      use_attention: whether to use attention to train the optimizer
      use_log_objective: whether to train on the log of the objective
      num_gradient_scales: the number of scales to use for gradient history
      zero_init_lr_weights: whether to initialize the lr weights to zero
      use_log_means_squared: whether to track the log of the means_squared,
        used as a measure of signal vs. noise in gradient.
      use_relative_lr: whether to use the relative learning rate as an
        input during training (requires dynamic_output_scale=True)
      use_extreme_indicator: whether to use the extreme indicator for learning
        rates as an input during training (requires dynamic_output_scale=True)
      max_log_lr: the maximum log learning rate allowed during train or test
      obj_train_max_multiplier: max objective increase during a training run
      use_problem_lr_mean: whether to use the mean over all learning rates in
        the problem when calculating the relative learning rate as opposed to
        the per-tensor mean
      use_gradient_shortcut: Whether to add a learned affine projection of the
        gradient to the update delta in addition to the gradient function
        computed by the RNN
      use_lr_shortcut: Whether to add as input the difference between the log lr
        and the desired log lr (1e-3) (requires dynamic_output_scale=True)
      use_grad_products: Whether to use gradient products in the rnn input.
        Only applicable if num_gradient_scales > 1
      use_multiple_scale_decays: Whether to use multiple scales for the scale
        decay, as with input decay
      learnable_inp_decay: Whether to learn the input decay weights and bias.
      learnable_rnn_init: Whether to learn the RNN state initialization.
      random_seed: Random seed for random variable initializers. (Default: None)
      **kwargs: args passed to TrainableOptimizer's constructor

    Raises:
      ValueError: If level_sizes is not a length 1, 2, or 3 list.
      ValueError: If there are any non-integer sizes in level_sizes.
      ValueError: If the init lr range is not of length 2.
      ValueError: If the init lr range is not a valid range (min > max).
    """
    if len(level_sizes) not in [1, 2, 3]:
      raise ValueError("HierarchicalRNN only supports 1, 2, or 3 levels in the "
                       "hierarchy, but {} were requested.".format(
                           len(level_sizes)))
    if any(not isinstance(level, (int, np.integer)) for level in level_sizes):
      raise ValueError("Level sizes must be integer values, were {}".format(
          level_sizes))
    # Normalize numpy integers to plain Python ints for the cell constructors.
    level_sizes = [int(level) for level in level_sizes]
    if len(init_lr_range) != 2:
      raise ValueError(
          "Initial LR range must be len 2, was {}".format(len(init_lr_range)))
    if init_lr_range[0] > init_lr_range[1]:
      raise ValueError("Initial LR range min is greater than max.")

    self.learnable_decay = learnable_decay
    self.dynamic_output_scale = dynamic_output_scale
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.num_gradient_scales = num_gradient_scales
    self.zero_init_lr_weights = zero_init_lr_weights
    self.use_log_means_squared = use_log_means_squared
    self.use_relative_lr = use_relative_lr
    self.use_extreme_indicator = use_extreme_indicator
    self.max_log_lr = max_log_lr
    self.use_problem_lr_mean = use_problem_lr_mean
    self.use_gradient_shortcut = use_gradient_shortcut
    self.use_lr_shortcut = use_lr_shortcut
    self.use_grad_products = use_grad_products
    self.use_multiple_scale_decays = use_multiple_scale_decays
    self.learnable_inp_decay = learnable_inp_decay
    self.learnable_rnn_init = learnable_rnn_init

    self.random_seed = random_seed

    self.num_layers = len(level_sizes)
    self.init_lr_range = init_lr_range

    # Variable-scope reuse flags; set to True after the first graph build.
    self.reuse_vars = None
    self.reuse_global_state = None
    self.cells = []
    self.init_vectors = []

    with tf.variable_scope(opt.OPTIMIZER_SCOPE):

      self._initialize_rnn_cells(level_sizes)

      # get the cell size for the per-parameter RNN (Level 0)
      cell_size = level_sizes[0]

      # Random normal initialization scaled by the output size. This is the
      # scale for the RNN *readouts*. RNN internal weight scale is set in the
      # BiasGRUCell call.
      scale_factor = FLAGS.hrnn_rnn_readout_scale / math.sqrt(cell_size)
      scaled_init = tf.random_normal_initializer(0., scale_factor,
                                                 seed=self.random_seed)

      # weights for projecting the hidden state to a parameter update
      self.update_weights = tf.get_variable("update_weights",
                                            shape=(cell_size, 1),
                                            initializer=scaled_init)

      if self.use_attention:
        # weights for projecting the hidden state to the location at which the
        # gradient is attended
        self.attention_weights = tf.get_variable(
            "attention_weights",
            initializer=self.update_weights.initialized_value())

      # weights for projecting the hidden state to the RMS decay term
      self._initialize_scale_decay((cell_size, 1), scaled_init)
      self._initialize_input_decay((cell_size, 1), scaled_init)

      self._initialize_lr((cell_size, 1), scaled_init)

    state_keys = ["parameter", "layer", "scl_decay", "inp_decay", "true_param"]

    if self.dynamic_output_scale:
      state_keys.append("log_learning_rate")

    for i in range(self.num_gradient_scales):
      state_keys.append("grad_accum{}".format(i + 1))
      state_keys.append("ms{}".format(i + 1))

    super(HierarchicalRNN, self).__init__(
        "hRNN", state_keys, use_attention=use_attention,
        use_log_objective=use_log_objective,
        obj_train_max_multiplier=obj_train_max_multiplier, **kwargs)

  def _initialize_rnn_cells(self, level_sizes):
    """Initializes the RNN cells to use in the hierarchical RNN.

    Appends one BiasGRUCell per level to self.cells, and a matching
    [1 x state_size] initial state to self.init_vectors. The initial state is a
    trainable tf.Variable when learnable_rnn_init is set; otherwise it is a
    plain (non-trainable) tf.random_uniform tensor.
    """

    # RNN Cell layers (0 -> lowest, 1 -> middle, 2 -> global)
    for level in range(self.num_layers):
      scope = "Level{}_RNN".format(level)
      with tf.variable_scope(scope):
        hcell = rnn_cells.BiasGRUCell(
            level_sizes[level],
            scale=FLAGS.biasgrucell_scale,
            gate_bias_init=FLAGS.biasgrucell_gate_bias_init,
            random_seed=self.random_seed)
        self.cells.append(hcell)
        if self.learnable_rnn_init:
          self.init_vectors.append(tf.Variable(
              tf.random_uniform([1, hcell.state_size], -1., 1.,
                                seed=self.random_seed),
              name="init_vector"))
        else:
          self.init_vectors.append(
              tf.random_uniform([1, hcell.state_size], -1., 1.,
                                seed=self.random_seed))

  def _initialize_scale_decay(self, weights_tensor_shape, scaled_init):
    """Initializes the scale decay weights and bias variables or tensors.

    When learnable_decay is off, the weights are zeros and the bias is the
    logit of 0.93, so the (sigmoid-activated) scale decay stays fixed.

    Args:
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if self.learnable_decay:
      self.scl_decay_weights = tf.get_variable("scl_decay_weights",
                                               shape=weights_tensor_shape,
                                               initializer=scaled_init)
      scl_decay_bias_init = tf.constant_initializer(
          FLAGS.scale_decay_bias_init)
      self.scl_decay_bias = tf.get_variable("scl_decay_bias",
                                            shape=(1,),
                                            initializer=scl_decay_bias_init)
    else:
      self.scl_decay_weights = tf.zeros_like(self.update_weights)
      self.scl_decay_bias = tf.log(0.93 / (1. - 0.93))

  def _initialize_input_decay(self, weights_tensor_shape, scaled_init):
    """Initializes the input scale decay weights and bias variables or tensors.

    Learnable only when learnable_decay, learnable_inp_decay and
    num_gradient_scales > 1 all hold; otherwise the weights are zeros and the
    bias is the logit of 0.89.

    Args:
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if (self.learnable_decay and self.num_gradient_scales > 1 and
        self.learnable_inp_decay):
      self.inp_decay_weights = tf.get_variable("inp_decay_weights",
                                               shape=weights_tensor_shape,
                                               initializer=scaled_init)
      inp_decay_bias_init = tf.constant_initializer(
          FLAGS.hrnn_default_decay_var_init)
      self.inp_decay_bias = tf.get_variable("inp_decay_bias",
                                            shape=(1,),
                                            initializer=inp_decay_bias_init)
    else:
      self.inp_decay_weights = tf.zeros_like(self.update_weights)
      self.inp_decay_bias = tf.log(0.89 / (1. - 0.89))

  def _initialize_lr(self, weights_tensor_shape, scaled_init):
    """Initializes the learning rate weights and bias variables or tensors.

    Args:
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if self.dynamic_output_scale:
      zero_init = tf.constant_initializer(0.)
      wt_init = zero_init if self.zero_init_lr_weights else scaled_init
      self.lr_weights = tf.get_variable("learning_rate_weights",
                                        shape=weights_tensor_shape,
                                        initializer=wt_init)
      self.lr_bias = tf.get_variable("learning_rate_bias", shape=(1,),
                                     initializer=zero_init)
    else:
      self.lr_weights = tf.zeros_like(self.update_weights)
      self.lr_bias = tf.zeros([1, 1])

  def _initialize_state(self, var):
    """Return a dictionary mapping names of state variables to their values."""
    var_vectorized = tf.reshape(var, [-1, 1])
    ndim = var_vectorized.get_shape().as_list()[0]

    state = {
        # parameter init tensor is [var_ndim x layer0_cell_size]
        "parameter": tf.ones([ndim, 1]) * self.init_vectors[0],
        "scl_decay": tf.zeros_like(var_vectorized),
        "inp_decay": tf.zeros_like(var_vectorized),
        "true_param": var,
    }

    if self.num_layers > 1:
      # layer init tensor is [1 x layer1_cell_size]
      state["layer"] = tf.ones([1, 1]) * self.init_vectors[1]

    if self.dynamic_output_scale:
      min_lr = self.init_lr_range[0]
      max_lr = self.init_lr_range[1]
      if min_lr == max_lr:
        log_init_lr = tf.log(min_lr * tf.ones_like(var_vectorized))
      else:
        # Use a random offset to increase the likelihood that the average of the
        # LRs for this variable is different from the LRs for other variables.
        # Each half-range draw spans [log(min_lr)/2, log(max_lr)/2], so their
        # sum spans [log(min_lr), log(max_lr)].
        actual_vals = tf.random_uniform(var_vectorized.get_shape().as_list(),
                                        np.log(min_lr) / 2.,
                                        np.log(max_lr) / 2.,
                                        seed=self.random_seed)
        offset = tf.random_uniform((), np.log(min_lr) / 2., np.log(max_lr) / 2.,
                                   seed=self.random_seed)
        log_init_lr = actual_vals + offset
      # Clip the log learning rate to max_log_lr at the top end and to
      # _MIN_LOG_LR at the bottom.
      clipped = tf.clip_by_value(log_init_lr, self._MIN_LOG_LR,
                                 self.max_log_lr)
      state["log_learning_rate"] = clipped

    for i in range(self.num_gradient_scales):
      state["grad_accum{}".format(i + 1)] = tf.zeros_like(var_vectorized)
      state["ms{}".format(i + 1)] = tf.zeros_like(var_vectorized)

    return state

  def _initialize_global_state(self):
    """Returns the initial global (Level 2) RNN state, as a list.

    The list is empty unless the hierarchy has three levels.
    """
    if self.num_layers < 3:
      return []
    rnn_global_init = tf.ones([1, 1]) * self.init_vectors[2]
    return [rnn_global_init]

  def _compute_updates(self, params, grads, states, global_state):
    """Computes new parameters, per-variable states and the global state.

    Args:
      params: the variables being optimized.
      grads: the gradients, aligned index-for-index with params.
      states: per-variable state dicts, aligned index-for-index with params.
      global_state: list holding the global RNN state (empty unless the
        hierarchy has three levels).

    Returns:
      (updated_params, updated_states, [updated_global_state],
       updated_attention)
    """
    # Store the updated parameters and states.
    updated_params = []
    updated_attention = []
    updated_states = []
    # Per-tensor (Level 1) RNN states, one per parameter tensor. Together they
    # form the input of the global (Level 2) RNN.
    layer_states = []

    with tf.variable_scope(opt.OPTIMIZER_SCOPE):

      mean_log_lr = self._compute_mean_log_lr(states)

      # Iterate over the layers.
      for param, grad_unflat, state in zip(params, grads, states):

        with tf.variable_scope("PerTensor", reuse=self.reuse_vars):
          self.reuse_vars = True
          grad = tf.reshape(grad_unflat, [-1, 1])

          # Create the RNN input. We will optionally extend it with additional
          # features such as curvature and gradient signal vs. noise.
          (grads_scaled, mean_squared_gradients,
           grads_accum) = self._compute_scaled_and_ms_grads(grad, state)
          # Copy: rnn_input is extended in place and grads_scaled must not be.
          rnn_input = list(grads_scaled)

          self._extend_rnn_input(rnn_input, state, grads_scaled,
                                 mean_squared_gradients, mean_log_lr)

          # Concatenate any features we've collected.
          rnn_input_tensor = tf.concat(rnn_input, 1)

          layer_state, new_param_state = self._update_rnn_cells(
              state, global_state, rnn_input_tensor,
              len(rnn_input) != len(grads_scaled))

          (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
           attention_delta) = self._compute_rnn_state_projections(
               state, new_param_state, grads_scaled)

          # Apply updates and store state variables.
          if self.use_attention:
            truth = state["true_param"]
            updated_param = truth - update_step
            attention_step = tf.reshape(lr_attend * attention_delta,
                                        truth.get_shape())
            updated_attention.append(truth - attention_step)
          else:
            updated_param = param - update_step
            updated_attention.append(updated_param)
          updated_params.append(updated_param)

          # Collect the new state.
          new_state = {
              "parameter": new_param_state,
              "scl_decay": scl_decay,
              "inp_decay": inp_decay,
              "true_param": updated_param,
          }
          if layer_state is not None:
            new_state["layer"] = layer_state
            layer_states.append(layer_state)

          if self.dynamic_output_scale:
            new_state["log_learning_rate"] = new_log_lr

          for i in range(self.num_gradient_scales):
            new_state["grad_accum{}".format(i + 1)] = grads_accum[i]
            new_state["ms{}".format(i + 1)] = mean_squared_gradients[i]
          updated_states.append(new_state)

      # The global RNN averages over the states of *all* per-tensor RNNs, not
      # only the one visited last by the loop above.
      updated_global_state = self._compute_updated_global_state(layer_states,
                                                                global_state)

    return (updated_params, updated_states, [updated_global_state],
            updated_attention)

  def _compute_mean_log_lr(self, states):
    """Computes the mean log learning rate across all variables.

    Returns:
      The mean over every coordinate of every variable's log learning rate,
      or None when the problem-wide mean is not needed (i.e. unless both
      use_problem_lr_mean and use_relative_lr are set).
    """
    if not (self.use_problem_lr_mean and self.use_relative_lr):
      return None

    sum_log_lr = 0.
    count_log_lr = 0.
    for state in states:
      sum_log_lr += tf.reduce_sum(state["log_learning_rate"])
      # Note: get_shape().num_elements()=num elements in the original tensor.
      count_log_lr += state["log_learning_rate"].get_shape().num_elements()
    return sum_log_lr / count_log_lr

  def _compute_scaled_and_ms_grads(self, grad, state):
    """Computes the scaled gradient and the mean squared gradients.

    Gradients are also accumulated across different timescales if appropriate.

    Args:
      grad: The gradient tensor for this layer.
      state: The optimizer state for this layer.

    Returns:
      The scaled gradients, mean squared gradients, and accumulated gradients.
    """
    input_decays = [state["inp_decay"]]
    scale_decays = [state["scl_decay"]]
    if self.use_multiple_scale_decays and self.num_gradient_scales > 1:
      for i in range(self.num_gradient_scales - 1):
        scale_decays.append(tf.sqrt(scale_decays[i]))

    for i in range(self.num_gradient_scales - 1):
      # Each accumulator on twice the timescale of the one before.
      input_decays.append(tf.sqrt(input_decays[i]))
    grads_accum = []
    grads_scaled = []
    mean_squared_gradients = []

    # populate the scaled gradients and associated mean_squared values
    if self.num_gradient_scales > 0:
      for i, decay in enumerate(input_decays):
        if self.num_gradient_scales == 1:
          # We don't accumulate if no scales, just take the current gradient.
          grad_accum = grad
        else:
          # The state vars are 1-indexed.
          old_accum = state["grad_accum{}".format(i + 1)]
          grad_accum = grad * (1. - decay) + old_accum * decay

        grads_accum.append(grad_accum)

        sd = scale_decays[i if self.use_multiple_scale_decays else 0]
        grad_scaled, ms = utils.rms_scaling(grad_accum, sd,
                                            state["ms{}".format(i + 1)],
                                            update_ms=True)
        grads_scaled.append(grad_scaled)
        mean_squared_gradients.append(ms)

    return grads_scaled, mean_squared_gradients, grads_accum

  def _extend_rnn_input(self, rnn_input, state, grads_scaled,
                        mean_squared_gradients, mean_log_lr):
    """Computes additional rnn inputs and adds them to the rnn_input list.

    Raises:
      ValueError: if a learning-rate-based feature is requested while
        dynamic_output_scale is off (there is no log learning rate state).
    """
    if self.num_gradient_scales > 1 and self.use_grad_products:
      # This gives a measure of curvature relative to input averaging
      # lengthscale and to the learning rate
      grad_products = [a * b for a, b in
                       zip(grads_scaled[:-1], grads_scaled[1:])]
      rnn_input.extend(grad_products)

    if self.use_log_means_squared:
      log_means_squared = [tf.log(ms + 1e-16)
                           for ms in mean_squared_gradients]

      avg = tf.reduce_mean(log_means_squared, axis=0)
      # This gives a measure of the signal vs. noise contribution to the
      # gradient, at the current averaging lengthscale. If all the noise
      # is averaged out, and if updates are small, these will be 0.
      mean_log_means_squared = [m - avg for m in log_means_squared]

      rnn_input.extend(mean_log_means_squared)

    if self.use_relative_lr or self.use_extreme_indicator:
      if not self.dynamic_output_scale:
        raise ValueError("Relative LR and Extreme Indicator features "
                         "require dynamic_output_scale to be set to True.")
      log_lr_vec = tf.reshape(state["log_learning_rate"], [-1, 1])
      if self.use_relative_lr:
        if self.use_problem_lr_mean:
          # Learning rate of this dimension vs. rest of target problem.
          relative_lr = log_lr_vec - mean_log_lr
        else:
          # Learning rate of this dimension vs. rest of tensor.
          relative_lr = log_lr_vec - tf.reduce_mean(log_lr_vec)
        rnn_input.append(relative_lr)
      if self.use_extreme_indicator:
        # Indicator of extremely large or extremely small learning rate.
        extreme_indicator = (tf.nn.relu(log_lr_vec - tf.log(1.)) -
                             tf.nn.relu(tf.log(1e-6) - log_lr_vec))
        rnn_input.append(extreme_indicator)

    if self.use_lr_shortcut:
      if not self.dynamic_output_scale:
        raise ValueError("The LR shortcut feature requires "
                         "dynamic_output_scale to be set to True.")
      log_lr_vec = tf.reshape(state["log_learning_rate"], [-1, 1])
      rnn_input.append(log_lr_vec - tf.log(1e-3))

  def _update_rnn_cells(self, state, global_state, rnn_input_tensor,
                        use_additional_features):
    """Updates the component RNN cells with the given state and tensor.

    Args:
      state: The current state of the optimizer.
      global_state: The current global RNN state.
      rnn_input_tensor: The input tensor to the RNN.
      use_additional_features: Whether the rnn input tensor contains additional
        features beyond the scaled gradients (affects whether the rnn input
        tensor is used as input to the RNN.)

    Returns:
      layer_state: The new state of the per-tensor RNN (None with one level).
      new_param_state: The new state of the per-parameter RNN.
    """
    # lowest level (per parameter)
    # input -> gradient for this parameter
    # bias -> output from the layer RNN
    with tf.variable_scope("Layer0_RNN"):
      total_bias = None
      if self.num_layers > 1:
        sz = 3 * self.cells[0].state_size  # size of the concatenated bias
        param_bias = utils.affine([state["layer"]], sz,
                                  scope="Param/Affine",
                                  scale=FLAGS.hrnn_affine_scale,
                                  random_seed=self.random_seed)
        total_bias = param_bias
        if self.num_layers == 3:
          global_bias = utils.affine(global_state, sz,
                                     scope="Global/Affine",
                                     scale=FLAGS.hrnn_affine_scale,
                                     random_seed=self.random_seed)
          total_bias += global_bias

      new_param_state, _ = self.cells[0](
          rnn_input_tensor, state["parameter"], bias=total_bias)

    if self.num_layers > 1:
      # middle level (per layer)
      # input -> average hidden state from each parameter in this layer
      # bias -> output from the RNN at the global level
      with tf.variable_scope("Layer1_RNN"):
        if not use_additional_features:
          # Restore old behavior and only add the mean of the new params.
          layer_input = tf.reduce_mean(new_param_state, 0, keep_dims=True)
        else:
          layer_input = tf.reduce_mean(
              tf.concat((new_param_state, rnn_input_tensor), 1), 0,
              keep_dims=True)
        if self.num_layers == 3:
          sz = 3 * self.cells[1].state_size
          layer_bias = utils.affine(global_state, sz,
                                    scale=FLAGS.hrnn_affine_scale,
                                    random_seed=self.random_seed)
          layer_state, _ = self.cells[1](
              layer_input, state["layer"], bias=layer_bias)
        else:
          layer_state, _ = self.cells[1](layer_input, state["layer"])
    else:
      layer_state = None

    return layer_state, new_param_state

  def _compute_rnn_state_projections(self, state, new_param_state,
                                     grads_scaled):
    """Computes the RNN state-based updates to parameters and update steps.

    Returns:
      (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
       attention_delta); attention_delta is None unless use_attention.
    """
    # Compute the update direction (a linear projection of the RNN output).
    update_weights = self.update_weights

    update_delta = utils.project(new_param_state, update_weights)
    if self.use_gradient_shortcut:
      # Include an affine projection of just the direction of the gradient
      # so that RNN hidden states are freed up to store more complex
      # functions of the gradient and other parameters.
      grads_scaled_tensor = tf.concat(list(grads_scaled), 1)
      update_delta += utils.affine(grads_scaled_tensor, 1,
                                   scope="GradsToDelta",
                                   include_bias=False,
                                   vec_mean=1. / len(grads_scaled),
                                   random_seed=self.random_seed)
    if self.dynamic_output_scale:
      # Normalize the update direction to unit RMS across the tensor.
      denom = tf.sqrt(tf.reduce_mean(update_delta ** 2) + 1e-16)

      update_delta /= denom

    if self.use_attention:
      attention_weights = self.attention_weights
      attention_delta = utils.project(new_param_state,
                                      attention_weights)
      if self.use_gradient_shortcut:
        attention_delta += utils.affine(grads_scaled_tensor, 1,
                                        scope="GradsToAttnDelta",
                                        include_bias=False,
                                        vec_mean=1. / len(grads_scaled),
                                        random_seed=self.random_seed)
      if self.dynamic_output_scale:
        attention_delta /= tf.sqrt(
            tf.reduce_mean(attention_delta ** 2) + 1e-16)
    else:
      attention_delta = None

    # The updated decay is an affine projection of the hidden state.
    scl_decay = utils.project(new_param_state, self.scl_decay_weights,
                              bias=self.scl_decay_bias,
                              activation=tf.nn.sigmoid)
    # This is only used if learnable_decay and num_gradient_scales > 1
    inp_decay = utils.project(new_param_state, self.inp_decay_weights,
                              bias=self.inp_decay_bias,
                              activation=tf.nn.sigmoid)

    # Also update the learning rate.
    lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
        state, new_param_state)

    update_step = tf.reshape(lr_param * update_delta,
                             state["true_param"].get_shape())

    return (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
            attention_delta)

  def _compute_new_learning_rate(self, state, new_param_state):
    """Computes the step sizes and (optionally) the new log learning rate.

    Returns:
      lr_param: multiplier applied to the update direction.
      lr_attend: multiplier applied to the attention direction.
      new_log_lr: the new log learning rate state, or None when
        dynamic_output_scale is off.
    """
    if self.dynamic_output_scale:
      # Compute the change in learning rate (an affine projection of the
      # RNN state, passed through a sigmoid or log depending on flags).
      # Update the learning rate, w/ momentum.
      lr_change = utils.project(new_param_state, self.lr_weights,
                                bias=self.lr_bias)
      step_log_lr = state["log_learning_rate"] + lr_change

      # Clip the log learning rate to [_MIN_LOG_LR, max_log_lr] with a
      # straight-through estimator: the forward value is clipped, but the
      # gradient w.r.t. the lr weights and bias is that of the unclipped
      # value. The clip and the subtraction live inside stop_gradient, so TF
      # never differentiates them; adding step_log_lr back outside it carries
      # the unclipped gradient through.
      step_log_lr += tf.stop_gradient(
          tf.clip_by_value(step_log_lr, self._MIN_LOG_LR, self.max_log_lr)
          - step_log_lr)

      lr_momentum_logit = tf.get_variable(
          "learning_rate_momentum_logit",
          initializer=FLAGS.learning_rate_momentum_logit_init)
      lrm = tf.nn.sigmoid(lr_momentum_logit)
      new_log_lr = (lrm * state["log_learning_rate"] +
                    (1. - lrm) * step_log_lr)
      param_stepsize_offset = tf.get_variable("param_stepsize_offset",
                                              initializer=-1.)
      lr_param = tf.exp(step_log_lr + param_stepsize_offset)
      lr_attend = tf.exp(step_log_lr) if self.use_attention else lr_param
    else:
      # Dynamic output scale is off. lr_weights and lr_bias are zero tensors in
      # this mode, so (assuming utils.project is an affine map) this evaluates
      # to 2 * sigmoid(0) = 1.
      lr_param = 2. * utils.project(new_param_state, self.lr_weights,
                                    bias=self.lr_bias,
                                    activation=tf.nn.sigmoid)
      new_log_lr = None
      lr_attend = lr_param

    return lr_param, lr_attend, new_log_lr

  def _compute_updated_global_state(self, layer_states, global_state):
    """Computes the new global state given the layer states and old state.

    Args:
      layer_states: The current layer states (one [1 x size] tensor per
        parameter tensor).
      global_state: The old global state.

    Returns:
      The updated global state, or [] unless the hierarchy has three levels.
    """
    updated_global_state = []
    if self.num_layers == 3:
      # highest (global) layer
      # input -> average hidden state from each layer-specific RNN
      # bias -> None
      with tf.variable_scope("Layer2_RNN", reuse=self.reuse_global_state):
        self.reuse_global_state = True
        global_input = tf.reduce_mean(tf.concat(layer_states, 0), 0,
                                      keep_dims=True)
        updated_global_state, _ = self.cells[2](global_input, global_state[0])
    return updated_global_state

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Overwrites the tf.train.Optimizer interface for applying gradients.

    Args:
      grads_and_vars: iterable of (gradient, variable) pairs. Pairs whose
        gradient is None are validated but otherwise skipped.
      global_step: optional variable incremented by one after the updates.
      name: optional name for the returned operation.

    Returns:
      A tuple (op, real_params): op applies the updates (and increments
      global_step when given); real_params are the "true_param" slots of the
      updated variables.

    Raises:
      TypeError: if a gradient or variable has an unsupported type.
      ValueError: if no variable has a gradient.
    """
    # Pull out the variables.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    for g, v in grads_and_vars:
      if not isinstance(g, (tf.Tensor, tf.IndexedSlices, type(None))):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      if not isinstance(v, tf.Variable):
        raise TypeError(
            "Variable must be a tf.Variable: %s" % v)
      if g is not None:
        self._assert_valid_dtypes([g, v])
    # Only pairs that have a gradient are updated. Using this filtered list
    # everywhere keeps params, grads and states aligned index-for-index.
    active_grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    var_list = [v for _, v in active_grads_and_vars]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s" %
                       (grads_and_vars,))

    # Create slots for the variables.
    with tf.control_dependencies(None):
      self._create_slots(var_list)

    # Store update ops in this list.
    # tf.name_scope(name, default_name, values) replaces the deprecated
    # tf.op_scope(values, name, default_name).
    with tf.name_scope(name, self._name, []) as name:
      # Prepare the global state.
      with tf.variable_scope(self._name, reuse=self.reuse_global_state):
        gs = self._initialize_global_state()
        if gs:
          global_state = [tf.get_variable("global_state", initializer=gs[0])]
        else:
          global_state = []

      # Get the states for each variable in the list.
      states = [{key: self.get_slot(var, key) for key in self.get_slot_names()}
                for var in var_list]

      # Compute updated values.
      grads, params = zip(*active_grads_and_vars)
      args = (params, grads, states, global_state)
      updates = self._compute_updates(*args)
      new_params, new_states, new_global_state, new_attention = updates

      # Assign op for new global state.
      update_ops = [tf.assign(gs_var, new_gs)
                    for gs_var, new_gs in zip(global_state, new_global_state)]

      # Create the assign ops for the params and state variables.
      args = (params, states, new_params, new_attention, new_states)
      for var, state, new_var, new_var_attend, new_state in zip(*args):
        # Assign updates to the state variables.
        state_assign_ops = [tf.assign(state_var, new_state[key])
                            for key, state_var in state.items()]

        # Update the parameter (after its state), inside a per-variable scope.
        with tf.control_dependencies(state_assign_ops), tf.name_scope(
            "update_" + var.op.name):
          if self.use_attention:
            # Assign to the attended location, rather than the actual location
            # so that the gradients are computed where attention is.
            param_update_op = var.assign(new_var_attend)
          else:
            param_update_op = var.assign(new_var)
          update_ops.append(param_update_op)

      real_params = [self.get_slot(var, "true_param") for var in var_list]

      if global_step is None:
        # NOTE: if using the optimizer in a non-test-optimizer setting (e.g.
        # on Inception), remove the real_params return value. Otherwise
        # the code will throw an error.
        return self._finish(update_ops, name), real_params
      else:
        with tf.control_dependencies([self._finish(update_ops, "update")]):
          return state_ops.assign_add(global_step, 1, name=name).op, real_params