# NOTE(review): the following lines are artifacts of a web file-viewer export
# (uploader name, commit id, size banner) and are not Python source. They are
# preserved as comments so the module remains importable.
# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# raw
# history blame
# 24 kB
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A base class definition for trainable optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import itertools
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
# Variable-scope name under which the optimizer's local state variables are
# created; is_local_state_variable() uses it to recognize those variables.
OPTIMIZER_SCOPE = "LOL"
# Name prefix for local variables that carry optimizer/problem state between
# partial unrolls of the meta-training loop.
_LOCAL_VARIABLE_PREFIX = "local_state_"
# Graph collection holding a per-base-name use counter so repeated variable
# construction produces unique names (see local_state_variables).
_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection"
# Small constant for numerical stability; not referenced in this file's
# visible code — presumably used by subclasses (TODO confirm).
EPSILON = 1e-6
class TrainableOptimizer(tf.train.Optimizer):
  """Base class for trainable optimizers.

  A trainable optimizer is an optimizer that has parameters that can themselves
  be learned (meta-optimized).

  Subclasses must implement:
      _compute_update(self, param, grad, state)
  """

  def __init__(self, name, state_keys, use_attention=False,
               use_log_objective=False, obj_train_max_multiplier=-1,
               use_second_derivatives=True, use_numerator_epsilon=False,
               **kwargs):
    """Initializes the optimizer with the given name and settings.

    Args:
      name: The name string for this optimizer.
      state_keys: The names of any required state variables (list)
      use_attention: Whether this optimizer uses attention (Default: False)
      use_log_objective: Whether this optimizer uses the logarithm of the
          objective when computing the loss (Default: False)
      obj_train_max_multiplier: The maximum multiplier for the increase in the
          objective before meta-training is stopped. If <= 0, meta-training is
          not stopped early. (Default: -1)
      use_second_derivatives: Whether this optimizer uses second derivatives in
          meta-training. This should be set to False if some second derivatives
          in the meta-training problem set are not defined in Tensorflow.
          (Default: True)
      use_numerator_epsilon: Whether to use epsilon in the numerator when
          scaling the problem objective during meta-training. (Default: False)
      **kwargs: Any additional keyword arguments (ignored here; accepted so
          subclasses can pass extra options through).
    """
    self.use_second_derivatives = use_second_derivatives
    # Keys are kept sorted so that flattened state lists always line up with
    # the dictionaries produced by _initialize_state (see flatten_and_sort).
    self.state_keys = sorted(state_keys)
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.obj_train_max_multiplier = obj_train_max_multiplier
    self.use_numerator_epsilon = use_numerator_epsilon

    use_locking = False
    super(TrainableOptimizer, self).__init__(use_locking, name)

  def _create_slots(self, var_list):
    """Creates all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    for var in var_list:
      init_states = self._initialize_state(var)
      # Iterate in sorted order so slot creation order is deterministic.
      for slot_name in sorted(init_states):
        slot_var_name = "{}_{}".format(self.get_name(), slot_name)
        value = init_states[slot_name]
        self._get_or_make_slot(var, value, slot_name, slot_var_name)

  def _initialize_state(self, var):
    """Initializes any state required for this variable.

    Args:
      var: a tensor containing parameters to be optimized

    Returns:
      state: a dictionary mapping state keys to initial state values (tensors)
    """
    return {}

  def _initialize_global_state(self):
    """Initializes any global state values."""
    return []

  def _apply_common(self, grad, var):
    """Applies the optimizer updates to the variables.

    Note: this should only get called via _apply_dense or _apply_sparse when
    using the optimizer via optimizer.minimize or optimizer.apply_gradients.
    During meta-training, the optimizer.train function should be used to
    construct an optimization path that is differentiable.

    Args:
      grad: A tensor representing the gradient.
      var: A tf.Variable with the same shape as grad.

    Returns:
      update_op: A tensorflow op that assigns new values to the variable, and
          also defines dependencies that update the state variables for the
          optimizer.
    """
    state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
    new_var, new_state = self._compute_update(var, grad, state)
    state_assign_ops = [tf.assign(state_var, new_state[key])
                        for key, state_var in state.items()]
    # The state must be written before the variable update so that a single
    # session.run of update_op leaves everything consistent.
    with tf.control_dependencies(state_assign_ops):
      update_op = var.assign(new_var)

    return update_op

  def _apply_dense(self, grad, var):
    """Adds ops to apply dense gradients to 'var'."""
    return self._apply_common(grad, var)

  def _apply_sparse(self, grad, var):
    """Adds ops to apply sparse gradients to 'var'."""
    return self._apply_common(grad, var)

  def _compute_update(self, param, grad, state):
    """Computes the update step for optimization.

    Args:
      param: A tensor of parameters to optimize.
      grad: The gradient tensor of the objective with respect to the parameters.
          (It has the same shape as param.)
      state: A dictionary containing any extra state required by the optimizer.

    Returns:
      updated_params: The updated parameters.
      updated_state: The dictionary of updated state variable(s).
    """
    raise NotImplementedError

  def _compute_updates(self, params, grads, states, global_state):
    """Maps the compute update functions for each parameter.

    This function can be overriden by a subclass if the subclass wants to
    combine information across the different parameters in the list.

    Args:
      params: A list of parameter tensors.
      grads: A list of gradients corresponding to each parameter.
      states: A list of state variables corresponding to each parameter.
      global_state: A list of global state variables for the problem.

    Returns:
      new_params: The updated parameters.
      new_states: The updated states.
      new_global_state: The updated global state.
      attention_params: A list of attention parameters. This is the same as
          new_params if the optimizer does not use attention.
    """
    # Zip up the arguments to _compute_update.
    args = zip(params, grads, states)

    # Call compute_update on each set of parameter/gradient/state args.
    new_params, new_states = zip(*list(
        itertools.starmap(self._compute_update, args)))

    # Global state is unused in the basic case, just pass it through.
    return list(new_params), list(new_states), global_state, list(new_params)

  def train(self, problem, dataset):
    """Creates graph operations to train the optimizer.

    Args:
      problem: A problem_generator.Problem instance to train on.
      dataset: A datasets.Dataset tuple to use when training.

    Returns:
      meta_objective: A tensorflow operation for computing the meta-objective
      obj_weights: A tensor placeholder for feeding in the objective weights
      obj_values: The subproblem objective values during optimization
      batches: The batch indexes tensor for overriding with feed_dict
      first_unroll: A placeholder signifying if this is a first unroll
          (this will propagate the gradients slightly differently).
      reset_state: A placeholder signifying that the rnn state should be reset.
      output_state: The final state of the optimizer
      init_loop_vars_to_override: Local variables that can be assigned to
          propagate the optimizer and problem state for unrolling
      final_loop_vals: Final values of the loop variables that can be
          assigned to init_loop_vars_to_override.
    """
    # Placeholder for the objective weights
    obj_weights = tf.placeholder(tf.float32)
    num_iter = tf.shape(obj_weights)[0]

    # Unpack the dataset and generate the minibatches for training
    data, labels = dataset
    # Convert the ndarrays to tensors so we can pass them back in via feed_dict
    data = tf.constant(data)
    labels = tf.constant(labels)
    batches = tf.placeholder(tf.int32)
    first_unroll = tf.placeholder_with_default(False, [])
    reset_state = tf.placeholder_with_default(False, [])

    training_output = collections.namedtuple("TrainingOutput",
                                             ["metaobj",
                                              "obj_weights",
                                              "problem_objectives",
                                              "initial_obj",
                                              "batches",
                                              "first_unroll",
                                              "reset_state",
                                              "output_state",
                                              "init_loop_vars",
                                              "output_loop_vars"])

    def loop_body(itr, obj_accum, params, attend_params, flattened_states,
                  global_state, all_obj, unused_init_obj, data,
                  labels, batches):
      """Body of the meta-training while loop for optimizing a sub-problem.

      Args:
        itr: The current meta-training iteration.
        obj_accum: The accumulated objective over all training steps so far.
        params: The parameters of the sub-problem.
        attend_params: The parameters of the sub-problems at the attended
            location.
        flattened_states: The states of the trainable optimizer, sorted and
            flattened into a list (since a while loop can't handle nested lists
            or dictionaries).
        global_state: The global state of the optimizer.
        all_obj: The list of all objective values in the training process.
        unused_init_obj: The initial objective (unused here, but needed in the
            variable list because it's used in a stopping condition in the
            loop_cond.)
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.

      Returns:
        itr: The updated meta-training iteration.
        obj_accum: The updated accumulated objective.
        params: The new parameters of the sub-problem.
        attend_params: The new parameters of the sub-problems at the attended
            location.
        flattened_states: The new states of the trainable optimizer.
        global_state: The updated global state.
        all_obj: The updates list of all objective values.
        unused_init_obj: The initial objective.
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.
      """
      batch_indices = tf.gather(batches, itr)
      batch_data = tf.gather(data, batch_indices)
      batch_labels = tf.gather(labels, batch_indices)

      # Compute the objective over the entire dataset (full batch).
      obj = problem.objective(params, data, labels)

      # Compute the gradients on just the current batch
      if self.use_attention:
        current_obj = problem.objective(attend_params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, attend_params)
      else:
        current_obj = problem.objective(params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, params)

      if not self.use_second_derivatives:
        new_grads = []
        for grad in grads:
          if isinstance(grad, tf.IndexedSlices):
            new_grads.append(
                tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
          else:
            new_grads.append(tf.stop_gradient(grad))
        grads = new_grads

      # store the objective value for the entire problem at each iteration
      all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)

      # accumulate the weighted objective for the entire dataset
      acc = tf.gather(obj_weights, itr) * obj

      obj_accum = tf.add(obj_accum, acc)
      # Set the shape to keep the shape invariant for obj_accum. Without this,
      # the graph builder thinks the tensor shape is unknown on the 2nd iter.
      obj_accum.set_shape([])

      # convert flattened_states to dictionaries
      dict_states = [dict(zip(self.state_keys, flat_state))
                     for flat_state in flattened_states]

      # compute the new parameters and states
      args = (params, grads, dict_states, global_state)
      updates = self._compute_updates(*args)
      new_params, new_states, new_global_state, new_attend_params = updates

      # flatten the states. NOTE: this must be a materialized list — under
      # Python 3, map() returns an iterator, which would not match the
      # tf.while_loop variable structure (a list of lists) and would break
      # loop construction.
      new_flattened_states = [flatten_and_sort(ns) for ns in new_states]

      return [itr + 1, obj_accum, new_params, new_attend_params,
              new_flattened_states, new_global_state, all_obj, unused_init_obj,
              data, labels, batches]

    def loop_cond(itr, obj_accum, unused_params, unused_attend_params,
                  unused_flattened_states, unused_global_state, all_obj,
                  init_obj, *args):
      """Termination conditions of the sub-problem optimization loop."""
      del args  # unused

      cond1 = tf.less(itr, num_iter)  # We've run < num_iter times
      cond2 = tf.is_finite(obj_accum)  # The objective is still finite

      if self.obj_train_max_multiplier > 0:
        current_obj = tf.gather(all_obj, itr)
        # Account for negative init_obj too
        max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
        max_obj = init_obj + max_diff
        # The objective is a reasonable multiplier of the original objective
        cond3 = tf.less(current_obj, max_obj)

        return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
                              name="training_loop_cond")
      else:
        return tf.logical_and(cond1, cond2, name="training_loop_cond")

    init = self._initialize_training_loop_parameters(
        problem, data, labels, batches, first_unroll, reset_state)
    loop_vars, invariants, initial_obj, init_loop_vars_to_override = init

    loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
                                swap_memory=True, shape_invariants=invariants)
    meta_obj, problem_objectives = loop_output[1], loop_output[6]

    # The meta objective is normalized by the initial objective at the start of
    # the series of partial unrolls.
    scaled_meta_objective = self.scale_objective(
        meta_obj, problem_objectives, initial_obj)

    final_loop_vals = (
        [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5])
    final_loop_vals.extend(itertools.chain(*loop_output[4]))

    return training_output(scaled_meta_objective,
                           obj_weights,
                           problem_objectives,
                           initial_obj,
                           batches,
                           first_unroll,
                           reset_state,
                           loop_output[4],
                           init_loop_vars_to_override,
                           final_loop_vals)

  def _initialize_training_loop_parameters(
      self, problem, data, labels, batches, first_unroll, reset_state):
    """Initializes the vars and params needed for the training process.

    Args:
      problem: The problem being optimized.
      data: The data for the problem.
      labels: The corresponding labels for the data.
      batches: The indexes needed to create shuffled batches of the data.
      first_unroll: Whether this is the first unroll in a partial unrolling.
      reset_state: Whether RNN state variables should be reset.

    Returns:
      loop_vars: The while loop variables for training.
      invariants: The corresponding variable shapes (required by while loop).
      initial_obj: The initial objective (used later for scaling).
      init_loop_vars_to_override: The loop vars that can be overridden when
          performing training via partial unrolls.
    """
    # Extract these separately so we don't have to make inter-variable
    # dependencies.
    initial_tensors = problem.init_tensors()

    return_initial_tensor_values = first_unroll
    initial_params_vars, initial_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    initial_attend_params_vars, initial_attend_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)

    # Recalculate the initial objective for the list on each partial unroll with
    # the new initial_params. initial_obj holds the value from the very first
    # unroll.
    initial_obj_init = problem.objective(initial_params, data, labels)
    return_initial_obj_init = first_unroll
    [initial_obj_var], [initial_obj] = local_state_variables(
        [initial_obj_init], return_initial_obj_init)

    # Initialize the loop variables.
    initial_itr = tf.constant(0, dtype=tf.int32)
    initial_meta_obj = tf.constant(0, dtype=tf.float32)
    # N.B. the use of initial_obj_init here rather than initial_obj
    initial_problem_objectives = tf.reshape(initial_obj_init, (1,))

    # Initialize the extra state.
    initial_state_vars = []
    initial_state = []
    state_shapes = []
    return_initial_state_values = reset_state
    for param in initial_tensors:
      param_state_vars, param_state = local_state_variables(
          flatten_and_sort(self._initialize_state(param)),
          return_initial_state_values)

      initial_state_vars.append(param_state_vars)
      initial_state.append(param_state)
      state_shapes.append([f.get_shape() for f in param_state])

    # Initialize any global (problem-level) state.
    initial_global_state_vars, initial_global_state = local_state_variables(
        self._initialize_global_state(), return_initial_state_values)

    global_shapes = []
    for item in initial_global_state:
      global_shapes.append(item.get_shape())

    # build the list of loop variables:
    loop_vars = [
        initial_itr,
        initial_meta_obj,
        initial_params,         # Local variables.
        initial_attend_params,  # Local variables.
        initial_state,          # Local variables.
        initial_global_state,   # Local variables.
        initial_problem_objectives,
        initial_obj,            # Local variable.
        data,
        labels,
        batches,
    ]

    invariants = [
        initial_itr.get_shape(),
        initial_meta_obj.get_shape(),
        [t.get_shape() for t in initial_params],
        [t.get_shape() for t in initial_attend_params],
        state_shapes,
        global_shapes,
        tensor_shape.TensorShape([None]),  # The problem objectives list grows
        initial_obj.get_shape(),
        tensor_shape.unknown_shape(),  # Placeholder shapes are unknown
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
    ]

    # Initialize local variables that we will override with final tensors at the
    # next iter.
    init_loop_vars_to_override = (
        [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
        initial_global_state_vars)
    init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))

    return loop_vars, invariants, initial_obj, init_loop_vars_to_override

  def scale_objective(self, total_obj, all_objs, initial_obj,
                      obj_scale_eps=1e-6):
    """Normalizes the objective based on the initial objective value.

    Args:
      total_obj: The total accumulated objective over the training run.
      all_objs: A list of all the individual objectives over the training run.
      initial_obj: The initial objective value.
      obj_scale_eps: The epsilon value to use in computations for stability.

    Returns:
      The scaled objective as a single value.
    """
    if self.use_log_objective:
      if self.use_numerator_epsilon:
        scaled_problem_obj = ((all_objs + obj_scale_eps) /
                              (initial_obj + obj_scale_eps))
        log_scaled_problem_obj = tf.log(scaled_problem_obj)
      else:
        scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
        log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
      return tf.reduce_mean(log_scaled_problem_obj)
    else:
      return total_obj / (initial_obj + obj_scale_eps)
def local_state_variables(init_values, return_init_values):
  """Create local variables initialized from init_values.

  This will create local variables from a list of init_values. Each variable
  will be named based on the value's shape and dtype.

  As a convenience, a boolean tensor allows you to return value from
  the created local variable or from the original init value.

  Args:
    init_values: iterable of tensors
    return_init_values: boolean tensor

  Returns:
    local_vars: list of the created local variables.
    vals: if return_init_values is true, then this returns the values of
      init_values. Otherwise it returns the values of the local_vars.
  """
  if not init_values:
    return [], []

  # This generates a harmless warning when saving the metagraph.
  counter_collection = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
  if not counter_collection:
    counter_collection.append(collections.defaultdict(int))
  name_counts = counter_collection[0]

  created_vars = []
  with tf.variable_scope(OPTIMIZER_SCOPE):
    # We can't use the init_value as an initializer as init_value may
    # itself depend on some problem variables. This would produce
    # inter-variable initialization order dependence which TensorFlow
    # sucks at making easy.
    for value in init_values:
      base_name = create_local_state_variable_name(value)
      unique_name = base_name + "_" + str(name_counts[base_name])
      name_counts[base_name] += 1
      # The overarching idea here is to be able to reuse variables between
      # different sessions on the same TensorFlow master without errors. By
      # uniquifying based on the type and name we mirror the checks made inside
      # TensorFlow, while still allowing some memory reuse. Ultimately this is a
      # hack due to the broken Session.reset().
      zero_init = tf.zeros(value.get_shape(), dtype=value.dtype)
      created_vars.append(
          tf.get_local_variable(unique_name, initializer=zero_init))

  # It makes things a lot simpler if we use the init_value the first
  # iteration, instead of the variable itself. It allows us to propagate
  # gradients through it as well as simplifying initialization. The variable
  # ends up assigned to after the first iteration.
  vals = tf.cond(return_init_values,
                 lambda: init_values,
                 lambda: created_vars)
  if len(init_values) == 1:
    # tf.cond extracts elements from singleton lists.
    vals = [vals]
  return created_vars, vals
def create_local_state_variable_name(tensor):
  """Create a name of the variable based on its type and shape."""
  shape = tensor.get_shape()
  if not shape.is_fully_defined():
    raise ValueError("Need a fully specified shape to create a local variable.")

  dims = "_".join(str(dim) for dim in shape.as_list())
  return _LOCAL_VARIABLE_PREFIX + dims + "_" + tensor.dtype.name
def is_local_state_variable(op):
  """Returns if this op is a local state variable created for training."""
  if op.node_def.op not in ("Variable", "VariableV2"):
    return False
  expected_prefix = OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX
  return op.name.startswith(expected_prefix)
def flatten_and_sort(dictionary):
  """Flattens a dictionary into a list of values sorted by the keys."""
  # Iterating the dict directly yields its keys; sort them for a stable order.
  return [dictionary[key] for key in sorted(dictionary)]