File size: 24,034 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A base class definition for trainable optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools

import tensorflow as tf

from tensorflow.python.framework import tensor_shape

# Variable scope under which all optimizer-owned (local state) variables are
# created. NOTE(review): "LOL" appears to be the project's codename for the
# learned-optimizer scope; renaming it would change variable/checkpoint names,
# so confirm against existing checkpoints before touching it.
OPTIMIZER_SCOPE = "LOL"
# Name prefix for the local variables that carry unrolled-loop state between
# partial unrolls (see local_state_variables / is_local_state_variable).
_LOCAL_VARIABLE_PREFIX = "local_state_"
# Graph collection holding the per-name use counts used to uniquify local
# state variable names across sessions on the same master.
_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection"
# Small constant for numerical stability in subclass computations.
EPSILON = 1e-6


class TrainableOptimizer(tf.train.Optimizer):
  """Base class for trainable optimizers.

  A trainable optimizer is an optimizer that has parameters that can themselves
  be learned (meta-optimized).

  Subclasses must implement:
      _compute_update(self, param, grad, state)
  """

  def __init__(self, name, state_keys, use_attention=False,
               use_log_objective=False, obj_train_max_multiplier=-1,
               use_second_derivatives=True, use_numerator_epsilon=False,
               **kwargs):
    """Initializes the optimizer with the given name and settings.

    Args:
      name: The name string for this optimizer.
      state_keys: The names of any required state variables (list)
      use_attention: Whether this optimizer uses attention (Default: False)
      use_log_objective: Whether this optimizer uses the logarithm of the
          objective when computing the loss (Default: False)
      obj_train_max_multiplier: The maximum multiplier for the increase in the
          objective before meta-training is stopped. If <= 0, meta-training is
          not stopped early. (Default: -1)
      use_second_derivatives: Whether this optimizer uses second derivatives in
          meta-training. This should be set to False if some second derivatives
          in the meta-training problem set are not defined in Tensorflow.
          (Default: True)
      use_numerator_epsilon: Whether to use epsilon in the numerator when
          scaling the problem objective during meta-training. (Default: False)
      **kwargs: Any additional keyword arguments.
    """
    self.use_second_derivatives = use_second_derivatives
    # Sort once so state is always flattened/unflattened in a canonical order
    # (flatten_and_sort and the while-loop plumbing in train() rely on this).
    self.state_keys = sorted(state_keys)
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.obj_train_max_multiplier = obj_train_max_multiplier
    self.use_numerator_epsilon = use_numerator_epsilon

    use_locking = False
    super(TrainableOptimizer, self).__init__(use_locking, name)

  def _create_slots(self, var_list):
    """Creates all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    for var in var_list:
      init_states = self._initialize_state(var)
      # Sorted iteration keeps slot creation order deterministic across runs.
      for slot_name in sorted(init_states):
        slot_var_name = "{}_{}".format(self.get_name(), slot_name)
        value = init_states[slot_name]
        self._get_or_make_slot(var, value, slot_name, slot_var_name)

  def _initialize_state(self, var):
    """Initializes any state required for this variable.

    Args:
      var: a tensor containing parameters to be optimized

    Returns:
      state: a dictionary mapping state keys to initial state values (tensors)
    """
    return {}

  def _initialize_global_state(self):
    """Initializes any global state values."""
    return []

  def _apply_common(self, grad, var):
    """Applies the optimizer updates to the variables.

    Note: this should only get called via _apply_dense or _apply_sparse when
    using the optimizer via optimizer.minimize or optimizer.apply_gradients.
    During meta-training, the optimizer.train function should be used to
    construct an optimization path that is differentiable.

    Args:
      grad: A tensor representing the gradient.
      var: A tf.Variable with the same shape as grad.

    Returns:
      update_op: A tensorflow op that assigns new values to the variable, and
          also defines dependencies that update the state variables for the
          optimizer.
    """
    state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
    new_var, new_state = self._compute_update(var, grad, state)
    state_assign_ops = [tf.assign(state_var, new_state[key])
                        for key, state_var in state.items()]
    # The state updates must complete before the variable itself is updated so
    # _compute_update always sees a consistent (pre-update) state.
    with tf.control_dependencies(state_assign_ops):
      update_op = var.assign(new_var)

    return update_op

  def _apply_dense(self, grad, var):
    """Adds ops to apply dense gradients to 'var'."""
    return self._apply_common(grad, var)

  def _apply_sparse(self, grad, var):
    """Adds ops to apply sparse gradients to 'var'."""
    return self._apply_common(grad, var)

  def _compute_update(self, param, grad, state):
    """Computes the update step for optimization.

    Args:
      param: A tensor of parameters to optimize.
      grad: The gradient tensor of the objective with respect to the parameters.
          (It has the same shape as param.)
      state: A dictionary containing any extra state required by the optimizer.

    Returns:
      updated_params: The updated parameters.
      updated_state: The dictionary of updated state variable(s).
    """
    raise NotImplementedError

  def _compute_updates(self, params, grads, states, global_state):
    """Maps the compute update functions for each parameter.

    This function can be overridden by a subclass if the subclass wants to
    combine information across the different parameters in the list.

    Args:
      params: A list of parameter tensors.
      grads: A list of gradients corresponding to each parameter.
      states: A list of state variables corresponding to each parameter.
      global_state: A list of global state variables for the problem.

    Returns:
      new_params: The updated parameters.
      new_states: The updated states.
      new_global_state: The updated global state.
      attention_params: A list of attention parameters. This is the same as
          new_params if the optimizer does not use attention.
    """
    # Zip up the arguments to _compute_update.
    args = zip(params, grads, states)

    # Call compute_update on each set of parameter/gradient/state args.
    new_params, new_states = zip(*list(
        itertools.starmap(self._compute_update, args)))

    # Global state is unused in the basic case, just pass it through.
    return list(new_params), list(new_states), global_state, list(new_params)

  def train(self, problem, dataset):
    """Creates graph operations to train the optimizer.

    Args:
      problem: A problem_generator.Problem instance to train on.
      dataset: A datasets.Dataset tuple to use when training.

    Returns:
      meta_objective: A tensorflow operation for computing the meta-objective
      obj_weights: A tensor placeholder for feeding in the objective weights
      obj_values: The subproblem objective values during optimization
      batches: The batch indexes tensor for overriding with feed_dict
      first_unroll: A placeholder signifying if this is a first unroll
        (this will propagate the gradients slightly differently).
      reset_state: A placeholder signifying that the rnn state should be reset.
      output_state: The final state of the optimizer
      init_loop_vars_to_override: Local variables that can be assigned to
        propagate the optimizer and problem state for unrolling
      final_loop_vals: Final values of the loop variables that can be
        assigned to init_loop_vars_to_override.
    """

    # Placeholder for the objective weights
    obj_weights = tf.placeholder(tf.float32)
    num_iter = tf.shape(obj_weights)[0]

    # Unpack the dataset and generate the minibatches for training
    data, labels = dataset
    # Convert the ndarrays to tensors so we can pass them back in via feed_dict
    data = tf.constant(data)
    labels = tf.constant(labels)
    batches = tf.placeholder(tf.int32)
    first_unroll = tf.placeholder_with_default(False, [])
    reset_state = tf.placeholder_with_default(False, [])

    training_output = collections.namedtuple("TrainingOutput",
                                             ["metaobj",
                                              "obj_weights",
                                              "problem_objectives",
                                              "initial_obj",
                                              "batches",
                                              "first_unroll",
                                              "reset_state",
                                              "output_state",
                                              "init_loop_vars",
                                              "output_loop_vars"])

    def loop_body(itr, obj_accum, params, attend_params, flattened_states,
                  global_state, all_obj, unused_init_obj, data,
                  labels, batches):
      """Body of the meta-training while loop for optimizing a sub-problem.

      Args:
        itr: The current meta-training iteration.
        obj_accum: The accumulated objective over all training steps so far.
        params: The parameters of the sub-problem.
        attend_params: The parameters of the sub-problems at the attended
            location.
        flattened_states: The states of the trainable optimizer, sorted and
            flattened into a list (since a while loop can't handle nested lists
            or dictionaries).
        global_state: The global state of the optimizer.
        all_obj: The list of all objective values in the training process.
        unused_init_obj: The initial objective (unused here, but needed in the
            variable list because it's used in a stopping condition in the
            loop_cond.)
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.

      Returns:
        itr: The updated meta-training iteration.
        obj_accum: The updated accumulated objective.
        params: The new parameters of the sub-problem.
        attend_params: The new parameters of the sub-problems at the attended
            location.
        flattened_states: The new states of the trainable optimizer.
        global_state: The updated global state.
        all_obj: The updates list of all objective values.
        unused_init_obj: The initial objective.
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.
      """
      batch_indices = tf.gather(batches, itr)
      batch_data = tf.gather(data, batch_indices)
      batch_labels = tf.gather(labels, batch_indices)

      # Compute the objective over the entire dataset (full batch).
      obj = problem.objective(params, data, labels)

      # Compute the gradients on just the current batch
      if self.use_attention:
        current_obj = problem.objective(attend_params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, attend_params)
      else:
        current_obj = problem.objective(params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, params)

      if not self.use_second_derivatives:
        new_grads = []
        for grad in grads:
          if isinstance(grad, tf.IndexedSlices):
            new_grads.append(
                tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
          else:
            new_grads.append(tf.stop_gradient(grad))
        grads = new_grads

      # store the objective value for the entire problem at each iteration
      all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)

      # accumulate the weighted objective for the entire dataset
      acc = tf.gather(obj_weights, itr) * obj

      obj_accum = tf.add(obj_accum, acc)
      # Set the shape to keep the shape invariant for obj_accum. Without this,
      # the graph builder thinks the tensor shape is unknown on the 2nd iter.
      obj_accum.set_shape([])

      # convert flattened_states to dictionaries
      dict_states = [dict(zip(self.state_keys, flat_state))
                     for flat_state in flattened_states]

      # compute the new parameters and states
      args = (params, grads, dict_states, global_state)
      updates = self._compute_updates(*args)
      new_params, new_states, new_global_state, new_attend_params = updates

      # Flatten the states. The explicit list() matters under Python 3, where
      # map() returns an iterator: tf.while_loop needs a concrete list to
      # match the loop-variable structure.
      new_flattened_states = list(map(flatten_and_sort, new_states))

      return [itr + 1, obj_accum, new_params, new_attend_params,
              new_flattened_states, new_global_state, all_obj, unused_init_obj,
              data, labels, batches]

    def loop_cond(itr, obj_accum, unused_params, unused_attend_params,
                  unused_flattened_states, unused_global_state, all_obj,
                  init_obj, *args):
      """Termination conditions of the sub-problem optimization loop."""
      del args  # unused

      cond1 = tf.less(itr, num_iter)  # We've run < num_iter times
      cond2 = tf.is_finite(obj_accum)  # The objective is still finite

      if self.obj_train_max_multiplier > 0:
        current_obj = tf.gather(all_obj, itr)
        # Account for negative init_obj too
        max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
        max_obj = init_obj + max_diff
        # The objective is a reasonable multiplier of the original objective
        cond3 = tf.less(current_obj, max_obj)

        return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
                              name="training_loop_cond")
      else:
        return tf.logical_and(cond1, cond2, name="training_loop_cond")

    init = self._initialize_training_loop_parameters(
        problem, data, labels, batches, first_unroll, reset_state)
    loop_vars, invariants, initial_obj, init_loop_vars_to_override = init

    loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
                                swap_memory=True, shape_invariants=invariants)
    meta_obj, problem_objectives = loop_output[1], loop_output[6]

    # The meta objective is normalized by the initial objective at the start of
    # the series of partial unrolls.
    scaled_meta_objective = self.scale_objective(
        meta_obj, problem_objectives, initial_obj)

    final_loop_vals = (
        [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5])
    final_loop_vals.extend(itertools.chain(*loop_output[4]))

    return training_output(scaled_meta_objective,
                           obj_weights,
                           problem_objectives,
                           initial_obj,
                           batches,
                           first_unroll,
                           reset_state,
                           loop_output[4],
                           init_loop_vars_to_override,
                           final_loop_vals)

  def _initialize_training_loop_parameters(
      self, problem, data, labels, batches, first_unroll, reset_state):
    """Initializes the vars and params needed for the training process.

    Args:
      problem: The problem being optimized.
      data: The data for the problem.
      labels: The corresponding labels for the data.
      batches: The indexes needed to create shuffled batches of the data.
      first_unroll: Whether this is the first unroll in a partial unrolling.
      reset_state: Whether RNN state variables should be reset.

    Returns:
      loop_vars: The while loop variables for training.
      invariants: The corresponding variable shapes (required by while loop).
      initial_obj: The initial objective (used later for scaling).
      init_loop_vars_to_override: The loop vars that can be overridden when
          performing training via partial unrolls.
    """
    # Extract these separately so we don't have to make inter-variable
    # dependencies.
    initial_tensors = problem.init_tensors()

    return_initial_tensor_values = first_unroll
    initial_params_vars, initial_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    initial_attend_params_vars, initial_attend_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    # Recalculate the initial objective for the list on each partial unroll with
    # the new initial_params. initial_obj holds the value from the very first
    # unroll.
    initial_obj_init = problem.objective(initial_params, data, labels)
    return_initial_obj_init = first_unroll
    [initial_obj_var], [initial_obj] = local_state_variables(
        [initial_obj_init], return_initial_obj_init)

    # Initialize the loop variables.
    initial_itr = tf.constant(0, dtype=tf.int32)
    initial_meta_obj = tf.constant(0, dtype=tf.float32)
    # N.B. the use of initial_obj_init here rather than initial_obj
    initial_problem_objectives = tf.reshape(initial_obj_init, (1,))

    # Initialize the extra state.
    initial_state_vars = []
    initial_state = []
    state_shapes = []
    return_initial_state_values = reset_state
    for param in initial_tensors:
      param_state_vars, param_state = local_state_variables(
          flatten_and_sort(self._initialize_state(param)),
          return_initial_state_values)

      initial_state_vars.append(param_state_vars)
      initial_state.append(param_state)
      state_shapes.append([f.get_shape() for f in param_state])

    # Initialize any global (problem-level) state.
    initial_global_state_vars, initial_global_state = local_state_variables(
        self._initialize_global_state(), return_initial_state_values)

    global_shapes = []
    for item in initial_global_state:
      global_shapes.append(item.get_shape())

    # build the list of loop variables:
    loop_vars = [
        initial_itr,
        initial_meta_obj,
        initial_params,         # Local variables.
        initial_attend_params,  # Local variables.
        initial_state,          # Local variables.
        initial_global_state,   # Local variables.
        initial_problem_objectives,
        initial_obj,            # Local variable.
        data,
        labels,
        batches,
    ]

    invariants = [
        initial_itr.get_shape(),
        initial_meta_obj.get_shape(),
        [t.get_shape() for t in initial_params],
        [t.get_shape() for t in initial_attend_params],
        state_shapes,
        global_shapes,
        tensor_shape.TensorShape([None]),   # The problem objectives list grows
        initial_obj.get_shape(),
        tensor_shape.unknown_shape(),  # Placeholder shapes are unknown
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
    ]

    # Initialize local variables that we will override with final tensors at the
    # next iter.
    init_loop_vars_to_override = (
        [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
        initial_global_state_vars)
    init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))

    return loop_vars, invariants, initial_obj, init_loop_vars_to_override

  def scale_objective(self, total_obj, all_objs, initial_obj,
                      obj_scale_eps=1e-6):
    """Normalizes the objective based on the initial objective value.

    Args:
      total_obj: The total accumulated objective over the training run.
      all_objs: A list of all the individual objectives over the training run.
      initial_obj: The initial objective value.
      obj_scale_eps: The epsilon value to use in computations for stability.

    Returns:
      The scaled objective as a single value.
    """
    if self.use_log_objective:
      if self.use_numerator_epsilon:
        scaled_problem_obj = ((all_objs + obj_scale_eps) /
                              (initial_obj + obj_scale_eps))
        log_scaled_problem_obj = tf.log(scaled_problem_obj)
      else:
        scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
        log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
      return tf.reduce_mean(log_scaled_problem_obj)
    else:
      return total_obj / (initial_obj + obj_scale_eps)


def local_state_variables(init_values, return_init_values):
  """Creates local variables initialized from a list of init_values.

  One local variable is created per entry in init_values, named after the
  value's shape and dtype. As a convenience, the boolean tensor
  return_init_values selects whether the returned values come from the
  original init values or from the created local variables.

  Args:
    init_values: iterable of tensors
    return_init_values: boolean tensor

  Returns:
    local_vars: list of the created local variables.
    vals: if return_init_values is true, then this returns the values of
      init_values. Otherwise it returns the values of the local_vars.
  """
  if not init_values:
    return [], []

  # Per-name use counts live in a graph collection so repeated calls keep
  # producing unique names. This generates a harmless warning when saving
  # the metagraph.
  count_collection = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
  if not count_collection:
    count_collection.append(collections.defaultdict(int))
  name_counts = count_collection[0]

  local_vars = []
  with tf.variable_scope(OPTIMIZER_SCOPE):
    # The init_value cannot serve as the initializer, since it may itself
    # depend on problem variables; that would create inter-variable
    # initialization order dependencies, which TensorFlow handles poorly.
    # Zero initializers are used instead, and the real values flow through
    # the tf.cond below.
    for init_value in init_values:
      base_name = create_local_state_variable_name(init_value)
      unique_name = base_name + "_" + str(name_counts[base_name])
      name_counts[base_name] += 1
      # Uniquifying on type and name mirrors the checks TensorFlow makes
      # internally, which allows reusing variables between different sessions
      # on the same master with some memory reuse. Ultimately this is a hack
      # around the broken Session.reset().
      zero_init = tf.zeros(init_value.get_shape(), dtype=init_value.dtype)
      local_vars.append(
          tf.get_local_variable(unique_name, initializer=zero_init))

  # Returning the init_values themselves on the first iteration (rather than
  # the variables) lets gradients propagate through them and simplifies
  # initialization; the variables get assigned after the first iteration.
  vals = tf.cond(return_init_values, lambda: init_values, lambda: local_vars)
  if len(init_values) == 1:
    # tf.cond extracts elements from singleton lists.
    vals = [vals]
  return local_vars, vals


def create_local_state_variable_name(tensor):
  """Create a name of the variable based on its type and shape."""
  shape = tensor.get_shape()
  if not shape.is_fully_defined():
    raise ValueError("Need a fully specified shape to create a local variable.")

  # e.g. a float32 tensor of shape (3, 4) -> "local_state_3_4_float32".
  dims = [str(dim) for dim in shape.as_list()]
  return _LOCAL_VARIABLE_PREFIX + "_".join(dims) + "_" + tensor.dtype.name


def is_local_state_variable(op):
  """Returns if this op is a local state variable created for training."""
  state_prefix = OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX
  if op.node_def.op not in ("Variable", "VariableV2"):
    return False
  return op.name.startswith(state_prefix)


def flatten_and_sort(dictionary):
  """Flattens a dictionary into a list of values sorted by the keys."""
  values = []
  for key in sorted(dictionary):
    values.append(dictionary[key])
  return values