# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities. | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import contextlib | |
import tensorflow as tf | |
import sonnet as snt | |
import itertools | |
import functools | |
from tensorflow.core.framework import node_def_pb2 | |
from tensorflow.python.framework import device as pydev | |
from tensorflow.python.framework import errors | |
from tensorflow.python.ops import variable_scope as variable_scope_ops | |
from sonnet.python.modules import util as snt_util | |
from tensorflow.python.util import nest | |


def eqzip(*args):
  """Zip, but raise an error if the lengths don't match.

  Args:
    *args: list of lists or tuples

  Returns:
    list: the result of zip

  Raises:
    ValueError: when the lengths don't match
  """
  sizes = [len(x) for x in args]
  if not all([sizes[0] == x for x in sizes]):
    raise ValueError("Lists are of different sizes. \n %s" % str(sizes))
  return zip(*args)
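
# Illustrative usage sketch (not part of the library API): eqzip behaves like
# zip but fails loudly on mismatched lengths, e.g.
#   list(eqzip([1, 2], ["a", "b"]))  # [(1, "a"), (2, "b")]
#   eqzip([1, 2], ["a"])             # raises ValueError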


@contextlib.contextmanager
def assert_no_new_variables():
  """Ensure that no tf.Variables are constructed inside the context.

  Yields:
    None

  Raises:
    ValueError: if a variable is created.
  """
  num_vars = len(tf.global_variables())
  old_variables = tf.global_variables()
  yield
  if len(tf.global_variables()) != num_vars:
    new_vars = set(tf.global_variables()) - set(old_variables)
    tf.logging.error("NEW VARIABLES CREATED")
    tf.logging.error(10 * "=")
    for v in new_vars:
      tf.logging.error(v)
    raise ValueError("Variables created inside an "
                     "assert_no_new_variables context")
  if old_variables != tf.global_variables():
    raise ValueError("Variables somehow changed inside an "
                     "assert_no_new_variables context. "
                     "This means something modified tf.global_variables().")


def get_variables_in_modules(module_list):
  """Returns all variables from a list of sonnet modules."""
  var_list = []
  for m in module_list:
    var_list.extend(snt.get_variables_in_module(m))
  return var_list


def state_barrier_context(state):
  """Return a context manager that prevents interior ops from running
  unless the whole state has been computed.

  This is to prevent assign race conditions.
  """
  tensors = [x for x in nest.flatten(state) if type(x) == tf.Tensor]
  tarray = [x.flow for x in nest.flatten(state) if hasattr(x, "flow")]
  return tf.control_dependencies(tensors + tarray)


def _identity_fn(tf_entity):
  if hasattr(tf_entity, "identity"):
    return tf_entity.identity()
  else:
    return tf.identity(tf_entity)


def state_barrier_result(state):
  """Return the same state, but with a control dependency to prevent it from
  being partially computed.
  """
  with state_barrier_context(state):
    return nest.map_structure(_identity_fn, state)
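
# Illustrative usage sketch: wrap a structure of tensors so that downstream
# consumers only see it once every element has been computed, e.g.
#   next_state = state_barrier_result(compute_next_state(state))
# where compute_next_state is a hypothetical function returning a nest of
# tensors / TensorArrays.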


def train_iterator(num_iterations):
  """Iterator that returns the index of the current step.

  This iterator runs forever if num_iterations is None, otherwise it runs for
  a fixed number of steps.
  """
  if num_iterations is None:
    return itertools.count()
  else:
    return xrange(num_iterations)


def print_op(op, msg):
  """Print a string and return an op wrapped in a control dependency to make
  sure it ran."""
  print_op = tf.Print(tf.constant(0), [tf.constant(0)], msg)
  return tf.group(op, print_op)
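
# Illustrative usage sketch: attach a log message to an existing op so the
# message is printed whenever the grouped op runs, and drive training with
# train_iterator (train_op and sess are hypothetical):
#   train_op = print_op(train_op, "ran a training step")
#   for step in train_iterator(num_iterations):
#     sess.run(train_op)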


class MultiQueueRunner(tf.train.QueueRunner):
  """A QueueRunner with multiple queues."""

  def __init__(self, queues, enqueue_ops):
    close_op = tf.group(*[q.close() for q in queues])
    cancel_op = tf.group(
        *[q.close(cancel_pending_enqueues=True) for q in queues])
    queue_closed_exception_types = (errors.OutOfRangeError,)

    enqueue_op = tf.group(*enqueue_ops, name="multi_enqueue")

    super(MultiQueueRunner, self).__init__(
        queues[0],
        enqueue_ops=[enqueue_op],
        close_op=close_op,
        cancel_op=cancel_op,
        queue_closed_exception_types=queue_closed_exception_types)


# This function is not elegant, but I tried so many other ways to get this to
# work and this is the only one that ended up not incurring significant
# overhead or obscure tensorflow bugs.
def sample_n_per_class(dataset, samples_per_class):
  """Create a new callable / dataset object that returns batches containing
  samples_per_class examples of each label.

  Args:
    dataset: fn
    samples_per_class: int

  Returns:
    function, [] -> batch where batch is the same type as the return of
    dataset().
  """
  with tf.control_dependencies(None), tf.name_scope(None):
    with tf.name_scope("queue_runner/sample_n_per_class"):
      batch = dataset()
      num_classes = batch.label_onehot.shape.as_list()[1]
      batch_size = num_classes * samples_per_class

      flatten = nest.flatten(batch)
      queues = []
      enqueue_ops = []
      capacity = samples_per_class * 20
      for i in xrange(num_classes):
        queue = tf.FIFOQueue(
            capacity=capacity,
            shapes=[f.shape.as_list()[1:] for f in flatten],
            dtypes=[f.dtype for f in flatten])
        queues.append(queue)

        idx = tf.where(tf.equal(batch.label, i))

        to_enqueue = []
        for elem in batch:
          new_e = tf.gather(elem, idx)
          new_e = tf.squeeze(new_e, 1)
          to_enqueue.append(new_e)

        remaining = (capacity - queue.size())
        to_add = tf.minimum(tf.shape(idx)[0], remaining)

        def _enqueue():
          return queue.enqueue_many([t[:to_add] for t in to_enqueue])

        enqueue_op = tf.cond(tf.equal(to_add, 0), tf.no_op, _enqueue)
        enqueue_ops.append(enqueue_op)

      # This has caused many deadlocks / issues. This is some logging to at
      # least shed light on what is going on.
      print_lam = lambda: tf.Print(
          tf.constant(0.0), [q.size() for q in queues],
          "MultiQueueRunner queues status. Has capacity %d" % capacity)
      some_percent_of_time = tf.less(tf.random_uniform([]), 0.0005)
      maybe_print = tf.cond(some_percent_of_time, print_lam,
                            lambda: tf.constant(0.0))
      with tf.control_dependencies([maybe_print]):
        enqueue_ops = [tf.group(e) for e in enqueue_ops]
      qr = MultiQueueRunner(queues=queues, enqueue_ops=enqueue_ops)
      tf.train.add_queue_runner(qr)

      def dequeue_batch():
        with tf.name_scope("sample_n_per_batch/dequeue/"):
          entries = []
          for q in queues:
            entries.append(q.dequeue_many(samples_per_class))
          flat_batch = [tf.concat(x, 0) for x in zip(*entries)]
          idx = tf.random_shuffle(tf.range(batch_size))
          flat_batch = [tf.gather(f, idx, axis=0) for f in flat_batch]
          return nest.pack_sequence_as(batch, flat_batch)

      return dequeue_batch
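
# Illustrative usage sketch (dataset_fn is a hypothetical callable matching the
# interface described in the docstring above):
#   get_batch = sample_n_per_class(dataset_fn, samples_per_class=4)
#   batch = get_batch()  # class-balanced batch with 4 examples per label
# Queue runners must be started (e.g. via tf.train.start_queue_runners) before
# the dequeued batch is evaluated, otherwise the dequeue will block.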


def structure_map_multi(func, values):
  """Applies func to tuples of corresponding leaves of the nests in values."""
  all_values = [nest.flatten(v) for v in values]
  rets = []
  for pair in zip(*all_values):
    rets.append(func(pair))
  return nest.pack_sequence_as(values[0], rets)


def structure_map_split(func, value):
  """Applies func to each leaf of value and regroups the results into a list
  of nests, one per output of func."""
  vv = nest.flatten(value)
  rets = []
  for v in vv:
    rets.append(func(v))
  return [nest.pack_sequence_as(value, r) for r in zip(*rets)]
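
# Illustrative usage sketch: structure_map_multi applies a function to
# corresponding leaves of several identically-structured nests, while
# structure_map_split fans one nest out into several (nest_a and nest_b are
# hypothetical nests of tensors):
#   summed = structure_map_multi(lambda pair: pair[0] + pair[1],
#                                [nest_a, nest_b])
#   halves = structure_map_split(lambda t: tf.split(t, 2, axis=0), nest_a)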


def assign_variables(targets, values):
  return tf.group(*[t.assign(v) for t, v in eqzip(targets, values)],
                  name="assign_variables")


def create_variables_in_class_scope(method):
  """Force the variables constructed in this class to live in the sonnet module.

  Wraps a method on a sonnet module.

  For example, the following will create two different sets of variables:

  ```
  class Mod(snt.AbstractModule):
    @create_variables_in_class_scope
    def dynamic_thing(self, input, name):
      return snt.Linear(10, name=name)(input)

  mod = Mod()
  mod.dynamic_thing(x, name="module_nameA")
  mod.dynamic_thing(x, name="module_nameB")
  # reuse
  mod.dynamic_thing(y, name="module_nameA")
  ```
  """
  @functools.wraps(method)
  def wrapper(obj, *args, **kwargs):
    def default_context_manager(reuse=None):
      variable_scope = obj.variable_scope
      return tf.variable_scope(variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    graph = tf.get_default_graph()

    # Temporarily enter the variable scope to capture it.
    with variable_scope_context_manager() as tmp_variable_scope:
      variable_scope = tmp_variable_scope

    with variable_scope_ops._pure_variable_scope(
        variable_scope, reuse=tf.AUTO_REUSE) as pure_variable_scope:
      name_scope = variable_scope.original_name_scope
      if name_scope[-1] != "/":
        name_scope += "/"
      with tf.name_scope(name_scope):
        sub_scope = snt_util.to_snake_case(method.__name__)
        with tf.name_scope(sub_scope) as scope:
          out_ops = method(obj, *args, **kwargs)
          return out_ops

  return wrapper