# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Collection of trainable optimizers for meta-optimization.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import math | |
import numpy as np | |
import tensorflow as tf | |
from learned_optimizer.optimizer import utils | |
from learned_optimizer.optimizer import trainable_optimizer as opt | |
# Default was 1e-3 | |
tf.app.flags.DEFINE_float("crnn_rnn_readout_scale", 0.5, | |
"""The initialization scale for the RNN readouts.""") | |
tf.app.flags.DEFINE_float("crnn_default_decay_var_init", 2.2, | |
"""The default initializer value for any decay/ | |
momentum style variables and constants. | |
sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.01.""") | |
FLAGS = tf.flags.FLAGS | |
class CoordinatewiseRNN(opt.TrainableOptimizer):
  """RNN that operates on each coordinate of the problem independently."""

  def __init__(self,
               cell_sizes,
               cell_cls,
               init_lr_range=(1., 1.),
               dynamic_output_scale=True,
               learnable_decay=True,
               zero_init_lr_weights=False,
               **kwargs):
    """Initializes the RNN per-parameter optimizer.

    Args:
      cell_sizes: List of hidden state sizes for each RNN cell in the network.
      cell_cls: tf.contrib.rnn class specifying the RNN cell type.
      init_lr_range: The range in which to initialize the learning rates.
      dynamic_output_scale: Whether to learn weights that dynamically modulate
          the output scale. (default: True)
      learnable_decay: Whether to learn weights that dynamically modulate the
          input scale via RMS-style decay. (default: True)
      zero_init_lr_weights: Whether to initialize the learning rate weights
          to zero.
      **kwargs: Args passed to TrainableOptimizer's constructor.

    Raises:
      ValueError: If the init lr range is not of length 2.
      ValueError: If the init lr range is not a valid range (min > max).
    """
    if len(init_lr_range) != 2:
      raise ValueError(
          "Initial LR range must be len 2, was {}".format(len(init_lr_range)))
    if init_lr_range[0] > init_lr_range[1]:
      raise ValueError("Initial LR range min is greater than max.")

    self.init_lr_range = init_lr_range
    self.zero_init_lr_weights = zero_init_lr_weights
    self.reuse_vars = False

    # create the RNN cell
    with tf.variable_scope(opt.OPTIMIZER_SCOPE):
      self.component_cells = [cell_cls(sz) for sz in cell_sizes]
      self.cell = tf.contrib.rnn.MultiRNNCell(self.component_cells)

      # random normal initialization scaled by the output size
      scale_factor = FLAGS.crnn_rnn_readout_scale / math.sqrt(cell_sizes[-1])
      scaled_init = tf.random_normal_initializer(0., scale_factor)

      # weights for projecting the hidden state to a parameter update
      self.update_weights = tf.get_variable("update_weights",
                                            shape=(cell_sizes[-1], 1),
                                            initializer=scaled_init)

      self._initialize_decay(learnable_decay, (cell_sizes[-1], 1), scaled_init)

      self._initialize_lr(dynamic_output_scale, (cell_sizes[-1], 1),
                          scaled_init)

      state_size = sum(sum(single_cell_size)
                       for single_cell_size in self.cell.state_size)
      self._init_vector = tf.get_variable(
          "init_vector", shape=[1, state_size],
          initializer=tf.random_uniform_initializer(-1., 1.))

    state_keys = ["rms", "rnn", "learning_rate", "decay"]
    super(CoordinatewiseRNN, self).__init__("cRNN", state_keys, **kwargs)

  def _initialize_decay(
      self, learnable_decay, weights_tensor_shape, scaled_init):
    """Initializes the decay weights and bias variables or tensors.

    Args:
      learnable_decay: Whether to use learnable decay.
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if learnable_decay:
      # weights for projecting the hidden state to the RMS decay term
      self.decay_weights = tf.get_variable("decay_weights",
                                           shape=weights_tensor_shape,
                                           initializer=scaled_init)
      self.decay_bias = tf.get_variable(
          "decay_bias", shape=(1,),
          initializer=tf.constant_initializer(
              FLAGS.crnn_default_decay_var_init))
    else:
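      # With learnable decay disabled, the decay gate in _compute_update is
      # the fixed constant sigmoid(2.2) ~ 0.9, i.e. plain RMS-style decay.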
      self.decay_weights = tf.zeros_like(self.update_weights)
      self.decay_bias = tf.constant(FLAGS.crnn_default_decay_var_init)

  def _initialize_lr(
      self, dynamic_output_scale, weights_tensor_shape, scaled_init):
    """Initializes the learning rate weights and bias variables or tensors.

    Args:
      dynamic_output_scale: Whether to use a dynamic output scale.
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if dynamic_output_scale:
      zero_init = tf.constant_initializer(0.)
      wt_init = zero_init if self.zero_init_lr_weights else scaled_init
      self.lr_weights = tf.get_variable("learning_rate_weights",
                                        shape=weights_tensor_shape,
                                        initializer=wt_init)
      self.lr_bias = tf.get_variable("learning_rate_bias", shape=(1,),
                                     initializer=zero_init)
    else:
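      # With a fixed output scale, the multiplier computed in _compute_update
      # is 2 * sigmoid(0) = 1, so the learning rate stays at its initial value.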
      self.lr_weights = tf.zeros_like(self.update_weights)
      self.lr_bias = tf.zeros([1, 1])

  def _initialize_state(self, var):
    """Return a dictionary mapping names of state variables to their values."""
    vectorized_shape = [var.get_shape().num_elements(), 1]

    min_lr = self.init_lr_range[0]
    max_lr = self.init_lr_range[1]
    if min_lr == max_lr:
      init_lr = tf.constant(min_lr, shape=vectorized_shape)
    else:
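      # Sample per-coordinate initial learning rates log-uniformly from
      # [min_lr, max_lr].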
      actual_vals = tf.random_uniform(vectorized_shape,
                                      np.log(min_lr),
                                      np.log(max_lr))
      init_lr = tf.exp(actual_vals)

    ones = tf.ones(vectorized_shape)
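    # Broadcast the single learned init vector to every coordinate's RNN state.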
    rnn_init = ones * self._init_vector

    return {
        "rms": tf.ones(vectorized_shape),
        "learning_rate": init_lr,
        "rnn": rnn_init,
        "decay": tf.ones(vectorized_shape),
    }

  def _compute_update(self, param, grad, state):
    """Update parameters given the gradient and state.

    Args:
      param: tensor of parameters
      grad: tensor of gradients with the same shape as param
      state: a dictionary containing any state for the optimizer

    Returns:
      updated_param: updated parameters
      updated_state: updated state variables in a dictionary
    """
    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
      if self.reuse_vars:
        scope.reuse_variables()
      else:
        self.reuse_vars = True

      param_shape = tf.shape(param)

      (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
       grad_indices) = self._extract_gradients_and_internal_state(
           grad, state, param_shape)

      # Vectorize and scale the gradients.
      grad_scaled, rms = utils.rms_scaling(grad_values, decay_state, rms_state)

      # Apply the RNN update.
      rnn_state_tuples = self._unpack_rnn_state_into_tuples(rnn_state)
      rnn_output, rnn_state_tuples = self.cell(grad_scaled, rnn_state_tuples)
      rnn_state = self._pack_tuples_into_rnn_state(rnn_state_tuples)

      # Compute the update direction (a linear projection of the RNN output).
      delta = utils.project(rnn_output, self.update_weights)

      # The updated decay is an affine projection of the RNN output.
      decay = utils.project(rnn_output, self.decay_weights,
                            bias=self.decay_bias, activation=tf.nn.sigmoid)

      # Compute the change in learning rate (an affine projection of the RNN
      # output, passed through a sigmoid and scaled by 2, so the change is
      # bounded).
      learning_rate_change = 2. * utils.project(rnn_output, self.lr_weights,
                                                bias=self.lr_bias,
                                                activation=tf.nn.sigmoid)

      # Update the learning rate.
      new_learning_rate = learning_rate_change * learning_rate_state
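      # Since learning_rate_change lies in (0, 2), a single step can at most
      # double the learning rate (and can shrink it by any factor below 1).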
      # Apply the update to the parameters.
      update = tf.reshape(new_learning_rate * delta, tf.shape(grad_values))

      if isinstance(grad, tf.IndexedSlices):
        update = utils.stack_tensor(update, grad_indices, param,
                                    param_shape[:1])
        rms = utils.update_slices(rms, grad_indices, state["rms"], param_shape)
        new_learning_rate = utils.update_slices(new_learning_rate,
                                                grad_indices,
                                                state["learning_rate"],
                                                param_shape)
        rnn_state = utils.update_slices(rnn_state, grad_indices, state["rnn"],
                                        param_shape)
        decay = utils.update_slices(decay, grad_indices, state["decay"],
                                    param_shape)

      new_param = param - update

      # Collect the update and new state.
      new_state = {
          "rms": rms,
          "learning_rate": new_learning_rate,
          "rnn": rnn_state,
          "decay": decay,
      }

    return new_param, new_state

  def _extract_gradients_and_internal_state(self, grad, state, param_shape):
    """Extracts the gradients and relevant internal state.

    If the gradient is sparse, extracts the appropriate slices from the state.

    Args:
      grad: The current gradient.
      state: The current state.
      param_shape: The shape of the parameter (used if gradient is sparse).

    Returns:
      grad_values: The gradient value tensor.
      decay_state: The current decay state.
      rms_state: The current rms state.
      rnn_state: The current state of the internal rnns.
      learning_rate_state: The current learning rate state.
      grad_indices: The indices for the gradient tensor, if sparse.
          None otherwise.
    """
    if isinstance(grad, tf.IndexedSlices):
      grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)
      decay_state = utils.slice_tensor(state["decay"], grad_indices,
                                       param_shape)
      rms_state = utils.slice_tensor(state["rms"], grad_indices, param_shape)
      rnn_state = utils.slice_tensor(state["rnn"], grad_indices, param_shape)
      learning_rate_state = utils.slice_tensor(state["learning_rate"],
                                               grad_indices, param_shape)
      decay_state.set_shape([None, 1])
      rms_state.set_shape([None, 1])
    else:
      grad_values = grad
      grad_indices = None
      decay_state = state["decay"]
      rms_state = state["rms"]
      rnn_state = state["rnn"]
      learning_rate_state = state["learning_rate"]
    return (grad_values, decay_state, rms_state, rnn_state,
            learning_rate_state, grad_indices)

  def _unpack_rnn_state_into_tuples(self, rnn_state):
    """Creates state tuples from the rnn state vector."""
    rnn_state_tuples = []
    cur_state_pos = 0
    for cell in self.component_cells:
      total_state_size = sum(cell.state_size)
      cur_state = tf.slice(rnn_state, [0, cur_state_pos],
                           [-1, total_state_size])
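      # Assumes each cell's state is a (c, h) pair of equal sizes, as in an
      # LSTM, so splitting the slice in half recovers the state tuple.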
      cur_state_tuple = tf.split(value=cur_state, num_or_size_splits=2,
                                 axis=1)
      rnn_state_tuples.append(cur_state_tuple)
      cur_state_pos += total_state_size
    return rnn_state_tuples

  def _pack_tuples_into_rnn_state(self, rnn_state_tuples):
    """Creates a single state vector concatenated along the column axis."""
    cell_states = [tf.concat([new_c, new_h], axis=1)
                   for new_c, new_h in rnn_state_tuples]
    return tf.concat(cell_states, axis=1)
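
# A rough end-to-end sketch (hypothetical setup; assumes the surrounding
# learned_optimizer framework, in which TrainableOptimizer provides the usual
# tf.train.Optimizer-style interface, and `loss` is the problem's scalar loss
# tensor):
#
#   rnn_opt = CoordinatewiseRNN(cell_sizes=[20, 20],
#                               cell_cls=tf.contrib.rnn.BasicLSTMCell,
#                               init_lr_range=(1e-6, 1e-2),
#                               dynamic_output_scale=True)
#   train_op = rnn_opt.minimize(loss)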