# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A base class definition for trainable optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import itertools
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
# Variable scope under which all optimizer-created variables live. The value
# "LOL" was a non-descriptive placeholder; "optimizer" matches how the scope is
# used as a recognizable prefix in is_local_state_variable() below.
OPTIMIZER_SCOPE = "optimizer"
# Name prefix for local variables that carry optimizer/problem state between
# partial unrolls (see local_state_variables()).
_LOCAL_VARIABLE_PREFIX = "local_state_"
# Graph collection holding per-name use counts used to uniquify local state
# variable names across sessions.
_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection"
# Small constant for numerical stability in divisions/logs.
EPSILON = 1e-6
class TrainableOptimizer(tf.train.Optimizer):
  """Base class for trainable optimizers.

  A trainable optimizer is an optimizer that has parameters that can themselves
  be learned (meta-optimized).

  Subclasses must implement:
      _compute_update(self, param, grad, state)
  """

  def __init__(self, name, state_keys, use_attention=False,
               use_log_objective=False, obj_train_max_multiplier=-1,
               use_second_derivatives=True, use_numerator_epsilon=False,
               **kwargs):
    """Initializes the optimizer with the given name and settings.

    Args:
      name: The name string for this optimizer.
      state_keys: The names of any required state variables (list)
      use_attention: Whether this optimizer uses attention (Default: False)
      use_log_objective: Whether this optimizer uses the logarithm of the
        objective when computing the loss (Default: False)
      obj_train_max_multiplier: The maximum multiplier for the increase in the
        objective before meta-training is stopped. If <= 0, meta-training is
        not stopped early. (Default: -1)
      use_second_derivatives: Whether this optimizer uses second derivatives in
        meta-training. This should be set to False if some second derivatives
        in the meta-training problem set are not defined in Tensorflow.
        (Default: True)
      use_numerator_epsilon: Whether to use epsilon in the numerator when
        scaling the problem objective during meta-training. (Default: False)
      **kwargs: Any additional keyword arguments.
    """
    self.use_second_derivatives = use_second_derivatives
    # Keys are kept sorted so that state dicts can be flattened into (and
    # rebuilt from) lists with a deterministic ordering (see loop_body).
    self.state_keys = sorted(state_keys)
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.obj_train_max_multiplier = obj_train_max_multiplier
    self.use_numerator_epsilon = use_numerator_epsilon

    use_locking = False
    super(TrainableOptimizer, self).__init__(use_locking, name)

  def _create_slots(self, var_list):
    """Creates all slots needed by the variables.

    One slot is created per (variable, state key) pair, initialized from
    _initialize_state.

    Args:
      var_list: A list of `Variable` objects.
    """
    for var in var_list:
      init_states = self._initialize_state(var)
      # Sorted iteration keeps slot creation order deterministic.
      for slot_name in sorted(init_states):
        slot_var_name = "{}_{}".format(self.get_name(), slot_name)
        value = init_states[slot_name]
        self._get_or_make_slot(var, value, slot_name, slot_var_name)

  def _initialize_state(self, var):
    """Initializes any state required for this variable.

    Base implementation is stateless; subclasses override to supply state.

    Args:
      var: a tensor containing parameters to be optimized

    Returns:
      state: a dictionary mapping state keys to initial state values (tensors)
    """
    return {}

  def _initialize_global_state(self):
    """Initializes any global state values. Base implementation has none."""
    return []

  def _apply_common(self, grad, var):
    """Applies the optimizer updates to the variables.

    Note: this should only get called via _apply_dense or _apply_sparse when
    using the optimizer via optimizer.minimize or optimizer.apply_gradients.
    During meta-training, the optimizer.train function should be used to
    construct an optimization path that is differentiable.

    Args:
      grad: A tensor representing the gradient.
      var: A tf.Variable with the same shape as grad.

    Returns:
      update_op: A tensorflow op that assigns new values to the variable, and
        also defines dependencies that update the state variables for the
        optimizer.
    """
    state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
    new_var, new_state = self._compute_update(var, grad, state)
    state_assign_ops = [tf.assign(state_var, new_state[key])
                        for key, state_var in state.items()]
    # State is written before the variable so the returned op carries both.
    with tf.control_dependencies(state_assign_ops):
      update_op = var.assign(new_var)

    return update_op

  def _apply_dense(self, grad, var):
    """Adds ops to apply dense gradients to 'var'."""
    return self._apply_common(grad, var)

  def _apply_sparse(self, grad, var):
    """Adds ops to apply sparse gradients to 'var'."""
    return self._apply_common(grad, var)

  def _compute_update(self, param, grad, state):
    """Computes the update step for optimization.

    Must be implemented by subclasses.

    Args:
      param: A tensor of parameters to optimize.
      grad: The gradient tensor of the objective with respect to the parameters.
        (It has the same shape as param.)
      state: A dictionary containing any extra state required by the optimizer.

    Returns:
      updated_params: The updated parameters.
      updated_state: The dictionary of updated state variable(s).
    """
    raise NotImplementedError

  def _compute_updates(self, params, grads, states, global_state):
    """Maps the compute update functions for each parameter.

    This function can be overriden by a subclass if the subclass wants to
    combine information across the different parameters in the list.

    Args:
      params: A list of parameter tensors.
      grads: A list of gradients corresponding to each parameter.
      states: A list of state variables corresponding to each parameter.
      global_state: A list of global state variables for the problem.

    Returns:
      new_params: The updated parameters.
      new_states: The updated states.
      new_global_state: The updated global state.
      attention_params: A list of attention parameters. This is the same as
        new_params if the optimizer does not use attention.
    """
    # Zip up the arguments to _compute_update.
    args = zip(params, grads, states)

    # Call compute_update on each set of parameter/gradient/state args.
    new_params, new_states = zip(*list(
        itertools.starmap(self._compute_update, args)))

    # Global state is unused in the basic case, just pass it through.
    return list(new_params), list(new_states), global_state, list(new_params)

  def train(self, problem, dataset):
    """Creates graph operations to train the optimizer.

    Args:
      problem: A problem_generator.Problem instance to train on.
      dataset: A datasets.Dataset tuple to use when training.

    Returns:
      meta_objective: A tensorflow operation for computing the meta-objective
      obj_weights: A tensor placeholder for feeding in the objective weights
      obj_values: The subproblem objective values during optimization
      batches: The batch indexes tensor for overriding with feed_dict
      first_unroll: A placeholder signifying if this is a first unroll
        (this will propagate the gradients slightly differently).
      reset_state: A placeholder signifying that the rnn state should be reset.
      output_state: The final state of the optimizer
      init_loop_vars_to_override: Local variables that can be assigned to
        propagate the optimizer and problem state for unrolling
      final_loop_vals: Final values of the loop variables that can be
        assigned to init_loop_vars_to_override.
    """
    # Placeholder for the objective weights; its length also determines the
    # number of inner-loop iterations.
    obj_weights = tf.placeholder(tf.float32)
    num_iter = tf.shape(obj_weights)[0]

    # Unpack the dataset and generate the minibatches for training
    data, labels = dataset
    # Convert the ndarrays to tensors so we can pass them back in via feed_dict
    data = tf.constant(data)
    labels = tf.constant(labels)
    batches = tf.placeholder(tf.int32)
    first_unroll = tf.placeholder_with_default(False, [])
    reset_state = tf.placeholder_with_default(False, [])

    training_output = collections.namedtuple("TrainingOutput",
                                             ["metaobj",
                                              "obj_weights",
                                              "problem_objectives",
                                              "initial_obj",
                                              "batches",
                                              "first_unroll",
                                              "reset_state",
                                              "output_state",
                                              "init_loop_vars",
                                              "output_loop_vars"])

    def loop_body(itr, obj_accum, params, attend_params, flattened_states,
                  global_state, all_obj, unused_init_obj, data,
                  labels, batches):
      """Body of the meta-training while loop for optimizing a sub-problem.

      Args:
        itr: The current meta-training iteration.
        obj_accum: The accumulated objective over all training steps so far.
        params: The parameters of the sub-problem.
        attend_params: The parameters of the sub-problems at the attended
          location.
        flattened_states: The states of the trainable optimizer, sorted and
          flattened into a list (since a while loop can't handle nested lists
          or dictionaries).
        global_state: The global state of the optimizer.
        all_obj: The list of all objective values in the training process.
        unused_init_obj: The initial objective (unused here, but needed in the
          variable list because it's used in a stopping condition in the
          loop_cond.)
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.

      Returns:
        itr: The updated meta-training iteration.
        obj_accum: The updated accumulated objective.
        params: The new parameters of the sub-problem.
        attend_params: The new parameters of the sub-problems at the attended
          location.
        flattened_states: The new states of the trainable optimizer.
        global_state: The updated global state.
        all_obj: The updates list of all objective values.
        unused_init_obj: The initial objective.
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.
      """
      batch_indices = tf.gather(batches, itr)
      batch_data = tf.gather(data, batch_indices)
      batch_labels = tf.gather(labels, batch_indices)

      # Compute the objective over the entire dataset (full batch).
      obj = problem.objective(params, data, labels)

      # Compute the gradients on just the current batch
      if self.use_attention:
        current_obj = problem.objective(attend_params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, attend_params)
      else:
        current_obj = problem.objective(params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, params)

      if not self.use_second_derivatives:
        new_grads = []
        for grad in grads:
          if isinstance(grad, tf.IndexedSlices):
            new_grads.append(
                tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
          else:
            new_grads.append(tf.stop_gradient(grad))
        grads = new_grads

      # store the objective value for the entire problem at each iteration
      all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)

      # accumulate the weighted objective for the entire dataset
      acc = tf.gather(obj_weights, itr) * obj
      obj_accum = tf.add(obj_accum, acc)
      # Set the shape to keep the shape invariant for obj_accum. Without this,
      # the graph builder thinks the tensor shape is unknown on the 2nd iter.
      obj_accum.set_shape([])

      # convert flattened_states to dictionaries, pairing each flat list with
      # the sorted state keys (mirrors flatten_and_sort's ordering)
      dict_states = [dict(zip(self.state_keys, flat_state))
                     for flat_state in flattened_states]

      # compute the new parameters and states
      args = (params, grads, dict_states, global_state)
      updates = self._compute_updates(*args)
      new_params, new_states, new_global_state, new_attend_params = updates

      # flatten the states
      # NOTE(review): under Python 3, map() returns a lazy iterator rather
      # than a list, but the loop-vars structure passed to tf.while_loop
      # uses a list here — confirm the intended Python version, or wrap in
      # list(...).
      new_flattened_states = map(flatten_and_sort, new_states)

      return [itr + 1, obj_accum, new_params, new_attend_params,
              new_flattened_states, new_global_state, all_obj, unused_init_obj,
              data, labels, batches]

    def loop_cond(itr, obj_accum, unused_params, unused_attend_params,
                  unused_flattened_states, unused_global_state, all_obj,
                  init_obj, *args):
      """Termination conditions of the sub-problem optimization loop."""
      del args  # unused

      cond1 = tf.less(itr, num_iter)  # We've run < num_iter times
      cond2 = tf.is_finite(obj_accum)  # The objective is still finite

      if self.obj_train_max_multiplier > 0:
        current_obj = tf.gather(all_obj, itr)
        # Account for negative init_obj too
        max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
        max_obj = init_obj + max_diff
        # The objective is a reasonable multiplier of the original objective
        cond3 = tf.less(current_obj, max_obj)

        return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
                              name="training_loop_cond")
      else:
        return tf.logical_and(cond1, cond2, name="training_loop_cond")

    init = self._initialize_training_loop_parameters(
        problem, data, labels, batches, first_unroll, reset_state)
    loop_vars, invariants, initial_obj, init_loop_vars_to_override = init

    loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
                                swap_memory=True, shape_invariants=invariants)
    # loop_output indices follow loop_vars: [1]=accumulated meta objective,
    # [6]=per-iteration problem objectives.
    meta_obj, problem_objectives = loop_output[1], loop_output[6]

    # The meta objective is normalized by the initial objective at the start of
    # the series of partial unrolls.
    scaled_meta_objective = self.scale_objective(
        meta_obj, problem_objectives, initial_obj)

    # Final values to be fed back into init_loop_vars_to_override on the next
    # partial unroll: initial objective, params, attend params, global state,
    # then the flattened per-parameter states.
    final_loop_vals = (
        [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5])
    final_loop_vals.extend(itertools.chain(*loop_output[4]))

    return training_output(scaled_meta_objective,
                           obj_weights,
                           problem_objectives,
                           initial_obj,
                           batches,
                           first_unroll,
                           reset_state,
                           loop_output[4],
                           init_loop_vars_to_override,
                           final_loop_vals)

  def _initialize_training_loop_parameters(
      self, problem, data, labels, batches, first_unroll, reset_state):
    """Initializes the vars and params needed for the training process.

    Args:
      problem: The problem being optimized.
      data: The data for the problem.
      labels: The corresponding labels for the data.
      batches: The indexes needed to create shuffled batches of the data.
      first_unroll: Whether this is the first unroll in a partial unrolling.
      reset_state: Whether RNN state variables should be reset.

    Returns:
      loop_vars: The while loop variables for training.
      invariants: The corresponding variable shapes (required by while loop).
      initial_obj: The initial objective (used later for scaling).
      init_loop_vars_to_override: The loop vars that can be overridden when
        performing training via partial unrolls.
    """
    # Extract these separately so we don't have to make inter-variable
    # dependencies.
    initial_tensors = problem.init_tensors()

    return_initial_tensor_values = first_unroll
    initial_params_vars, initial_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    initial_attend_params_vars, initial_attend_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)

    # Recalculate the initial objective for the list on each partial unroll with
    # the new initial_params. initial_obj holds the value from the very first
    # unroll.
    initial_obj_init = problem.objective(initial_params, data, labels)
    return_initial_obj_init = first_unroll
    [initial_obj_var], [initial_obj] = local_state_variables(
        [initial_obj_init], return_initial_obj_init)

    # Initialize the loop variables.
    initial_itr = tf.constant(0, dtype=tf.int32)
    initial_meta_obj = tf.constant(0, dtype=tf.float32)
    # N.B. the use of initial_obj_init here rather than initial_obj
    initial_problem_objectives = tf.reshape(initial_obj_init, (1,))

    # Initialize the extra state.
    initial_state_vars = []
    initial_state = []
    state_shapes = []
    return_initial_state_values = reset_state
    for param in initial_tensors:
      param_state_vars, param_state = local_state_variables(
          flatten_and_sort(self._initialize_state(param)),
          return_initial_state_values)

      initial_state_vars.append(param_state_vars)
      initial_state.append(param_state)
      state_shapes.append([f.get_shape() for f in param_state])

    # Initialize any global (problem-level) state.
    initial_global_state_vars, initial_global_state = local_state_variables(
        self._initialize_global_state(), return_initial_state_values)

    global_shapes = []
    for item in initial_global_state:
      global_shapes.append(item.get_shape())

    # build the list of loop variables:
    loop_vars = [
        initial_itr,
        initial_meta_obj,
        initial_params,         # Local variables.
        initial_attend_params,  # Local variables.
        initial_state,          # Local variables.
        initial_global_state,   # Local variables.
        initial_problem_objectives,
        initial_obj,            # Local variable.
        data,
        labels,
        batches,
    ]

    invariants = [
        initial_itr.get_shape(),
        initial_meta_obj.get_shape(),
        [t.get_shape() for t in initial_params],
        [t.get_shape() for t in initial_attend_params],
        state_shapes,
        global_shapes,
        tensor_shape.TensorShape([None]),  # The problem objectives list grows
        initial_obj.get_shape(),
        tensor_shape.unknown_shape(),  # Placeholder shapes are unknown
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
    ]

    # Initialize local variables that we will override with final tensors at the
    # next iter.
    init_loop_vars_to_override = (
        [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
        initial_global_state_vars)
    init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))

    return loop_vars, invariants, initial_obj, init_loop_vars_to_override

  def scale_objective(self, total_obj, all_objs, initial_obj,
                      obj_scale_eps=1e-6):
    """Normalizes the objective based on the initial objective value.

    Args:
      total_obj: The total accumulated objective over the training run.
      all_objs: A list of all the individual objectives over the training run.
      initial_obj: The initial objective value.
      obj_scale_eps: The epsilon value to use in computations for stability.

    Returns:
      The scaled objective as a single value.
    """
    if self.use_log_objective:
      if self.use_numerator_epsilon:
        scaled_problem_obj = ((all_objs + obj_scale_eps) /
                              (initial_obj + obj_scale_eps))
        log_scaled_problem_obj = tf.log(scaled_problem_obj)
      else:
        scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
        log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
      # Mean of the per-iteration log-scaled objectives.
      return tf.reduce_mean(log_scaled_problem_obj)
    else:
      return total_obj / (initial_obj + obj_scale_eps)
def local_state_variables(init_values, return_init_values):
  """Create local variables initialized from init_values.

  This will create local variables from a list of init_values. Each variable
  will be named based on the value's shape and dtype.

  As a convenience, a boolean tensor allows you to return value from
  the created local variable or from the original init value.

  Args:
    init_values: iterable of tensors
    return_init_values: boolean tensor

  Returns:
    local_vars: list of the created local variables.
    vals: if return_init_values is true, then this returns the values of
      init_values. Otherwise it returns the values of the local_vars.
  """
  if not init_values:
    return [], []

  # Per-name use counts live in a graph collection so they persist across
  # calls within the same graph.
  # This generates a harmless warning when saving the metagraph.
  variable_use_count = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
  if not variable_use_count:
    variable_use_count.append(collections.defaultdict(int))
  variable_use_count = variable_use_count[0]

  local_vars = []
  with tf.variable_scope(OPTIMIZER_SCOPE):
    # We can't use the init_value as an initializer as init_value may
    # itself depend on some problem variables. This would produce
    # inter-variable initialization order dependence, which TensorFlow
    # does not handle well; so initialize with zeros of the same
    # shape/dtype instead.
    for init_value in init_values:
      name = create_local_state_variable_name(init_value)
      unique_name = name + "_" + str(variable_use_count[name])
      variable_use_count[name] += 1
      # The overarching idea here is to be able to reuse variables between
      # different sessions on the same TensorFlow master without errors. By
      # uniquifying based on the type and name we mirror the checks made inside
      # TensorFlow, while still allowing some memory reuse. Ultimately this is a
      # hack due to the broken Session.reset().
      local_vars.append(
          tf.get_local_variable(
              unique_name,
              initializer=tf.zeros(
                  init_value.get_shape(), dtype=init_value.dtype)))

  # It makes things a lot simpler if we use the init_value the first
  # iteration, instead of the variable itself. It allows us to propagate
  # gradients through it as well as simplifying initialization. The variable
  # ends up assigned to after the first iteration.
  vals = tf.cond(return_init_values, lambda: init_values, lambda: local_vars)
  if len(init_values) == 1:
    # tf.cond extracts elements from singleton lists.
    vals = [vals]
  return local_vars, vals
def create_local_state_variable_name(tensor):
  """Create a name of the variable based on its type and shape.

  Args:
    tensor: A tensor with a fully-defined static shape.

  Returns:
    A string of the form "<prefix><dim>_<dim>_..._<dtype>".

  Raises:
    ValueError: if the tensor's static shape is not fully defined.
  """
  shape = tensor.get_shape()
  if not shape.is_fully_defined():
    raise ValueError("Need a fully specified shape to create a local variable.")

  dims = "_".join(str(dim) for dim in shape.as_list())
  return _LOCAL_VARIABLE_PREFIX + dims + "_" + tensor.dtype.name
def is_local_state_variable(op):
  """Returns if this op is a local state variable created for training."""
  # Must be a variable op whose name carries the optimizer-scope/state prefix.
  if op.node_def.op not in ("Variable", "VariableV2"):
    return False
  state_prefix = OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX
  return op.name.startswith(state_prefix)
def flatten_and_sort(dictionary):
  """Flattens a dictionary into a list of values sorted by the keys."""
  ordered_keys = sorted(dictionary)
  return [dictionary[key] for key in ordered_keys]