# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers for use in unrolled optimization. | |
These optimizers contain a compute_updates function and its own ability to keep | |
track of internal state. | |
These functions can be used with a tf.while_loop to perform multiple training | |
steps per sess.run. | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import sonnet as snt
import tensorflow as tf

from learning_unsupervised_learning import utils

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class UnrollableOptimizer(snt.AbstractModule):
  """Interface for optimizers that can be used in unrolled computation.

  apply_gradients is derived from compute_updates and assign_state.
  """

  def __init__(self, *args, **kwargs):
    super(UnrollableOptimizer, self).__init__(*args, **kwargs)
    # Connect the module once so that any internal variables are built
    # immediately.
    self()

  def compute_updates(self, xs, gs, state=None):
    """Compute next-step updates for a given variable list and state.

    Args:
      xs: list of tensors
        The "variables" to perform an update on. These must be in the same
        order for which get_state was originally called.
      gs: list of tensors
        Gradients of `xs` with respect to some loss.
      state: Any
        Optimizer-specific state used to keep track of accumulators, such as
        momentum terms.

    Returns:
      new_xs: list of tensors
        The updated values of `xs`.
      new_state: Any
        The updated optimizer-specific state.
    """
    raise NotImplementedError()

  def _build(self):
    pass

  def get_state(self, var_list):
    """Get the state value associated with a list of tf.Variables.

    This state is commonly a NamedTuple containing some mapping between
    variables and the state associated with those variables, e.g. a momentum
    accumulator tracked by the optimizer.

    Args:
      var_list: list of tf.Variable

    Returns:
      state: Any
        Optimizer-specific state.
    """
    raise NotImplementedError()

  def assign_state(self, state):
    """Assigns the given state to the optimizer's internal variables.

    Args:
      state: Any

    Returns:
      op: tf.Operation
        The operation that performs the assignment.
    """
    raise NotImplementedError()

  def apply_gradients(self, grad_vars):
    """Applies (gradient, variable) pairs, mirroring tf.train.Optimizer."""
    gradients, variables = zip(*grad_vars)
    # Fetch the current state, functionally compute the new variable values
    # and new state, then assign both back in a single grouped op.
    state = self.get_state(variables)
    new_vars, new_state = self.compute_updates(variables, gradients, state)
    assign_op = self.assign_state(new_state)
    op = utils.assign_variables(variables, new_vars)
    return tf.group(assign_op, op, name="apply_gradients")
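
# Illustrative usage (an added sketch, not original code): apply_gradients
# mirrors the tf.train.Optimizer interface, so for hypothetical `loss` and
# `var_list`:
#
#   gs = tf.gradients(loss, var_list)
#   train_op = opt.apply_gradients(list(zip(gs, var_list)))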


class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
  """Gradient descent with a multiplicative decay ("rolling") term.

  The update applied to each variable is: x <- x * (1 - lr) - g * lr.
  """

  def __init__(self,
               learning_rate,
               name="UnrollableGradientDescentRollingOptimizer"):
    self.learning_rate = learning_rate
    super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)

  def compute_updates(self, xs, gs, state=None, learning_rates=None):
    # Default the per-variable learning rates so that the base class's
    # apply_gradients, which passes only (xs, gs, state), keeps working.
    if learning_rates is None:
      learning_rates = [None] * len(xs)
    new_vars = []
    for x, g, lr in utils.eqzip(xs, gs, learning_rates):
      if lr is None:
        lr = self.learning_rate
      if g is not None:
        new_vars.append(x * (1 - lr) - g * lr)
      else:
        new_vars.append(x)
    return new_vars, state

  def get_state(self, var_list):
    # This optimizer is stateless; return a dummy tensor as its state.
    return tf.constant(0.0)

  def assign_state(self, state, var_list=None):
    # Nothing to assign for a stateless optimizer.
    return tf.no_op()
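

# Minimal smoke test (an illustrative addition, not part of the original
# module): one compute_updates step applies x' = x * (1 - lr) - g * lr.
if __name__ == "__main__":
  opt = UnrollableGradientDescentRollingOptimizer(learning_rate=0.1)
  x = tf.constant([1.0, 2.0])
  g = tf.constant([0.5, -0.5])
  new_xs, _ = opt.compute_updates([x], [g], state=opt.get_state([x]))
  with tf.Session() as sess:
    # Expected: [1.0 * 0.9 - 0.05, 2.0 * 0.9 + 0.05] = [0.85, 1.85]
    print(sess.run(new_xs[0]))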