File size: 12,803 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Collection of trainable optimizers for meta-optimization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
import tensorflow as tf

from learned_optimizer.optimizer import utils
from learned_optimizer.optimizer import trainable_optimizer as opt


# Command-line flags controlling how the coordinatewise RNN optimizer is
# initialized.

# Scale of the random-normal initializer used for the RNN readout weights; the
# constructor divides it by sqrt(cell_sizes[-1]).
# Default was 1e-3
tf.app.flags.DEFINE_float("crnn_rnn_readout_scale", 0.5,
                          """The initialization scale for the RNN readouts.""")
# Pre-sigmoid value used for the decay bias (learnable variable or constant).
tf.app.flags.DEFINE_float("crnn_default_decay_var_init", 2.2,
                          """The default initializer value for any decay/
                             momentum style variables and constants.
                             sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.1.""")

FLAGS = tf.flags.FLAGS


class CoordinatewiseRNN(opt.TrainableOptimizer):
  """RNN that operates on each coordinate of the problem independently."""

  def __init__(self,
               cell_sizes,
               cell_cls,
               init_lr_range=(1., 1.),
               dynamic_output_scale=True,
               learnable_decay=True,
               zero_init_lr_weights=False,
               **kwargs):
    """Initializes the RNN per-parameter optimizer.

    Builds one RNN cell per entry of `cell_sizes`, stacks them in a
    MultiRNNCell and creates the readout variables (update, decay and
    learning-rate projections) plus the learned initial RNN state, all inside
    the optimizer variable scope.

    Args:
      cell_sizes: List of hidden state sizes for each RNN cell in the network
      cell_cls: tf.contrib.rnn class for specifying the RNN cell type. Each
          cell is built as `cell_cls(size)` and its `state_size` must be a
          sequence (the per-cell sizes are summed), e.g. an LSTM-style
          (c, h) pair.
      init_lr_range: the range in which to initialize the learning rates.
          If min == max every coordinate starts at that value; otherwise the
          learning rates are sampled log-uniformly, which requires a strictly
          positive range.
      dynamic_output_scale: whether to learn weights that dynamically modulate
          the output scale (default: True)
      learnable_decay: whether to learn weights that dynamically modulate the
          input scale via RMS style decay (default: True)
      zero_init_lr_weights: whether to initialize the lr weights to zero
      **kwargs: args passed to TrainableOptimizer's constructor

    Raises:
      ValueError: If the init lr range is not of length 2.
      ValueError: If the init lr range is not a valid range (min > max).
      ValueError: If min != max and the range is not strictly positive (the
          log-uniform sampling of the initial learning rate would be
          undefined).
    """
    if len(init_lr_range) != 2:
      raise ValueError(
          "Initial LR range must be len 2, was {}".format(len(init_lr_range)))
    if init_lr_range[0] > init_lr_range[1]:
      raise ValueError("Initial LR range min is greater than max.")
    # A non-degenerate range is sampled in log space, so log(min) must exist;
    # a non-positive bound would silently yield -inf/NaN learning rates.
    if init_lr_range[0] != init_lr_range[1] and init_lr_range[0] <= 0:
      raise ValueError(
          "Initial LR range must be strictly positive when min != max, "
          "was {}".format(init_lr_range))
    self.init_lr_range = init_lr_range

    self.zero_init_lr_weights = zero_init_lr_weights
    # Flipped to True after the first _compute_update call so later calls
    # reuse the variables created by the first one.
    self.reuse_vars = False

    # create the RNN cell
    with tf.variable_scope(opt.OPTIMIZER_SCOPE):
      self.component_cells = [cell_cls(sz) for sz in cell_sizes]
      self.cell = tf.contrib.rnn.MultiRNNCell(self.component_cells)

      # random normal initialization scaled by the output size
      scale_factor = FLAGS.crnn_rnn_readout_scale / math.sqrt(cell_sizes[-1])
      scaled_init = tf.random_normal_initializer(0., scale_factor)

      # weights for projecting the hidden state to a parameter update
      self.update_weights = tf.get_variable("update_weights",
                                            shape=(cell_sizes[-1], 1),
                                            initializer=scaled_init)

      self._initialize_decay(learnable_decay, (cell_sizes[-1], 1), scaled_init)

      self._initialize_lr(dynamic_output_scale, (cell_sizes[-1], 1),
                          scaled_init)

      # Total width of the flattened RNN state: every component of every
      # layer's state, laid out side by side.
      total_state_size = sum(
          sum(layer_state_size) for layer_state_size in self.cell.state_size)
      # Learned initial RNN state, shared by all coordinates.
      self._init_vector = tf.get_variable(
          "init_vector", shape=[1, total_state_size],
          initializer=tf.random_uniform_initializer(-1., 1.))

    state_keys = ["rms", "rnn", "learning_rate", "decay"]
    super(CoordinatewiseRNN, self).__init__("cRNN", state_keys, **kwargs)

  def _initialize_decay(
      self, learnable_decay, weights_tensor_shape, scaled_init):
    """Sets `self.decay_weights` and `self.decay_bias`.

    With a learnable decay both attributes are trainable variables; otherwise
    they are fixed tensors (all-zero weights and a constant bias). In both
    cases the bias value comes from the `crnn_default_decay_var_init` flag.

    Args:
      learnable_decay: Whether to use learnable decay.
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if not learnable_decay:
      # Fixed decay: zero weights shaped like the update readout, so only the
      # constant bias is non-zero.
      self.decay_weights = tf.zeros_like(self.update_weights)
      self.decay_bias = tf.constant(FLAGS.crnn_default_decay_var_init)
      return

    # Learnable decay: readout weights projecting the hidden state to the RMS
    # decay term, plus a bias starting at the flag's default.
    self.decay_weights = tf.get_variable(
        "decay_weights", shape=weights_tensor_shape, initializer=scaled_init)
    bias_initializer = tf.constant_initializer(
        FLAGS.crnn_default_decay_var_init)
    self.decay_bias = tf.get_variable(
        "decay_bias", shape=(1,), initializer=bias_initializer)

  def _initialize_lr(
      self, dynamic_output_scale, weights_tensor_shape, scaled_init):
    """Sets `self.lr_weights` and `self.lr_bias`.

    With a dynamic output scale both attributes are trainable variables;
    otherwise they are fixed all-zero tensors.

    Args:
      dynamic_output_scale: Whether to use a dynamic output scale.
      weights_tensor_shape: The shape the weight tensor should take.
      scaled_init: The scaled initialization for the weights tensor.
    """
    if not dynamic_output_scale:
      # Static output scale: zero weights and zero bias, nothing to train.
      self.lr_weights = tf.zeros_like(self.update_weights)
      self.lr_bias = tf.zeros([1, 1])
      return

    zeros_initializer = tf.constant_initializer(0.)
    # The weights start either at zero or at the scaled random init,
    # depending on how the optimizer was configured; the bias always starts
    # at zero.
    if self.zero_init_lr_weights:
      weights_initializer = zeros_initializer
    else:
      weights_initializer = scaled_init

    self.lr_weights = tf.get_variable(
        "learning_rate_weights",
        shape=weights_tensor_shape,
        initializer=weights_initializer)
    self.lr_bias = tf.get_variable(
        "learning_rate_bias", shape=(1,), initializer=zeros_initializer)

  def _initialize_state(self, var):
    """Builds the initial optimizer state for one variable.

    All state tensors have one row per element of `var`.

    Args:
      var: The variable to create optimizer state for; its static shape is
          used to determine the number of elements.

    Returns:
      A dictionary with keys "rms", "learning_rate", "rnn" and "decay"
      mapping to the initial value of each state tensor.
    """
    flat_shape = [var.get_shape().num_elements(), 1]
    lr_low, lr_high = self.init_lr_range[0], self.init_lr_range[1]

    if lr_low != lr_high:
      # Sample the per-coordinate learning rates log-uniformly in
      # [lr_low, lr_high): uniform in log space, then exponentiate.
      log_lr = tf.random_uniform(flat_shape, np.log(lr_low), np.log(lr_high))
      initial_lr = tf.exp(log_lr)
    else:
      # Degenerate range: every coordinate starts at the same learning rate.
      initial_lr = tf.constant(lr_low, shape=flat_shape)

    # Broadcast the learned [1, state_size] initial vector to one row per
    # coordinate.
    initial_rnn = tf.ones(flat_shape) * self._init_vector

    initial_state = {}
    initial_state["rms"] = tf.ones(flat_shape)
    initial_state["learning_rate"] = initial_lr
    initial_state["rnn"] = initial_rnn
    initial_state["decay"] = tf.ones(flat_shape)
    return initial_state

  def _compute_update(self, param, grad, state):
    """Computes the new parameter value and optimizer state for one step.

    The gradient is RMS-scaled, fed through the RNN cell, and the cell output
    is projected to an update direction, a new decay and a multiplicative
    learning-rate change. For sparse gradients (tf.IndexedSlices) only the
    touched slices of the parameter and state are updated.

    Args:
      param: tensor of parameters
      grad: tensor of gradients with the same shape as param, or a
          tf.IndexedSlices for sparse gradients
      state: a dictionary containing any state for the optimizer

    Returns:
      updated_param: updated parameters
      updated_state: updated state variables in a dictionary
    """
    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as optimizer_scope:
      # Every call after the first switches the scope into reuse mode; the
      # first call only records that it has happened.
      if self.reuse_vars:
        optimizer_scope.reuse_variables()
      self.reuse_vars = True

      param_shape = tf.shape(param)

      extracted = self._extract_gradients_and_internal_state(
          grad, state, param_shape)
      (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
       grad_indices) = extracted

      # Scale the gradients by their running RMS.
      grad_scaled, rms = utils.rms_scaling(grad_values, decay_state, rms_state)

      # Run one step of the RNN on the scaled gradients.
      cell_state_in = self._unpack_rnn_state_into_tuples(rnn_state)
      rnn_output, cell_state_out = self.cell(grad_scaled, cell_state_in)
      rnn_state = self._pack_tuples_into_rnn_state(cell_state_out)

      # Update direction: a linear projection of the RNN output.
      delta = utils.project(rnn_output, self.update_weights)

      # New decay: a projection of the RNN output with bias, passed through a
      # sigmoid.
      decay = utils.project(
          rnn_output, self.decay_weights,
          bias=self.decay_bias, activation=tf.nn.sigmoid)

      # Multiplicative learning-rate change: twice a sigmoid-activated
      # projection of the RNN output, so the change is bounded.
      lr_gate = utils.project(
          rnn_output, self.lr_weights,
          bias=self.lr_bias, activation=tf.nn.sigmoid)
      learning_rate_change = 2. * lr_gate
      new_learning_rate = learning_rate_change * learning_rate_state

      # Scale the direction by the learning rate and restore the gradient's
      # shape.
      scaled_delta = new_learning_rate * delta
      update = tf.reshape(scaled_delta, tf.shape(grad_values))

      if isinstance(grad, tf.IndexedSlices):
        # Sparse case: write the per-slice results back into the full-size
        # update and state tensors.
        update = utils.stack_tensor(update, grad_indices, param,
                                    param_shape[:1])
        rms = utils.update_slices(rms, grad_indices, state["rms"], param_shape)
        new_learning_rate = utils.update_slices(
            new_learning_rate, grad_indices, state["learning_rate"],
            param_shape)
        rnn_state = utils.update_slices(
            rnn_state, grad_indices, state["rnn"], param_shape)
        decay = utils.update_slices(
            decay, grad_indices, state["decay"], param_shape)

      new_param = param - update

    new_state = {
        "rms": rms,
        "learning_rate": new_learning_rate,
        "rnn": rnn_state,
        "decay": decay,
    }
    return new_param, new_state

  def _extract_gradients_and_internal_state(self, grad, state, param_shape):
    """Returns the gradient values and the matching pieces of state.

    A dense gradient is returned together with the full state tensors. For a
    sparse gradient (tf.IndexedSlices) the gradient is first accumulated and
    only the state slices at the gradient's indices are returned.

    Args:
      grad: The current gradient.
      state: The current state.
      param_shape: The shape of the parameter (used if gradient is sparse).

    Returns:
      grad_values: The gradient value tensor.
      decay_state: The current decay state.
      rms_state: The current rms state.
      rnn_state: The current state of the internal rnns.
      learning_rate_state: The current learning rate state.
      grad_indices: The indices for the gradient tensor, if sparse.
          None otherwise.
    """
    if not isinstance(grad, tf.IndexedSlices):
      # Dense gradient: use the state as-is, there are no indices.
      return (grad, state["decay"], state["rms"], state["rnn"],
              state["learning_rate"], None)

    grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)

    def _state_rows(key):
      """Slices the named state tensor at the gradient's indices."""
      return utils.slice_tensor(state[key], grad_indices, param_shape)

    decay_state = _state_rows("decay")
    rms_state = _state_rows("rms")
    rnn_state = _state_rows("rnn")
    learning_rate_state = _state_rows("learning_rate")

    # Pin the static shape of the sliced decay and rms tensors to one column.
    for sliced in (decay_state, rms_state):
      sliced.set_shape([None, 1])

    return (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
            grad_indices)

  def _unpack_rnn_state_into_tuples(self, rnn_state):
    """Splits the flat rnn state into one two-part state per component cell.

    Args:
      rnn_state: 2-D tensor holding the states of all component cells side by
          side along the column axis.

    Returns:
      A list with one entry per component cell; each entry is the pair of
      tensors obtained by halving that cell's columns.
    """
    per_cell_states = []
    column_offset = 0
    for component in self.component_cells:
      width = sum(component.state_size)
      # Take this cell's columns (all rows), then cut them into two equal
      # halves.
      cell_columns = tf.slice(rnn_state, [0, column_offset], [-1, width])
      per_cell_states.append(tf.split(cell_columns, 2, axis=1))
      column_offset += width
    return per_cell_states

  def _pack_tuples_into_rnn_state(self, rnn_state_tuples):
    """Creates a single state vector concatenated along column axis.

    Args:
      rnn_state_tuples: Sequence of two-element state tuples, one per
          component cell, in cell order.

    Returns:
      A tensor with the two tensors of each tuple laid out side by side along
      axis 1, in tuple order, or None if `rnn_state_tuples` is empty.
    """
    # Flatten the per-cell pairs into one list; unpacking enforces that every
    # entry really is a two-element tuple.
    state_parts = []
    for new_c, new_h in rnn_state_tuples:
      state_parts.append(new_c)
      state_parts.append(new_h)

    if not state_parts:
      return None

    # One concat over all parts instead of growing the result pair by pair,
    # which re-copied the accumulated state once per layer.
    return tf.concat(state_parts, axis=1)