# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A base class definition for trainable optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools

import tensorflow as tf

from tensorflow.python.framework import tensor_shape
OPTIMIZER_SCOPE = "LOL" | |
_LOCAL_VARIABLE_PREFIX = "local_state_" | |
_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection" | |
EPSILON = 1e-6 | |
class TrainableOptimizer(tf.train.Optimizer):
  """Base class for trainable optimizers.

  A trainable optimizer is an optimizer that has parameters that can themselves
  be learned (meta-optimized).

  Subclasses must implement:
      _compute_update(self, param, grad, state)
  """

  def __init__(self, name, state_keys, use_attention=False,
               use_log_objective=False, obj_train_max_multiplier=-1,
               use_second_derivatives=True, use_numerator_epsilon=False,
               **kwargs):
    """Initializes the optimizer with the given name and settings.

    Args:
      name: The name string for this optimizer.
      state_keys: The names of any required state variables (list).
      use_attention: Whether this optimizer uses attention (Default: False)
      use_log_objective: Whether this optimizer uses the logarithm of the
        objective when computing the loss (Default: False)
      obj_train_max_multiplier: The maximum multiplier for the increase in the
        objective before meta-training is stopped. If <= 0, meta-training is
        not stopped early. (Default: -1)
      use_second_derivatives: Whether this optimizer uses second derivatives in
        meta-training. This should be set to False if some second derivatives
        in the meta-training problem set are not defined in Tensorflow.
        (Default: True)
      use_numerator_epsilon: Whether to use epsilon in the numerator when
        scaling the problem objective during meta-training. (Default: False)
      **kwargs: Any additional keyword arguments.
    """
    self.use_second_derivatives = use_second_derivatives
    # Keys are sorted so the ordering matches the flattened state lists built
    # by flatten_and_sort() inside the training while loop.
    self.state_keys = sorted(state_keys)
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.obj_train_max_multiplier = obj_train_max_multiplier
    self.use_numerator_epsilon = use_numerator_epsilon

    use_locking = False
    super(TrainableOptimizer, self).__init__(use_locking, name)

  def _create_slots(self, var_list):
    """Creates all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    for var in var_list:
      init_states = self._initialize_state(var)
      for slot_name in sorted(init_states):
        slot_var_name = "{}_{}".format(self.get_name(), slot_name)
        value = init_states[slot_name]
        self._get_or_make_slot(var, value, slot_name, slot_var_name)

  def _initialize_state(self, var):
    """Initializes any state required for this variable.

    Args:
      var: a tensor containing parameters to be optimized

    Returns:
      state: a dictionary mapping state keys to initial state values (tensors)
    """
    return {}

  def _initialize_global_state(self):
    """Initializes any global (problem-level) state values."""
    return []

  def _apply_common(self, grad, var):
    """Applies the optimizer updates to the variables.

    Note: this should only get called via _apply_dense or _apply_sparse when
    using the optimizer via optimizer.minimize or optimizer.apply_gradients.
    During meta-training, the optimizer.train function should be used to
    construct an optimization path that is differentiable.

    Args:
      grad: A tensor representing the gradient.
      var: A tf.Variable with the same shape as grad.

    Returns:
      update_op: A tensorflow op that assigns new values to the variable, and
          also defines dependencies that update the state variables for the
          optimizer.
    """
    state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
    new_var, new_state = self._compute_update(var, grad, state)
    state_assign_ops = [tf.assign(state_var, new_state[key])
                        for key, state_var in state.items()]
    # Ensure the state is updated before (and whenever) the variable is.
    with tf.control_dependencies(state_assign_ops):
      update_op = var.assign(new_var)

    return update_op

  def _apply_dense(self, grad, var):
    """Adds ops to apply dense gradients to 'var'."""
    return self._apply_common(grad, var)

  def _apply_sparse(self, grad, var):
    """Adds ops to apply sparse gradients to 'var'."""
    return self._apply_common(grad, var)

  def _compute_update(self, param, grad, state):
    """Computes the update step for optimization.

    Args:
      param: A tensor of parameters to optimize.
      grad: The gradient tensor of the objective with respect to the parameters.
          (It has the same shape as param.)
      state: A dictionary containing any extra state required by the optimizer.

    Returns:
      updated_params: The updated parameters.
      updated_state: The dictionary of updated state variable(s).
    """
    raise NotImplementedError

  def _compute_updates(self, params, grads, states, global_state):
    """Maps the compute update functions for each parameter.

    This function can be overriden by a subclass if the subclass wants to
    combine information across the different parameters in the list.

    Args:
      params: A list of parameter tensors.
      grads: A list of gradients corresponding to each parameter.
      states: A list of state variables corresponding to each parameter.
      global_state: A list of global state variables for the problem.

    Returns:
      new_params: The updated parameters.
      new_states: The updated states.
      new_global_state: The updated global state.
      attention_params: A list of attention parameters. This is the same as
          new_params if the optimizer does not use attention.
    """
    # Zip up the arguments to _compute_update.
    args = zip(params, grads, states)

    # Call compute_update on each set of parameter/gradient/state args.
    new_params, new_states = zip(*list(
        itertools.starmap(self._compute_update, args)))

    # Global state is unused in the basic case, just pass it through.
    return list(new_params), list(new_states), global_state, list(new_params)

  def train(self, problem, dataset):
    """Creates graph operations to train the optimizer.

    Args:
      problem: A problem_generator.Problem instance to train on.
      dataset: A datasets.Dataset tuple to use when training.

    Returns:
      meta_objective: A tensorflow operation for computing the meta-objective
      obj_weights: A tensor placeholder for feeding in the objective weights
      obj_values: The subproblem objective values during optimization
      batches: The batch indexes tensor for overriding with feed_dict
      first_unroll: A placeholder signifying if this is a first unroll
        (this will propagate the gradients slightly differently).
      reset_state: A placeholder signifying that the rnn state should be reset.
      output_state: The final state of the optimizer
      init_loop_vars_to_override: Local variables that can be assigned to
        propagate the optimizer and problem state for unrolling
      final_loop_vals: Final values of the loop variables that can be
        assigned to init_loop_vars_to_override.
    """
    # Placeholder for the objective weights
    obj_weights = tf.placeholder(tf.float32)
    num_iter = tf.shape(obj_weights)[0]

    # Unpack the dataset and generate the minibatches for training
    data, labels = dataset
    # Convert the ndarrays to tensors so we can pass them back in via feed_dict
    data = tf.constant(data)
    labels = tf.constant(labels)
    batches = tf.placeholder(tf.int32)
    first_unroll = tf.placeholder_with_default(False, [])
    reset_state = tf.placeholder_with_default(False, [])

    training_output = collections.namedtuple("TrainingOutput",
                                             ["metaobj",
                                              "obj_weights",
                                              "problem_objectives",
                                              "initial_obj",
                                              "batches",
                                              "first_unroll",
                                              "reset_state",
                                              "output_state",
                                              "init_loop_vars",
                                              "output_loop_vars"])

    def loop_body(itr, obj_accum, params, attend_params, flattened_states,
                  global_state, all_obj, unused_init_obj, data,
                  labels, batches):
      """Body of the meta-training while loop for optimizing a sub-problem.

      Args:
        itr: The current meta-training iteration.
        obj_accum: The accumulated objective over all training steps so far.
        params: The parameters of the sub-problem.
        attend_params: The parameters of the sub-problems at the attended
            location.
        flattened_states: The states of the trainable optimizer, sorted and
            flattened into a list (since a while loop can't handle nested lists
            or dictionaries).
        global_state: The global state of the optimizer.
        all_obj: The list of all objective values in the training process.
        unused_init_obj: The initial objective (unused here, but needed in the
            variable list because it's used in a stopping condition in the
            loop_cond.)
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.

      Returns:
        itr: The updated meta-training iteration.
        obj_accum: The updated accumulated objective.
        params: The new parameters of the sub-problem.
        attend_params: The new parameters of the sub-problems at the attended
            location.
        flattened_states: The new states of the trainable optimizer.
        global_state: The updated global state.
        all_obj: The updates list of all objective values.
        unused_init_obj: The initial objective.
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.
      """
      batch_indices = tf.gather(batches, itr)
      batch_data = tf.gather(data, batch_indices)
      batch_labels = tf.gather(labels, batch_indices)

      # Compute the objective over the entire dataset (full batch).
      obj = problem.objective(params, data, labels)

      # Compute the gradients on just the current batch
      if self.use_attention:
        current_obj = problem.objective(attend_params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, attend_params)
      else:
        current_obj = problem.objective(params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, params)

      if not self.use_second_derivatives:
        # Cut the gradient-of-gradient path so meta-training does not
        # differentiate through the sub-problem's gradients.
        new_grads = []
        for grad in grads:
          if isinstance(grad, tf.IndexedSlices):
            new_grads.append(
                tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
          else:
            new_grads.append(tf.stop_gradient(grad))
        grads = new_grads

      # store the objective value for the entire problem at each iteration
      all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)

      # accumulate the weighted objective for the entire dataset
      acc = tf.gather(obj_weights, itr) * obj
      obj_accum = tf.add(obj_accum, acc)
      # Set the shape to keep the shape invariant for obj_accum. Without this,
      # the graph builder thinks the tensor shape is unknown on the 2nd iter.
      obj_accum.set_shape([])

      # convert flattened_states to dictionaries
      dict_states = [dict(zip(self.state_keys, flat_state))
                     for flat_state in flattened_states]

      # compute the new parameters and states
      args = (params, grads, dict_states, global_state)
      updates = self._compute_updates(*args)
      new_params, new_states, new_global_state, new_attend_params = updates

      # Flatten the states. BUGFIX: wrap map() in list() — under Python 3,
      # map() returns a lazy iterator, which breaks the while-loop variable
      # structure that expects a concrete list of lists.
      new_flattened_states = list(map(flatten_and_sort, new_states))

      return [itr + 1, obj_accum, new_params, new_attend_params,
              new_flattened_states, new_global_state, all_obj, unused_init_obj,
              data, labels, batches]

    def loop_cond(itr, obj_accum, unused_params, unused_attend_params,
                  unused_flattened_states, unused_global_state, all_obj,
                  init_obj, *args):
      """Termination conditions of the sub-problem optimization loop."""
      del args  # unused

      cond1 = tf.less(itr, num_iter)  # We've run < num_iter times
      cond2 = tf.is_finite(obj_accum)  # The objective is still finite

      if self.obj_train_max_multiplier > 0:
        current_obj = tf.gather(all_obj, itr)
        # Account for negative init_obj too
        max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
        max_obj = init_obj + max_diff
        # The objective is a reasonable multiplier of the original objective
        cond3 = tf.less(current_obj, max_obj)

        return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
                              name="training_loop_cond")
      else:
        return tf.logical_and(cond1, cond2, name="training_loop_cond")

    init = self._initialize_training_loop_parameters(
        problem, data, labels, batches, first_unroll, reset_state)
    loop_vars, invariants, initial_obj, init_loop_vars_to_override = init

    loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
                                swap_memory=True, shape_invariants=invariants)
    meta_obj, problem_objectives = loop_output[1], loop_output[6]

    # The meta objective is normalized by the initial objective at the start of
    # the series of partial unrolls.
    scaled_meta_objective = self.scale_objective(
        meta_obj, problem_objectives, initial_obj)

    final_loop_vals = (
        [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5])
    final_loop_vals.extend(itertools.chain(*loop_output[4]))

    return training_output(scaled_meta_objective,
                           obj_weights,
                           problem_objectives,
                           initial_obj,
                           batches,
                           first_unroll,
                           reset_state,
                           loop_output[4],
                           init_loop_vars_to_override,
                           final_loop_vals)

  def _initialize_training_loop_parameters(
      self, problem, data, labels, batches, first_unroll, reset_state):
    """Initializes the vars and params needed for the training process.

    Args:
      problem: The problem being optimized.
      data: The data for the problem.
      labels: The corresponding labels for the data.
      batches: The indexes needed to create shuffled batches of the data.
      first_unroll: Whether this is the first unroll in a partial unrolling.
      reset_state: Whether RNN state variables should be reset.

    Returns:
      loop_vars: The while loop variables for training.
      invariants: The corresponding variable shapes (required by while loop).
      initial_obj: The initial objective (used later for scaling).
      init_loop_vars_to_override: The loop vars that can be overridden when
          performing training via partial unrolls.
    """
    # Extract these separately so we don't have to make inter-variable
    # dependencies.
    initial_tensors = problem.init_tensors()

    return_initial_tensor_values = first_unroll
    initial_params_vars, initial_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    initial_attend_params_vars, initial_attend_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)

    # Recalculate the initial objective for the list on each partial unroll with
    # the new initial_params. initial_obj holds the value from the very first
    # unroll.
    initial_obj_init = problem.objective(initial_params, data, labels)
    return_initial_obj_init = first_unroll
    [initial_obj_var], [initial_obj] = local_state_variables(
        [initial_obj_init], return_initial_obj_init)

    # Initialize the loop variables.
    initial_itr = tf.constant(0, dtype=tf.int32)
    initial_meta_obj = tf.constant(0, dtype=tf.float32)
    # N.B. the use of initial_obj_init here rather than initial_obj
    initial_problem_objectives = tf.reshape(initial_obj_init, (1,))

    # Initialize the extra state.
    initial_state_vars = []
    initial_state = []
    state_shapes = []
    return_initial_state_values = reset_state
    for param in initial_tensors:
      param_state_vars, param_state = local_state_variables(
          flatten_and_sort(self._initialize_state(param)),
          return_initial_state_values)

      initial_state_vars.append(param_state_vars)
      initial_state.append(param_state)
      state_shapes.append([f.get_shape() for f in param_state])

    # Initialize any global (problem-level) state.
    initial_global_state_vars, initial_global_state = local_state_variables(
        self._initialize_global_state(), return_initial_state_values)

    global_shapes = []
    for item in initial_global_state:
      global_shapes.append(item.get_shape())

    # build the list of loop variables:
    loop_vars = [
        initial_itr,
        initial_meta_obj,
        initial_params,         # Local variables.
        initial_attend_params,  # Local variables.
        initial_state,          # Local variables.
        initial_global_state,   # Local variables.
        initial_problem_objectives,
        initial_obj,            # Local variable.
        data,
        labels,
        batches,
    ]

    invariants = [
        initial_itr.get_shape(),
        initial_meta_obj.get_shape(),
        [t.get_shape() for t in initial_params],
        [t.get_shape() for t in initial_attend_params],
        state_shapes,
        global_shapes,
        tensor_shape.TensorShape([None]),  # The problem objectives list grows
        initial_obj.get_shape(),
        tensor_shape.unknown_shape(),  # Placeholder shapes are unknown
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
    ]

    # Initialize local variables that we will override with final tensors at the
    # next iter.
    init_loop_vars_to_override = (
        [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
        initial_global_state_vars)
    init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))

    return loop_vars, invariants, initial_obj, init_loop_vars_to_override

  def scale_objective(self, total_obj, all_objs, initial_obj,
                      obj_scale_eps=1e-6):
    """Normalizes the objective based on the initial objective value.

    Args:
      total_obj: The total accumulated objective over the training run.
      all_objs: A list of all the individual objectives over the training run.
      initial_obj: The initial objective value.
      obj_scale_eps: The epsilon value to use in computations for stability.

    Returns:
      The scaled objective as a single value.
    """
    if self.use_log_objective:
      if self.use_numerator_epsilon:
        scaled_problem_obj = ((all_objs + obj_scale_eps) /
                              (initial_obj + obj_scale_eps))
        log_scaled_problem_obj = tf.log(scaled_problem_obj)
      else:
        scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
        log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
      return tf.reduce_mean(log_scaled_problem_obj)
    else:
      return total_obj / (initial_obj + obj_scale_eps)
def local_state_variables(init_values, return_init_values):
  """Create local variables initialized from init_values.

  This will create local variables from a list of init_values. Each variable
  will be named based on the value's shape and dtype.

  As a convenience, a boolean tensor allows you to return value from
  the created local variable or from the original init value.

  Args:
    init_values: iterable of tensors
    return_init_values: boolean tensor

  Returns:
    local_vars: list of the created local variables.
    vals: if return_init_values is true, then this returns the values of
      init_values. Otherwise it returns the values of the local_vars.
  """
  if not init_values:
    return [], []

  # Keep a per-graph use count for each generated name so repeated calls get
  # unique variable names. The count dict is stashed in a graph collection.
  # This generates a harmless warning when saving the metagraph.
  variable_use_count = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
  if not variable_use_count:
    variable_use_count.append(collections.defaultdict(int))
  variable_use_count = variable_use_count[0]

  local_vars = []
  with tf.variable_scope(OPTIMIZER_SCOPE):
    # We can't use the init_value as an initializer as init_value may
    # itself depend on some problem variables. This would produce
    # inter-variable initialization order dependence, which TensorFlow
    # does not make easy to manage, so initialize with zeros instead.
    for init_value in init_values:
      name = create_local_state_variable_name(init_value)
      unique_name = name + "_" + str(variable_use_count[name])
      variable_use_count[name] += 1
      # The overarching idea here is to be able to reuse variables between
      # different sessions on the same TensorFlow master without errors. By
      # uniquifying based on the type and name we mirror the checks made inside
      # TensorFlow, while still allowing some memory reuse. Ultimately this is a
      # hack due to the broken Session.reset().
      local_vars.append(
          tf.get_local_variable(
              unique_name,
              initializer=tf.zeros(
                  init_value.get_shape(), dtype=init_value.dtype)))

  # It makes things a lot simpler if we use the init_value the first
  # iteration, instead of the variable itself. It allows us to propagate
  # gradients through it as well as simplifying initialization. The variable
  # ends up assigned to after the first iteration.
  vals = tf.cond(return_init_values, lambda: init_values, lambda: local_vars)
  if len(init_values) == 1:
    # tf.cond extracts elements from singleton lists, so re-wrap the single
    # tensor to keep the return type a list.
    vals = [vals]
  return local_vars, vals
def create_local_state_variable_name(tensor):
  """Creates a variable name derived from the tensor's shape and dtype.

  Args:
    tensor: A tensor with a fully defined static shape.

  Returns:
    A name string of the form `local_state_<dim>_<dim>..._<dtype>`.

  Raises:
    ValueError: If the tensor's static shape is not fully defined.
  """
  shape = tensor.get_shape()
  if not shape.is_fully_defined():
    raise ValueError("Need a fully specified shape to create a local variable.")

  shape_part = "_".join(str(dim) for dim in shape.as_list())
  return _LOCAL_VARIABLE_PREFIX + shape_part + "_" + tensor.dtype.name
def is_local_state_variable(op):
  """Returns whether `op` is a local state variable created for training.

  An op qualifies when it is a (V1 or V2) Variable op whose name lives under
  the optimizer scope with the local-state prefix.
  """
  if op.node_def.op not in ("Variable", "VariableV2"):
    return False
  state_name_prefix = OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX
  return op.name.startswith(state_name_prefix)
def flatten_and_sort(dictionary):
  """Flattens a dictionary into a list of values sorted by the keys."""
  ordered_keys = sorted(dictionary)
  return [dictionary[key] for key in ordered_keys]