# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Collection of trainable optimizers for meta-optimization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import tensorflow as tf
from learned_optimizer.optimizer import utils
from learned_optimizer.optimizer import trainable_optimizer as opt

# Default was 1e-3.
tf.app.flags.DEFINE_float("crnn_rnn_readout_scale", 0.5,
                          """The initialization scale for the RNN readouts.""")
tf.app.flags.DEFINE_float("crnn_default_decay_var_init", 2.2,
                          """The default initializer value for any
                          decay/momentum-style variables and constants.
                          sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.1.""")
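
# A quick check of the sigmoid values quoted above (illustrative only):
#
#   1. / (1. + np.exp(-2.2))  # ~0.90
#   1. / (1. + np.exp(2.2))   # ~0.10
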
FLAGS = tf.app.flags.FLAGS


class CoordinatewiseRNN(opt.TrainableOptimizer):
"""RNN that operates on each coordinate of the problem independently."""
def __init__(self,
cell_sizes,
cell_cls,
init_lr_range=(1., 1.),
dynamic_output_scale=True,
learnable_decay=True,
zero_init_lr_weights=False,
**kwargs):
"""Initializes the RNN per-parameter optimizer.
Args:
cell_sizes: List of hidden state sizes for each RNN cell in the network
cell_cls: tf.contrib.rnn class for specifying the RNN cell type
init_lr_range: the range in which to initialize the learning rates.
dynamic_output_scale: whether to learn weights that dynamically modulate
the output scale (default: True)
learnable_decay: whether to learn weights that dynamically modulate the
input scale via RMS style decay (default: True)
zero_init_lr_weights: whether to initialize the lr weights to zero
**kwargs: args passed to TrainableOptimizer's constructor
Raises:
ValueError: If the init lr range is not of length 2.
ValueError: If the init lr range is not a valid range (min > max).
"""
    if len(init_lr_range) != 2:
      raise ValueError(
          "Initial LR range must have length 2, was {}".format(
              len(init_lr_range)))
if init_lr_range[0] > init_lr_range[1]:
raise ValueError("Initial LR range min is greater than max.")
self.init_lr_range = init_lr_range
self.zero_init_lr_weights = zero_init_lr_weights
self.reuse_vars = False
# create the RNN cell
with tf.variable_scope(opt.OPTIMIZER_SCOPE):
self.component_cells = [cell_cls(sz) for sz in cell_sizes]
self.cell = tf.contrib.rnn.MultiRNNCell(self.component_cells)
# random normal initialization scaled by the output size
scale_factor = FLAGS.crnn_rnn_readout_scale / math.sqrt(cell_sizes[-1])
scaled_init = tf.random_normal_initializer(0., scale_factor)
# weights for projecting the hidden state to a parameter update
self.update_weights = tf.get_variable("update_weights",
shape=(cell_sizes[-1], 1),
initializer=scaled_init)
self._initialize_decay(learnable_decay, (cell_sizes[-1], 1), scaled_init)
self._initialize_lr(dynamic_output_scale, (cell_sizes[-1], 1),
scaled_init)
      # Total size of the flattened RNN state across all component cells.
      state_size = sum(sum(sz) for sz in self.cell.state_size)
self._init_vector = tf.get_variable(
"init_vector", shape=[1, state_size],
initializer=tf.random_uniform_initializer(-1., 1.))
state_keys = ["rms", "rnn", "learning_rate", "decay"]
super(CoordinatewiseRNN, self).__init__("cRNN", state_keys, **kwargs)
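
  # A minimal construction sketch (illustrative only; the cell sizes, cell
  # class, and LR range below are arbitrary choices, and any extra keyword
  # arguments are forwarded to TrainableOptimizer's constructor):
  #
  #   optimizer = CoordinatewiseRNN(
  #       cell_sizes=[20, 20],
  #       cell_cls=tf.contrib.rnn.LSTMCell,
  #       init_lr_range=(1e-4, 1e-2),
  #       dynamic_output_scale=True,
  #       learnable_decay=True)
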
def _initialize_decay(
self, learnable_decay, weights_tensor_shape, scaled_init):
"""Initializes the decay weights and bias variables or tensors.
Args:
learnable_decay: Whether to use learnable decay.
weights_tensor_shape: The shape the weight tensor should take.
scaled_init: The scaled initialization for the weights tensor.
"""
if learnable_decay:
# weights for projecting the hidden state to the RMS decay term
self.decay_weights = tf.get_variable("decay_weights",
shape=weights_tensor_shape,
initializer=scaled_init)
self.decay_bias = tf.get_variable(
"decay_bias", shape=(1,),
initializer=tf.constant_initializer(
FLAGS.crnn_default_decay_var_init))
else:
self.decay_weights = tf.zeros_like(self.update_weights)
self.decay_bias = tf.constant(FLAGS.crnn_default_decay_var_init)
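
  # Note on the non-learnable path above (a sketch, assuming utils.project
  # computes activation(matmul(x, weights) + bias)): with zero weights, the
  # decay in _compute_update reduces to sigmoid(crnn_default_decay_var_init),
  # i.e. a fixed RMS decay of roughly 0.9 for every coordinate.
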
def _initialize_lr(
self, dynamic_output_scale, weights_tensor_shape, scaled_init):
"""Initializes the learning rate weights and bias variables or tensors.
Args:
dynamic_output_scale: Whether to use a dynamic output scale.
weights_tensor_shape: The shape the weight tensor should take.
scaled_init: The scaled initialization for the weights tensor.
"""
if dynamic_output_scale:
zero_init = tf.constant_initializer(0.)
wt_init = zero_init if self.zero_init_lr_weights else scaled_init
self.lr_weights = tf.get_variable("learning_rate_weights",
shape=weights_tensor_shape,
initializer=wt_init)
self.lr_bias = tf.get_variable("learning_rate_bias", shape=(1,),
initializer=zero_init)
else:
self.lr_weights = tf.zeros_like(self.update_weights)
self.lr_bias = tf.zeros([1, 1])
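
  # Likewise for the non-dynamic path above (same utils.project assumption):
  # with zero weights and bias, the learning rate change in _compute_update
  # is 2 * sigmoid(0) = 1, so each coordinate's learning rate stays at its
  # initialized value.
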
def _initialize_state(self, var):
"""Return a dictionary mapping names of state variables to their values."""
vectorized_shape = [var.get_shape().num_elements(), 1]
min_lr = self.init_lr_range[0]
max_lr = self.init_lr_range[1]
if min_lr == max_lr:
init_lr = tf.constant(min_lr, shape=vectorized_shape)
    else:
      # Sample initial learning rates log-uniformly from [min_lr, max_lr].
      actual_vals = tf.random_uniform(vectorized_shape,
                                      np.log(min_lr),
                                      np.log(max_lr))
      init_lr = tf.exp(actual_vals)

    # Broadcast the learned init vector to one RNN state row per coordinate.
    ones = tf.ones(vectorized_shape)
    rnn_init = ones * self._init_vector
return {
"rms": tf.ones(vectorized_shape),
"learning_rate": init_lr,
"rnn": rnn_init,
"decay": tf.ones(vectorized_shape),
}
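
  # A numpy sketch of the log-uniform initialization above (illustrative
  # only): sampling uniformly in log space and exponentiating spreads the
  # initial per-coordinate learning rates evenly across orders of magnitude.
  #
  #   lrs = np.exp(np.random.uniform(np.log(1e-4), np.log(1e-2),
  #                                  size=(5, 1)))  # each in [1e-4, 1e-2]
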
def _compute_update(self, param, grad, state):
"""Update parameters given the gradient and state.
Args:
param: tensor of parameters
grad: tensor of gradients with the same shape as param
state: a dictionary containing any state for the optimizer
Returns:
updated_param: updated parameters
updated_state: updated state variables in a dictionary
"""
with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
if self.reuse_vars:
scope.reuse_variables()
else:
self.reuse_vars = True
param_shape = tf.shape(param)
(grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
grad_indices) = self._extract_gradients_and_internal_state(
grad, state, param_shape)
# Vectorize and scale the gradients.
grad_scaled, rms = utils.rms_scaling(grad_values, decay_state, rms_state)
# Apply the RNN update.
rnn_state_tuples = self._unpack_rnn_state_into_tuples(rnn_state)
rnn_output, rnn_state_tuples = self.cell(grad_scaled, rnn_state_tuples)
rnn_state = self._pack_tuples_into_rnn_state(rnn_state_tuples)
# Compute the update direction (a linear projection of the RNN output).
delta = utils.project(rnn_output, self.update_weights)
      # The updated decay is an affine projection of the RNN output, passed
      # through a sigmoid so it stays in (0, 1).
      decay = utils.project(rnn_output, self.decay_weights,
                            bias=self.decay_bias, activation=tf.nn.sigmoid)
      # Compute the multiplicative change in learning rate: an affine
      # projection of the RNN output passed through a sigmoid and scaled by
      # 2, so each step multiplies the learning rate by a factor in (0, 2).
      learning_rate_change = 2. * utils.project(rnn_output, self.lr_weights,
                                                bias=self.lr_bias,
                                                activation=tf.nn.sigmoid)

      # Update the learning rate.
      new_learning_rate = learning_rate_change * learning_rate_state
# Apply the update to the parameters.
update = tf.reshape(new_learning_rate * delta, tf.shape(grad_values))
      # For sparse gradients, scatter the updated slices back into the
      # full-size update and state tensors.
      if isinstance(grad, tf.IndexedSlices):
update = utils.stack_tensor(update, grad_indices, param,
param_shape[:1])
rms = utils.update_slices(rms, grad_indices, state["rms"], param_shape)
new_learning_rate = utils.update_slices(new_learning_rate, grad_indices,
state["learning_rate"],
param_shape)
rnn_state = utils.update_slices(rnn_state, grad_indices, state["rnn"],
param_shape)
decay = utils.update_slices(decay, grad_indices, state["decay"],
param_shape)
new_param = param - update
# Collect the update and new state.
new_state = {
"rms": rms,
"learning_rate": new_learning_rate,
"rnn": rnn_state,
"decay": decay,
}
return new_param, new_state
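
  # One coordinate's step above, in scalar form (a sketch; the exact RMS
  # accumulator lives in utils.rms_scaling and is assumed here to be the
  # usual running average decay * rms + (1 - decay) * g**2):
  #
  #   g_hat = g / sqrt(rms)                          # RMS-scaled gradient
  #   h     = rnn(g_hat, h)                          # per-coordinate RNN step
  #   lr    = lr * 2. * sigmoid(dot(w_lr, h) + b_lr) # bounded LR change
  #   theta = theta - lr * dot(w_update, h)          # apply projected update
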
def _extract_gradients_and_internal_state(self, grad, state, param_shape):
"""Extracts the gradients and relevant internal state.
If the gradient is sparse, extracts the appropriate slices from the state.
Args:
grad: The current gradient.
state: The current state.
param_shape: The shape of the parameter (used if gradient is sparse).
Returns:
grad_values: The gradient value tensor.
decay_state: The current decay state.
rms_state: The current rms state.
rnn_state: The current state of the internal rnns.
learning_rate_state: The current learning rate state.
grad_indices: The indices for the gradient tensor, if sparse.
None otherwise.
"""
if isinstance(grad, tf.IndexedSlices):
grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)
decay_state = utils.slice_tensor(state["decay"], grad_indices,
param_shape)
rms_state = utils.slice_tensor(state["rms"], grad_indices, param_shape)
rnn_state = utils.slice_tensor(state["rnn"], grad_indices, param_shape)
learning_rate_state = utils.slice_tensor(state["learning_rate"],
grad_indices, param_shape)
decay_state.set_shape([None, 1])
rms_state.set_shape([None, 1])
else:
grad_values = grad
grad_indices = None
decay_state = state["decay"]
rms_state = state["rms"]
rnn_state = state["rnn"]
learning_rate_state = state["learning_rate"]
return (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
grad_indices)
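
  # Illustrative note: for something like an embedding lookup, grad arrives
  # as a tf.IndexedSlices touching only a few rows, and the slicing above
  # lets the RNN and state updates run only on those active coordinates
  # rather than on the full parameter tensor.
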
def _unpack_rnn_state_into_tuples(self, rnn_state):
"""Creates state tuples from the rnn state vector."""
rnn_state_tuples = []
cur_state_pos = 0
for cell in self.component_cells:
total_state_size = sum(cell.state_size)
cur_state = tf.slice(rnn_state, [0, cur_state_pos],
[-1, total_state_size])
      # Assumes each cell state is an even-sized (c, h) pair, as in an LSTM.
      cur_state_tuple = tf.split(value=cur_state, num_or_size_splits=2,
                                 axis=1)
rnn_state_tuples.append(cur_state_tuple)
cur_state_pos += total_state_size
return rnn_state_tuples
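
  # Layout sketch: for two LSTM cells of sizes [3, 5], cell.state_size is
  # [(3, 3), (5, 5)], so rnn_state has 16 columns and unpacks as
  # [c0 | h0 | c1 | h1] -> [(c0, h0), (c1, h1)].
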
def _pack_tuples_into_rnn_state(self, rnn_state_tuples):
"""Creates a single state vector concatenated along column axis."""
rnn_state = None
for new_state_tuple in rnn_state_tuples:
new_c, new_h = new_state_tuple
if rnn_state is None:
rnn_state = tf.concat([new_c, new_h], axis=1)
else:
        rnn_state = tf.concat([rnn_state, tf.concat([new_c, new_h], axis=1)],
                              axis=1)
return rnn_state
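
  # Round-trip sketch: packing [(c0, h0), (c1, h1)] concatenates each pair
  # along axis 1, restoring the [c0 | h0 | c1 | h1] column layout that
  # _unpack_rnn_state_into_tuples expects on the next unroll step.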