NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
9.41 kB
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import tensorflow as tf
import sonnet as snt
import itertools
import functools
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.ops import variable_scope as variable_scope_ops
from sonnet.python.modules import util as snt_util
from tensorflow.python.util import nest
def eqzip(*args):
    """Zip but raises error if lengths don't match.

    Unlike the builtin ``zip``, which silently truncates to the shortest
    input, this refuses to pair up sequences of unequal length.

    Args:
        *args: list of lists or tuples (anything supporting ``len()``).

    Returns:
        list: the result of zip. A real list is returned on both Python 2 and
        Python 3, so the result can be iterated more than once. Calling with no
        arguments returns an empty list.

    Raises:
        ValueError: when the lengths don't match.
    """
    sizes = [len(x) for x in args]
    # The generator never touches sizes[0] when `sizes` is empty, so a call
    # with no arguments no longer raises IndexError.
    if any(size != sizes[0] for size in sizes):
        raise ValueError("Lists are of different sizes. \n %s"%str(sizes))
    return list(zip(*args))
@contextlib.contextmanager
def assert_no_new_variables():
    """Ensure that no tf.Variables are constructed inside the context.

    The global variable collection is snapshotted on entry and compared on
    exit. The comparison only runs when the body exits normally; an exception
    raised inside the context propagates unchanged.

    Yields:
        None

    Raises:
        ValueError: if there is a variable created, or if the global variable
            collection was otherwise modified inside the context.
    """
    old_variables = tf.global_variables()
    yield
    # Query the collection once on exit and reuse it for both checks.
    current_variables = tf.global_variables()
    if len(current_variables) != len(old_variables):
        new_vars = set(current_variables) - set(old_variables)
        tf.logging.error("NEW VARIABLES CREATED")
        tf.logging.error(10 * "=")
        for v in new_vars:
            tf.logging.error(v)
        raise ValueError("Variables created inside an "
                         "assert_no_new_variables context")
    # Same count but different contents/order: something replaced variables.
    if old_variables != current_variables:
        # Trailing space fixes the previous "context.This means" run-on.
        raise ValueError("Variables somehow changed inside an "
                         "assert_no_new_variables context. "
                         "This means something modified the tf.global_variables()")
def get_variables_in_modules(module_list):
    """Gather the variables of several sonnet modules into one flat list.

    Args:
        module_list: iterable of sonnet modules.

    Returns:
        list: the variables reported by ``snt.get_variables_in_module`` for
        each module, concatenated in the order the modules were given.
    """
    return [variable
            for module in module_list
            for variable in snt.get_variables_in_module(module)]
def state_barrier_context(state):
    """Return a context manager that prevents interior ops from running
    unless the whole state has been computed.

    This is to prevent assign race conditions.

    Args:
        state: a nested structure accepted by ``nest.flatten``.

    Returns:
        A ``tf.control_dependencies`` context manager that depends on every
        leaf whose type is exactly ``tf.Tensor``, followed by the ``flow``
        attribute of every leaf that has one (presumably TensorArrays).
    """
    leaves = nest.flatten(state)
    # Exact type match: subclasses of tf.Tensor are not gated here.
    plain_tensors = [leaf for leaf in leaves if type(leaf) == tf.Tensor]
    flows = [leaf.flow for leaf in leaves if hasattr(leaf, "flow")]
    return tf.control_dependencies(plain_tensors + flows)
def _identity_fn(tf_entity):
if hasattr(tf_entity, "identity"):
return tf_entity.identity()
else:
return tf.identity(tf_entity)
def state_barrier_result(state):
    """Return the same state, but with a control dependency to prevent it from
    being partially computed.

    Every leaf of ``state`` is passed through ``_identity_fn`` while inside
    ``state_barrier_context(state)``. The returned copies therefore carry
    control dependencies on the tensors and flows that context gathers.

    Args:
        state: nested structure accepted by ``nest.map_structure``.

    Returns:
        A structure with the same nesting as ``state``.
    """
    barrier = state_barrier_context(state)
    with barrier:
        guarded = nest.map_structure(_identity_fn, state)
    return guarded
def train_iterator(num_iterations):
    """Iterator that returns an index of the current step.

    This iterator runs forever if num_iterations is None,
    otherwise it runs for some fixed amount of steps.

    Args:
        num_iterations: number of steps, or None for an unbounded iterator.

    Returns:
        ``itertools.count()`` when ``num_iterations`` is None, otherwise
        ``range(num_iterations)``.
    """
    if num_iterations is None:
        return itertools.count()
    # `xrange` does not exist on Python 3 (NameError). `range` works on both
    # versions and is lazy on Python 3.
    return range(num_iterations)
def print_op(op, msg):
    """Print a string and return an op wrapped in a control dependency to make
    sure it ran.

    Args:
        op: op (or tensor) to group with the print.
        msg: message passed to ``tf.Print``.

    Returns:
        A ``tf.group`` of ``op`` and a ``tf.Print`` node, so running the
        result runs both.
    """
    # Named `printer` so the local does not shadow this function's own name.
    dummy_input = tf.constant(0)
    printer = tf.Print(dummy_input, [tf.constant(0)], msg)
    return tf.group(op, printer)
class MultiQueueRunner(tf.train.QueueRunner):
    """A QueueRunner with multiple queues """

    def __init__(self, queues, enqueue_ops):
        """Drive several queues from a single combined enqueue op.

        Args:
            queues: list of queues. All of them are closed/cancelled together;
                the first one is handed to the base QueueRunner as its queue.
            enqueue_ops: ops combined into one op named "multi_enqueue".
        """
        def _group_over_queues(make_op):
            # One grouped op that applies `make_op` to every queue.
            return tf.group(*[make_op(q) for q in queues])

        close_op = _group_over_queues(lambda q: q.close())
        cancel_op = _group_over_queues(
            lambda q: q.close(cancel_pending_enqueues=True))
        combined_enqueue = tf.group(*enqueue_ops, name="multi_enqueue")
        super(MultiQueueRunner, self).__init__(
            queues[0],
            enqueue_ops=[combined_enqueue],
            close_op=close_op,
            cancel_op=cancel_op,
            queue_closed_exception_types=(errors.OutOfRangeError,))
# This function is not elegant, but I tried many other ways to get this to
# work and this is the only one that did not incur significant overhead or
# trigger obscure TensorFlow bugs.
def sample_n_per_class(dataset, samples_per_class):
    """Create a new callable / dataset object that returns batches with
    samples_per_class examples per label.

    ``dataset()`` is called once to build the input pipeline. Its examples are
    routed by ``batch.label`` into one FIFO queue per class. A
    ``MultiQueueRunner`` that fills those queues is registered with
    ``tf.train.add_queue_runner``, so queue runners must be started for the
    returned function's ops to produce data. Each class queue holds at most
    ``20 * samples_per_class`` examples, and examples that do not fit are
    dropped rather than enqueued.

    Args:
        dataset: fn taking no arguments and returning a batch structure with
            ``label`` (class ids compared against 0..num_classes-1) and
            ``label_onehot`` (whose static dimension 1 is num_classes) fields.
            Every leaf of ``nest.flatten(batch)`` is batched along axis 0.
        samples_per_class: int

    Returns:
        function, [] -> batch where batch is the same type as the return of
        dataset(). Each call dequeues ``samples_per_class`` examples from every
        class and shuffles them together along axis 0.
    """
    with tf.control_dependencies(None), tf.name_scope(None):
        with tf.name_scope("queue_runner/sample_n_per_class"):
            batch = dataset()
            num_classes = batch.label_onehot.shape.as_list()[1]
            batch_size = num_classes * samples_per_class

            # The queue components are defined from this flattening, so the
            # per-class slices below must come from it too. Iterating `batch`
            # directly breaks for nested or dict-structured batches.
            flatten = nest.flatten(batch)
            queues = []
            enqueue_ops = []
            capacity = samples_per_class * 20
            # `range`, not `xrange`: the latter does not exist on Python 3.
            for i in range(num_classes):
                queue = tf.FIFOQueue(
                    capacity=capacity,
                    shapes=[f.shape.as_list()[1:] for f in flatten],
                    dtypes=[f.dtype for f in flatten])
                queues.append(queue)

                # Indices of examples labelled `i`. The squeeze below assumes
                # tf.where yields [n, 1] (i.e. batch.label is 1-D).
                idx = tf.where(tf.equal(batch.label, i))
                to_enqueue = [tf.squeeze(tf.gather(elem, idx), 1)
                              for elem in flatten]

                # Enqueue only what currently fits so the enqueue does not
                # block on a full queue; the surplus is dropped.
                remaining = (capacity - queue.size())
                to_add = tf.minimum(tf.shape(idx)[0], remaining)

                # tf.cond calls this immediately, so closing over the loop
                # variables is safe.
                def _enqueue():
                    return queue.enqueue_many([t[:to_add] for t in to_enqueue])

                enqueue_op = tf.cond(tf.equal(to_add, 0), tf.no_op, _enqueue)
                enqueue_ops.append(enqueue_op)

            # This has caused many deadlocks / issues. This is some logging to
            # at least shed light on what is going on.
            def _print_queue_sizes():
                return tf.Print(tf.constant(0.0), [q.size() for q in queues],
                                "MultiQueueRunner queues status. Has capacity %d"
                                % capacity)

            some_percent_of_time = tf.less(tf.random_uniform([]), 0.0005)
            maybe_print = tf.cond(some_percent_of_time, _print_queue_sizes,
                                  lambda: tf.constant(0.0))
            with tf.control_dependencies([maybe_print]):
                enqueue_ops = [tf.group(e) for e in enqueue_ops]
            qr = MultiQueueRunner(queues=queues, enqueue_ops=enqueue_ops)
            tf.train.add_queue_runner(qr)

    def dequeue_batch():
        with tf.name_scope("sample_n_per_batch/dequeue/"):
            entries = [q.dequeue_many(samples_per_class) for q in queues]
            # Concatenate each component across classes, then shuffle rows so
            # the classes are interleaved within the batch.
            flat_batch = [tf.concat(x, 0) for x in zip(*entries)]
            idx = tf.random_shuffle(tf.range(batch_size))
            flat_batch = [tf.gather(f, idx, axis=0) for f in flat_batch]
            return nest.pack_sequence_as(batch, flat_batch)

    return dequeue_batch
def structure_map_multi(func, values):
    """Apply ``func`` leaf-wise across several structures.

    Args:
        func: callable receiving one tuple that holds the corresponding leaf
            from each entry of ``values``.
        values: non-empty sequence of nested structures, expected to share
            the structure of ``values[0]``.

    Returns:
        A structure shaped like ``values[0]`` whose leaves are the ``func``
        results.
    """
    flat_per_value = [nest.flatten(v) for v in values]
    mapped = [func(leaves) for leaves in zip(*flat_per_value)]
    return nest.pack_sequence_as(values[0], mapped)
def structure_map_split(func, value):
    """Map ``func`` over the leaves of ``value`` and unzip the results.

    ``func`` is called on every leaf and must return a sequence. The k-th
    elements of all those sequences are packed back into the structure of
    ``value``.

    Returns:
        list of structures shaped like ``value``. Its length is that of the
        shortest sequence returned by ``func`` (``zip`` semantics).
    """
    per_leaf = [func(leaf) for leaf in nest.flatten(value)]
    return [nest.pack_sequence_as(value, group) for group in zip(*per_leaf)]
def assign_variables(targets, values):
    """Group ``targets[k].assign(values[k])`` for every k into one op.

    Args:
        targets: sequence of objects with an ``assign`` method (variables).
        values: sequence of the same length as ``targets``.

    Returns:
        A ``tf.group`` op named "assign_variables".

    Raises:
        ValueError: (from ``eqzip``) if the two sequences differ in length.
    """
    assignments = [target.assign(value)
                   for target, value in eqzip(targets, values)]
    return tf.group(*assignments, name="assign_variables")
def create_variables_in_class_scope(method):
    """Force the variables constructed in this class to live in the sonnet module.

    Wraps a method on a sonnet module. For example the following will create
    two different variables.

    ```
    class Mod(snt.AbstractModule):
        @create_variables_in_class_scope
        def dynamic_thing(self, input, name):
            return snt.Linear(10, name=name)(input)

    mod = Mod(name="mod")  # assumes Mod also implements _build
    mod.dynamic_thing(x, name="module_nameA")
    mod.dynamic_thing(x, name="module_nameB")
    # reuse
    mod.dynamic_thing(y, name="module_nameA")
    ```

    How the wrapper works:
    1. It captures the object's variable scope, using
       ``obj._enter_variable_scope()`` when available and otherwise
       ``tf.variable_scope(obj.variable_scope)``.
    2. It re-enters that scope with ``reuse=tf.AUTO_REUSE``, so repeated
       calls share variables.
    3. It runs ``method`` under the module's original name scope plus a
       sub-scope named after the snake_cased method name.
    """
    @functools.wraps(method)
    def wrapper(obj, *args, **kwargs):
        def default_context_manager(reuse=None):
            variable_scope = obj.variable_scope
            return tf.variable_scope(variable_scope, reuse=reuse)

        variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                                 default_context_manager)

        # Temporarily enter the variable scope to capture it
        with variable_scope_context_manager() as tmp_variable_scope:
            variable_scope = tmp_variable_scope

        # _pure_variable_scope re-enters the captured scope (presumably
        # without opening a new name scope — name scopes are set explicitly
        # below). AUTO_REUSE creates variables on first use and reuses them
        # afterwards.
        with variable_scope_ops._pure_variable_scope(
                variable_scope, reuse=tf.AUTO_REUSE):
            name_scope = variable_scope.original_name_scope
            # endswith() also tolerates an empty scope name, where
            # name_scope[-1] would raise IndexError.
            if not name_scope.endswith("/"):
                name_scope += "/"
            # A trailing "/" makes tf.name_scope re-enter this exact scope
            # instead of creating a uniquified one.
            with tf.name_scope(name_scope):
                sub_scope = snt_util.to_snake_case(method.__name__)
                with tf.name_scope(sub_scope):
                    return method(obj, *args, **kwargs)

    return wrapper