import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond as control_flow_ops_cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor
from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell
from tensorflow.python.util import nest

def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original TensorFlow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the
    time axis) in addition to the outputs at each time step and the final state.

    Returns a tuple of:
        states for all time steps,
        outputs for all time steps,
        final cell state.
    """
    assert_like_rnncell("dummy_name", cell)
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not tf.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)
        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [s.shape if s.shape.is_fully_defined() else
                           array_ops.shape(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)

        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]
        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)
        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, s: ta.write(time, s), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)
        returned = tf.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        # Stack the TensorArrays and convert from time-major [time, batch, depth]
        # to batch-major [batch, time, depth].
        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
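
# A minimal sketch (not part of the original module) of the loop_fn contract that
# raw_rnn expects. At time 0 loop_fn is called as loop_fn(0, None, None, None) and
# must return the first input, the initial state, and an emit structure with no
# batch dimension (or None to default to cell.output_size); on later calls it
# receives the previous cell output and state. The helper name
# make_constant_loop_fn and its arguments are hypothetical.
def make_constant_loop_fn(constant_input, initial_state, sequence_length):
    """Feed the same [batch_size, input_dim] tensor at every time step."""
    def loop_fn(time, cell_output, cell_state, loop_state):
        elements_finished = time >= sequence_length  # [batch_size] boolean tensor
        next_cell_state = initial_state if cell_output is None else cell_state
        emit_output = cell_output  # None at time 0 -> raw_rnn falls back to cell.output_size
        return (elements_finished, constant_input, next_cell_state, emit_output, None)
    return loop_fn
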
def rnn_teacher_force(inputs, cell, sequence_length, initial_state, scope='dynamic-rnn-teacher-force'):
    """
    Implementation of an RNN with teacher-forced inputs.
    Used in the same way as tf.dynamic_rnn.
    """
    # Convert batch-major inputs to time-major for TensorArray indexing.
    inputs = array_ops.transpose(inputs, (1, 0, 2))
    inputs_ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
    inputs_ta = inputs_ta.unstack(inputs)

    def loop_fn(time, cell_output, cell_state, loop_state):
        emit_output = cell_output
        next_cell_state = initial_state if cell_output is None else cell_state
        elements_finished = time >= sequence_length
        finished = math_ops.reduce_all(elements_finished)
        # Once every sequence has finished, feed zeros instead of reading past
        # the end of the TensorArray.
        next_input = control_flow_ops_cond.cond(
            finished,
            lambda: array_ops.zeros([array_ops.shape(inputs)[1], inputs.shape.as_list()[2]], dtype=dtypes.float32),
            lambda: inputs_ta.read(time)
        )
        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    states, outputs, final_state = raw_rnn(cell, loop_fn, scope=scope)
    return states, outputs, final_state
def rnn_free_run(cell, initial_state, sequence_length, initial_input=None, scope='dynamic-rnn-free-run'):
    """
    Implementation of an RNN which feeds its predictions back to itself as the
    input at the next time step.

    cell must implement two methods:

        cell.output_function(state) which takes in the state at time step t and
        returns the cell input at time step t+1.

        cell.termination_condition(state) which returns a boolean tensor of shape
        [batch_size] denoting which sequences no longer need to be sampled.
    """
    with vs.variable_scope(scope, reuse=True):
        if initial_input is None:
            initial_input = cell.output_function(initial_state)

        def loop_fn(time, cell_output, cell_state, loop_state):
            next_cell_state = initial_state if cell_output is None else cell_state
            elements_finished = math_ops.logical_or(
                time >= sequence_length,
                cell.termination_condition(next_cell_state)
            )
            finished = math_ops.reduce_all(elements_finished)
            next_input = control_flow_ops_cond.cond(
                finished,
                lambda: array_ops.zeros_like(initial_input),
                lambda: initial_input if cell_output is None else cell.output_function(next_cell_state)
            )
            # At time 0 the emit structure must omit the batch dimension, hence
            # next_input[0]; later steps emit the full [batch_size, ...] tensor.
            emit_output = next_input[0] if cell_output is None else next_input
            next_loop_state = None
            return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

        states, outputs, final_state = raw_rnn(cell, loop_fn, scope=scope)
        return states, outputs, final_state
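
# A minimal sketch (not from the original source) of a cell wrapper satisfying
# the interface rnn_free_run requires. It assumes an inner cell with an
# LSTMStateTuple state; SamplingCellWrapper, its projection layer, and the
# norm-based stopping rule are all hypothetical. Because rnn_free_run opens its
# scope with reuse=True, the 'output_projection' variables must already exist,
# e.g. from a prior teacher-forced training pass under the same scope.
class SamplingCellWrapper(object):

    def __init__(self, cell, input_dim, max_norm):
        self._cell = cell
        self._input_dim = input_dim
        self._max_norm = max_norm

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)

    def __call__(self, inputs, state, scope=None):
        return self._cell(inputs, state)

    def output_function(self, state):
        # Project the hidden state at step t to the cell input at step t + 1.
        return tf.layers.dense(state.h, self._input_dim, name='output_projection')

    def termination_condition(self, state):
        # Finish a sequence once its hidden-state norm passes a made-up threshold.
        return tf.norm(state.h, axis=1) > self._max_norm
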