# Handwriting_Model_Inf / rnn_ops.py
# Uploaded by 3morrrrr ("Upload 14 files", commit 569596a, verified).
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import cond as control_flow_ops_cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell
from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor
from tensorflow.python.util import nest
from tensorflow.python.framework import tensor_shape
def _swap_batch_time(x):
    """Swap the two leading axes of ``x`` (time-major <-> batch-major).

    Works for any rank >= 2; for a rank-3 tensor this is exactly
    ``transpose(x, (1, 0, 2))``. Tensors whose static rank is known to be
    below 2 have no batch axis to swap and are returned unchanged.
    """
    ndims = x.shape.ndims
    if ndims is not None:
        if ndims < 2:
            return x
        return array_ops.transpose(x, [1, 0] + list(range(2, ndims)))
    # Static rank unknown: build the permutation when the graph runs.
    perm = array_ops.concat([[1, 0], math_ops.range(2, array_ops.rank(x))], axis=0)
    return array_ops.transpose(x, perm)


def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (stacked along the time axis)
    in addition to the outputs at each timestep and the final state.

    Args:
        cell: an RNN cell; checked with ``assert_like_rnncell``. It is called as
            ``cell(input, state)`` and must return ``(output, new_state)``.
        loop_fn: callable ``loop_fn(time, cell_output, cell_state, loop_state)``
            returning ``(elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)``. The first call is made with
            ``time=0`` and ``cell_output``, ``cell_state`` and ``loop_state`` all
            ``None``. Its ``next_cell_state`` is the initial state, and its
            ``emit_output`` is the emit structure: per-example shapes without the
            batch axis, or ``None`` to fall back to ``cell.output_size``.
            Returning ``None`` as ``next_loop_state`` keeps the previous loop state.
        parallel_iterations: forwarded to ``tf.while_loop``; defaults to 32.
        swap_memory: forwarded to ``tf.while_loop``.
        scope: variable scope name; defaults to ``"rnn"``.

    Returns:
        ``(states, outputs, final_state)``:
        states: same structure as the cell state. Each leaf holds the state
            after every step, stacked and transposed to batch-major
            (``[batch, time, ...]``). Finished examples keep repeating their
            last state.
        outputs: same structure as the emit structure, batch-major. Finished
            examples emit zeros.
        final_state: the cell state when the loop exits.

    Raises:
        TypeError: if ``loop_fn`` is not callable.
    """
    assert_like_rnncell("dummy_name", cell)
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not tf.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        # Static verification that batch sizes all match (merge_with raises on a
        # mismatch). Keeping the merged value also picks up a static batch size
        # from any input that knows it, not only from the first one.
        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0)
        for input_shape_i in input_shape:
            static_batch_size = static_batch_size.merge_with(
                tensor_shape.dimension_at_index(input_shape_i, 0))

        batch_size = tensor_shape.dimension_value(static_batch_size)
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = [ops.convert_to_tensor(s) for s in nest.flatten(state)]
        state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        # Emit sizes exclude the batch axis, so it is prepended here.
        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i)
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)

        # Written for finished examples in place of the real emit.
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]
        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        # State tensors already carry their batch axis, and tf.while_loop keeps
        # each loop variable's shape invariant, so every written state has the
        # static shape of its initial tensor. Prepending the batch size again
        # (as done for emits) would make the element rank one too high and make
        # every write fail as soon as the state shape is fully defined.
        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=s.dtype,
                dynamic_size=True,
                element_shape=s.shape,
                size=0,
                name="rnn_state_%d" % i)
            for i, s in enumerate(flat_state)
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            # Examples already finished emit zeros and keep their previous state.
            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(
                lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(
                lambda ta, st: ta.write(time, st), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)
            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = tf.while_loop(
            condition, body,
            loop_vars=[time, elements_finished, next_input, state_ta,
                       emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        (state_ta, emit_ta, final_state, _) = returned[-4:]

        # TensorArray.stack() is time-major; convert every leaf to batch-major.
        flat_states = [_swap_batch_time(ta.stack()) for ta in nest.flatten(state_ta)]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = [_swap_batch_time(ta.stack()) for ta in nest.flatten(emit_ta)]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
def rnn_teacher_force(inputs, cell, sequence_length, initial_state, scope='dynamic-rnn-teacher-force'):
    """
    Implementation of an rnn with teacher forcing inputs provided.
    Used in the same way as tf.dynamic_rnn (but note the return order below).

    Args:
        inputs: batch-major inputs. They are transposed with perm (1, 0, 2), so
            they must be rank 3 ([batch, time, depth]). The input TensorArray
            and the padding zeros use ``inputs.dtype``, so non-float32 inputs
            are supported.
        cell: RNN cell passed to raw_rnn.
        sequence_length: compared against the step counter; an example is
            finished at step ``time`` once ``time >= sequence_length``.
        initial_state: cell state used for the first step.
        scope: variable scope name forwarded to raw_rnn.

    Returns:
        ``(states, outputs, final_state)`` as returned by raw_rnn. This order
        differs from tf.dynamic_rnn's ``(outputs, state)``.
    """
    # Time-major copy so the TensorArray can be read one step at a time.
    inputs = array_ops.transpose(inputs, (1, 0, 2))
    input_dtype = inputs.dtype
    inputs_ta = tensor_array_ops.TensorArray(dtype=input_dtype, size=array_ops.shape(inputs)[0])
    inputs_ta = inputs_ta.unstack(inputs)

    # Shape of the all-zeros input fed once every example has finished. The
    # depth is taken statically when known (keeps static shapes for the cell)
    # and dynamically otherwise, rather than failing on an unknown depth.
    batch_size = array_ops.shape(inputs)[1]
    depth = tensor_shape.dimension_value(inputs.shape[2])
    if depth is None:
        depth = array_ops.shape(inputs)[2]

    def loop_fn(time, cell_output, cell_state, loop_state):
        # raw_rnn makes its first call with cell_output=None: that call supplies
        # the initial state, and emitting None selects cell.output_size as the
        # emit structure.
        emit_output = cell_output
        next_cell_state = initial_state if cell_output is None else cell_state
        elements_finished = time >= sequence_length
        finished = math_ops.reduce_all(elements_finished)
        # Once every example is finished, feed zeros instead of reading
        # inputs_ta at `time`.
        next_input = control_flow_ops_cond.cond(
            finished,
            lambda: array_ops.zeros([batch_size, depth], dtype=input_dtype),
            lambda: inputs_ta.read(time))
        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    states, outputs, final_state = raw_rnn(cell, loop_fn, scope=scope)
    return states, outputs, final_state
def rnn_free_run(cell, initial_state, sequence_length, initial_input=None, scope='dynamic-rnn-free-run'):
    """
    Implementation of an rnn which feeds its predictions back to itself at the next timestep.

    Besides the usual cell interface, ``cell`` must implement:
        cell.output_function(state): takes the state at timestep t and returns
            the cell input at timestep t+1.
        cell.termination_condition(state): returns a boolean tensor of shape
            [batch_size] flagging which sequences no longer need to be sampled.

    Args:
        cell: RNN cell with the two extra methods above.
        initial_state: cell state used for the first step.
        sequence_length: step cap; an example is finished at step ``time`` once
            ``time >= sequence_length`` (or its termination condition holds).
        initial_input: first input fed to the cell. When None it is computed as
            ``cell.output_function(initial_state)`` inside
            ``variable_scope(scope, reuse=True)``.
        scope: variable scope name, also forwarded to raw_rnn.

    Returns:
        ``(states, outputs, final_state)`` as returned by raw_rnn. At each step
        the emitted output is the input that will be fed to the next step.
    """
    # The scope is entered unconditionally, as in the original code.
    with vs.variable_scope(scope, reuse=True):
        if initial_input is None:
            initial_input = cell.output_function(initial_state)

    def loop_fn(step, prev_output, prev_state, _loop_state):
        is_first_step = prev_output is None
        if is_first_step:
            state_now = initial_state
        else:
            state_now = prev_state

        finished = math_ops.logical_or(
            step >= sequence_length,
            cell.termination_condition(state_now),
        )
        all_done = math_ops.reduce_all(finished)

        # Branch taken while at least one example is still running.
        if is_first_step:
            def _feed_fn():
                return initial_input
        else:
            def _feed_fn():
                return cell.output_function(state_now)

        feed = control_flow_ops_cond.cond(
            all_done,
            lambda: array_ops.zeros_like(initial_input),
            _feed_fn,
        )

        # On the first call raw_rnn uses the emit only as a per-example shape
        # template, so a single example (leading axis indexed away) is emitted.
        if is_first_step:
            emitted = feed[0]
        else:
            emitted = feed
        return (finished, feed, state_now, emitted, None)

    return raw_rnn(cell, loop_fn, scope=scope)