import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import cond as control_flow_ops_cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell
from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor
from tensorflow.python.util import nest
from tensorflow.python.framework import tensor_shape
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
"""
raw_rnn adapted from the original tensorflow implementation
(https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
to emit arbitrarily nested states for each time step (concatenated along the time axis)
in addition to the outputs at each timestep and the final state
returns (
states for all timesteps,
outputs for all timesteps,
final cell state,
)
"""
    assert_like_rnncell("cell", cell)
if not callable(loop_fn):
raise TypeError("loop_fn must be a callable")
parallel_iterations = parallel_iterations or 32
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
if not tf.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
time = constant_op.constant(0, dtype=dtypes.int32)
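        # The time-0 call to loop_fn supplies the first input, the initial cell
        # state and (optionally) the structure/dtype of the emitted outputs.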
(elements_finished, next_input, initial_state, emit_structure,
init_loop_state) = loop_fn(time, None, None, None)
flat_input = nest.flatten(next_input)
# Need a surrogate loop state for the while_loop if none is available.
loop_state = (init_loop_state if init_loop_state is not None
else constant_op.constant(0, dtype=dtypes.int32))
input_shape = [input_.get_shape() for input_ in flat_input]
static_batch_size = input_shape[0][0]
for input_shape_i in input_shape:
# Static verification that batch sizes all match
static_batch_size.merge_with(input_shape_i[0])
batch_size = static_batch_size.value
const_batch_size = batch_size
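        # Fall back to the dynamic (runtime) batch size when it is not known statically.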
if batch_size is None:
batch_size = array_ops.shape(flat_input[0])[0]
nest.assert_same_structure(initial_state, cell.state_size)
state = initial_state
flat_state = nest.flatten(state)
flat_state = [ops.convert_to_tensor(s) for s in flat_state]
state = nest.pack_sequence_as(structure=state,
flat_sequence=flat_state)
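        # Work out the dtype and per-step shape of everything that will be emitted.
        # If loop_fn did not return an emit structure, fall back to the cell's output_size.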
if emit_structure is not None:
flat_emit_structure = nest.flatten(emit_structure)
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
array_ops.shape(emit) for emit in flat_emit_structure]
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
emit_structure = cell.output_size
flat_emit_size = nest.flatten(emit_structure)
flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
flat_state_size = [s.shape if s.shape.is_fully_defined() else
array_ops.shape(s) for s in flat_state]
flat_state_dtypes = [s.dtype for s in flat_state]
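        # One TensorArray per flattened emit/state component; dynamic_size=True lets
        # them grow to the (statically unknown) number of time steps.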
flat_emit_ta = [
tensor_array_ops.TensorArray(
dtype=dtype_i,
dynamic_size=True,
element_shape=(tensor_shape.TensorShape([const_batch_size])
.concatenate(_maybe_tensor_shape_from_tensor(size_i))),
size=0,
name="rnn_output_%d" % i
)
for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
]
emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)
flat_zero_emit = [
array_ops.zeros(_concat(batch_size, size_i), dtype_i)
for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]
zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)
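        # zero_emit is written for batch entries that have already finished, so every
        # TensorArray row is fully defined.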
flat_state_ta = [
tensor_array_ops.TensorArray(
dtype=dtype_i,
dynamic_size=True,
element_shape=(tensor_shape.TensorShape([const_batch_size])
.concatenate(_maybe_tensor_shape_from_tensor(size_i))),
size=0,
name="rnn_state_%d" % i
)
for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
]
state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)
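        # Step the cell inside a while_loop until every sequence in the batch is finished.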
def condition(unused_time, elements_finished, *_):
return math_ops.logical_not(math_ops.reduce_all(elements_finished))
def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
(next_output, cell_state) = cell(current_input, state)
nest.assert_same_structure(state, cell_state)
nest.assert_same_structure(cell.output_size, next_output)
next_time = time + 1
(next_finished, next_input, next_state, emit_output,
next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)
nest.assert_same_structure(state, next_state)
nest.assert_same_structure(current_input, next_input)
nest.assert_same_structure(emit_ta, emit_output)
# If loop_fn returns None for next_loop_state, just reuse the previous one.
loop_state = loop_state if next_loop_state is None else next_loop_state
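            # For batch entries that are already finished, emit zeros and carry the
            # previous state forward unchanged.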
def _copy_some_through(current, candidate):
"""Copy some tensors through via array_ops.where."""
def copy_fn(cur_i, cand_i):
# TensorArray and scalar get passed through.
if isinstance(cur_i, tensor_array_ops.TensorArray):
return cand_i
if cur_i.shape.ndims == 0:
return cand_i
# Otherwise propagate the old or the new value.
with ops.colocate_with(cand_i):
return array_ops.where(elements_finished, cur_i, cand_i)
return nest.map_structure(copy_fn, current, candidate)
emit_output = _copy_some_through(zero_emit, emit_output)
next_state = _copy_some_through(state, next_state)
emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)
elements_finished = math_ops.logical_or(elements_finished, next_finished)
return (next_time, elements_finished, next_input, state_ta,
emit_ta, next_state, loop_state)
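        # The last four loop variables are the state TensorArrays, the emit
        # TensorArrays, the final cell state and the final loop state.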
returned = tf.while_loop(
condition, body, loop_vars=[
time, elements_finished, next_input, state_ta,
emit_ta, state, loop_state],
parallel_iterations=parallel_iterations,
swap_memory=swap_memory
)
(state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]
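        # TensorArray.stack() gives time-major tensors [time, batch, depth]; transpose
        # them to batch-major [batch, time, depth] to match dynamic_rnn conventions.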
flat_states = nest.flatten(state_ta)
flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)
flat_outputs = nest.flatten(emit_ta)
flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)
return (states, outputs, final_state)
def rnn_teacher_force(inputs, cell, sequence_length, initial_state, scope='dynamic-rnn-teacher-force'):
"""
Implementation of an rnn with teacher forcing inputs provided.
Used in the same way as tf.dynamic_rnn.
"""
inputs = array_ops.transpose(inputs, (1, 0, 2))
inputs_ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
def loop_fn(time, cell_output, cell_state, loop_state):
emit_output = cell_output
next_cell_state = initial_state if cell_output is None else cell_state
elements_finished = time >= sequence_length
finished = math_ops.reduce_all(elements_finished)
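        # Once every sequence has finished there is nothing left to read, so feed
        # correctly shaped zeros instead of reading past the end of the TensorArray;
        # the value is never consumed because the while_loop terminates.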
next_input = control_flow_ops_cond.cond(
finished,
lambda: array_ops.zeros([array_ops.shape(inputs)[1], inputs.shape.as_list()[2]], dtype=dtypes.float32),
lambda: inputs_ta.read(time)
)
next_loop_state = None
return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)
states, outputs, final_state = raw_rnn(cell, loop_fn, scope=scope)
return states, outputs, final_state
def rnn_free_run(cell, initial_state, sequence_length, initial_input=None, scope='dynamic-rnn-free-run'):
"""
    Implementation of an RNN which feeds its predictions back to itself as the input
    at the next time step.
cell must implement two methods:
cell.output_function(state) which takes in the state at timestep t and returns
the cell input at timestep t+1.
cell.termination_condition(state) which returns a boolean tensor of shape
[batch_size] denoting which sequences no longer need to be sampled.
"""
with vs.variable_scope(scope, reuse=True):
if initial_input is None:
initial_input = cell.output_function(initial_state)
def loop_fn(time, cell_output, cell_state, loop_state):
next_cell_state = initial_state if cell_output is None else cell_state
elements_finished = math_ops.logical_or(
time >= sequence_length,
cell.termination_condition(next_cell_state)
)
finished = math_ops.reduce_all(elements_finished)
next_input = control_flow_ops_cond.cond(
finished,
lambda: array_ops.zeros_like(initial_input),
lambda: initial_input if cell_output is None else cell.output_function(next_cell_state)
)
emit_output = next_input[0] if cell_output is None else next_input
next_loop_state = None
return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)
states, outputs, final_state = raw_rnn(cell, loop_fn, scope=scope)
return states, outputs, final_state
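

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The _SamplingCell
# below is a hypothetical illustration of the output_function /
# termination_condition interface that rnn_free_run expects; real models would
# substitute their own cell. Only the teacher-forced pass is exercised here,
# assuming a standard tf.compat.v1 LSTMCell and placeholder inputs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _SamplingCell(tf.nn.rnn_cell.LSTMCell):
        """Hypothetical cell sketching the interface required by rnn_free_run."""

        def __init__(self, num_units, input_size):
            super(_SamplingCell, self).__init__(num_units)
            self._projection = tf.layers.Dense(input_size)

        def output_function(self, state):
            # The input at time t + 1 is a projection of the hidden state at time t.
            return self._projection(state.h)

        def termination_condition(self, state):
            # Never stop early; rely solely on sequence_length.
            return tf.zeros([tf.shape(state.h)[0]], dtype=tf.bool)

    batch_size, max_len, input_dim, num_units = 4, 20, 3, 32
    inputs = tf.placeholder(tf.float32, [None, max_len, input_dim])
    seq_len = tf.placeholder(tf.int32, [None])

    cell = _SamplingCell(num_units, input_dim)
    initial_state = cell.zero_state(tf.shape(inputs)[0], tf.float32)

    # Teacher-forced pass; rnn_free_run would afterwards be called with the same
    # cell and scope so that reuse=True can find the variables created here.
    states, outputs, final_state = rnn_teacher_force(
        inputs, cell, seq_len, initial_state, scope='rnn')
    print(states, outputs, final_state)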