# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================



"""Optimizers for use in unrolled optimization.

These optimizers contain a compute_updates function and its own ability to keep
track of internal state.
These functions can be used with a tf.while_loop to perform multiple training
steps per sess.run.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import tensorflow as tf
import sonnet as snt

from learning_unsupervised_learning import utils



class UnrollableOptimizer(snt.AbstractModule):
  """Interface for optimizers that can be used in unrolled computation.

  `apply_gradients` is derived from `compute_updates`, `get_state`, and
  `assign_state`.
  """

  def __init__(self, *args, **kwargs):
    super(UnrollableOptimizer, self).__init__(*args, **kwargs)
    # Connect the sonnet module immediately so that construction happens at
    # build time rather than on first use.
    self()

  @abc.abstractmethod
  def compute_updates(self, xs, gs, state=None):
    """Compute next step updates for a given variable list and state.

    Args:
      xs: list of tensors
        The "variables" to perform an update on.
        Note these must match the same order for which get_state was originally
        called.
      gs: list of tensors
        Gradients of `xs` with respect to some loss.
      state: Any
        Optimizer specific state to keep track of accumulators such as momentum
        terms
    """
    raise NotImplementedError()

  def _build(self):
    # Sonnet modules must define _build; nothing needs to be constructed at
    # connection time for these optimizers.
    pass

  @abc.abstractmethod
  def get_state(self, var_list):
    """Get the state value associated with a list of tf.Variables.

    This state is commonly going to be a NamedTuple that contains some
    mapping between variables and the state associated with those variables.
    This state could be a moving momentum variable tracked by the optimizer.

    Args:
        var_list: list of tf.Variable
    Returns:
      state: Any
        Optimizer specific state
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def assign_state(self, state):
    """Assigns the state to the optimizer's internal variables.

    Args:
      state: Any
        Optimizer-specific state, as returned by `get_state`.
    Returns:
      op: tf.Operation
        The operation that performs the assignment.
    """
    raise NotImplementedError()

  def apply_gradients(self, grad_vars):
    """Applies (gradient, variable) pairs, mirroring tf.train.Optimizer."""
    gradients, variables = zip(*grad_vars)
    state = self.get_state(variables)
    # Pass state by keyword so that subclasses whose compute_updates takes
    # extra optional arguments (e.g. per-variable learning rates) still
    # receive it correctly.
    new_vars, new_state = self.compute_updates(variables, gradients,
                                               state=state)
    assign_op = self.assign_state(new_state)
    op = utils.assign_variables(variables, new_vars)
    return tf.group(assign_op, op, name="apply_gradients")
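

# What follows is an illustrative sketch, not part of the original library: a
# minimal stateful optimizer implementing the interface above, using classic
# momentum. The accumulators are held in backing variables so that get_state /
# assign_state can read and write them between unrolls. All names below
# (_ExampleMomentumOptimizer, _accumulator_for) are invented for illustration.
class _ExampleMomentumOptimizer(UnrollableOptimizer):

  def __init__(self, learning_rate, momentum=0.9,
               name="ExampleMomentumOptimizer"):
    self.learning_rate = learning_rate
    self.momentum = momentum
    # Ordered so that get_state and assign_state traverse accumulators in a
    # deterministic, matching order.
    self._accumulators = collections.OrderedDict()
    super(_ExampleMomentumOptimizer, self).__init__(name=name)

  def _accumulator_for(self, x):
    # Lazily create one non-trainable accumulator variable per optimized
    # variable.
    if x.name not in self._accumulators:
      self._accumulators[x.name] = tf.Variable(
          tf.zeros_like(x), trainable=False, name="momentum_accum")
    return self._accumulators[x.name]

  def get_state(self, var_list):
    return [self._accumulator_for(x).read_value() for x in var_list]

  def compute_updates(self, xs, gs, state=None):
    new_xs = []
    new_state = []
    for x, g, m in utils.eqzip(xs, gs, state):
      if g is None:
        # No gradient: leave both the variable and its accumulator unchanged.
        new_xs.append(x)
        new_state.append(m)
      else:
        new_m = self.momentum * m + g
        new_xs.append(x - self.learning_rate * new_m)
        new_state.append(new_m)
    return new_xs, new_state

  def assign_state(self, state):
    assigns = [v.assign(s) for v, s in
               utils.eqzip(list(self._accumulators.values()), state)]
    return tf.group(*assigns, name="assign_state")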


class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
  """Stateless gradient descent that also decays variables toward zero.

  The update `x * (1 - lr) - g * lr` is a plain gradient step combined with
  multiplicative shrinkage of the variable, i.e. gradient descent with an
  L2-style weight decay applied at the same rate.
  """

  def __init__(self,
               learning_rate,
               name="UnrollableGradientDescentRollingOptimizer"):
    self.learning_rate = learning_rate
    super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)

  def compute_updates(self, xs, gs, learning_rates=None, state=None):
    # A per-variable learning rate of None (or a missing list) falls back to
    # the global learning rate.
    if learning_rates is None:
      learning_rates = [None] * len(xs)
    new_vars = []
    for x, g, lr in utils.eqzip(xs, gs, learning_rates):
      if lr is None:
        lr = self.learning_rate
      if g is not None:
        new_vars.append(x * (1 - lr) - g * lr)
      else:
        # No gradient for this variable; leave it unchanged.
        new_vars.append(x)
    return new_vars, state

  def get_state(self, var_list):
    # This optimizer is stateless; a dummy scalar is threaded through the
    # unroll so that the interface matches stateful optimizers.
    return tf.constant(0.0)

  def assign_state(self, state, var_list=None):
    # Nothing to assign for a stateless optimizer.
    return tf.no_op()
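

# An illustrative sketch, not part of the original library, of the unrolled
# training pattern described in the module docstring: several gradient steps
# are taken inside a single tf.while_loop, so one sess.run executes multiple
# updates. apply_gradients cannot be used directly inside the loop because it
# assigns variables; instead the functional compute_updates is threaded
# through the loop and the result is assigned once at the end. The toy
# quadratic loss and all names here are invented for illustration.
def _example_unrolled_training(num_steps=5):
  w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
  opt = UnrollableGradientDescentRollingOptimizer(learning_rate=0.1)

  def loss_fn(x):
    # Toy quadratic loss pulling every coordinate of x toward 1.
    return tf.reduce_sum(tf.square(x - 1.0))

  def body(t, x, state):
    g = tf.gradients(loss_fn(x), [x])[0]
    new_xs, new_state = opt.compute_updates([x], [g], state=state)
    return t + 1, new_xs[0], new_state

  _, final_w, final_state = tf.while_loop(
      cond=lambda t, unused_x, unused_state: t < num_steps,
      body=body,
      loop_vars=[tf.constant(0), w.read_value(), opt.get_state([w])])

  # Write the unrolled result back into the variable and optimizer state once,
  # outside the loop.
  return tf.group(w.assign(final_w), opt.assign_state(final_state))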