# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers for use in unrolled optimization.
These optimizers contain a compute_updates function and its own ability to keep
track of internal state.
These functions can be used with a tf.while_loop to perform multiple training
steps per sess.run.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import sonnet as snt
import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops

from learning_unsupervised_learning import utils


class UnrollableOptimizer(snt.AbstractModule):
  """Interface for optimizers that can be used in unrolled computation.

  apply_gradients is derived from compute_updates and assign_state.
  """

  def __init__(self, *args, **kwargs):
    super(UnrollableOptimizer, self).__init__(*args, **kwargs)
    # Connect the module once so that Sonnet considers it built.
    self()

  @abc.abstractmethod
  def compute_updates(self, xs, gs, state=None):
    """Compute next step updates for a given variable list and state.

    Args:
      xs: list of tensors
        The "variables" to perform an update on.
        Note these must be in the same order as when get_state was originally
        called.
      gs: list of tensors
        Gradients of `xs` with respect to some loss.
      state: Any
        Optimizer-specific state used to keep track of accumulators such as
        momentum terms.

    Returns:
      new_xs: list of tensors
        The updated values of `xs`.
      new_state: Any
        The updated optimizer state.
    """
    raise NotImplementedError()

  def _build(self):
    # No submodules to construct.
    pass

  @abc.abstractmethod
  def get_state(self, var_list):
    """Get the state value associated with a list of tf.Variables.

    This state is commonly a NamedTuple that contains some mapping between
    variables and the state associated with those variables, e.g. a moving
    momentum value tracked by the optimizer (see the illustrative momentum
    sketch below this class).

    Args:
      var_list: list of tf.Variable

    Returns:
      state: Any
        Optimizer-specific state.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def assign_state(self, state):
    """Assigns the state to the optimizer's internal variables.

    Args:
      state: Any

    Returns:
      op: tf.Operation
        The operation that performs the assignment.
    """
    raise NotImplementedError()

  def apply_gradients(self, grad_vars):
    # Derived from the functional pieces above: read the current state,
    # compute new values, then write both state and variables back.
    gradients, variables = zip(*grad_vars)
    state = self.get_state(variables)
    new_vars, new_state = self.compute_updates(variables, gradients, state)
    assign_op = self.assign_state(new_state)
    op = utils.assign_variables(variables, new_vars)
    return tf.group(assign_op, op, name="apply_gradients")
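

# The following subclass is an illustrative sketch added for exposition; it is
# not part of the original library. It shows the NamedTuple-state pattern the
# get_state docstring describes: accumulators live in non-trainable
# tf.Variables so that assign_state can write unrolled values back.
MomentumState = collections.namedtuple("MomentumState", ["accumulators"])


class UnrollableMomentumOptimizer(UnrollableOptimizer):
  """Sketch of a momentum optimizer written against the interface above."""

  def __init__(self, learning_rate, momentum=0.9,
               name="UnrollableMomentumOptimizer"):
    self.learning_rate = learning_rate
    self.momentum = momentum
    self._accumulator_vars = None
    super(UnrollableMomentumOptimizer, self).__init__(name=name)

  def compute_updates(self, xs, gs, state=None):
    new_xs = []
    new_accs = []
    for x, g, acc in utils.eqzip(xs, gs, state.accumulators):
      if g is None:
        new_xs.append(x)
        new_accs.append(acc)
      else:
        # Classical momentum: acc <- m * acc + g; x <- x - lr * acc.
        new_acc = self.momentum * acc + g
        new_xs.append(x - self.learning_rate * new_acc)
        new_accs.append(new_acc)
    return new_xs, MomentumState(accumulators=new_accs)

  def get_state(self, var_list):
    # Lazily create one accumulator variable per optimized variable. This
    # sketch assumes get_state is always called with the same var_list.
    if self._accumulator_vars is None:
      self._accumulator_vars = [
          tf.Variable(tf.zeros_like(v.initialized_value()), trainable=False)
          for v in var_list
      ]
    return MomentumState(
        accumulators=[a.read_value() for a in self._accumulator_vars])

  def assign_state(self, state):
    assigns = [
        a.assign(t)
        for a, t in utils.eqzip(self._accumulator_vars, state.accumulators)
    ]
    return tf.group(*assigns, name="assign_state")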


class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
  """Stateless gradient descent whose update also decays variables to zero."""

  def __init__(self,
               learning_rate,
               name="UnrollableGradientDescentRollingOptimizer"):
    self.learning_rate = learning_rate
    super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)

  def compute_updates(self, xs, gs, state=None, learning_rates=None):
    # learning_rates optionally supplies a per-variable learning rate; the
    # keyword-argument ordering keeps this override compatible with the
    # base-class apply_gradients, which passes only (xs, gs, state).
    if learning_rates is None:
      learning_rates = [None] * len(xs)
    new_vars = []
    for x, g, lr in utils.eqzip(xs, gs, learning_rates):
      if lr is None:
        lr = self.learning_rate
      if g is not None:
        # Rolling update: decay the variable toward zero while taking a
        # gradient step, i.e. x <- (1 - lr) * x - lr * g.
        new_vars.append(x * (1 - lr) - g * lr)
      else:
        new_vars.append(x)
    return new_vars, state

  def get_state(self, var_list):
    # This optimizer is stateless; return a dummy value to satisfy the
    # interface.
    return tf.constant(0.0)

  def assign_state(self, state, var_list=None):
    # Nothing to assign for a stateless optimizer.
    return tf.no_op()
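

# --- Usage sketch -----------------------------------------------------------
# An illustrative example, not part of the original library. Because
# compute_updates returns new values instead of mutating variables, parameter
# values can be threaded through the loop variables of a tf.while_loop,
# performing several training steps in a single sess.run. `loss_fn` (a
# function mapping a list of parameter tensors to a scalar loss) and all names
# below are assumptions made for the example.
def _example_unrolled_steps(loss_fn, init_xs, num_steps=5, learning_rate=0.1):
  opt = UnrollableGradientDescentRollingOptimizer(learning_rate=learning_rate)

  def cond(i, *unused_xs):
    return i < num_steps

  def body(i, *xs):
    xs = list(xs)
    loss = loss_fn(xs)
    gs = tf.gradients(loss, xs)
    # The rolling optimizer is stateless, so passing plain tensors (rather
    # than tf.Variables) to get_state is fine here.
    new_xs, _ = opt.compute_updates(xs, gs, state=opt.get_state(xs))
    return [i + 1] + list(new_xs)

  results = tf.while_loop(cond, body, [tf.constant(0)] + list(init_xs))
  return results[1:]  # Final parameter values after num_steps updates.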