# Copyright 2017 Google, Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A trainable ADAM optimizer that learns its internal variables.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import tensorflow as tf from learned_optimizer.optimizer import trainable_optimizer as opt from learned_optimizer.optimizer import utils class TrainableAdam(opt.TrainableOptimizer): """Adam optimizer with learnable scalar parameters. See Kingma et. al., 2014 for algorithm (http://arxiv.org/abs/1412.6980). """ def __init__(self, learning_rate=1e-3, beta1=0.9, beta2=0.999, epsilon=1e-8, **kwargs): """Initializes the TrainableAdam optimizer with the given initial values. Args: learning_rate: The learning rate (default: 1e-3). beta1: The exponential decay rate for the 1st moment estimates. beta2: The exponential decay rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. **kwargs: Any additional keyword arguments for TrainableOptimizer. Raises: ValueError: if the learning rate or epsilon is not positive ValueError: if beta1 or beta2 is not in (0, 1). """ if learning_rate <= 0: raise ValueError("Learning rate must be positive.") if epsilon <= 0: raise ValueError("Epsilon must be positive.") if not 0 < beta1 < 1 or not 0 < beta2 < 1: raise ValueError("Beta values must be between 0 and 1, exclusive.") self._reuse_vars = False with tf.variable_scope(opt.OPTIMIZER_SCOPE): def inv_sigmoid(x): return np.log(x / (1.0 - x)) self.log_learning_rate = tf.get_variable( "log_learning_rate", shape=[], initializer=tf.constant_initializer(np.log(learning_rate))) self.beta1_logit = tf.get_variable( "beta1_logit", shape=[], initializer=tf.constant_initializer(inv_sigmoid(beta1))) self.beta2_logit = tf.get_variable( "beta2_logit", shape=[], initializer=tf.constant_initializer(inv_sigmoid(beta2))) self.log_epsilon = tf.get_variable( "log_epsilon", shape=[], initializer=tf.constant_initializer(np.log(epsilon))) # Key names are derived from Algorithm 1 described in # https://arxiv.org/pdf/1412.6980.pdf state_keys = ["m", "v", "t"] super(TrainableAdam, self).__init__("Adam", state_keys, **kwargs) def _initialize_state(self, var): """Returns a dictionary mapping names of state variables to their values.""" vectorized_shape = var.get_shape().num_elements(), 1 return {key: tf.zeros(vectorized_shape) for key in self.state_keys} def _compute_update(self, param, grad, state): """Calculates the new internal state and parameters. If the gradient is sparse, updates the appropriate slices in the internal state and stacks the update tensor. Args: param: A tensor of parameters. grad: A tensor of gradients with the same shape as param. state: A dictionary containing any state for the optimizer. Returns: updated_param: The updated parameters. updated_state: The updated state variables in a dictionary. """ with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope: if self._reuse_vars: scope.reuse_variables() else: self._reuse_vars = True (grad_values, first_moment, second_moment, timestep, grad_indices ) = self._extract_gradients_and_internal_state( grad, state, tf.shape(param)) beta1 = tf.nn.sigmoid(self.beta1_logit) beta2 = tf.nn.sigmoid(self.beta2_logit) epsilon = tf.exp(self.log_epsilon) + 1e-10 learning_rate = tf.exp(self.log_learning_rate) old_grad_shape = tf.shape(grad_values) grad_values = tf.reshape(grad_values, [-1, 1]) new_timestep = timestep + 1 new_first_moment = self._update_adam_estimate( first_moment, grad_values, beta1) new_second_moment = self._debias_adam_estimate( second_moment, tf.square(grad_values), beta2) debiased_first_moment = self._debias_adam_estimate( new_first_moment, beta1, new_timestep) debiased_second_moment = self._debias_adam_estimate( new_second_moment, beta2, new_timestep) # Propagating through the square root of 0 is very bad for stability. update = (learning_rate * debiased_first_moment / (tf.sqrt(debiased_second_moment + 1e-10) + epsilon)) update = tf.reshape(update, old_grad_shape) if grad_indices is not None: param_shape = tf.shape(param) update = utils.stack_tensor( update, grad_indices, param, param_shape[:1]) new_first_moment = utils.update_slices( new_first_moment, grad_indices, state["m"], param_shape) new_second_moment = utils.update_slices( new_second_moment, grad_indices, state["v"], param_shape) new_timestep = utils.update_slices( new_timestep, grad_indices, state["t"], param_shape) new_param = param - update # collect the update and new state new_state = { "m": new_first_moment, "v": new_second_moment, "t": new_timestep } return new_param, new_state def _update_adam_estimate(self, estimate, value, beta): """Returns a beta-weighted average of estimate and value.""" return (beta * estimate) + ((1 - beta) * value) def _debias_adam_estimate(self, estimate, beta, t_step): """Returns a debiased estimate based on beta and the timestep.""" return estimate / (1 - tf.pow(beta, t_step)) def _extract_gradients_and_internal_state(self, grad, state, param_shape): """Extracts the gradients and relevant internal state. If the gradient is sparse, extracts the appropriate slices from the state. Args: grad: The current gradient. state: The current state. param_shape: The shape of the parameter (used if gradient is sparse). Returns: grad_values: The gradient value tensor. first_moment: The first moment tensor (internal state). second_moment: The second moment tensor (internal state). timestep: The current timestep (internal state). grad_indices: The indices for the gradient tensor, if sparse. None otherwise. """ grad_values = grad grad_indices = None first_moment = state["m"] second_moment = state["v"] timestep = state["t"] if isinstance(grad, tf.IndexedSlices): grad_indices, grad_values = utils.accumulate_sparse_gradients(grad) first_moment = utils.slice_tensor( first_moment, grad_indices, param_shape) second_moment = utils.slice_tensor( second_moment, grad_indices, param_shape) timestep = utils.slice_tensor(timestep, grad_indices, param_shape) return grad_values, first_moment, second_moment, timestep, grad_indices