# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A trainable ADAM optimizer that learns its internal variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from learned_optimizer.optimizer import trainable_optimizer as opt
from learned_optimizer.optimizer import utils


class TrainableAdam(opt.TrainableOptimizer):
  """Adam optimizer with learnable scalar parameters.

  See Kingma et al., 2014 for the algorithm (http://arxiv.org/abs/1412.6980).
  """
  def __init__(self,
               learning_rate=1e-3,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               **kwargs):
    """Initializes the TrainableAdam optimizer with the given initial values.

    Args:
      learning_rate: The learning rate (default: 1e-3).
      beta1: The exponential decay rate for the 1st moment estimates.
      beta2: The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability.
      **kwargs: Any additional keyword arguments for TrainableOptimizer.

    Raises:
      ValueError: if the learning rate or epsilon is not positive.
      ValueError: if beta1 or beta2 is not in (0, 1).
    """
    if learning_rate <= 0:
      raise ValueError("Learning rate must be positive.")
    if epsilon <= 0:
      raise ValueError("Epsilon must be positive.")
    if not 0 < beta1 < 1 or not 0 < beta2 < 1:
      raise ValueError("Beta values must be between 0 and 1, exclusive.")

    self._reuse_vars = False

    with tf.variable_scope(opt.OPTIMIZER_SCOPE):

      def inv_sigmoid(x):
        return np.log(x / (1.0 - x))

      # The scalars are stored unconstrained (logs for the learning rate and
      # epsilon, logits for the betas) and mapped back through exp() and
      # sigmoid() in _compute_update, so gradient-based training keeps them
      # in their valid ranges.
      self.log_learning_rate = tf.get_variable(
          "log_learning_rate",
          shape=[],
          initializer=tf.constant_initializer(np.log(learning_rate)))
      self.beta1_logit = tf.get_variable(
          "beta1_logit",
          shape=[],
          initializer=tf.constant_initializer(inv_sigmoid(beta1)))
      self.beta2_logit = tf.get_variable(
          "beta2_logit",
          shape=[],
          initializer=tf.constant_initializer(inv_sigmoid(beta2)))
      self.log_epsilon = tf.get_variable(
          "log_epsilon",
          shape=[],
          initializer=tf.constant_initializer(np.log(epsilon)))

    # Key names are derived from Algorithm 1 described in
    # https://arxiv.org/pdf/1412.6980.pdf
    state_keys = ["m", "v", "t"]
    super(TrainableAdam, self).__init__("Adam", state_keys, **kwargs)

  def _initialize_state(self, var):
    """Returns a dictionary mapping names of state variables to their values."""
    vectorized_shape = var.get_shape().num_elements(), 1

    return {key: tf.zeros(vectorized_shape) for key in self.state_keys}

  def _compute_update(self, param, grad, state):
    """Calculates the new internal state and parameters.

    If the gradient is sparse, updates the appropriate slices in the internal
    state and stacks the update tensor.

    Args:
      param: A tensor of parameters.
      grad: A tensor of gradients with the same shape as param.
      state: A dictionary containing any state for the optimizer.

    Returns:
      updated_param: The updated parameters.
      updated_state: The updated state variables in a dictionary.
    """
    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
      if self._reuse_vars:
        scope.reuse_variables()
      else:
        self._reuse_vars = True

      (grad_values, first_moment, second_moment, timestep, grad_indices
      ) = self._extract_gradients_and_internal_state(
          grad, state, tf.shape(param))

      beta1 = tf.nn.sigmoid(self.beta1_logit)
      beta2 = tf.nn.sigmoid(self.beta2_logit)
      epsilon = tf.exp(self.log_epsilon) + 1e-10
      learning_rate = tf.exp(self.log_learning_rate)

      old_grad_shape = tf.shape(grad_values)
      grad_values = tf.reshape(grad_values, [-1, 1])

      new_timestep = timestep + 1
      new_first_moment = self._update_adam_estimate(
          first_moment, grad_values, beta1)
      new_second_moment = self._update_adam_estimate(
          second_moment, tf.square(grad_values), beta2)
      debiased_first_moment = self._debias_adam_estimate(
          new_first_moment, beta1, new_timestep)
      debiased_second_moment = self._debias_adam_estimate(
          new_second_moment, beta2, new_timestep)

      # Propagating through the square root of 0 is very bad for stability.
      update = (learning_rate * debiased_first_moment /
                (tf.sqrt(debiased_second_moment + 1e-10) + epsilon))
      update = tf.reshape(update, old_grad_shape)

      if grad_indices is not None:
        param_shape = tf.shape(param)
        update = utils.stack_tensor(
            update, grad_indices, param, param_shape[:1])
        new_first_moment = utils.update_slices(
            new_first_moment, grad_indices, state["m"], param_shape)
        new_second_moment = utils.update_slices(
            new_second_moment, grad_indices, state["v"], param_shape)
        new_timestep = utils.update_slices(
            new_timestep, grad_indices, state["t"], param_shape)

      new_param = param - update

      # Collect the update and new state.
      new_state = {
          "m": new_first_moment,
          "v": new_second_moment,
          "t": new_timestep
      }

    return new_param, new_state

  def _update_adam_estimate(self, estimate, value, beta):
    """Returns a beta-weighted average of estimate and value."""
    return (beta * estimate) + ((1 - beta) * value)

  def _debias_adam_estimate(self, estimate, beta, t_step):
    """Returns a debiased estimate based on beta and the timestep."""
    return estimate / (1 - tf.pow(beta, t_step))

  def _extract_gradients_and_internal_state(self, grad, state, param_shape):
    """Extracts the gradients and relevant internal state.

    If the gradient is sparse, extracts the appropriate slices from the state.

    Args:
      grad: The current gradient.
      state: The current state.
      param_shape: The shape of the parameter (used if gradient is sparse).

    Returns:
      grad_values: The gradient value tensor.
      first_moment: The first moment tensor (internal state).
      second_moment: The second moment tensor (internal state).
      timestep: The current timestep (internal state).
      grad_indices: The indices for the gradient tensor, if sparse. None
          otherwise.
    """
    grad_values = grad
    grad_indices = None
    first_moment = state["m"]
    second_moment = state["v"]
    timestep = state["t"]

    if isinstance(grad, tf.IndexedSlices):
      grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)
      first_moment = utils.slice_tensor(
          first_moment, grad_indices, param_shape)
      second_moment = utils.slice_tensor(
          second_moment, grad_indices, param_shape)
      timestep = utils.slice_tensor(timestep, grad_indices, param_shape)

    return grad_values, first_moment, second_moment, timestep, grad_indices
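

# ------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It relies only on what
# this file already constructs: the four scalars created in __init__ are
# ordinary trainable TensorFlow variables under opt.OPTIMIZER_SCOPE, so a
# meta-training loop could collect and update them. It assumes that
# TrainableOptimizer needs no constructor arguments beyond those forwarded
# here; how the learned_optimizer framework actually drives meta-training is
# not shown.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
  trainable_adam = TrainableAdam(learning_rate=1e-3, beta1=0.9, beta2=0.999)
  meta_parameters = tf.get_collection(
      tf.GraphKeys.TRAINABLE_VARIABLES, scope=opt.OPTIMIZER_SCOPE)
  # Expected contents: log_learning_rate, beta1_logit, beta2_logit, and
  # log_epsilon.
  for variable in meta_parameters:
    print(variable.name, variable.get_shape())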