Upload 20 files

- aglib/meliad/transformer/__init__.py +15 -0
- aglib/meliad/transformer/attention.py +443 -0
- aglib/meliad/transformer/decoder_stack.py +426 -0
- aglib/meliad/transformer/ht_main.py +56 -0
- aglib/meliad/transformer/ht_main_inference.py +76 -0
- aglib/meliad/transformer/inference_utils.py +271 -0
- aglib/meliad/transformer/launcher.py +119 -0
- aglib/meliad/transformer/memory_factory.py +120 -0
- aglib/meliad/transformer/memory_layer.py +431 -0
- aglib/meliad/transformer/metric_utils.py +115 -0
- aglib/meliad/transformer/models.py +317 -0
- aglib/meliad/transformer/nn_components.py +437 -0
- aglib/meliad/transformer/position.py +242 -0
- aglib/meliad/transformer/position_fourier.py +218 -0
- aglib/meliad/transformer/position_t5.py +155 -0
- aglib/meliad/transformer/synthetic_text_data.py +0 -0
- aglib/meliad/transformer/tasks.py +52 -0
- aglib/meliad/transformer/text_dataset.py +775 -0
- aglib/meliad/transformer/transformer_base.py +451 -0
- aglib/meliad/transformer/transformer_layer.py +817 -0
aglib/meliad/transformer/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

aglib/meliad/transformer/attention.py
ADDED
@@ -0,0 +1,443 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer attention functions."""

import typing
from typing import Any, Callable, Mapping, NewType, Optional, Sequence, Tuple, Union

from absl import logging
from flax import linen as nn
import jax
import jax.numpy as jnp

from transformer import nn_components
from transformer import position


Array = jnp.ndarray
ArrayTree = Union[Array, Tuple["ArrayTree", ...]]
DecoderState = NewType("DecoderState", Mapping[str, Array])

# Tuple of keys, values, importance.
KVITuple = Tuple[Array, Array, Optional[Array]]

# Tuple of keys, values, queries, queries2, importance.
KVQITuple = Tuple[Array, Array, Array, Optional[Array], Optional[Array]]

# Tuple of scale factors.  See TransformerBase.attention_scale_factors().
AttnScaleTuple = Tuple[Optional[Array], Optional[Array]]


def initial_kvi(shape: Sequence[int], use_importance: bool, dtype: Any):
  """Returns initial (zero) keys/values/i that can be passed to prev_kvi."""
  z = jnp.zeros(shape, dtype=dtype)
  if use_importance:
    i = jnp.zeros((shape[0], shape[1]), dtype=dtype)  # (bsize, window_length)
  else:
    i = None
  return (z, z, i)


def concat_kvqi(kvqi: KVQITuple, prev_kvi: Optional[KVITuple]) -> (
    Tuple[KVQITuple, Optional[KVITuple]]):
  """Concatenate previous keys,values with current keys,values.

  Args:
    kvqi: Current keys, values, queries, queries2, importance.
    prev_kvi: Previous keys, values, importance.

  Returns:
    (kvqi: Concatenated (keys, values, queries, importance),
     next_kvi: Next (keys, values, importance))  (from kvqi)
  """

  (keys, values, queries, queries2, importance) = kvqi
  # The current keys,values,importance will be passed to the next window.
  next_kvi = (keys, values, importance)
  (batch_size, _, num_heads, head_dim) = keys.shape  # (b, _, h, d)

  if prev_kvi is None:
    return (kvqi, None)  # If prev_kvi is None, next_kvi should be None.

  # Unpack prev_kvi and check shapes.
  (pkeys, pvalues, pimportance) = prev_kvi
  num_pkeys = pkeys.shape[1]
  assert pkeys.shape == (batch_size, num_pkeys, num_heads, head_dim)
  assert pkeys.shape == pvalues.shape
  if pimportance is not None:
    assert pimportance.shape == (batch_size, num_pkeys)

  # Concatenate keys and values.
  keys = jnp.concatenate([pkeys, keys], axis=1)  # (b, k, h, d)
  values = jnp.concatenate([pvalues, values], axis=1)  # (b, k, h, d)
  if importance is not None:
    assert pimportance is not None
    importance = jnp.concatenate([pimportance, importance], axis=1)  # (b, k)
    logging.info("attn: importance = %r", importance)

  return ((keys, values, queries, queries2, importance), next_kvi)


def simple_attention(keys: Array,
                     values: Array,
                     queries: Array,
                     importance: Optional[Array],
                     *,
                     relative_position_bias: Optional[Array] = None,
                     scale_factor: Optional[Array] = None,
                     causal_mask: Optional[Array] = None,
                     dropout_multiplier: Optional[Array] = None,
                     dtype: Any = jnp.float32) -> Array:
  """Simple attention from a set of queries to a set of keys,values.

  Args:
    keys: of shape [batch_size, num_keys, num_heads, head_dim].
    values: of shape [batch_size, num_keys, num_heads, head_dim].
    queries: of shape [batch_size, num_queries, num_heads, head_dim].
    importance: of shape [batch_size, num_keys].

    *: ---- the following arguments are passed by keyword only ----
    relative_position_bias: A positional attention matrix of shape
        [num_heads, num_queries, num_keys]
    scale_factor: Learned scale factor for use with normalized keys,queries
        of shape [num_heads]
    causal_mask: A boolean array of shape [num_heads, num_queries, num_keys]
    dropout_multiplier: A random mask of either 0.0 or 1.0/keep_prob,
        of shape [num_heads, num_queries, num_keys]
    dtype: data type to perform attention at.

  Returns:
    Attention outputs of shape [batch_size, num_queries, num_heads, head_size]
  """

  # (batch_size, num_keys, num_heads, head_dim)
  (batch_size, num_keys, num_heads, head_dim) = keys.shape  # (b, k, h, d)
  num_queries = queries.shape[1]
  assert keys.shape == values.shape
  assert queries.shape == (batch_size, num_queries, num_heads, head_dim)
  if importance is not None:
    assert importance.shape == (batch_size, num_keys)

  logging.info("attn: keys = %r", keys)
  logging.info("attn: queries = %r", queries)

  # Compute attention matrix.
  attn = jnp.einsum("...qhd,...khd->...hqk", queries, keys)  # (b, h, q, k)

  logging.info("attn: content attn = %r", attn)

  # Apply relative position bias.
  if relative_position_bias is not None:
    logging.info("attn: pbias = %r", relative_position_bias)
    relative_position_bias = jnp.asarray(relative_position_bias, dtype=dtype)
    pbias = position.broadcast_mask(relative_position_bias, attn)
    attn = attn + pbias

  # Apply learned attention scale.
  if scale_factor is not None:
    logging.info("attn: learned attention scale: %s", scale_factor)
    # Broadcast scale over batch/keys/queries.
    scale_factor = jnp.asarray(scale_factor, dtype=dtype)
    scale_factor = scale_factor.reshape((1, num_heads, 1, 1))
    attn = attn * scale_factor

  # Apply causal mask.
  if causal_mask is not None:
    causal_mask = position.broadcast_mask(causal_mask, attn)
    attn = jnp.where(causal_mask, attn, jnp.asarray(-1_000_000.0, dtype=dtype))

  logging.info("attn: pre-softmax attn = %r", attn)

  # Normalize attention matrix with softmax.
  # min_x should be much smaller than minimum expected values in attn, but
  # much larger than the masked_out values created by the causal mask.  That
  # way, if all tokens are masked out, then softmax will attend to nothing,
  # rather than attend to everything equally.
  min_x = jnp.asarray(-1000.0, dtype=dtype)
  attn = nn_components.safe_softmax(attn, axis=-1, min_x=min_x)  # (b, h, q, k)

  # Apply dropout to attention matrix.
  if dropout_multiplier is not None:
    logging.debug("attn: drop = %r", dropout_multiplier)
    dropout_multiplier = jnp.asarray(dropout_multiplier, dtype=dtype)
    attn = attn * dropout_multiplier

  logging.info("attn: final attn = %r", attn)

  # Compute output -- values weighted by attention matrix.
  y = jnp.einsum("...hqk,...khd->...qhd", attn, values)  # (b, q, h, d)

  logging.info("attn: y = %r", y)
  return y
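
# --- Usage sketch (illustrative; not part of attention.py) ---
# A minimal shape check for simple_attention, assuming this module is
# importable as `transformer.attention`.  All names and sizes below are
# toy values chosen for the example.
#
# import jax
# from transformer import attention
#
# (b, q, h, d) = (2, 4, 8, 16)
# rngs = jax.random.split(jax.random.PRNGKey(0), 3)
# keys = jax.random.normal(rngs[0], (b, q, h, d))
# values = jax.random.normal(rngs[1], (b, q, h, d))
# queries = jax.random.normal(rngs[2], (b, q, h, d))
#
# # With no bias, scale, mask, or dropout, this reduces to plain dot-product
# # attention followed by a softmax over the key axis.
# y = attention.simple_attention(keys, values, queries, importance=None)
# assert y.shape == (b, q, h, d)  # One output vector per query position.
# --- End usage sketch ---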


def external_attention(external_keys: Array,
                       external_values: Array,
                       queries: Array,
                       *,
                       scale_factor: Optional[Array] = None,
                       dtype: Any = jnp.float32) -> Array:
  """Attention over (keys, values) retrieved from external memory.

  Args:
    external_keys: per-query keys from external memory, of shape
        [batch_size, num_queries, num_heads, num_neighbors, head_size]
    external_values: per-query values from external memory, of shape
        [batch_size, num_queries, num_heads, num_neighbors, head_size]
    queries: current queries, of shape:
        [batch_size, num_queries, num_heads, head_size]

    *: ---- the following arguments are passed by keyword only. ---
    scale_factor: Learned scale factor for use with normalized keys,queries
        of shape [num_heads]
    dtype: data type to perform attention at.

  Returns:
    Attention outputs of shape [batch_size, num_queries, num_heads, head_size]
  """

  (batch_size, num_queries, num_heads, _, head_dim) = external_keys.shape
  assert queries.shape == (batch_size, num_queries, num_heads, head_dim)
  assert external_values.shape == external_keys.shape

  # Build attention matrix.
  logging.info("ext_attn: external keys = %r", external_keys)
  ext_attn = jnp.einsum("...qhd,...qhid->...hqi", queries, external_keys)

  logging.info("ext_attn: external_mem_attn: %s", ext_attn)
  if scale_factor is not None:
    scale_factor = jnp.asarray(scale_factor, dtype=dtype)
    scale_factor = scale_factor.reshape((1, num_heads, 1, 1))
    logging.info("ext_attn: scaling external_mem_attn by %s", scale_factor)
    ext_attn = ext_attn * scale_factor

  ext_attn = nn.softmax(ext_attn, axis=-1)

  # Compute weighted sum of values.
  ext_y = jnp.einsum("...hqi,...qhid->...qhd", ext_attn, external_values)
  logging.info("ext_attn: ext_y = %r", ext_y)
  return ext_y


def sliding_attention_window_shape(kvi: KVITuple,
                                   prev_kvi: Optional[KVITuple],
                                   queries: Array,
                                   window_length: int) -> Tuple[int, int]:
  """Return (num_queries, num_keys) for the sliding attention window."""

  # Do error checking here.
  (keys, values, importance) = kvi
  assert keys.shape == queries.shape
  assert values.shape == queries.shape

  # Get sizes...
  (batch_size, sequence_length, _, _) = queries.shape

  if importance is not None:
    assert importance.ndim == 2
    assert importance.shape == (batch_size, sequence_length)

  assert window_length > 0
  if window_length >= sequence_length:
    # No sliding window.
    num_queries = sequence_length
    num_keys = sequence_length
    if prev_kvi is not None:
      num_keys += prev_kvi[0].shape[1]
  else:
    # Sliding window.
    if prev_kvi is not None:
      assert prev_kvi[0].shape[1] == window_length
    num_queries = window_length
    num_keys = window_length * 2

  return (num_queries, num_keys)


def split_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
    Sequence[ArrayTree]):
  """Recursively splits a possibly nested tuple of arrays along the given axis.

  Args:
    tree: A nested tree of tuples and arrays.
    sections: The number of sections to split the tree into.
    axis: The axis to do the split on arrays.

  Returns:
    A list of trees, of length sections, where each has the same shape as the
    original, but with arrays of size 1/sections.
  """

  if tree is None:
    return [None] * sections
  elif isinstance(tree, jnp.ndarray):
    return jnp.split(tree, sections, axis=axis)
  elif isinstance(tree, tuple):
    # Recursively split each element of the tuple into a list.
    branch_lists = [split_tree(tree_i, sections, axis=axis) for tree_i in tree]
    # Rearrange the tuple of lists into a list of tuples.
    return [tuple([brs[i] for brs in branch_lists]) for i in range(sections)]
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % tree)


def concat_trees(tree_list: Sequence[ArrayTree], axis: int = 0) -> ArrayTree:
  """Merges a list of trees into a single tree by concatenating their elements.

  Args:
    tree_list: A list of trees, all of the same shape.
    axis: The axis to concatenate arrays on.

  Returns:
    A single tree, with the same shape as the trees in tree_list.
  """

  # All trees in the list are required to have the same shape.
  # We return a tree with the same shape as each of the trees in the list.
  first_tree = tree_list[0]
  if first_tree is None:
    # Merge a list of None into a single None.
    for tree_i in tree_list:
      assert tree_i is None
    return None
  elif isinstance(first_tree, jnp.ndarray):
    # Concatenate a list of arrays.
    for tree_i in tree_list:
      assert isinstance(tree_i, jnp.ndarray)
    return jnp.concatenate(tree_list, axis=axis)
  elif isinstance(first_tree, tuple):
    # Reshape a list of tuples into a tuple of concatenated lists.
    for tree_i in tree_list:
      assert isinstance(tree_i, tuple) and len(tree_i) == len(first_tree)
    num_branches = len(first_tree)
    return tuple([concat_trees([tree[b] for tree in tree_list], axis=axis)
                  for b in range(num_branches)])
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % first_tree)


def reshape_transpose_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
    ArrayTree):
  """Reshape and transpose arrays so that the window is dimension 0."""

  # We could use jax tree utils for this, but we do it the hard way so the
  # implementation can be compared with split_tree.
  if tree is None:
    return None
  elif isinstance(tree, jnp.ndarray):
    tree = typing.cast(Array, tree)  # Tell type-checker about isinstance
    ndim = tree.ndim
    wlen = tree.shape[axis] // sections
    assert sections * wlen == tree.shape[axis]  # Must be evenly divisible.

    # Break the axis dimension into sections * window_size
    arr = tree
    sh = list(arr.shape)
    nshape = sh[0:axis] + [sections, wlen] + sh[axis + 1:]
    arr = jnp.reshape(arr, nshape)

    # Transpose sections to be dimension 0.
    tdims = [axis] + list(range(0, axis)) + list(range(axis + 1, ndim + 1))
    arr = jnp.transpose(arr, tdims)
    return arr
  elif isinstance(tree, tuple):
    return tuple([reshape_transpose_tree(b, sections, axis) for b in tree])
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % tree)


def transpose_reshape_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
    ArrayTree):
  """Inverse of reshape_transpose_tree: moves dimension 0 back into the axis."""

  # We could use jax tree utils for this, but we do it the hard way so the
  # implementation can be compared with split_tree.
  if tree is None:
    return None
  elif isinstance(tree, jnp.ndarray):
    tree = typing.cast(Array, tree)  # Tell type-checker about isinstance
    ndim = tree.ndim - 1  # Input tree has 1 extra dimension on front.
    assert axis < ndim
    wlen = tree.shape[axis + 1]  # Window length.

    # Transpose dimension 0 back to its proper place.
    arr = tree
    tdims = list(range(1, axis + 1)) + [0] + list(range(axis + 1, ndim + 1))
    arr = jnp.transpose(arr, tdims)

    # Combine the sections and window_size dimensions.
    sh = list(arr.shape)
    nshape = sh[0:axis] + [sections * wlen] + sh[axis + 2:]
    arr = jnp.reshape(arr, nshape)
    return arr
  elif isinstance(tree, tuple):
    return tuple([transpose_reshape_tree(b, sections, axis) for b in tree])
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % tree)


def split_and_scan(fn: Callable[[ArrayTree, ArrayTree],
                                Tuple[ArrayTree, ArrayTree]],
                   carry: ArrayTree, input_arrays: ArrayTree,
                   sections: int, axis: int = 0,
                   max_unrolled_windows: int = -1) -> (
                       Tuple[ArrayTree, ArrayTree]):
  """Scan over a set of input arrays in chunks.

  Splits each array in 'input_arrays' into the number of chunks given by
  'sections', and then loops over the chunks using a scan operation.
  Returns a concatenation of the results.

  Args:
    fn: A function from (carry, input_i) -> (carry, output_i).
    carry: The initial state for the scan, that will be passed from one
        iteration to the next.
    input_arrays: A nested tree of tuples of arrays.
    sections: The number of sections or chunks for the split.
    axis: The axis to split each array along.
    max_unrolled_windows: If 0 <= max_unrolled_windows < sections,
        use jax.lax.scan rather than unrolling the windows with a python loop.

  Returns:
    (carry, output)
  """

  if sections == 1:
    logging.info("Single window, no scan.")
    return fn(carry, input_arrays)

  if axis < 0:
    raise ValueError(f"Axis must be non-negative. Got {axis}")

  logging.info("Scanning over %d windows", sections)

  if 0 <= max_unrolled_windows and max_unrolled_windows < sections:
    logging.info("Using jax.lax.scan.")
    in_arrs = reshape_transpose_tree(input_arrays, sections, axis)
    (carry, out_arrs) = jax.lax.scan(fn, carry, in_arrs)
    output_arrays = transpose_reshape_tree(out_arrs, sections, axis)
    return (carry, output_arrays)

  logging.info("Using unrolled for-loop.")
  in_list = split_tree(input_arrays, sections, axis=axis)
  out_list = []

  for (k, in_chunk) in enumerate(in_list):
    logging.info("Processing window %d", k)
    (carry, out_chunk) = fn(carry, in_chunk)
    out_list.append(out_chunk)

  output_arrays = concat_trees(out_list, axis=axis)
  return (carry, output_arrays)
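
The tree utilities above (split_tree, concat_trees) and split_and_scan are the
workhorses for chunked evaluation of long sequences. A minimal sketch of how
they compose, using a toy cumulative-sum step; the function and values below
are illustrative only, assuming the module imports as transformer.attention:

import jax.numpy as jnp
from transformer import attention

def cumsum_chunk(carry, chunk):
  # carry: running total, shape (batch, 1, dim); chunk: (batch, window, dim).
  out = jnp.cumsum(chunk, axis=1) + carry
  return (out[:, -1:, :], out)  # New carry is the last position's total.

x = jnp.ones((2, 8, 3))            # (batch, seq_len, dim)
carry0 = jnp.zeros((2, 1, 3))
(carry, y) = attention.split_and_scan(
    cumsum_chunk, carry0, x, sections=4, axis=1)
# y equals jnp.cumsum(x, axis=1): each 2-token window sees the running total
# carried over from the previous window, mirroring how DecoderStack threads
# window_state across windows.
assert y.shape == x.shape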
aglib/meliad/transformer/decoder_stack.py
ADDED
@@ -0,0 +1,426 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hierarchical transformer."""

import functools
from typing import Any, Callable, Optional, Sequence, Tuple

from absl import logging

from flax import linen as nn
from flax import struct
import gin
import jax.numpy as jnp
from transformer import attention
from transformer import metric_utils
from transformer import nn_components
from transformer import position
from transformer import transformer_layer


Array = Any


# Basic task options are shared among multiple classes.
@gin.configurable
@struct.dataclass
class TransformerTaskConfig:
  """Configuration hyperparameters for sequence-to-sequence tasks."""

  dataset_name: str = "synthetic"
  train_split: str = "train"
  test_split: str = "test"
  sequential_chunks: bool = True  # Process chunks of text in sequential order.

  sequence_length: int = 4096
  batch_size: int = 1  # per device batch size
  vocab_size: int = 256


DStackDecoderState = Tuple[transformer_layer.DecoderState, ...]
DStackWindowState = Tuple[transformer_layer.WindowState, ...]


@gin.configurable
class DecoderStack(nn.Module):
  """Stack of transformer decoder layers."""

  mode: str
  task_config: TransformerTaskConfig = gin.REQUIRED

  # Configurable hyperparameters.
  num_layers: int = gin.REQUIRED
  embedding_size: int = gin.REQUIRED
  embedding_stddev: float = 1.0

  # The class to use for an individual transformer layer.
  layer_factory: Any = gin.REQUIRED

  # Window length to use for the decoder stack.
  # If nonzero, use this instead of TransformerLayer.window_length.
  dstack_window_length: int = 0
  use_absolute_positions: bool = False
  use_final_layernorm: bool = True
  final_dropout_rate: float = 0.0
  final_mlp_factory: Optional[Callable[[int], nn.Module]] = None

  # Enable recurrence on particular layers.
  recurrent_layer_indices: Sequence[int] = ()
  feedback_recurrence: bool = True

  # The factory function which creates a MemoryManager, or None.
  memory_factory: Any = None
  # Layers to equip with external memory.
  memory_layer_indices: Sequence[int] = ()

  dtype: Any = jnp.float32

  def is_training(self):
    return self.mode == "train"

  def supports_generate(self) -> bool:
    return all([lyr.supports_generate() for lyr in self.transformer_layers])

  def setup(self):
    task_config = self.task_config

    embed_init = nn.initializers.normal(stddev=self.embedding_stddev,
                                        dtype=jnp.float32)
    self.embed = nn.Embed(num_embeddings=task_config.vocab_size,
                          features=self.embedding_size,
                          embedding_init=embed_init)

    # Create a memory_factory.MemoryManager object, which is shared among
    # all transformer layers.  Each layer will use the MemoryManager object
    # to instantiate a block of memory for that layer.
    memory = None
    if self.memory_factory is not None:
      if self.memory_layer_indices:
        memory = self.memory_factory(batch_size=task_config.batch_size,
                                     mode=self.mode)
      else:
        logging.warning(
            "Memory factory specified, but memory_layer_indices is empty.")

    # Allow negative numbers in memory_layer_indices.
    # Negative numbers refer to layers at the top of the stack.
    for k in self.memory_layer_indices:
      if k < -self.num_layers or k >= self.num_layers:
        raise ValueError(f"Invalid memory layer index {k}")
    # The % operator will convert negative k to self.num_layers + k.
    mem_layer_indices = [
        idx % self.num_layers for idx in self.memory_layer_indices
    ]

    # Allow negative numbers in recurrent_layer_indices.
    for k in self.recurrent_layer_indices:
      if k < -self.num_layers or k >= self.num_layers:
        raise ValueError(f"Invalid recurrent layer index {k}")
    recurrent_layer_indices = [
        idx % self.num_layers for idx in self.recurrent_layer_indices
    ]
    # Turn on cross attention if there are recurrent layers with feedback.
    enable_cross_attn = (self.feedback_recurrence and
                         self.recurrent_layer_indices and
                         self.dstack_window_length > 0)

    layers = []
    for i in range(0, self.num_layers):
      mem = memory if (i in mem_layer_indices) else None
      rec_i = i in recurrent_layer_indices
      layer_fn = functools.partial(
          self.layer_factory,
          mode=self.mode,
          batch_size=self.task_config.batch_size,
          embedding_size=self.embedding_size,
          name=f"transformer{i}",
          recurrent_attention=rec_i,
          cross_attention=enable_cross_attn and not rec_i)
      if mem:
        logging.info("Using external memory with transformer layer %d.", i)
        layer_fn = functools.partial(
            layer_fn,
            memory=mem,
            # We use partial function applications here only to avoid
            # overwriting the head size unless memory is involved.
            head_size=mem.key_size,
            num_heads=mem.num_heads)
      layers.append(layer_fn())
    self.transformer_layers = layers

    if self.use_final_layernorm:
      self.final_layernorm = nn_components.LayerNorm()

    if self.final_mlp_factory is not None:
      self.final_mlp = self.final_mlp_factory(self.embedding_size)

  def init_decoder_state(self, sequence_length: int,
                         start_of_sequence: Array) -> DStackDecoderState:
    """Return initial state for autoregressive generation."""
    return tuple([
        layer.init_decoder_state(sequence_length, start_of_sequence)
        for layer in self.transformer_layers
    ])

  def load_window_state(self, start_of_sequence: Array) -> DStackWindowState:
    """Load cached state that is passed from one window to the next."""
    return tuple([
        layer.load_window_state(start_of_sequence)
        for layer in self.transformer_layers
    ])

  def store_window_state(self, window_state: DStackWindowState):
    """Write window state to the cache."""
    for (layer, wstate) in zip(self.transformer_layers, window_state):
      layer.store_window_state(wstate)

  def _eval_layer_stack(self, xs: Array, start_of_sequence: Array,
                        window_state: Optional[DStackWindowState],
                        decoder_state: Optional[DStackDecoderState]) -> (
                            Tuple[Array, Optional[DStackWindowState],
                                  Optional[DStackDecoderState], Any]):
    """Evaluate a stack of transformer layers on an input."""

    ys = xs  # (batch_size, seq_len, num_hidden)
    importance = None  # (batch_size, sequence_length)
    next_window_states = []
    next_decoder_states = []
    attn_viz_dicts = []

    # If we have a recurrent layer, grab the keys and values from it.
    # All other layers can then cross-attend to the recurrent keys and values.
    recurrent_kv = None
    enable_cross_attn = (self.feedback_recurrence and
                         self.recurrent_layer_indices and
                         self.dstack_window_length > 0)
    if enable_cross_attn and window_state is not None:
      # TODO(delesley): fix this so it works with the autoregressive decoder.
      assert decoder_state is None
      logging.info("dstack: using recurrent cross attention on all layers.")
      for (layer, wstate_i) in zip(self.transformer_layers, window_state):
        rkv = layer.get_recurrent_kv(wstate_i)
        if rkv is not None:
          recurrent_kv = rkv

    # Apply transformer layers.
    for (i, layer) in enumerate(self.transformer_layers):
      if layer.recurrent_attention:
        cross_kv = None  # The recurrent layer handles rkv internally.
      else:
        cross_kv = recurrent_kv  # Other layers cross-attend to recurrent one.

      logging.info("dstack: ---- Layer %d ----", i)
      wstate_i = None if window_state is None else window_state[i]
      dstate_i = None if decoder_state is None else decoder_state[i]
      (ys, importance, n_wstate_i, n_dstate_i, viz_dict) = layer(
          ys, start_of_sequence,
          importance=importance,
          cross_attention_kv=cross_kv,  # cross-attend to recurrent_kv.
          window_state=wstate_i,
          decoder_state=dstate_i)
      next_window_states.append(n_wstate_i)
      next_decoder_states.append(n_dstate_i)
      attn_viz_dicts.append(viz_dict)

    window_state = tuple(next_window_states)
    decoder_state = tuple(next_decoder_states)
    return (ys, window_state, decoder_state, attn_viz_dicts)

  def __call__(self,
               input_tokens: Array,
               target_tokens: Array,
               start_of_sequence: Array,
               decoder_state: Optional[DStackDecoderState] = None) -> (
                   Tuple[Array, Optional[DStackDecoderState], Any]):
    """Call the decoder stack.

    This function will embed tokens, run the embeddings through a stack of
    decoder layers, and then compute logits for the target tokens using the
    transpose of the embeddings.  It returns un-normalized (pre-softmax)
    logits.

    Args:
      input_tokens: Integer array of shape [batch_size, sequence_length]
      target_tokens: For compatibility.  Ignored by this class.
      start_of_sequence: Boolean array of shape [batch_size],
          which indicates whether a sequence is at the start of sequence.
      decoder_state: State object for autoregressive decoding,
          created from init_decoder_state.

    Returns:
      (logits, of shape [batch_size, sequence_length, vocab_size],
       next_decoder_state: for autoregressive decoding,
       viz_dict: dictionary of visualizations,
      )
    """
    del target_tokens
    task_config = self.task_config

    # Embed tokens.
    embeddings = self.embed(input_tokens)  # (batch_size, seq_len, num_hidden)
    embeddings = embeddings.astype(self.dtype)
    sequence_length = embeddings.shape[1]
    logging.info("dstack: embeddings = %r", embeddings)

    # Add absolute position encodings if necessary.
    if self.use_absolute_positions:
      # Use a large max_wavelength so that only part of the input vector
      # is used for positions.
      positions = position.position_encoding(
          num_positions=task_config.sequence_length,
          input_dim=self.embedding_size,
          max_wavelength=10_000)
      positions = jnp.asarray(positions, dtype=self.dtype)
      positions = jnp.expand_dims(positions, 0)  # Add batch dimension.
      logging.info("dstack: absolute positions = %r", positions)
      embeddings = embeddings + positions

    # Function to run the whole transformer stack on a single window.
    # ---------------------------------------------------------------
    def single_window_stack(carry, inputs_w):
      (window_state_w, start_of_seq_w) = carry
      (outputs_w, window_state_w, _, _) = self._eval_layer_stack(
          inputs_w, start_of_seq_w,
          window_state=window_state_w, decoder_state=None)

      # start_of_sequence is false after the first window.
      bsize = self.task_config.batch_size
      next_start_of_seq = jnp.asarray([False] * bsize, dtype=jnp.bool_)
      return ((window_state_w, next_start_of_seq), outputs_w)

    # Find the number of windows.  A sequence may be split into multiple
    # windows here, or alternatively, it may be split (or further split) within
    # TransformerLayer, depending on configuration.
    if (self.dstack_window_length == 0 or
        self.dstack_window_length >= sequence_length):
      num_windows = 1
    else:
      num_windows = sequence_length // self.dstack_window_length
      assert (num_windows * self.dstack_window_length) == sequence_length

    # Evaluate the stack of layers, scanning over windows if configured.
    # ------------------------------------------------------------------
    if decoder_state is None:
      logging.info("dstack: scanning over %d windows.", num_windows)
      # Load cached state from the previous training step, for truncated BPTT.
      window_state = self.load_window_state(start_of_sequence)

      # Scan single_window_stack over the sequence.
      cstate = (window_state, start_of_sequence)
      (cstate, ys) = attention.split_and_scan(single_window_stack,
                                              cstate,
                                              embeddings,
                                              sections=num_windows,
                                              axis=1)
      (window_state, _) = cstate

      # Cache state for the next training step, for truncated BPTT.
      self.store_window_state(window_state)
      attn_viz_dicts = {}  # Temporarily disabled.
    else:
      logging.info("dstack: autoregressive generator.")
      # Run as an autoregressive decoder: evaluate the whole stack on a token.
      # Do not load or store window_state; decoder_state is used instead.
      (ys, _, decoder_state, _) = self._eval_layer_stack(
          embeddings, start_of_sequence,
          window_state=None, decoder_state=decoder_state)
      attn_viz_dicts = {}

    # Apply layernorm to the final output, before calculating logits.
    # With a pre-layernorm architecture, this has to be done here.
    if self.use_final_layernorm:
      logging.info("dstack: Final layernorm.")
      ys = self.final_layernorm(ys)

    # Final dropout before token prediction.
    drop_tile_shape = (1, 128, self.embedding_size)
    get_dropout_rng = lambda: self.make_rng("dropout")
    ys = nn_components.tiled_dropout(ys, drop_tile_shape,
                                     self.final_dropout_rate,
                                     rng_function=get_dropout_rng,
                                     deterministic=not self.is_training())

    # Apply an MLP at the very end to convert the output of the transformer
    # into a vector to look up target tokens in the embedding table.
    # This final layer allows the NN to distinguish between the "input context",
    # which is returned by the transformer resnet, and the "predicted target".
    if self.final_mlp_factory is not None:
      logging.info("dstack: Final MLP layer.")
      ys = self.final_mlp(ys)

    # Reverse embedding to generate logits which predict the output tokens.
    logits = self.embed.attend(ys)  # (..., seq_len, vocab_size)
    logging.info("dstack: logits = %r", logits)

    # Normalize so that the range of logits is reasonable.
    logits = logits / jnp.sqrt(logits.shape[-1]).astype(self.dtype)

    # Produce various visualizations in generate mode.
    # TODO(delesley): Too many visualizations crashes the summary writer.
    if self.mode == "generate":
      img_dict = self._make_images(attn_viz_dicts, [])
      hist_dict = {}  # metric_utils.make_histograms(attn_viz_dicts)
      info_dict = {**img_dict, **hist_dict}
    else:
      info_dict = {}  # Don't output any visualizations.

    return (logits, decoder_state, info_dict)

  def _make_importance_image(self, importance_list, scaled=True) -> Array:
    rows = []
    for imp in importance_list:
      rows += [imp] * 8  # Rows are 8 pixels high for better visibility.
    image = jnp.stack(rows)
    if scaled:
      image = jnp.exp(image)
    image = metric_utils.normalize_image(image, True)
    return metric_utils.reshape_image(image)

  def _make_images(self, viz_dicts, importance_list):
    image_dict = {}
    for (i, viz_dict) in enumerate(viz_dicts):
      if "attn_importance_gate" in viz_dict:
        imp_gate = viz_dict["attn_importance_gate"][0]  # First item in batch.
        imp_strip = metric_utils.normalize_image(imp_gate[:, 0:8, :], True)
      else:
        imp_strip = None

      for (k, attn_images) in viz_dict.items():
        if k not in {"attn_content",
                     "attn_pre_softmax",
                     "attn_log",
                     "attn",
                     "attn_position_bias",
                     "attn_importance_bias",
                     "attn_importance_gate"}:
          continue

        attn_img = attn_images[0]  # Grab the first item in the batch.
        attn_img = metric_utils.normalize_image(attn_img,
                                                as_group=(k != "attn"))
        if imp_strip is not None and k in {"attn_log", "attn"}:
          # Show importance bias in a strip at the bottom of the image.
          attn_img = metric_utils.overlay_images(attn_img, imp_strip)
        attn_img = metric_utils.reshape_image(attn_img)  # Returns None on fail.
        if attn_img is not None:
          image_dict[k + "_" + str(i)] = attn_img

    if importance_list:
      # Create an image out of the importance for each layer.
      image_dict["importance_gate"] = self._make_importance_image(
          importance_list, scaled=True)
      image_dict["importance_raw"] = self._make_importance_image(
          importance_list, scaled=False)
    return image_dict
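
Two small pieces of arithmetic in DecoderStack are easy to get wrong when
configuring it. A toy check of both, with illustrative values only:

# Negative layer indices count from the top of the stack: Python's % operator
# maps -1 to num_layers - 1, exactly as setup() does.
num_layers = 12
memory_layer_indices = (-1, -3)
assert [i % num_layers for i in memory_layer_indices] == [11, 9]

# When dstack_window_length is nonzero and shorter than the sequence, the
# sequence must divide evenly into windows, or __call__ asserts.
sequence_length = 4096
dstack_window_length = 512
num_windows = sequence_length // dstack_window_length   # 8 windows
assert num_windows * dstack_window_length == sequence_length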
aglib/meliad/transformer/ht_main.py
ADDED
@@ -0,0 +1,56 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Main program to train htransformer models."""

from typing import Sequence

from absl import app
from absl import flags
from clu import platform
import jax
from transformer import launcher
import tensorflow.compat.v2 as tf


FLAGS = flags.FLAGS


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  launcher.parse_gin_configuration()

  # Hide any GPUs from TensorFlow.  Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  # Set global seed for datasets.
  # tf.random.set_seed(1234)

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                       f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  launcher.run_training_loop(testing=False)


if __name__ == "__main__":
  app.run(main)
aglib/meliad/transformer/ht_main_inference.py
ADDED
@@ -0,0 +1,76 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Program to run a transformer model over a single article."""

# This program is currently a template, which can be expanded to do more
# sophisticated analysis.

from typing import Sequence

from absl import app
from absl import flags
from clu import platform
import jax
from transformer import inference_utils
from transformer import tasks  # pylint: disable=unused-import
import tensorflow.compat.v2 as tf


flags.DEFINE_string("workdir", "", "Directory to save model checkpoints.")
flags.DEFINE_string("load_dir", "", "Directory to load pre-trained model.")
flags.DEFINE_integer("num_steps", 110, "Number of steps.")

flags.DEFINE_list(
    "gin_search_paths",
    ["transformer/configs"],
    "List of paths where the Gin config files are located.")
flags.DEFINE_multi_string(
    "gin_file", ["base_htrans.gin"], "List of Gin config files.")
flags.DEFINE_multi_string(
    "gin_param", None, "Newline separated list of Gin parameter bindings.")

FLAGS = flags.FLAGS


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Hide any GPUs from TensorFlow.  Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                       f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  inference_utils.parse_gin_configuration(FLAGS.gin_file, FLAGS.gin_param,
                                          gin_paths=FLAGS.gin_search_paths)

  # read_article's first positional argument is the dataset split, so the
  # verbose flag must be passed by keyword.
  article_data = inference_utils.read_article(verbose=True)
  (_, vocab) = article_data
  (task, task_state, _) = inference_utils.create_model_and_task(
      vocab, load_dir=FLAGS.load_dir)
  outs = inference_utils.run_model(task, task_state, article_data,
                                   verbose=True)
  inference_utils.get_token_losses(outs)


if __name__ == "__main__":
  app.run(main)
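
inference_utils (next file) exposes the same flow as a procedural API intended
for colabs. A sketch of driving it directly from Python, performing the same
steps as main() above; the checkpoint directory is hypothetical:

from transformer import inference_utils

# Load gin config, read one article, build the model, and score the article.
inference_utils.parse_gin_configuration(
    ["base_htrans.gin"], [], gin_paths=["transformer/configs"])
article_data = inference_utils.read_article(verbose=True)
(_, vocab) = article_data
(task, task_state, _) = inference_utils.create_model_and_task(
    vocab, load_dir="/path/to/pretrained")  # Hypothetical checkpoint dir.
outs = inference_utils.run_model(task, task_state, article_data)
inference_utils.get_token_losses(outs)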
aglib/meliad/transformer/inference_utils.py
ADDED
@@ -0,0 +1,271 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Various utility functions for doing inference on data.

This file provides a simple procedural API for loading a model, loading data,
and running the model over data.  It is intended for use in, e.g., colabs.
"""

from typing import Any, Dict, Optional, Sequence, Tuple

from absl import logging
import gin
import jax
import training_loop
from transformer import decoder_stack
from transformer import models
from transformer import text_dataset
import numpy as np
import seqio


Trainer = training_loop.Trainer
TrainState = training_loop.TrainState
TrainingTask = training_loop.TrainingTask
PRNGKeys = training_loop.PRNGKeys

ModelInput = Dict[str, Any]     # Input to model.
MetricsOutput = Dict[str, Any]  # Metrics output by model.
ArticleData = Tuple[Sequence[ModelInput], seqio.Vocabulary]
TaskState = Tuple[TrainState, int]


DEFAULT_GIN_PATHS = [
    "transformer/configs"
]


def parse_gin_configuration(gin_files: Optional[Sequence[str]],
                            gin_params: Optional[Sequence[str]],
                            gin_paths: Optional[Sequence[str]] = None):
  """Load gin configuration options.

  Args:
    gin_files: A list of gin file names with the configuration to load.
    gin_params: A list of additional parameter overrides.
    gin_paths: A list of paths to search for gin_files.
  """

  # We allow None values to more easily handle command-line flags.
  if gin_files is None:
    gin_files = []
  if gin_params is None:
    gin_params = []
  if gin_paths is None:
    gin_paths = DEFAULT_GIN_PATHS

  logging.info("Parsing gin configuration.")
  for path in gin_paths:
    logging.info("Added Gin search path %s", path)
    gin.add_config_file_search_path(path)
  for file_name in gin_files:
    logging.info("Loading Gin config file %s", file_name)
  for param in gin_params:
    logging.info("Overriding Gin param %s", param)
  gin.parse_config_files_and_bindings(gin_files, gin_params)


def read_article(split: Optional[str] = None,
                 verbose: bool = False) -> ArticleData:
  """Read a single article from the dataset and save it as a list of blocks.

  This routine will return blocks for a single article, so the tokens will
  have a batch size of 1.  The blocks can be fed to the model directly as
  input.

  Args:
    split: The dataset split to load from.  Defaults to the test split.
    verbose: If True, will dump the contents of the article to the log.

  Returns:
    A pair of (list_of_blocks, vocabulary)
  """

  logging.info("Reading article.")

  text_dataset.set_default_data_directory()
  task_config = decoder_stack.TransformerTaskConfig()
  batch_size = 1

  if split is None:
    split = task_config.test_split

  (test_ds, vocab) = text_dataset.load_text_dataset(
      name=task_config.dataset_name,
      split=split,
      sequence_length=task_config.sequence_length,
      batch_size=batch_size,
      sequential=task_config.sequential_chunks,
      shard_dataset=False)

  logging.info("Configured vocab_size = %d", task_config.vocab_size)
  logging.info("Task vocabulary size = %d", vocab.vocab_size)
  if task_config.vocab_size < vocab.vocab_size:
    raise ValueError(
        "Task vocabulary size does not match configured vocab_size: " +
        f"{task_config.vocab_size} < {vocab.vocab_size}")

  article_segments = []
  ds_iter = test_ds.as_numpy_iterator()
  vocab_map = {"targets": vocab}

  segment_num = 0
  while True:
    try:
      x = next(ds_iter)
    except StopIteration:
      logging.info("End of epoch?  Something went wrong.")
      break

    # Stop when the next article starts, but only after we have started
    # reading; otherwise this loop would quit immediately.
    if article_segments:
      if x["start_of_sequence"][0]:
        break

    if verbose:
      logging.info("Segment %d = %s", segment_num,
                   text_dataset.pretty_print_article(x, vocab_map,
                                                     max_length=10_000))
    article_segments.append(x)
    segment_num += 1

  logging.info("Done reading article: %d segments.", segment_num)
  logging.info("Num tokens = %d", segment_num * task_config.sequence_length)
  return (article_segments, vocab)


def create_model_and_task(vocab: seqio.Vocabulary,
                          load_dir: Optional[str] = None) -> (
                              Tuple[TrainingTask, TaskState, Trainer]):
  """Initialize the model and get a task for inference.

  The task will be configured to take test (inference) steps with the model.
  The task will also be configured to run on a single replica, at batch size 1.

  Args:
    vocab: The vocabulary for the training data, used for logging and decoding.
    load_dir: A directory which contains a pre-trained model.

  Returns:
    (task -- has a run_step method to take individual steps with the model,
     state -- contains trainable parameters and other state,
     trainer -- a Trainer object (see training_loop.py))
  """

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  # This task won't be pulling from a dataset.
  def null_iter_fn():
    return None

  trainer = training_loop.Trainer(
      get_training_dataset_iterator=null_iter_fn,
      get_test_dataset_iterator=None,
      pretty_print_input_function=None,
      process_summaries_function=models.process_summaries_function(vocab),
      load_dir=load_dir,
      workdir="",  # Don't log or save checkpoints.
      replicate_mode=False)  # Run on a single device at batch size 1.

  # Create and initialize the model.
  (tstate, start_step, imodel, prngs) = trainer.initialize_model()

  # Create an inference task.
  writers = {}
  task = trainer.create_training_task("test", imodel, prngs, writers)

  # Register any additional actions.
  # Actions are cleared first for use with colab.
  training_loop.clear_interstep_callbacks()
  training_loop.register_interstep_callbacks()

  task_state = (tstate, start_step)
  return (task, task_state, trainer)


def run_model(task: TrainingTask, task_state: TaskState,
              article_data: ArticleData, verbose: bool = False) -> (
                  Sequence[MetricsOutput]):
  """Run the model on an article, and return the outputs for each segment.

  Args:
    task: The task to run, from create_model_and_task.
    task_state: The state of the model, from create_model_and_task.
    article_data: The article and vocabulary, from read_article.
    verbose: If True, will send input and output to the log.

  Returns:
    A sequence of model outputs for each block.
  """

  logging.info("Running the model.")

  (article_segments, vocab) = article_data
  (tstate, start_step) = task_state
  vocab_map = {"targets": vocab}

  # Ignore the iterator for the test task, and loop over the article.
  step = start_step
  segment_num = 0

  # Loop over the article, and run the model on each segment.
  segment_outputs = []
  for x in article_segments:
    if verbose:
      logging.info("Segment [%d] = %s", segment_num,
                   text_dataset.pretty_print_article(x, vocab_map,
                                                     max_length=10_000))
    else:
      logging.info("Segment %d, step %d.", segment_num, step)

    (tstate, metrics_np) = task.run_step(tstate, x, step)
    training_loop.run_interstep_callbacks("test", step)
    segment_outputs.append(metrics_np)
|
236 |
+
|
237 |
+
if verbose:
|
238 |
+
logging.info("Output [%d] = %s", segment_num, metrics_np)
|
239 |
+
|
240 |
+
del x
|
241 |
+
segment_num += 1
|
242 |
+
step += 1
|
243 |
+
|
244 |
+
logging.info("Done running the model: %d segments.", segment_num)
|
245 |
+
return segment_outputs
|
246 |
+
|
247 |
+
|
248 |
+
def get_token_losses(segment_outputs: Sequence[Any]) -> np.ndarray:
|
249 |
+
"""Return the loss for each token in a sequence.
|
250 |
+
|
251 |
+
Given a list of model outputs, extract the token losses from each output
|
252 |
+
and concatenate them together.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
segment_outputs: the outputs from run_model().
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
An array of shape (batch_size, sequence_length), of float.
|
259 |
+
"""
|
260 |
+
|
261 |
+
block_token_losses = []
|
262 |
+
for seg in segment_outputs:
|
263 |
+
if "token_losses" in seg:
|
264 |
+
block_token_losses.append(seg["token_losses"])
|
265 |
+
else:
|
266 |
+
raise ValueError("Token losses were not recorded.")
|
267 |
+
|
268 |
+
logging.info("Got token losses for %d segments", len(block_token_losses))
|
269 |
+
token_losses = np.concatenate(block_token_losses, axis=-1)
|
270 |
+
logging.info("token_losses.shape = %r", token_losses.shape)
|
271 |
+
return token_losses
|
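
# Example end-to-end usage (a sketch added for illustration, not part of the
# original file; the checkpoint path below is hypothetical). Note that
# get_token_losses() requires the model to be configured with
# output_token_losses=True (see models.DecoderOnlyLanguageModel):
#
#   parse_gin_configuration(gin_files=["base_htrans.gin"], gin_params=[])
#   (segments, vocab) = read_article()
#   (task, state, trainer) = create_model_and_task(
#       vocab, load_dir="/path/to/checkpoint")
#   outputs = run_model(task, state, (segments, vocab))
#   token_losses = get_token_losses(outputs)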
aglib/meliad/transformer/launcher.py
ADDED
@@ -0,0 +1,119 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setup the data pipeline and launch the main training loop."""

from absl import flags
from absl import logging

import gin
import jax

import training_loop
from transformer import decoder_stack
from transformer import models
from transformer import tasks  # pylint: disable=unused-import
from transformer import text_dataset


flags.DEFINE_string("workdir", "", "Directory to save model checkpoints.")
flags.DEFINE_string("load_dir", "", "Directory to load pre-trained model.")
flags.DEFINE_integer("num_steps", 110, "Number of steps.")

flags.DEFINE_list(
    "gin_search_paths",
    ["transformer/configs"],
    "List of paths where the Gin config files are located.")
flags.DEFINE_multi_string(
    "gin_file", ["base_htrans.gin"], "List of Gin config files.")
flags.DEFINE_multi_string(
    "gin_param", None, "Newline separated list of Gin parameter bindings.")

FLAGS = flags.FLAGS


def parse_gin_configuration():
  """Load and parse Gin configuration from command-line flags."""
  for gin_file_path in FLAGS.gin_search_paths:
    logging.info("Added Gin search path %s", gin_file_path)
    gin.add_config_file_search_path(gin_file_path)
  for gin_file in FLAGS.gin_file:
    logging.info("Loading Gin config file %s", gin_file)
  if FLAGS.gin_param:
    for gin_param in FLAGS.gin_param:
      logging.info("Overriding Gin param %s", gin_param)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


def run_training_loop(testing: bool = False):
  """Setup data pipeline and launch the main training loop."""

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  text_dataset.set_default_data_directory()
  task_config = decoder_stack.TransformerTaskConfig()
  batch_size = task_config.batch_size * jax.local_device_count()

  (train_ds, vocab) = text_dataset.load_text_dataset(
      name=task_config.dataset_name,
      split=task_config.train_split,  # train
      sequence_length=task_config.sequence_length,
      batch_size=batch_size,
      sequential=task_config.sequential_chunks,
      shard_dataset=True)

  (test_ds, test_vocab) = text_dataset.load_text_dataset(
      name=task_config.dataset_name,
      split=task_config.test_split,  # test
      sequence_length=task_config.sequence_length,
      batch_size=batch_size,
      sequential=task_config.sequential_chunks,
      shard_dataset=False)

  logging.info("Configured vocab_size = %d", task_config.vocab_size)
  logging.info("Task vocabulary size = %d", vocab.vocab_size)
  assert vocab.vocab_size == test_vocab.vocab_size  # Sanity check.
  if task_config.vocab_size < vocab.vocab_size:
    raise ValueError(
        "Task vocabulary size does not match configured vocab_size: " +
        f"{task_config.vocab_size} < {vocab.vocab_size}")

  # Pretty printing depends on the vocabulary object.
  def pretty_print_article_fn(article) -> str:
    return text_dataset.pretty_print_article(article, {"targets": vocab}, 32768)

  train_ds_iter_fn = text_dataset.get_iterator_function(train_ds)
  test_ds_iter_fn = text_dataset.get_iterator_function(test_ds)

  if testing:
    # Build trainer, which is configurable by Gin, and run training loop.
    trainer = training_loop.Trainer(
        get_training_dataset_iterator=train_ds_iter_fn,
        get_test_dataset_iterator=test_ds_iter_fn,
        pretty_print_input_function=pretty_print_article_fn,
        process_summaries_function=models.process_summaries_function(vocab),
        num_steps=FLAGS.num_steps,  # Ignore Gin config for these options.
        load_dir=FLAGS.load_dir,
        workdir=FLAGS.workdir)
  else:
    trainer = training_loop.Trainer(
        get_training_dataset_iterator=train_ds_iter_fn,
        get_test_dataset_iterator=test_ds_iter_fn,
        pretty_print_input_function=pretty_print_article_fn,
        process_summaries_function=models.process_summaries_function(vocab),
        load_dir=FLAGS.load_dir,
        workdir=FLAGS.workdir)

  trainer.train()
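
# A hypothetical launch command (sketch; ht_main is the entry point included
# in this upload, but the exact module path and the binding below are
# assumptions, not documented defaults):
#
#   python -m transformer.ht_main \
#     --workdir=/tmp/htrans_workdir \
#     --gin_file=base_htrans.gin \
#     --gin_param='TransformerTaskConfig.batch_size = 8'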
aglib/meliad/transformer/memory_factory.py
ADDED
@@ -0,0 +1,120 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax modules and functions for using external memory."""

from typing import Any, Optional, Tuple

from absl import logging
from flax import linen
import gin
import jax
from transformer import memory_layer


PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
Array = Any
MemoryResource = Any


class MemoryManager:
  """Manages any external resources that may be required by external memory.

  MemoryManager also functions as a factory, to create Flax modules that will
  read and write to whatever external memory has been configured.
  """

  def __init__(self,
               batch_size: int,
               mode: str,
               num_heads: int,
               key_size: int,
               value_size: int,
               database_size: Optional[int] = None,
               dtype: Dtype = "float32",
               off_device_memory: Optional[MemoryResource] = None):
    """Create a MemoryManager object.

    A MemoryManager configures external memory, and is used as a factory to
    construct flax modules that read or write to the memory.

    Args:
      batch_size: The number of separate documents in a batch.
      mode: e.g. ("train", or "test")
      num_heads: The number of transformer heads.
      key_size: The length of the key vectors.
      value_size: The length of the value vectors.
      database_size: The total number of tokens in the database.
      dtype: The datatype used for keys and values.
      off_device_memory: An object which manages underlying SCAM memory.
        If None, then the model will use on-device memory.
    """
    self.batch_size = batch_size
    self.mode = mode
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size
    self.database_size = database_size
    self.dtype = dtype
    self.off_device_memory = off_device_memory

  def create_memory_layer(self) -> linen.Module:
    """Create a flax Module that implements external memory."""

    num_datasets = (
        self.batch_size * self.num_heads  #
        if self.off_device_memory is None  #
        else self.num_heads)
    if self.off_device_memory is not None:
      mem_layer = None
      if mem_layer is None:
        raise ValueError("Off-device memory is not supported at this time.")
      return memory_layer.BatchedMemory(
          mem_layer,
          split_dimensions=(-2,),
      )
    else:
      assert self.database_size is not None
      mem_layer = memory_layer.MemoryOnTpu(num_datasets=num_datasets,
                                           key_features=self.key_size,
                                           value_features=self.value_size,
                                           database_size=self.database_size,
                                           dtype=self.dtype)
      # Handle queries of shape [batch_size, seq_len, num_heads, kv_features]
      return memory_layer.BatchedMemory(mem_layer,
                                        split_dimensions=(0, -2))


@gin.configurable
def memory_on_tpu_factory(batch_size: int,
                          mode: str,
                          num_heads: int = gin.REQUIRED,
                          key_size: int = gin.REQUIRED,
                          value_size: int = gin.REQUIRED,
                          database_size: int = gin.REQUIRED,
                          dtype: Dtype = gin.REQUIRED) -> MemoryManager:
  """Implement SCAM memory on device."""
  return MemoryManager(batch_size=batch_size,
                       mode=mode,
                       num_heads=num_heads,
                       key_size=key_size,
                       value_size=value_size,
                       database_size=database_size,
                       dtype=dtype,
                       off_device_memory=None)
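
# Hypothetical Gin bindings for this factory (a sketch for illustration; the
# values are assumptions, not defaults from the released configs):
#
#   memory_on_tpu_factory.num_heads = 8
#   memory_on_tpu_factory.key_size = 64
#   memory_on_tpu_factory.value_size = 64
#   memory_on_tpu_factory.database_size = 65536
#   memory_on_tpu_factory.dtype = "float32"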
aglib/meliad/transformer/memory_layer.py
ADDED
@@ -0,0 +1,431 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FLAX layers for on-TPU memory."""

import abc
import functools
from typing import Callable, Sequence, Tuple, TypeVar, Union

from absl import logging
from flax import linen
import gin
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np  # use with care!

Shape = Sequence[int]
Dtype = jnp.dtype
Array = jnp.ndarray

Axes = Union[int, Tuple[int, ...]]
F = TypeVar('F', bound=Callable)


class MemoryLayer(linen.Module, metaclass=abc.ABCMeta):
  """Internal interface for memory layers without batch dim.

  See BatchedMemory for a layer that can be used in Flax models.
  """
  num_datasets: int

  @abc.abstractmethod
  def update(self, key: Array, value: Array) -> int:
    """Adds key/value pairs to memory.

    Args:
      key: of shape (num_kv, num_datasets, k_features)
      value: of shape (num_kv, num_datasets, v_features)

    Returns:
      Dummy value so that TPU operations can wait for the update to finish if
      desired.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def topk_retrieval(self, query: Array,
                     num_neighbors: int) -> Tuple[Array, Array]:
    """Retrieves the nearest neighbors for each query.

    Args:
      query: of shape (num_queries, num_datasets, k_features)
      num_neighbors: int indicating the number of neighbors to retrieve

    Returns:
      Tuple of selected keys and selected values of shapes
      (num_queries, num_datasets, num_neighbors, k_features), and
      (num_queries, num_datasets, num_neighbors, v_features)
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def reset(self, datasets: Array) -> int:
    """Reset some or all of the datasets in the memory.

    Args:
      datasets: A vector of shape (num_datasets) of type bool. Each position
        indicates whether the dataset with the same index should be reset.

    Returns:
      Dummy value so that TPU operations can wait for the update to finish if
      desired.
    """
    raise NotImplementedError()

  def __call__(self, query, num_neighbors):
    self.topk_retrieval(query, num_neighbors)


def _target_dimensions(shape: Shape,
                       source_dimensions: Sequence[int]) -> Sequence[int]:
  target_dimensions = range(-2, -2 - len(source_dimensions), -1)
  assert len(source_dimensions) == len(target_dimensions)
  return sorted(d % len(shape) for d in target_dimensions)


def _rearrange_dimensions_shapes(
    shape: Shape, split_dimensions: Sequence[int]) -> Tuple[Shape, Shape]:
  split_shape = tuple(shape[d] for d in split_dimensions)
  remaining_shape = tuple(
      shape[d] for d in range(len(shape)) if d not in split_dimensions)
  batch_shape = remaining_shape[:-1]
  return split_shape, batch_shape


def _rearrange_dimensions(x: Array, split_dimensions: Sequence[int]) -> Array:
  """Rearrange array so that we can split by a single dimension.

  Turns an array of shape [d1, ..., dn, features] and a list of dimensions to
  split by into [prod(remaining_dimensions), prod(split_dimensions),
  features]

  Args:
    x: array of shape [d1, ..., dn, features]
    split_dimensions: list of dimensions that should end up in dimension -2.

  Returns:
    Rearranged array as described above.
  """
  split_dimensions = [d % len(x.shape) for d in split_dimensions]
  split_dimensions = sorted(split_dimensions)
  split_shape, batch_shape = _rearrange_dimensions_shapes(
      x.shape, split_dimensions)

  target_dimensions = _target_dimensions(x.shape, split_dimensions)
  x = jnp.moveaxis(x, split_dimensions, target_dimensions)
  assert len(x.shape) > len(split_dimensions)
  assert all(isinstance(d, int) and d >= 0 for d in batch_shape)
  assert all(isinstance(d, int) and d >= 0 for d in split_shape)
  new_shape = [
      # The use of numpy is okay here, since shapes are concrete at jit time.
      np.prod(batch_shape),
      np.prod(split_shape),
      x.shape[-1]  # features dimension
  ]
  res = x.reshape(new_shape)
  assert res.ndim == 3
  return res


def _restore_dimensions(x: Array, original_shape: Shape,
                        split_dimensions: Sequence[int]) -> Array:
  """Restores arrays encoded with _rearrange_dimensions.

  Args:
    x: Array of shape [prod(batch_shape), prod(split_shape), feature...]
    original_shape: Shape of the array to restore to.
    split_dimensions: Dimensions that were multiplied into dimension 2.

  Returns:
    Array of the original shape and axis order for all dimensions in
    batch_shape and split_shape. Feature dimensions may have changed (can
    include additional dimensions for neighbors, for example).
  """
  split_dimensions = [d % len(original_shape) for d in split_dimensions]
  split_dimensions = sorted(split_dimensions)
  split_shape, batch_shape = _rearrange_dimensions_shapes(
      original_shape, split_dimensions)

  features_shape = x.shape[2:]
  x = x.reshape((*batch_shape, *split_shape, *features_shape))

  # rearrange
  target_dimensions = _target_dimensions(original_shape, split_dimensions)
  x = jnp.moveaxis(x, target_dimensions, split_dimensions)
  return x


@gin.configurable
class BatchedMemory(linen.Module):
  """Equips a memory module with a batch dimension."""

  # We wrap this linen.Module:
  wrapped: MemoryLayer

  # `split_dimensions` indicates the dimensions of the query and update tensors
  # that will go to separate databases. By default, we use a separate database
  # for each head.
  # Note that some implementations of the memory share memory across all hosts
  # and devices (memory_on_borg, unless configured otherwise) or just across
  # devices of each host (memory_on_host).
  # Default is (-2,) to split by head only; use (0, -2) to also split by batch
  # dimensions.
  split_dimensions: Tuple[int, ...] = (-2,)

  query_stride: int = 1
  update_stride: int = 1

  def update(self, key: Array, value: Array):
    """Adds key/value pairs to memory.

    Args:
      key: typically of shape (batch, kv_len, num_heads, k_features). This
        tensor is split up into datasets according to `split_dimensions`.
      value: typically of shape (batch, kv_len, num_heads, v_features). This
        tensor is split up into datasets according to `split_dimensions`.

    Returns:
      A dummy value 0, once the operation has completed.
    """
    if key.ndim != 4 or value.ndim != 4:
      raise ValueError('Expected batched inputs; got shapes: %s and %s.' %
                       (key.shape, value.shape))
    key = _rearrange_dimensions(key, self.split_dimensions)
    value = _rearrange_dimensions(value, self.split_dimensions)
    update_stride = self.update_stride
    if update_stride == 1:
      return self.wrapped.update(key, value)
    return self.wrapped.update(key[update_stride - 1::update_stride, ...],
                               value[update_stride - 1::update_stride, ...])

  def topk_retrieval(self, query: Array, num_neighbors: int):
    """Retrieves the nearest neighbors for each query.

    Args:
      query: typically of shape (batch, q_len, num_heads, k_features). This
        tensor is split up into datasets according to `split_dimensions`.
      num_neighbors: number of neighbors to retrieve

    Returns:
      Tuple of tensors with the retrieved keys and value of the same shape as
      query, but with an extra dimension of length num_neighbors - typically:
      (batch, q_len, num_heads, num_neighbors, k_features)
    """
    if query.ndim != 4:
      raise ValueError('Expected batched inputs; got shape: %s.' % query.shape)
    query_stride = self.query_stride
    original_shape = query.shape
    query = _rearrange_dimensions(query, self.split_dimensions)
    if query_stride == 1:
      key, value = self.wrapped.topk_retrieval(query, num_neighbors)
    else:
      num_queries, num_heads, k_features = query.shape
      throttled_query = query[0::query_stride, ...]
      key = jnp.zeros(
          shape=(num_queries, num_heads, num_neighbors, k_features),
          dtype=query.dtype)
      throttled_key, throttled_value = (
          self.wrapped.topk_retrieval(throttled_query, num_neighbors))
      _, _, _, v_features = throttled_value.shape
      value = jnp.zeros(
          shape=(num_queries, num_heads, num_neighbors, v_features),
          dtype=query.dtype)
      key = key.at[0::query_stride, ...].set(throttled_key)
      value = value.at[0::query_stride, ...].set(throttled_value)
    key = _restore_dimensions(key, original_shape, self.split_dimensions)
    # Note that `original_shape` here may have the wrong feature dimension (if
    # k_features != v_features). But `_restore_dimensions` does not depend on
    # that dimension and the tests cover this case.
    value = _restore_dimensions(value, original_shape, self.split_dimensions)
    assert key.ndim == len(original_shape) + 1
    return key, value

  def reset(self, datasets: Array) -> int:
    """Resets the memory.

    Args:
      datasets: of shape (num_datasets,), typically the same as (num_heads,).

    Returns:
      A dummy value 0, once the operation has completed.
    """
    return self.wrapped.reset(datasets)

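# A minimal usage sketch for the classes above (added for illustration, not
# part of the original file; shapes and sizes are assumptions). With the
# default split_dimensions=(-2,), num_datasets must equal num_heads:
#
#   import jax
#   import jax.numpy as jnp
#   mem = BatchedMemory(MemoryOnTpu(num_datasets=4, database_size=128,
#                                   key_features=64, value_features=64))
#   k = jnp.ones((2, 8, 4, 64))  # (batch, kv_len, num_heads, k_features)
#   v = jnp.ones((2, 8, 4, 64))
#   variables = mem.init(jax.random.PRNGKey(0), k, v,
#                        method=BatchedMemory.update)
#   _, variables = mem.apply(variables, k, v, method=BatchedMemory.update,
#                            mutable=['database'])
#   keys, values = mem.apply(variables, k, num_neighbors=4,
#                            method=BatchedMemory.topk_retrieval)
#   # keys and values have shape (2, 8, 4, 4, 64).

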
@functools.partial(jax.jit, static_argnames=('num_buckets', 'bucket_size'))
def _chunking_sparsify(query: Array, key: Array, num_buckets: int,
                       bucket_size: int) -> Tuple[Array, Array, Array]:
  """Approximate top k operation for a single head."""
  # q = q_length, f = qk features, d = database_size
  scores = jnp.einsum('qf,df->qd', query, key)
  mask = (key.sum(-1) == 0).astype(jnp.bfloat16) * -1e6
  scores += mask

  num_queries, _ = scores.shape
  reshaped_scores = jnp.reshape(scores,
                                (num_queries, bucket_size, num_buckets))

  sparse_scores = linen.softmax(reshaped_scores * 1e6, axis=1)

  # topk_scores and topk_indices will only be computed if we depend on their
  # results.
  topk_scores = jnp.max(reshaped_scores, axis=1)
  local_indices = jnp.argmax(reshaped_scores, axis=1)
  topk_indices = (
      local_indices * num_buckets + jnp.arange(num_buckets).reshape(
          (1, num_buckets)))
  return sparse_scores, topk_scores, topk_indices


def _retrieve_topk_gatherless(
    query: Array, key: Array, value: Array,
    num_neighbors: int) -> Tuple[Array, Array, Array, Array]:
  """Retrieves for a single head - used to simplify array accesses."""
  num_kv, query_features = query.shape
  database_size, key_features = key.shape
  _, value_features = value.shape
  assert query_features == key_features
  num_buckets = num_neighbors
  if num_buckets > database_size:
    raise ValueError('More buckets than items in database. %s > %s' %
                     (num_buckets, database_size))
  if database_size % num_buckets:
    raise ValueError('Buckets must divide database: %s %% %s.' %
                     (database_size, num_buckets))
  bucket_size = database_size // num_buckets

  sparse_scores, topk_scores, topk_indices = _chunking_sparsify(
      query, key, num_buckets, bucket_size)
  key = key.reshape(bucket_size, num_buckets, key_features)
  value = value.reshape(bucket_size, num_buckets, value_features)
  selected_keys = jnp.einsum('qbn,bnd->qnd', sparse_scores, key)
  selected_values = jnp.einsum('qbn,bnd->qnd', sparse_scores, value)

  assert selected_keys.shape == (num_kv, num_neighbors, key_features)
  assert selected_values.shape == (num_kv, num_neighbors, value_features)
  return selected_keys, selected_values, topk_scores, topk_indices


class MemoryOnTpu(MemoryLayer):
  """Approximate top K search on TPU."""
  # database_size must be integer multiple of prod(batch_dims) * num_neighbors.
  database_size: int
  dtype: Dtype = jnp.float32  # pylint: disable=g-bare-generic
  key_features: int = 64
  value_features: int = 64
  report_scores_and_indices: bool = False

  def setup(self):
    self.db_index = self.variable('database', 'database_index',
                                  functools.partial(jnp.zeros, dtype=jnp.int32),
                                  (self.num_datasets,))
    self.key_db = self.variable(
        'database', 'key_db', functools.partial(jnp.zeros, dtype=self.dtype),
        (self.num_datasets, self.database_size, self.key_features))
    self.value_db = self.variable(
        'database', 'value_db', functools.partial(jnp.zeros, dtype=self.dtype),
        (self.num_datasets, self.database_size, self.value_features))

    self.retrieved_indices = self.variable(
        'database', 'retrieved_indices',
        functools.partial(jnp.zeros, dtype=jnp.int32), (0, 0, 0))
    self.retrieved_indices_scores = self.variable(
        'database', 'retrieved_indices_scores',
        functools.partial(jnp.zeros, dtype=jnp.float32), (0, 0, 0))

  def _update_kv_database(self, database, new_values, start_index):
    num_datasets, database_size, _ = database.shape
    assert database_size == self.database_size, (
        f'{database_size} vs {self.database_size}')
    assert num_datasets == self.num_datasets
    assert new_values.ndim == 3
    assert start_index.shape == (self.num_datasets,)

    def _update(database, new_values, start_index):
      return lax.dynamic_update_slice(
          database, new_values, start_indices=(start_index, 0))

    return jax.vmap(
        _update, in_axes=(0, 0, 0), out_axes=0)(database, new_values,
                                                start_index)

  def update(self, key: Array, value: Array) -> int:
    """Add keys and values to the memory; overwrite oldest if memory is full."""
    key = lax.stop_gradient(key)
    value = lax.stop_gradient(value)
    assert len(key.shape) == len(value.shape)
    assert key.shape[:-1] == value.shape[:-1]
    num_kv, num_datasets, key_features = key.shape
    assert num_datasets == self.num_datasets
    assert key_features == self.key_features
    assert value.shape[-1] == self.value_features
    assert self.database_size % num_kv == 0, (
        'Database size must be integer multiple of num_kv.')
    key = jnp.moveaxis(key, source=1, destination=0)  # split by dataset
    value = jnp.moveaxis(value, source=1, destination=0)  # split by dataset

    # start_index can be larger than DB - we use that to detect which entries
    # are not written to yet
    start_index = self.db_index.value % self.database_size
    self.key_db.value = self._update_kv_database(self.key_db.value, key,
                                                 start_index)
    self.value_db.value = self._update_kv_database(self.value_db.value, value,
                                                   start_index)
    self.db_index.value = self.db_index.value + num_kv
    return 0

  def topk_retrieval(self, query: Array,
                     num_neighbors: int) -> Tuple[Array, Array]:
    """Nearest neighbors by full multiplication and approximate top k on TPU."""
    query = lax.stop_gradient(query)
    unused_num_kv, num_datasets, query_features = query.shape
    assert num_datasets == self.num_datasets
    assert query_features == self.key_features
    query = jnp.moveaxis(query, source=1, destination=0)

    # Process different heads sequentially
    selected_keys, selected_values, topk_scores, topk_indices = lax.map(
        lambda x: _retrieve_topk_gatherless(*x, num_neighbors),
        (query, self.key_db.value, self.value_db.value))

    if self.report_scores_and_indices:
      # TODO(mrabe): These variable updates may not work perfectly yet. Find
      # out why Flax does not like them.
      self.retrieved_indices.value = topk_indices
      self.retrieved_indices_scores.value = topk_scores

    assert selected_keys.ndim == selected_values.ndim == 4
    selected_keys = jnp.moveaxis(selected_keys, source=0, destination=1)
    selected_values = jnp.moveaxis(selected_values, source=0, destination=1)
    return selected_keys, selected_values

  def reset(self, datasets: Array) -> int:
    """Resets specified datasets."""
    datasets = lax.stop_gradient(datasets)
    assert datasets.shape == (self.num_datasets,)
    assert datasets.dtype == jnp.bool_

    def _reset_single_dataset(input_tuple):
      """Resets a single head; reset is a single bool."""
      database, reset = input_tuple
      assert reset.shape == tuple(), reset.shape
      assert reset.dtype == jnp.bool_
      return database * (1 - reset)

    self.db_index.value = self.db_index.value * (1 - datasets)
    self.key_db.value = lax.map(
        _reset_single_dataset, xs=(self.key_db.value, datasets))
    self.value_db.value = lax.map(
        _reset_single_dataset, xs=(self.value_db.value, datasets))
    return 0
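
# Worked example of the bucketed top-k above (an added sketch): with
# database_size=128 and num_neighbors=4, the database is viewed as 32 rows of
# 4 interleaved buckets (bucket n holds indices n, n+4, n+8, ...). The sharp
# softmax(scores * 1e6, axis=1) is close to a one-hot argmax over each
# bucket's 32 entries, so the einsum with the reshaped key/value tables picks
# out approximately the best entry per bucket, yielding 4 "neighbors" per
# query without any gather operations.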
aglib/meliad/transformer/metric_utils.py
ADDED
@@ -0,0 +1,115 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper routines for recording various training metrics."""

from typing import Any
import jax.numpy as jnp


Array = Any


def compute_accuracy_sum(logits, targets, valid_loss_mask=None):
  """Compute accuracy for logits and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    valid_loss_mask: None or array of shape bool[batch, length]

  Returns:
    The number of correct tokens in the output.
  """
  if logits.shape[:-1] != targets.shape:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (logits.shape, targets.shape))
  if valid_loss_mask is not None and valid_loss_mask.shape != targets.shape:
    raise ValueError("Incorrect shapes. Got shape %s targets and %s mask" %
                     (targets.shape, valid_loss_mask.shape))

  accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if valid_loss_mask is not None:
    accuracy = jnp.logical_and(accuracy, valid_loss_mask)
  return jnp.sum(accuracy)  # Sum of the number of True values.


def reshape_image(image):
  """Reshape image to something that tensorboard recognizes.

  Args:
    image: Array of shape [xsize, ysize] or [num_images, xsize, ysize]

  Returns:
    Array of shape [num_images, xsize, ysize, 1]
  """

  # Reshape to [num_images, xdim, ydim, rgb] for tensorboard.
  sh = image.shape
  if image.ndim == 2:
    return jnp.reshape(image, [1, sh[0], sh[1], 1]).astype(jnp.float32)
  elif image.ndim == 3:
    return jnp.reshape(image, [sh[0], sh[1], sh[2], 1]).astype(jnp.float32)
  else:
    return None  # Not an image.


def normalize_image(images: Array, as_group: bool = False) -> Array:
  """Rescale the values in images to between 0.0 and 1.0.

  Args:
    images: Array of size [batch_size, xsize, ysize]
    as_group: Scale all images in the batch by the same amount if True.

  Returns:
    A rescaled image of the same shape.
  """

  images = images.astype(jnp.float32)  # Return images as float32.
  if as_group:
    # Normalize the batch of images as a group.
    min_img = jnp.min(images)
    max_img = jnp.max(images)
  else:
    # Normalize each image in the batch individually.
    min_img = jnp.min(images, axis=(-2, -1), keepdims=True)
    max_img = jnp.max(images, axis=(-2, -1), keepdims=True)
  norm_image = (images - min_img) / (max_img - min_img + 1e-6)
  return jnp.where(jnp.isfinite(norm_image), norm_image, 0.0)


def overlay_images(image1: Array, image2: Array) -> Array:
  """Place image1 on top of image2, broadcasting image2 if necessary.

  Args:
    image1: array of shape [num_images, xsize, ysize]
    image2: array of shape [num_images, xsize, ysize]

  Returns:
    A combined image.
  """

  assert image1.ndim == 3  # (num_images, xsize, ysize)
  assert image2.ndim == 3
  image2 = jnp.broadcast_to(image2, image1.shape)
  return jnp.concatenate([image1, image2], axis=1)


def make_histograms(viz_dicts):
  """Generate image histograms."""
  hist_dict = {}
  for (i, viz_dict) in enumerate(viz_dicts):
    for (k, images) in viz_dict.items():
      hist_dict["h_" + k + "_" + str(i)] = images
  return hist_dict
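
# Quick sanity check for compute_accuracy_sum (an added sketch, not part of
# the original file): argmax of the logits below is [1, 0, 1], so two of the
# three targets match.
#
#   logits = jnp.array([[[0., 1.], [1., 0.], [0., 1.]]])
#   targets = jnp.array([[1, 1, 1]])
#   assert int(compute_accuracy_sum(logits, targets)) == 2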
aglib/meliad/transformer/models.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
# Copyright 2022 Google.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Sequence to sequence model."""
|
16 |
+
|
17 |
+
from typing import Any, Callable, Dict, Tuple
|
18 |
+
|
19 |
+
from absl import logging
|
20 |
+
from flax import linen as nn
|
21 |
+
from flax.training import common_utils
|
22 |
+
import gin
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
import metrics_summary
|
26 |
+
from transformer import decoder_stack
|
27 |
+
from transformer import metric_utils
|
28 |
+
from transformer import text_dataset
|
29 |
+
import numpy as np
|
30 |
+
import seqio
|
31 |
+
|
32 |
+
|
33 |
+
Array = jnp.ndarray
|
34 |
+
MetricsSummary = metrics_summary.MetricsSummary
|
35 |
+
|
36 |
+
|
37 |
+
# TODO(mrabe): Remove this function and find a better way to turn text metrics
|
38 |
+
# into text on tensorboard.
|
39 |
+
def process_summaries(vocab: seqio.Vocabulary,
|
40 |
+
met_summary: MetricsSummary,
|
41 |
+
mode: str) -> MetricsSummary:
|
42 |
+
"""Compute some additional summaries, and convert tokens to text.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
vocab: The vocabulary to detokenize generated text.
|
46 |
+
met_summary: The summary object to process.
|
47 |
+
mode: The mode of the summary (e.g. "test", "train")
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
The modified summary dictionary.
|
51 |
+
"""
|
52 |
+
|
53 |
+
mdict = met_summary.current_metric_dict()
|
54 |
+
|
55 |
+
# Calculate perplexity from the average nats_per_token over all replicas.
|
56 |
+
# This has to be done here, because the perplexities themselves can't be
|
57 |
+
# averaged in the usual way.
|
58 |
+
if "nats_per_token" in mdict:
|
59 |
+
nats_per_token = mdict["nats_per_token"].to_value()
|
60 |
+
met_summary.add({"perplexity": np.exp(nats_per_token)})
|
61 |
+
|
62 |
+
if mode == "generate" and "gen_tokens" in mdict:
|
63 |
+
# Convert output tokens to example output text.
|
64 |
+
# Write text to both the summary, and pretty-print to the log file.
|
65 |
+
gen_toks = mdict["gen_tokens"].to_value()
|
66 |
+
if np.ndim(gen_toks) != 2:
|
67 |
+
raise ValueError("Unsupported shape for gen_tokens: %s" % gen_toks.shape)
|
68 |
+
|
69 |
+
ntoks = gen_toks.shape[-1]
|
70 |
+
gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
|
71 |
+
logging.info("Generated text = %s", gen_text)
|
72 |
+
met_summary.add_text({"gen_text": gen_text})
|
73 |
+
del mdict["gen_tokens"] # Otherwise it will turn into a histogram.
|
74 |
+
|
75 |
+
return met_summary
|
76 |
+
|
77 |
+
|
78 |
+
@gin.configurable
|
79 |
+
def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
|
80 |
+
[MetricsSummary, str], MetricsSummary]:
|
81 |
+
"""Return a function that processes summaries with the given vocabulary."""
|
82 |
+
# For use with training_loop.process_summaries_function
|
83 |
+
def process_fn(met_summary: MetricsSummary, mode: str):
|
84 |
+
return process_summaries(vocab, met_summary, mode)
|
85 |
+
return process_fn
|
86 |
+
|
87 |
+
|
88 |
+
@gin.configurable
|
89 |
+
class DecoderOnlyLanguageModel(nn.Module):
|
90 |
+
"""Decoder only language modeling."""
|
91 |
+
|
92 |
+
mode: str
|
93 |
+
task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
|
94 |
+
decoder_factory: Callable[[], Any] = gin.REQUIRED
|
95 |
+
|
96 |
+
sample_method: str = "sample" # Can be {"sample", "greedy"}
|
97 |
+
output_token_losses: bool = False
|
98 |
+
|
99 |
+
def get_fake_input(self):
|
100 |
+
"""Returns a fake input for initialization of the appropriate shape."""
|
101 |
+
b = self.task_config.batch_size
|
102 |
+
fake_input_dict = {
|
103 |
+
"targets": jnp.ones([b, self.task_config.sequence_length],
|
104 |
+
dtype=jnp.int32),
|
105 |
+
"start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
|
106 |
+
"epoch": jnp.ones([b], dtype=jnp.int32),
|
107 |
+
}
|
108 |
+
if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
|
109 |
+
# We are not adding the loss mask to the dummy input by default as it can
|
110 |
+
# cause a slowdown during evaluation and perhaps inference.
|
111 |
+
fake_input_dict["loss_mask"] = jnp.ones(
|
112 |
+
[b, self.task_config.sequence_length], dtype=jnp.bool_)
|
113 |
+
return fake_input_dict
|
114 |
+
|
115 |
+
def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
|
116 |
+
"""Summary operation to use for recorded metrics."""
|
117 |
+
metric_ops = {
|
118 |
+
"loss": "mean",
|
119 |
+
"nats_per_token": "mean",
|
120 |
+
"bits_per_token": "mean",
|
121 |
+
"bits_per_char": "mean",
|
122 |
+
"accuracy": "mean",
|
123 |
+
"num_tokens": "mean",
|
124 |
+
"num_chars_per_device": "mean",
|
125 |
+
"num_chars_per_batch": "mean",
|
126 |
+
"nonzero_tokens": "mean",
|
127 |
+
"num_tokens_per_device": "mean",
|
128 |
+
"num_tokens_per_batch": "mean",
|
129 |
+
"epoch": "mean",
|
130 |
+
}
|
131 |
+
if aggregate_over == "steps":
|
132 |
+
return metric_ops
|
133 |
+
elif aggregate_over == "devices":
|
134 |
+
# Ensure that statistics that refer to the total batch size stay constant
|
135 |
+
# as TPU topologies change. For those we have to sum over devices, but
|
136 |
+
# compute the mean over steps.
|
137 |
+
metric_ops.update({
|
138 |
+
"num_tokens_per_batch": "sum",
|
139 |
+
"num_chars_per_batch": "sum",
|
140 |
+
"loss": "sum"})
|
141 |
+
return metric_ops
|
142 |
+
else:
|
143 |
+
raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)
|
144 |
+
|
145 |
+
def setup(self):
|
146 |
+
self.decoder = self.decoder_factory(mode=self.mode,
|
147 |
+
task_config=self.task_config) # pytype: disable=wrong-keyword-args # trace-all-classes
|
148 |
+
|
149 |
+
def __call__(self, inputs: ...):
|
150 |
+
task_config = self.task_config
|
151 |
+
|
152 |
+
input_tokens = inputs["targets"] # [b, seq_len]
|
153 |
+
start_of_sequence = inputs["start_of_sequence"] # [b]
|
154 |
+
epochs = inputs["epoch"] # [b]
|
155 |
+
if "loss_mask" in inputs:
|
156 |
+
loss_mask = inputs["loss_mask"] # [b, seq_len]
|
157 |
+
else:
|
158 |
+
loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)
|
159 |
+
|
160 |
+
input_tokens = jnp.asarray(input_tokens)
|
161 |
+
assert input_tokens.ndim == 2
|
162 |
+
assert input_tokens.shape[0] == task_config.batch_size
|
163 |
+
assert input_tokens.shape[1] == task_config.sequence_length
|
164 |
+
assert start_of_sequence.shape[0] == task_config.batch_size
|
165 |
+
|
166 |
+
# Sanity check to avoid out-of-bounds on token lookup.
|
167 |
+
input_tokens = input_tokens % task_config.vocab_size
|
168 |
+
|
169 |
+
logging.info("langmodel: Compiling model for mode %s", self.mode)
|
170 |
+
logging.info("langmodel: input_tokens = %r", input_tokens)
|
171 |
+
logging.info("langmodel: start_of_sequece = %r", start_of_sequence)
|
172 |
+
logging.info("langmodel: epochs = %r", epochs)
|
173 |
+
|
174 |
+
# The target outputs are the next character in each sequence.
|
175 |
+
# Shift tokens left and pad with a zero at the end.
|
176 |
+
# TODO(delesley): We don't predict the first token of each sequence.
|
177 |
+
target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
|
178 |
+
logging.info("langmodel: target_tokens = %r", target_tokens)
|
179 |
+
|
180 |
+
# Invoke the decoder stack.
|
181 |
+
# The decoder will return pre-softmax logits for the predicted targets.
|
182 |
+
(logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
|
183 |
+
target_tokens=target_tokens,
|
184 |
+
start_of_sequence=start_of_sequence)
|
185 |
+
|
186 |
+
# Softmax cross-entropy loss on target tokens.
|
187 |
+
logits = nn.log_softmax(logits, axis=-1) # (b, seq_len, vocab_size)
|
188 |
+
logging.info("langmodel: logits = %r", logits)
|
189 |
+
soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
|
190 |
+
logging.info("langmodel: soft_targets = %r", soft_targets)
|
191 |
+
|
192 |
+
losses = -jnp.sum(soft_targets * logits, axis=-1) # (b, seq_len)
|
193 |
+
logging.info("langmodel: losses = %r", losses)
|
194 |
+
|
195 |
+
# Don't predict null tokens which are past the end-of-sequence.
|
196 |
+
# Also don't predict the 0 at the end of the sequence.
|
197 |
+
# TODO(delesley): Predict the final end-of-sequence marker.
|
198 |
+
loss_mask = jnp.logical_and(
|
199 |
+
loss_mask,
|
200 |
+
input_tokens > 0)
|
201 |
+
loss_mask = jnp.logical_and(
|
202 |
+
loss_mask,
|
203 |
+
target_tokens > 0)
|
204 |
+
logging.info("langmodel: loss_mask = %r", loss_mask)
|
205 |
+
|
206 |
+
losses = jnp.where(loss_mask, losses, 0.0) # (batch_size, seq_len)
|
207 |
+
loss = jnp.sum(losses) # total loss on device
|
208 |
+
|
209 |
+
token_count = jnp.sum(loss_mask) # tokens on device
|
210 |
+
token_count_nz = token_count + 1.0e-6
|
211 |
+
loss_per_token = loss / token_count_nz
|
212 |
+
bits_per_token = loss_per_token * 1.442695 # log(e)/log(2)
|
213 |
+
accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
|
214 |
+
loss_mask)
|
215 |
+
accuracy = accuracy / token_count_nz # Percent correct.
|
216 |
+
epoch = jnp.mean(epochs)
|
217 |
+
|
218 |
+
if self.mode == "generate" and self.decoder.supports_generate():
|
219 |
+
# Generate example text.
|
220 |
+
logging.info("lang_model: text inference.")
|
221 |
+
      gen_tokens = self.generate(inputs, task_config.sequence_length)

      # Return generated text, along with visualizations and histograms.
      metrics = {"gen_tokens": gen_tokens, **d_metrics}
      return (loss, metrics)

    # Just return metrics related to the loss.
    metrics = {
        "loss": loss,  # will be summed over devices
        "nats_per_token": (loss_per_token, token_count),
        "bits_per_token": (bits_per_token, token_count),
        "accuracy": (accuracy, token_count),
        "num_tokens_per_device": token_count,
        "num_tokens_per_batch": token_count,  # will be summed over devices
        "epoch": epoch,
    }

    # Compute bits per character if we have the number of characters.
    if "num_chars" in inputs:
      num_chars = jnp.sum(inputs["num_chars"])
      bits_per_char = loss / (num_chars + 1e-6) * 1.442695
      metrics["num_chars_per_device"] = num_chars
      metrics["num_chars_per_batch"] = num_chars  # will be summed over devices
      metrics["bits_per_char"] = (bits_per_char, num_chars)

    # Provided to make sure that the data pipeline and the model agree
    # on the number of tokens with a loss.
    if "nonzero_tokens" in inputs:
      nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
      metrics["nonzero_tokens"] = nonzero_tokens

    if self.output_token_losses:
      metrics["token_losses"] = losses

    return (loss, metrics)

  def generate(self, inputs: ..., sequence_length: int) -> Array:
    """Generate an output sequence.

    Args:
      inputs: the same as argument to __call__.
      sequence_length: the length of sequence to generate.

    Returns:
      An array of generated tokens of shape (batch_size, sequence_length).
    """
    # TODO(delesley): Add support for passing the prefix as an argument.
    # TODO(delesley): Add support for temperature, gumbel softmax, beam search.

    batch_size = self.task_config.batch_size
    input_tokens = inputs["targets"]  # [b,seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]

    # Initialize decoder.
    dstate = self.decoder.init_decoder_state(sequence_length,
                                             start_of_sequence)

    # TODO(delesley): Handle start-of-sequence in a better way.
    # There is no special token for start of sequence, so we grab the first
    # one from the ground-truth input data.
    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    sample_method = self.sample_method
    sample_prng = self.make_rng("sample")

    # Greedy autoregressive decoder function.
    def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
      prng = jax.random.fold_in(sample_prng, i)
      (dstate, input_token) = scan_state
      del i
      (logits, dstate, _) = self.decoder(input_tokens=input_token,
                                         target_tokens=None,
                                         start_of_sequence=no_start_of_seq,
                                         decoder_state=dstate)
      if sample_method == "sample":
        logging.info("Using categorical sampling.")
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "greedy":
        logging.info("Using greedy sampling.")
        output_token = jnp.argmax(logits, axis=-1)
      else:
        raise ValueError(f"Invalid sampling method: {sample_method}")
      logging.info("generate_loop_fn: output_token = %r", output_token)
      return ((dstate, output_token), output_token)

    # Scan over the sequence length.
    iterations = jnp.arange(sequence_length)
    initial_scan_state = (dstate, first_token)
    (_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
    logging.info("generate: output_tokens = %r", output_tokens)

    # Output_tokens has shape (sequence_length, batch_size, 1)
    assert output_tokens.shape == (sequence_length, batch_size, 1)
    output_tokens = jnp.reshape(
        output_tokens, (sequence_length, self.task_config.batch_size))
    output_tokens = output_tokens.transpose([1, 0])
    return output_tokens
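# Editor's note (not part of the upload): generate() above drives the decoder
# autoregressively with jax.lax.scan, carrying (decoder_state, last_token) as
# the loop state.  A minimal runnable sketch of that pattern, with a
# hypothetical toy_decoder standing in for self.decoder:

import jax
import jax.numpy as jnp

def toy_decoder(state, token):
  # Hypothetical stand-in: logits over a 5-token vocabulary plus new state.
  logits = jax.nn.one_hot((token + state) % 5, 5) * 10.0
  return logits, state + 1

def loop_fn(carry, i):
  del i  # the position index is unused in this toy version
  state, token = carry
  logits, state = toy_decoder(state, token)
  token = jnp.argmax(logits, axis=-1)  # the "greedy" sampling branch
  return (state, token), token

(_, tokens) = jax.lax.scan(loop_fn, (jnp.array(0), jnp.array(0)),
                           jnp.arange(8))
print(tokens)  # eight greedily decoded tokens, shape (8,)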
aglib/meliad/transformer/nn_components.py
ADDED
@@ -0,0 +1,437 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core NN components used in models."""

from typing import Any, Callable, Optional, Tuple, Union

from absl import logging
from flax import linen as nn
import gin
import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp


PRNGKey = Any
Array = jnp.ndarray
Shape = Tuple[int, ...]
Dtype = Union[jnp.dtype, str]


def scalar_initializer(x):
  """Like linen.zeros, but initializes a parameter to a scalar value."""
  def init_fun(key, shape, dtype):
    del key
    return jnp.broadcast_to(jnp.array(x, dtype=dtype), shape)
  return init_fun


def swish(x: Array) -> Array:
  """Swish function, which is very similar to gelu."""
  return x * nn.sigmoid(x)


def soft_abs(x: Array) -> Array:
  """Soft version of absolute value, that is smoothly differentiable."""
  return jnp.sqrt(jnp.square(x) + 1) - 1


def get_activation_function(fname: Optional[str]) -> Callable[[Array], Array]:
  """Get activation function from the specified string."""
  if fname is None:
    return lambda x: x
  elif fname == "relu":
    return nn.relu
  elif fname == "swish":
    return swish
  elif fname == "sigmoid":
    return nn.sigmoid
  elif fname == "tanh":
    return nn.tanh
  else:
    raise ValueError("Unknown activation function %s" % fname)


# Adapted from flax.linen.softmax.
def safe_softmax(x: Array,
                 axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                 min_x: Optional[Array] = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  This version of softmax is intended for use with causal attention masks, and
  safely covers the situation where all elements are masked out.  If min_x is
  not None, then probability will be distributed between the values in x, and
  min_x.  If x >> min_x, then the probability allocated to min_x will be zero,
  and this function will be the same as the usual softmax.  However, if
  x << min_x, (because all the values in x are masked out) then probability
  will be allocated to min_x instead, and the probability allocated to x will
  be 0.  I.e., attention will attend to nothing if everything is masked out.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x: input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    min_x: the value of a minimum element which will be included in the
      softmax sum.  The value of min_x should be small when compared to the
      expected values in x.  If all of the values in x are smaller than
      min_x, then probability will be allocated to the minimum element
      instead, and the result of softmax will sum to less than 1.

  Returns:
    An array of the same shape as x.
  """
  # Subtract maximum value in x for numerical stability, so that the exponent
  # never exceeds numerical precision.
  x_max = lax.stop_gradient(jnp.max(x, axis, initial=min_x, keepdims=True))
  if min_x is not None:
    min_x = jnp.asarray(min_x, dtype=x.dtype)
    x_max = jnp.maximum(x_max, min_x)
  unnormalized = jnp.exp(x - x_max)
  x_sum = jnp.sum(unnormalized, axis=axis, keepdims=True)
  if min_x is not None:
    x_sum = x_sum + jnp.exp(min_x - x_max)
  return unnormalized / x_sum
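# Editor's note (not part of the upload): a quick check of the min_x
# behaviour documented above.  A fully-masked row under the ordinary softmax
# still normalizes to a uniform distribution, whereas safe_softmax with
# min_x lets the probability mass escape to the virtual minimum element:

import jax
import jax.numpy as jnp

masked_row = jnp.full((4,), -1e9)
print(jnp.sum(jax.nn.softmax(masked_row)))            # ~1.0 (uniform row)
print(jnp.sum(safe_softmax(masked_row, min_x=-1e6)))  # ~0.0 (attends to nothing)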


def dropout_multiplier_mask(rng, dropout_rate: float, shape: Shape,
                            dtype: Dtype):
  """Returns an array which can be multiplied by an input to perform dropout.

  Args:
    rng: A random number generator.
    dropout_rate: The rate at which to drop.
    shape: The shape of the output array.
    dtype: The type of the output array.

  Returns:
    An array of given shape, where values are { 0.0, 1.0/keep_probability }.
  """
  if dropout_rate <= 0.0:
    return jnp.ones(shape, dtype=dtype)

  logging.info("dropout mask: %s", shape)
  keep_prob = 1.0 - dropout_rate
  keep = jax.random.bernoulli(rng, keep_prob, shape)
  dropout_multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype))
  return dropout_multiplier


def tiled_dropout(x: Array, shape: Shape, dropout_rate: float,
                  rng_function: Callable[[], jax.random.KeyArray],
                  deterministic: bool) -> Array:
  """Tiles a dropout mask over a larger array.

  This will generate a smaller dropout mask of the given shape, and tile it
  over a larger array, which reduces the computational cost and memory
  associated with generating a large dropout mask.

  Args:
    x: The input array.
    shape: The shape of the dropout mask to tile.
    dropout_rate: The rate at which to drop.
    rng_function: A function which returns a random number generator, e.g.
                  lambda: self.make_rng("dropout").  The function will not
                  be called if dropout is not enabled.
    deterministic: If True, don't do dropout.

  Returns:
    An array of the same shape as x, with some values dropped out.
  """
  if deterministic or dropout_rate <= 0.0:
    return x

  if x.ndim != len(shape):
    raise ValueError("Shapes must have same number of dimensions %r, %r." %
                     (x.shape, shape))
  for (xd, sd) in zip(x.shape, shape):
    if (xd % sd) != 0:
      raise ValueError("Incompatible shapes %r, %r" % (x.shape, shape))

  # Get random number generator for dropout.
  rng = rng_function()

  repeats = [(1 if sd == 1 else xd // sd) for (xd, sd) in zip(x.shape, shape)]
  logging.info("tiled dropout %r, tile: %r", x.shape, shape)

  dtype = x.dtype
  keep_prob = 1.0 - dropout_rate
  keep = jax.random.bernoulli(rng, keep_prob, shape)
  keep = jnp.tile(keep, repeats)
  keep = jnp.broadcast_to(keep, x.shape)
  x_scaled = x / jnp.asarray(keep_prob, dtype=dtype)
  return lax.select(keep, x_scaled, jnp.zeros_like(x, dtype=dtype))
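# Editor's note (not part of the upload): a small demonstration of the
# tiling.  With input (4, 8) and tile shape (1, 4), only four keep/drop
# decisions are sampled; the mask repeats across columns and broadcasts
# across rows:

import jax
import jax.numpy as jnp

x = jnp.ones((4, 8))
y = tiled_dropout(x, shape=(1, 4), dropout_rate=0.5,
                  rng_function=lambda: jax.random.PRNGKey(0),
                  deterministic=False)
print(y)  # rows are identical; columns 4-7 repeat the pattern of columns 0-3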


@gin.configurable
class MLP(nn.Module):
  """Implements a multi-layer perceptron, with optional resnet or gate."""

  # Arguments to module.
  num_output_features: int  # Length of output vectors.

  # Gin configurable parameters.
  num_layers: int = gin.REQUIRED        # Number of layers in the MLP.
  num_hidden_units: int = gin.REQUIRED  # Length of hidden unit vectors.
  hidden_activation: Optional[str] = "relu"  # Hidden layer activation fn.
  final_activation: Optional[str] = None     # Final layer activation fn.
  use_bias: bool = True            # Use a bias in each dense layer.
  gate_type: Optional[str] = None  # { "residual", "bias", "full", "lstm" }
  initializer_scale: float = 1.0   # Scale of initial values.
  dtype: Any = jnp.float32

  def setup(self):
    kernel_init = jax.nn.initializers.variance_scaling(
        scale=self.initializer_scale, mode="fan_in",
        distribution="truncated_normal")

    assert self.num_layers > 0
    hlayers = []
    for i in range(0, self.num_layers - 1):
      assert self.num_hidden_units > 0
      hlayer = nn.Dense(self.num_hidden_units,
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                        dtype=self.dtype,
                        name=f"hidden{i}")
      hlayers.append(hlayer)
    self.hidden_layers = hlayers
    self.output_layer = nn.Dense(self.num_output_features,
                                 use_bias=self.use_bias,
                                 kernel_init=kernel_init,
                                 dtype=self.dtype)

    if self.gate_type is None or self.gate_type == "residual":
      return

    # We use a low but non-zero bias so that adafactor knows how to scale it.
    gate_bias_init = jax.nn.initializers.normal(stddev=0.1)
    # Also use a lower than normal kernel.
    gate_kernel_init = jax.nn.initializers.variance_scaling(
        scale=0.1, mode="fan_in", distribution="truncated_normal")

    if self.gate_type == "bias":
      self.gate_bias = self.param("gate_bias", gate_bias_init,
                                  (self.num_output_features,), jnp.float32)
    elif self.gate_type == "full":
      self.gate_layer = nn.Dense(self.num_output_features,
                                 use_bias=True,
                                 bias_init=gate_bias_init,
                                 kernel_init=gate_kernel_init,
                                 dtype=self.dtype)
    elif self.gate_type == "lstm":
      self.input_gate = nn.Dense(self.num_output_features,
                                 use_bias=True,
                                 bias_init=gate_bias_init,
                                 kernel_init=gate_kernel_init,
                                 dtype=self.dtype)
      self.forget_gate = nn.Dense(self.num_output_features,
                                  use_bias=True,
                                  bias_init=gate_bias_init,
                                  kernel_init=gate_kernel_init,
                                  dtype=self.dtype)
    else:
      raise ValueError("Unsupported gate_type: %s" % self.gate_type)

  def _gate(self, y_hidden: Array, state: Array, y_out: Array) -> Array:
    """Compute the value to use for the gate."""

    if self.gate_type == "residual":
      # Residual connection: just add y_out to the state.
      logging.info("mlp: residual")
      return state + y_out

    elif self.gate_type == "bias":
      # Simple gate: use a gru_style gate with a learned bias (no kernel).
      bias = jnp.asarray(self.gate_bias, dtype=self.dtype)
      bias = jnp.reshape(bias, (1,) * (y_out.ndim - 1) + (-1,))  # batch dims.
      g = jax.nn.sigmoid(bias)
      logging.info("mlp: gate bias = %r", g)
      return (state * g) + (y_out * (1 - g))

    elif self.gate_type == "full":
      # Normal GRU style gate -- compute g using both a kernel and bias.
      g = jax.nn.sigmoid(self.gate_layer(y_hidden) + 1)  # biased to remember
      logging.info("mlp: gate full = %r", g)
      return (state * g) + (y_out * (1 - g))

    elif self.gate_type == "lstm":
      # LSTM style gate with input and forget gates.
      fg = jax.nn.sigmoid(self.forget_gate(y_hidden) + 1)  # biased to remember
      ig = jax.nn.sigmoid(self.input_gate(y_hidden) - 1)
      logging.info("mlp: gate lstm = %r, %r", ig, fg)
      return (state * fg) + (y_out * ig)

    else:
      raise ValueError("Unsupported gate type %s" % self.gate_type)

  def __call__(self, x: Array, state: Optional[Array],
               apply_dropout: bool = False,
               dropout_rate: float = 0.0,
               drop_tile_shape: Optional[Shape] = None,
               rng_function: Optional[Callable[[], Any]] = None) -> Array:
    """Apply the multi-layer perceptron to the input x.

    For simple MLPs, returns f(x), where f is the MLP function.
    For resnets and gated architectures, it returns
      state + f(x)          -- for resnet.
      g*state + (1-g)*f(x)  -- for gated architecture, where g is the gate.

    Args:
      x: The input to the MLP.
      state: The prior value, if this MLP is used as part of a resnet or gated
             architecture.
      apply_dropout: If true, applies dropout to the result.
      dropout_rate: The dropout rate to use.
      drop_tile_shape: The dropout tile shape.
      rng_function: Gets a random number seed for dropout.

    Returns:
      The combination of f(x) and the (optional) prior state.
    """

    x = jnp.asarray(x, self.dtype)
    hidden_act_fun = get_activation_function(self.hidden_activation)
    final_act_fun = get_activation_function(self.final_activation)
    if self.hidden_layers:
      # Apply some number of hidden layers.
      y = x
      for layer in self.hidden_layers:
        logging.info("mlp: hidden %d, %s", self.num_hidden_units,
                     self.hidden_activation)
        y = hidden_act_fun(layer(y))
    else:
      # Apply the hidden activation function to the input.
      logging.info("mlp: activation = %s", self.hidden_activation)
      y = hidden_act_fun(x)

    y_hidden = y  # The hidden layer right before the output.
    logging.info("mlp: final activation = %s", self.final_activation)
    y_out = self.output_layer(y_hidden)  # The MLP final output.
    y_out = final_act_fun(y_out)         # Apply final activation function.
    logging.info("mlp: final = %r", y_out)

    # Optionally apply dropout to the output.
    if apply_dropout:
      if drop_tile_shape is None:
        raise ValueError("drop_tile_shape must be specified for dropout.")
      if rng_function is None:
        raise ValueError("rng_function must be specified for dropout.")
      logging.info("mlp: dropout rate = %s", dropout_rate)
      y_out = tiled_dropout(
          y_out, shape=drop_tile_shape, dropout_rate=dropout_rate,
          rng_function=rng_function, deterministic=False)

    if state is None:
      # Simple MLP.  No gate to combine y_out with the state.
      assert self.gate_type is None
      logging.info("mlp: gate type = None.")
      return y_out

    # When using state, gate_type must be specified.
    assert self.gate_type is not None
    return self._gate(y_hidden, state, y_out)
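# Editor's note (not part of the upload): all of the gate variants above are
# blends of the previous state and the new MLP output.  The arithmetic of
# the "full" (GRU-style) update, in isolation:

import jax
import jax.numpy as jnp

state = jnp.array([1.0, -2.0])
y_out = jnp.array([0.5, 0.5])
g = jax.nn.sigmoid(jnp.array([1.0, 1.0]))  # stands in for gate_layer(y_hidden) + 1
print(state * g + y_out * (1 - g))          # mostly state: biased to remember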


# Modified slightly from the flax implementation.
@gin.configurable
class LayerNorm(nn.Module):
  """Layer normalization (https://arxiv.org/abs/1607.06450).

  Operates on the last axis of the input data.

  It normalizes the activations of the layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  i.e. applies a transformation that maintains the mean activation within
  each example close to 0 and the activation standard deviation close to 1.

  Attributes:
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    use_bias: If True, bias (beta) is added.
    use_scale: If True, multiply by scale (gamma).
    use_mean: If True, compute and adjust for the mean.
        Note that the T5X layernorm does not use the mean.
        Empirically, ignoring the mean can stabilize learning in transformers.
    use_scalar_scale_bias: If True, use a single scalar for scale & bias.
    enable_layernorm: If False, does not perform layernorm.
    bias_init: Initializer for bias, by default, zero.
    scale_init: Initializer for scale, by default, one.
  """
  epsilon: float = 1e-6
  dtype: Any = jnp.float32
  use_scale: bool = True   # Apply a learned scale.
  use_bias: bool = False   # Apply a learned bias.
  use_mean: bool = False   # Calculate and adjust for the mean.
  use_scalar_scale_bias: bool = False  # Learn a single scalar scale & bias.
  enable_layernorm: bool = True        # Turn off layernorm if false.
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
  scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones

  @nn.compact
  def __call__(self, x):
    """Applies layer normalization on the input.

    Args:
      x: the inputs

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    if not self.enable_layernorm:
      return x
    x = jnp.asarray(x)

    # Calculate mean and variance at higher precision.
    xf = jnp.asarray(x, jnp.float32)
    if self.use_mean:
      mean = jnp.mean(xf, axis=-1, keepdims=True)
      xf = xf - mean
    var = jnp.mean(lax.square(xf), axis=-1, keepdims=True)
    mul = lax.rsqrt(var + self.epsilon)

    # Rescale x.
    # If not use_mean, then rescale around zero instead.  (A simplification.)
    if self.use_mean:
      y = (x - mean) * mul
    else:
      y = x * mul

    if self.use_scalar_scale_bias:
      # Learn a single scalar value for bias and scale.
      # (Which mirrors the single value for mean and stddev above.)
      num_scale_bias_features = 1
    else:
      # Learn a different value per neuron/feature for bias and scale.
      num_scale_bias_features = x.shape[-1]

    # Apply learned scale and bias.
    if self.use_scale:
      y = y * jnp.asarray(
          self.param("scale", self.scale_init, (num_scale_bias_features,)),
          dtype=self.dtype)
    if self.use_bias:
      y = y + jnp.asarray(
          self.param("bias", self.bias_init, (num_scale_bias_features,)),
          dtype=self.dtype)
    return y.astype(self.dtype)
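# Editor's note (not part of the upload): with the defaults above
# (use_mean=False, use_bias=False) this LayerNorm is RMSNorm-like; it
# rescales by 1/sqrt(mean(x**2) + eps) and applies a learned per-feature
# scale.  A minimal init/apply sketch:

import jax
import jax.numpy as jnp

layer = LayerNorm()
x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
params = layer.init(jax.random.PRNGKey(0), x)
print(layer.apply(params, x))  # row rescaled to approximately unit RMS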
aglib/meliad/transformer/position.py
ADDED
@@ -0,0 +1,242 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for dealing with relative and absolute positions, and masks."""

from typing import Optional, Tuple, Union

import jax.numpy as jnp
import numpy as np


Array = jnp.ndarray
NpArray = np.ndarray
Dtype = Union[jnp.dtype, str]


def relative_positions(num_queries: int, num_keys: int,
                       offset: Optional[int] = None):
  """Returns a jax array of relative positions between query and key.

  If num_keys >= num_queries, e.g. for transformer XL or sliding window,
  then offset should be (num_keys - num_queries) to make the last N queries
  line up with the last N keys.  This is the default if offset is None.

  Args:
    num_queries: Number of queries.
    num_keys: Number of keys.
    offset: Offset of the first query wrt. the first key.

  Returns:
    A /jax/ array of shape [num_queries, num_keys] with the signed distance
    from each query to each key.
  """

  # Get the offset of each query wrt. each key.
  # If not specified, assume the last N queries line up with the last N keys.
  if offset is None:
    if num_keys < num_queries:
      raise ValueError("Number of keys %d must be >= number of queries %d" %
                       (num_keys, num_queries))
    offset = num_keys - num_queries
  qidx = jnp.arange(0, num_queries, dtype=jnp.int32).reshape(num_queries, 1)
  kidx = jnp.arange(0, num_keys, dtype=jnp.int32).reshape(1, num_keys)
  return kidx - (qidx + offset)
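# Editor's note (not part of the upload): with the default offset the last N
# queries line up with the last N keys.  For example:
#
#   relative_positions(num_queries=2, num_keys=4)
#   => [[-2, -1,  0,  1],
#       [-3, -2, -1,  0]]
#
# Query 0 sits at key position 2 and query 1 at key position 3, so each row
# holds the signed distance from that query to every key.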


def relative_positions_np(num_queries: int, num_keys: int,
                          offset: Optional[int] = None):
  """Returns a numpy array of relative positions between query and key.

  If num_keys >= num_queries, e.g. for transformer XL or sliding window,
  then offset should be (num_keys - num_queries) to make the last N queries
  line up with the last N keys.  This is the default if offset is None.

  Args:
    num_queries: Number of queries.
    num_keys: Number of keys.
    offset: Offset of the first query wrt. the first key.

  Returns:
    A /numpy/ array of shape [num_queries, num_keys] with the signed distance
    from each query to each key.
  """

  # Get the offset of each query wrt. each key.
  # If not specified, assume the last N queries line up with the last N keys.
  if offset is None:
    if num_keys < num_queries:
      raise ValueError("Number of keys %d must be >= number of queries %d" %
                       (num_keys, num_queries))
    offset = num_keys - num_queries
  qidx = np.arange(0, num_queries, dtype=np.int32).reshape(num_queries, 1)
  kidx = np.arange(0, num_keys, dtype=np.int32).reshape(1, num_keys)
  return kidx - (qidx + offset)


def broadcast_mask(mask: Array, attn: Array):
  """Broadcast a mask or bias over all the dimensions of attn."""

  # Add leading dimensions for batch_size, num_heads if necessary.
  if mask.ndim < attn.ndim:
    mask = jnp.expand_dims(mask, axis=tuple(range(0, attn.ndim - mask.ndim)))
  return mask


def causal_mask(num_queries: int, num_keys: int, window_length: int = 0):
  """Returns a boolean causal mask of shape (num_queries, num_keys)."""

  # The mask ranges over the window_length positions prior to each query.
  if window_length == 0:
    window_length = num_queries

  kqpos = relative_positions(num_queries, num_keys)  # 2D mask

  # The causal mask includes only those tokens *before* the current token.
  # This slightly improves perplexity in practice, and simplifies generation.
  # Each token attends to exactly window_length prior tokens.
  mask = (kqpos < 0) & (kqpos >= -window_length)
  return mask
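# Editor's note (not part of the upload): because the mask tests (kqpos < 0),
# a token attends only to tokens strictly before it, never to itself:
#
#   causal_mask(num_queries=3, num_keys=3)
#   => [[False, False, False],    # the first token attends to nothing
#       [ True, False, False],
#       [ True,  True, False]]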


def position_encoding(num_positions: int,
                      input_dim: int,
                      *,
                      offset: int = 0,
                      max_wavelength: float = 0) -> NpArray:
  """Returns a position encoding of shape (num_positions, input_dim).

  Positions are encoded as sin/cos pairs at geometrically increasing
  wavelengths.

  The length of a half-wave (peak to trough) increases geometrically from 1 to
  max_wavelength.  (Technically, it's slightly less; the last sin/cos pair has
  a wavelength of max_wavelength**((d-1)/d), where d = input_dim/2.)

  NOTE: unlike prior published position encodings, we multiply the position of
  each token by pi to convert from fractions of a wave (position/wavelength)
  to radians.  Thus, the highest frequency wave alternates between -1 and 1 on
  every token, whereas in prior published work the highest frequency alternates
  between -1 and 1 every pi tokens.  The max_wavelength is also effectively
  1/pi times as long, so a prior published factor of 10,000
  (e.g. https://arxiv.org/abs/1706.03762) would equate to a max_wavelength
  of 31,416.

  This encoding also does not alternate between sin/cos values, but puts all of
  the cos values on one side, and the sin values on the other.  That makes it
  easier to split the sin,cos values to construct or apply a rotation matrix.

  The default value for max_wavelength is 2 * num_positions.

  Args:
    num_positions: The number of positions.
    input_dim: The dimension of the position vector.
    *: --- The following are keyword arguments only. ---
    offset: Positions count from offset to (offset + num_positions).
    max_wavelength: The maximum length of a half-wave (peak to trough).

  Returns:
    Numpy matrix of shape (num_positions, input_dim) containing the encodings.
    Position encodings are packed as concat(cos_values, sin_values, axis=1).
  """

  if max_wavelength == 0:
    max_wavelength = 2 * num_positions
  assert max_wavelength > 1

  assert (input_dim % 2) == 0
  idim2 = input_dim // 2

  # t ranges from 0 <= t < 1
  t = np.arange(0, idim2, dtype=np.float32) / idim2

  # wavelength (columns)
  # The length of a half-wave (trough to peak) increases geometrically
  # 1 <= wavelength < max_wavelength
  wavelength = float(max_wavelength)**t
  wavelength = np.reshape(wavelength, (1, idim2))  # broadcast over rows

  # k is the position in the sequence (rows)
  k = np.arange(offset, num_positions + offset, dtype=np.float32)
  k = np.reshape(k, (num_positions, 1))  # broadcast over columns

  # For each position (row) compute an angle (column) at various wavelengths.
  # NOTE: unlike prior published work, we multiply by pi to convert to radians.
  pi_f = np.array(np.pi, dtype=np.float32)
  angles = pi_f * k / wavelength  # shape (num_positions, idim2)
  posx = np.cos(angles, dtype=np.float32)
  posy = np.sin(angles, dtype=np.float32)
  return np.concatenate([posx, posy], axis=1)  # shape (num_positions, idim)
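# Editor's note (not part of the upload): a quick shape and frequency check
# of the encoding above.  The first cosine column has wavelength 1, so it
# alternates sign on every token, as the docstring describes:

import numpy as np

enc = position_encoding(4, 8)  # max_wavelength defaults to 2 * 4 = 8
assert enc.shape == (4, 8)
print(enc[:, 0])  # [ 1., -1.,  1., -1.] -- the highest-frequency cosine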


def rotate_kq(keys: Array, queries: Array,
              *,  # the following args must be passed by keyword.
              max_wavelength: float,
              offset: Optional[int] = None,
              dtype: Optional[Dtype] = None) -> Tuple[Array, Array]:
  """Rotate keys and queries by the relative distance between query and key.

  Implements rotary position embeddings (RoPE)
  https://arxiv.org/abs/2104.09864.

  Args:
    keys: array of shape (batch_size, num_keys, num_heads, head_size)
    queries: array of shape (batch_size, num_queries, num_heads, head_size)
    max_wavelength: The maximum length of a half-wave (peak to trough)
    offset: The relative positional offset from keys[i] to queries[i].
      Defaults to num_keys - num_queries if not specified.
    dtype: The precision to perform the rotation at.
      Defaults to keys.dtype.

  Returns:
    (keys, queries) after rotation.
  """

  (batch_size, num_keys, num_heads, head_size) = keys.shape
  (_, num_queries, _, _) = queries.shape
  assert queries.shape == (batch_size, num_queries, num_heads, head_size)

  if offset is None:
    assert num_keys >= num_queries
    offset = num_keys - num_queries

  if dtype is None:
    dtype = keys.dtype

  def rotate_k_or_q(kq: Array, num_kq: int, kq_offset: int) -> Array:
    nonlocal max_wavelength
    nonlocal dtype

    # Get position encodings, which can be used to do a rotation.
    kq_pos = position_encoding(num_kq, head_size, offset=kq_offset,
                               max_wavelength=max_wavelength)
    # Broadcast over batch_size and num_heads.
    kq_pos = np.reshape(kq_pos, (1, num_kq, 1, head_size))
    # Split position encoding into separate sin/cos values in order to
    # construct a rotation matrix.
    (cosa, sina) = np.split(kq_pos, 2, axis=-1)
    cosa = jnp.asarray(cosa, dtype=dtype)  # convert from numpy -> jax
    sina = jnp.asarray(sina, dtype=dtype)  # convert from numpy -> jax

    # Split keys/queries into real & imaginary (i.e. x & y) parts.
    (kqx, kqy) = jnp.split(kq, 2, axis=-1)
    # Apply rotation matrix.
    kqx_rot = (kqx * cosa) - (kqy * sina)
    kqy_rot = (kqx * sina) + (kqy * cosa)
    # Concatenate back into keys/queries.
    return jnp.concatenate([kqx_rot, kqy_rot], axis=-1)

  keys = rotate_k_or_q(keys, num_keys, -offset)  # pylint: disable=invalid-unary-operand-type
  queries = rotate_k_or_q(queries, num_queries, 0)
  return (keys, queries)
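# Editor's note (not part of the upload): a minimal usage sketch of the RoPE
# rotation above.  A useful property of rotary embeddings is that the
# post-rotation dot products depend only on the relative distance between
# query and key positions:

import jax
import jax.numpy as jnp

kq_rng = jax.random.split(jax.random.PRNGKey(0), 2)
k = jax.random.normal(kq_rng[0], (1, 6, 2, 8))  # (batch, keys, heads, head_size)
q = jax.random.normal(kq_rng[1], (1, 6, 2, 8))
(k_rot, q_rot) = rotate_kq(k, q, max_wavelength=1024)
scores = jnp.einsum("bqhd,bkhd->bhqk", q_rot, k_rot)
print(scores.shape)  # (1, 2, 6, 6)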
aglib/meliad/transformer/position_fourier.py
ADDED
@@ -0,0 +1,218 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class for Fourier relative position biases.

This implementation uses the same Fourier position encodings that are used
in the absolute position encoding.  However, instead of adding the positions
to the input, where the position vector and content vectors become entangled,
the relative encoding computes a relative position bias matrix, which is then
added to the content-based attention matrix before applying softmax.

The bias matrix is computed as follows.  First, a learned transformation is
applied to each query position, which transforms it so that it matches a set
of key positions.  The relative position bias between query 'i' and key 'j' is
the dot product between the transformed position 'i', and position 'j'.

The learned transformation is designed so that the match between query and key
is a function of the relative distance between the two.  Although absolute
positions are fed as inputs, the rest of the network can't "see" the absolute
positions; it can only transform them by some relative amount.

A position vector consists of a sequence of (sin, cos) pairs, which have
geometrically increasing wavelengths that span from 2 (for the first pair
in each vector) to twice the length of the token sequence (for the last pair).
Each sin/cos pair encodes the (x, y) value of a 2D unit vector at a particular
angle.  For each sin/cos pair in the query position vector, we apply a learned
2x2 rotation matrix, which will rotate and scale the pair by some amount.

The dot product of two (sin, cos) pairs is the cosine of the angle between
them.  The dot product of the query position and key position vectors is thus
the sum of such cosines.  By rotating and scaling the query position, it is
possible to approximate any function over relative position as a Fourier
series: a sum of cosine waves at different wavelengths.  The rotation provides
phase, and the scale provides magnitude.

Put another way, rotating the (sin, cos) pairs of a query position will compute
a relative offset from the /query/ position to some target /key/ position.
"""

from typing import Any, Optional

from flax import linen as nn
import gin
import jax.numpy as jnp
from transformer import position
import numpy as np


Array = jnp.ndarray


def _initialize_frel_rotation_matrix(rng, num_heads, vec_size):
  """Initialize the rotation matrices."""
  # Initialize each rotation matrix to the identity * scale.
  #
  # Initially scale by 1 / number of sine waves = 1/2 the position vector size.
  # With this initialization, the initial position bias terms should be
  # between -1.0 and 1.0 after the rotation matrix has been applied.
  del rng  # required for init function but unused
  scale = float(2.0 / vec_size)
  tmat_a = jnp.ones([num_heads, vec_size // 2], dtype=jnp.float32) * scale
  tmat_b = jnp.zeros([num_heads, vec_size // 2], dtype=jnp.float32)
  return jnp.concatenate([tmat_a, tmat_b], axis=1)


@gin.configurable
class RelativeFourierPositions(nn.Module):
  """An implementation of Fourier relative positions."""

  # The number of attention heads.
  num_heads: int = 8

  # The maximum number of keys to attend to.
  # The sin/cos wavelengths of the position vectors will be tuned to this max.
  max_number_of_keys: int = 1024

  # Size of the position vector.  Needs to be large enough to address the keys.
  position_vector_size: int = 128

  # Data type to use for the rotation matrices.
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, num_queries: int, num_keys: int,
               offset: Optional[int] = None,
               bidirectional: bool = True) -> Array:
    """Returns relative positional attention matrix.

    If num_keys >= num_queries, e.g. for transformer XL or sliding window,
    then offset should be (num_keys - num_queries) to make the last N queries
    line up with the last N keys.  This is the default if offset is None.

    Args:
      num_queries: Number of queries.
      num_keys: Number of keys.
      offset: Offset of the first query with respect to the first key.
        (See position.relative_positions() for more info.)
      bidirectional: Unused, included for compatibility.
        Relative positions are always bidirectional.
    Returns:
      Attention bias matrix of shape (1, num_heads, num_queries, num_keys)
    """

    # Get the offset of each query with respect to each key.
    # If not specified, the last N queries line up with the last N keys.
    if offset is None:
      assert num_keys >= num_queries
      offset = num_keys - num_queries
    max_wavelength = 2 * self.max_number_of_keys

    # Compute absolute position vectors for keys.
    # Use numpy to compute these arrays statically.
    # ks : (num_keys, pvec_size)
    ks = position.position_encoding(num_keys,
                                    self.position_vector_size,
                                    offset=0,
                                    max_wavelength=max_wavelength)

    # Compute absolute position vectors for queries.
    # qs : (num_queries, pvec_size)
    if offset >= 0 and offset + num_queries <= num_keys:
      # Query positions are a subset of the key positions.
      qs = ks[offset:offset + num_queries]
    else:
      # Query positions must be computed separately.
      qs = position.position_encoding(num_queries,
                                      self.position_vector_size,
                                      offset=offset,
                                      max_wavelength=max_wavelength)

    # Split qs into x and y coordinates for rotation.
    (qx, qy) = np.split(qs, 2, axis=-1)
    qs_xs = np.concatenate([qx, qx], axis=-1)
    qs_ys = np.concatenate([qy, qy], axis=-1)
    del qs

    # Convert from numpy to jax.
    ks = jnp.asarray(ks, dtype=self.dtype)
    qs_xs = jnp.asarray(qs_xs, dtype=self.dtype)
    qs_ys = jnp.asarray(qs_ys, dtype=self.dtype)

    # Initialize the rotation matrices to the identity.
    rotation_matrix = self.param("rotation_matrix",
                                 _initialize_frel_rotation_matrix,
                                 self.num_heads,
                                 self.position_vector_size)

    rotation_matrix = jnp.asarray(rotation_matrix, dtype=self.dtype)

    # Unpack rotation_matrix to a set of 2x2 matrices.
    rmat1 = rotation_matrix  # [rm_a, rm_b]
    (rm_a, rm_b) = jnp.split(rotation_matrix, 2, axis=-1)
    rmat2 = jnp.concatenate([-rm_b, rm_a], axis=-1)

    # Vectors in qs consist of a set of (x,y) (e.g. sin,cos) pairs.
    # We transform each (x,y) pair with a 2D rotation matrix:
    #
    #   x' = a*x + -b*y
    #   y' = b*x + a*y
    #
    # or equivalently, x' + y'i = (a + bi)(x + yi) where i = sqrt(-1).
    #
    # For an angle theta, and scale s, a = cos(theta)*s, b = sin(theta)*s,
    # and a + bi = s*exp(i*theta).  We avoid computing sin,cos by training a,b
    # directly.
    #
    #   qs_xs = [x0 .. xn; x0 .. xn]    -- layout of qs_xs
    #   qs_ys = [y0 .. yn; y0 .. yn]
    #   rmat1 = [a0 .. an; b0 .. bn]    -- layout of (a,b) values in rmat1
    #   rmat2 = [-b0 .. -bn; a0 .. an]
    #
    # rot_qs: (num_heads, num_queries, pvec_size)

    # Broadcast qs over the number of heads.
    # Broadcast rmat over the number of queries.
    qs_xs = qs_xs[jnp.newaxis, ...]     # (1, num_queries, pvec_size)
    qs_ys = qs_ys[jnp.newaxis, ...]
    rmat1 = rmat1[:, jnp.newaxis, ...]  # (num_heads, 1, pvec_size)
    rmat2 = rmat2[:, jnp.newaxis, ...]
    rot_qs = ((rmat1 * qs_xs) + (rmat2 * qs_ys))

    # Compute the dot product of each position vector in ks by the rotated qs.
    #
    # The dot product of each (x, y) pair in ks, and each (x', y') in rot_qs,
    # is equal to the cosine of the angle between them, times the length
    # of (x', y').
    #
    # The angle of the cosine for each pair depends on:
    #   - The distance between the key and the query, divided by the
    #     wavelength.  (From the initial position encoding for ks and qs).
    #   - The rotation performed by (a,b).
    #
    # The length of (x', y') is equal to the scale of (a, b).
    #
    # The dot product of two complete position vectors is the sum of the
    # cosines for all pairs.  The cosines form a progression of geometrically
    # increasing wavelengths, and each wave has a scale and phase provided by
    # the rotation matrix.  The sum of such waves can thus approximate any
    # function of position.
    #
    # pbias: (num_heads, num_queries, num_keys)
    pbias = jnp.einsum("hqd,kd->hqk", rot_qs, ks)

    # Add batch dimension; --> shape (1, num_heads, num_queries, num_keys)
    pbias = jnp.expand_dims(pbias, 0)
    return pbias.astype(self.dtype)
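# Editor's note (not part of the upload): a minimal init/apply sketch for the
# module above.  It consumes only shapes and returns a bias to be added to
# the attention logits before the softmax:

import jax

rfp = RelativeFourierPositions(num_heads=4, max_number_of_keys=512,
                               position_vector_size=64)
params = rfp.init(jax.random.PRNGKey(0), num_queries=16, num_keys=32)
pbias = rfp.apply(params, num_queries=16, num_keys=32)
print(pbias.shape)  # (1, 4, 16, 32)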
aglib/meliad/transformer/position_t5.py
ADDED
@@ -0,0 +1,155 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class for T5 relative position biases.

Adapted from flaxformer.components.relative_position_biases.py
"""

from typing import Any, Callable, Optional

from flax import linen as nn
import gin
from jax import lax
import jax.numpy as jnp
from transformer import position
import numpy as np


Array = Any


@gin.configurable
class T5RelativePositionBiases(nn.Module):
  """Adds T5-style relative positional embeddings to the attention logits.

  Attributes:
    num_buckets: Number of buckets to bucket distances between key and query
      positions into.
    max_distance: Maximum distance before everything is lumped into the last
      distance bucket.
    num_heads: Number of heads in the attention layer.  Each head will get a
      different relative position weighting.
    dtype: Type of arrays through this module.
    embedding_init: initializer for relative embedding table.
  """
  num_buckets: int
  max_distance: int
  num_heads: int
  dtype: Any
  embedding_init: Callable[..., Array] = nn.linear.default_embed_init

  @staticmethod
  def _relative_position_bucket(relative_position,
                                bidirectional=True,
                                num_buckets=32,
                                max_distance=128):
    """Translate relative position to a bucket number for relative attention.

    The relative position is defined as memory_position - query_position, i.e.
    the distance in tokens from the attending position to the attended-to
    position.  If bidirectional=False, then positive relative positions are
    invalid.
    We use smaller buckets for small absolute relative_position and larger
    buckets for larger absolute relative_positions.  All relative
    positions >= max_distance map to the same bucket.  All relative
    positions <= -max_distance map to the same bucket.  This should allow for
    more graceful generalization to longer sequences than the model has been
    trained on.

    Args:
      relative_position: an int32 array
      bidirectional: a boolean - whether the attention is bidirectional
      num_buckets: an integer
      max_distance: an integer

    Returns:
      a Tensor with the same shape as relative_position, containing int32
      values in the range [0, num_buckets)
    """
    ret = 0
    n = -relative_position
    if bidirectional:
      num_buckets //= 2
      ret += (n < 0).astype(np.int32) * num_buckets
      n = np.abs(n)
    else:
      n = np.maximum(n, 0)
    # now n is in the range [0, inf)
    max_exact = num_buckets // 2
    is_small = (n < max_exact)
    val_if_large = max_exact + (
        np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
        np.log(max_distance / max_exact) *
        (num_buckets - max_exact)).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    ret += np.where(is_small, n, val_if_large)
    return ret
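# Editor's note (not part of the upload): with the defaults above
# (bidirectional=True, num_buckets=32, max_distance=128), small distances get
# their own buckets, larger ones are log-spaced, and positive relative
# positions land in the upper half of the bucket range:

import numpy as np

rp = np.array([0, -1, -2, -8, -50, -500, 1, 50])
print(T5RelativePositionBiases._relative_position_bucket(rp))
# => [ 0  1  2  8 13 15 17 29]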

  @nn.compact
  def __call__(self, num_queries, num_keys, offset: Optional[int] = None,
               bidirectional=True):
    """Produce relative position embedding attention biases.

    Args:
      num_queries: Number of queries.
      num_keys: Number of keys.
      offset: Offset of the first query with respect to the first key.
        (See position.relative_positions() for more info.)
      bidirectional: whether to allow positive memory-query relative position
        embeddings.

    Returns:
      output: `(1, num_heads, num_queries, num_keys)` attention bias
    """

    # Find the distance between each query and each key.
    # This is where this implementation differs from the T5 implementation;
    # this version lines the /last/ N queries up with the /last/ N keys,
    # (which is appropriate for XL/sliding window) while the T5 version lines
    # up the /first/ N queries with the first N keys, in cases where the
    # number of keys and queries differ.
    relative_position = position.relative_positions_np(
        num_queries=num_queries, num_keys=num_keys, offset=offset)

    rp_bucket = self._relative_position_bucket(
        relative_position,
        bidirectional=bidirectional,
        num_buckets=self.num_buckets,
        max_distance=self.max_distance)
    relative_attention_bias = self.param('rel_embedding', self.embedding_init,
                                         (self.num_heads, self.num_buckets),
                                         jnp.float32)

    relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
    # Instead of using a slow gather, we create a leading-dimension one-hot
    # array from rp_bucket and use it to perform the gather-equivalent via a
    # contraction, i.e.:
    # (num_head, num_buckets) x (num_buckets one-hot, num_queries, num_keys).
    # This is equivalent to relative_attention_bias[:, rp_bucket]
    bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
    rp_bucket_one_hot = jnp.array(
        rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
    # --> shape (num_heads, num_queries, num_keys)
    values = lax.dot_general(
        relative_attention_bias,
        rp_bucket_one_hot,
        (
            ((1,), (0,)),  # rhs, lhs contracting dims
            ((), ())))     # no batched dims
    # Add a singleton batch dimension.
    # --> shape (1, num_heads, num_queries, num_keys)
    out = values[jnp.newaxis, ...]

    return out
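# Editor's note (not part of the upload): a tiny check that the one-hot
# contraction above really is a gather.  For a table of shape (num_heads,
# num_buckets) and integer buckets rp_bucket, the contraction matches fancy
# indexing:

import jax.numpy as jnp

table = jnp.arange(8.0).reshape(2, 4)    # (num_heads=2, num_buckets=4)
rp_bucket = jnp.array([[0, 1], [3, 2]])  # (num_queries, num_keys)
one_hot = (rp_bucket[jnp.newaxis, ...] ==
           jnp.arange(4).reshape(4, 1, 1)).astype(jnp.float32)
gathered = jnp.einsum("hb,bqk->hqk", table, one_hot)
assert (gathered == table[:, rp_bucket]).all()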
aglib/meliad/transformer/synthetic_text_data.py
ADDED
The diff for this file is too large to render. See raw diff.
aglib/meliad/transformer/tasks.py
ADDED
@@ -0,0 +1,52 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Add Tasks to registry."""

import functools

from transformer import text_dataset
import seqio
import t5.data
from t5.data import preprocessors
import tensorflow as tf


TaskRegistry = seqio.TaskRegistry


def define_pg19_task(name: str, vocab: seqio.Vocabulary):
  seqio.TaskRegistry.add(
      name,
      seqio.TfdsDataSource(
          tfds_name="pg19:0.1.1"
      ),
      preprocessors=[
          functools.partial(text_dataset.rekey_articles,
                            rekey={"book_text": "targets"},
                            keep={"book_title", "book_id", "publication_date"}),
          seqio.preprocessors.tokenize,
      ],
      output_features={
          "targets": seqio.Feature(vocab,
                                   add_eos=False, dtype=tf.int32),
      }
  )


T5_DEFAULT_VOCABULARY = t5.data.get_default_vocabulary()
define_pg19_task("pg19_bytes", seqio.ByteVocabulary())
define_pg19_task("pg19_tokens", T5_DEFAULT_VOCABULARY)
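# Editor's note (not part of the upload): once registered, the tasks can be
# pulled from the seqio registry like any other task.  A usage sketch,
# assuming the pg19 tfds data is available locally:

import seqio

task = seqio.get_mixture_or_task("pg19_bytes")
ds = task.get_dataset(sequence_length={"targets": 4096}, split="train")
for example in ds.take(1):
  print(example["targets"].shape)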
aglib/meliad/transformer/text_dataset.py
ADDED
@@ -0,0 +1,775 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Load text datasets for long-range transformer models."""

import os
import re
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Set, Tuple, Union

from absl import flags
from absl import logging
import gin
import jax
from transformer import synthetic_text_data
import numpy as np
import seqio
import tensorflow.compat.v2 as tf
from tensorflow.compat.v1 import gfile  # Assumed import: load_enwik8 below calls gfile.Open.


flags.DEFINE_string("default_data_dir", None,
                    "Default directory where data is stored.")
FLAGS = flags.FLAGS


_DEFAULT_DATA_DIRECTORY = None


@gin.configurable
def set_default_data_directory(directory_name=None):
  """Set the default directory where training data is located."""
  global _DEFAULT_DATA_DIRECTORY
  # If the data directory has been overridden with a command-line flag, use
  # it. If not, see if directory_name has been configured by Gin.
  # Otherwise, use the default tfds directory.
  if FLAGS.default_data_dir:
    directory_name = FLAGS.default_data_dir
  if directory_name is not None:
    seqio.set_tfds_data_dir_override(directory_name)
    _DEFAULT_DATA_DIRECTORY = directory_name


def get_iterator_function(dataset: Optional[tf.data.Dataset]):
  """Returns a function which gets an iterator over the given dataset."""
  if dataset is None:
    return None
  else:
    return dataset.as_numpy_iterator


@gin.configurable
def get_loss_mask_tokens(
    split: str,
    loss_mask_start_tokens: Sequence[int] = (),
    loss_mask_end_tokens: Sequence[int] = (),
    splits: Sequence[str] = ("all",)
) -> Tuple[Sequence[int], Sequence[int]]:
  """Returns two token sequences to indicate start and end of the loss.

  Please configure loss_mask_start_tokens, loss_mask_end_tokens, and
  splits via gin. Example gin config to only apply loss between tokens 2
  and 1 for the test set (and everywhere for any other data split):

  ```
  text_dataset.get_loss_mask_tokens:
    loss_mask_start_tokens=(2,)
    loss_mask_end_tokens=(1,)
    splits=("test",)
  ```

  Args:
    split: The mode ("test", "train", ...)
    loss_mask_start_tokens: token sequence that starts the loss
    loss_mask_end_tokens: token sequence that stops the loss
    splits: Only compute the loss mask for splits in this list.
      By default it is "all", which is a reserved split string that applies
      to all splits.
  """
  if "all" in splits or split in splits:
    return loss_mask_start_tokens, loss_mask_end_tokens
  return (), ()
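# --- Illustrative direct call (values assumed; normally bound via gin) ---
masked = get_loss_mask_tokens("test", loss_mask_start_tokens=(2,),
                              loss_mask_end_tokens=(1,), splits=("test",))
assert masked == ((2,), (1,))
# For a split not listed in `splits`, the loss is left unmasked:
assert get_loss_mask_tokens("train", (2,), (1,), splits=("test",)) == ((), ())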
@gin.configurable
def load_text_dataset(name: str,
                      split: str,
                      sequence_length: int,
                      batch_size: int,
                      sequential: bool = True,
                      shard_dataset: bool = True,
                      verbose: bool = False,
                      ) -> Tuple[tf.data.Dataset, seqio.Vocabulary]:
  """Load a text dataset of long articles or books, and split_and_batch them.

  The input dataset must produce complete books or articles, where each
  article is a dictionary containing a "targets" field.
  See split_and_batch for more information on the output dataset.

  Args:
    name: The name of the seqio task which produces the dataset.
    split: The name of the split to use, e.g. "train" or "test".
    sequence_length: Split text into sequences of this length.
    batch_size: Draw from batch_size articles in each batch.
    sequential: If True, return the chunks of each article in sequence.
    shard_dataset: If True, split data set into shards.
    verbose: Log (an excerpt of) every text example loaded from disk. If
      False, will only print 1 excerpt every 60 seconds.

  Returns:
    (dataset, vocabulary)
    where vocabulary is the seqio.Vocabulary which is used to encode
    "targets".
  """

  logging.info("Loading text data set %s, split=%s, shape=(%d, %d)",
               name, split, batch_size, sequence_length)

  if name == "synthetic":
    ds = synthetic_data_long(split, sequence_length, batch_size)
    return (ds, seqio.PassThroughVocabulary(256, 0))
  elif name == "synthetic_short":
    ds = synthetic_data_short(split, sequence_length, batch_size)
    return (ds, seqio.PassThroughVocabulary(256, 0))
  elif name == "enwik8":
    # TODO(delesley): Encapsulate enwik8 into a Task.
    ds = load_enwik8(split, sequence_length, batch_size,
                     data_dir=_DEFAULT_DATA_DIRECTORY)
    return (ds, seqio.PassThroughVocabulary(256, 0))

  # Bypass the seqio "feature converter", and get the task directly.
  task = seqio.get_mixture_or_task(name)
  vocab = task.output_features["targets"].vocabulary

  # Create the task input pipeline.
  if shard_dataset:
    logging.info("Shards: %d of %d", jax.process_index(), jax.process_count())
    shard_info = seqio.ShardInfo(index=jax.process_index(),
                                 num_shards=jax.process_count())
  else:
    shard_info = None

  if sequential:
    task_seqlen = None          # We do our own splitting.
    shuffle_buffer_size = 1000  # Number of full-length books.
  else:
    task_seqlen = {"targets": sequence_length}  # Ask the task to split.
    shuffle_buffer_size = 10_000  # Number of chunks.

  ds = task.get_dataset(
      sequence_length=task_seqlen,
      split=split,
      use_cached=False,
      shuffle=True,
      shuffle_buffer_size=shuffle_buffer_size,
      seed=None,
      shard_info=shard_info,
      num_epochs=1)

  if sequence_length == 0:
    return (ds, vocab)  # Don't chop into subsequences.

  def extract_fn(article):
    return article["targets"]

  include_loss_mask = bool(get_loss_mask_tokens(split)[0])
  ds = split_and_batch(ds,
                       split=split,
                       extract_fn=extract_fn,
                       sequence_length=sequence_length,
                       batch_size=batch_size,
                       auto_rewind=True,
                       vocab=vocab,
                       include_loss_mask=include_loss_mask,
                       verbose=verbose)
  return (ds, vocab)
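# --- Usage sketch (illustrative addition): the "synthetic" dataset needs no
# files on disk, assuming the texts in synthetic_text_data are bundled with
# the repo. Lengths and batch size below are arbitrary.
ds, vocab = load_text_dataset(
    name="synthetic", split="train", sequence_length=256, batch_size=2)
batch = next(ds.as_numpy_iterator())
print(batch["targets"].shape)      # (2, 256), int32 byte values.
print(batch["start_of_sequence"])  # [ True  True ] on the first batch.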
def rekey_articles(ds: tf.data.Dataset,
                   rekey: Mapping[str, str],
                   keep: Optional[Set[str]] = None) -> tf.data.Dataset:
  """Rekey the articles in ds.

  Fields in rekey will be renamed, fields in keep will be kept, and others
  will be discarded. E.g., for PG19:

  rekey_articles(ds,
                 rekey={"book_text": "targets"},
                 keep={"book_title", "book_id"})

  Args:
    ds: The dataset to rekey.
    rekey: Dictionary which contains fields to rename.
    keep: Set of fields to keep.

  Returns:
    A rekeyed dataset.
  """

  def rekey_fn(article):
    result_dict = {}
    for (k, v) in article.items():
      if k in rekey:
        result_dict[rekey[k]] = v
      elif k in keep:
        result_dict[k] = v
    return result_dict

  return ds.map(rekey_fn)


def pretty_print_article(article,
                         vocab_map: Mapping[str, Optional[seqio.Vocabulary]],
                         max_length: int = 60) -> str:
  """Convert the contents of a long article to a short string."""
  if not hasattr(article, "items"):
    return pretty_print_value(article, max_length)  # Not a dictionary.
  dstr = "{"
  for (k, v) in article.items():
    if vocab_map and k in vocab_map:
      vstr = decode_tokens(v, vocab_map[k], max_length)
    else:
      vstr = pretty_print_value(v, max_length)
    dstr += "\n  " + k + ": " + vstr
  return dstr + "\n}"


def pretty_print_value(value, max_length: int) -> str:
  """Convert a possibly large value to a short string."""
  if isinstance(value, bytes):
    if len(value) <= max_length:
      return str(value)
    else:
      return f"bytes[{len(value)}] " + str(value[:max_length]) + "..."
  elif isinstance(value, str):
    if len(value) <= max_length:
      return value
    else:
      return f"str[{len(value)}] " + value[:max_length] + "..."
  elif isinstance(value, np.ndarray):
    vstr = f"ndarray({value.shape}, {value.dtype.str})"
    if value.size <= (max_length / 4):
      vstr += " = " + str(value)
    return vstr
  elif np.ndim(value) == 0:
    return str(value)  # Scalar data.
  else:
    return str(type(value))


def decode_tokens(tokens: Any, vocab: seqio.Vocabulary,
                  max_length: int) -> str:
  """Convert tokens to a human-readable string."""
  if isinstance(tokens, np.ndarray):
    tstr = f"ndarray({tokens.shape}, {tokens.dtype.str}) = "
  else:
    tstr = f"{str(type(tokens))} = "

  if np.ndim(tokens) == 1:
    tstr += decode_tokens_1d(tokens, vocab, max_length)
  elif np.ndim(tokens) == 2:
    jtstr = ",\n  ".join([decode_tokens_1d(s, vocab, max_length)
                          for s in tokens])
    tstr += f"[\n  {jtstr}\n  ]"
  else:
    tstr = pretty_print_value(tokens, max_length)
  return tstr


def decode_tokens_1d(tokens: Any, vocab: Any, max_length: int,
                     raw_string: bool = False) -> Union[str, bytes]:
  """Convert a 1D array of tokens to a human-readable string.

  Args:
    tokens: 1-dimensional array of integers.
    vocab: The vocabulary to detokenize the array.
    max_length: The maximum number of tokens to detokenize.
    raw_string: If True, return the string as bytes.
      If False, pretty-print it (e.g. rendering newlines as "\n").

  Returns:
    The detokenized string.
  """

  assert np.ndim(tokens) == 1
  # The type of tokens is np.ndarray((sequence_length,), "int32").
  # We have to convert this to an actual list of python integers, NOT numpy
  # integers, or decode will blow up, and fail to marshall the data to C++.
  dtoks = [int(i) for i in tokens[:max_length]]
  tstr = vocab.decode(dtoks)

  # Convert the decoded string to a byte string.
  # PassThroughVocabulary returns a list, not a string.
  if isinstance(tstr, str):
    tstr = bytes(tstr.encode("utf-8"))
  else:
    tstr = bytes(tstr)

  # If raw_string, return immediately.
  if raw_string:
    return tstr

  # Otherwise format it for pretty-printing.
  # Converting bytes to str will render, e.g., newlines as "\n".
  tstr = str(tstr)
  if len(tokens) > max_length:
    tstr += "..."
  return tstr
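# --- Illustrative example (assumes seqio.ByteVocabulary's default offset of
# 3 reserved special-token IDs). ---
vocab = seqio.ByteVocabulary()  # Byte values are offset by 3 special tokens.
toks = np.array(vocab.encode("Hi"), dtype=np.int32)  # [75, 108]
print(decode_tokens_1d(toks, vocab, max_length=60))  # b'Hi'
print(decode_tokens_1d(toks, vocab, max_length=1))   # b'H'...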
def bytes_to_tokens(s: str):
  """Convert a byte string to an array of integers."""
  return np.fromiter((char for char in s), count=len(s), dtype=np.int32)


def pad_chunk(s: Optional[np.ndarray], sequence_length: int):
  """Pad an array s out to the given sequence_length."""
  if s is None:
    return np.zeros(sequence_length, dtype=np.int32)
  assert np.ndim(s) == 1
  chunk_len = len(s)
  assert chunk_len <= sequence_length
  if chunk_len == sequence_length:
    return s
  else:
    return np.pad(s, (0, sequence_length - chunk_len),
                  mode="constant", constant_values=0)


def split_article(tokens: np.ndarray, sequence_length: int, split: str,
                  include_loss_mask: bool) -> (
                      Iterable[Tuple[np.ndarray, np.ndarray]]):
  """Split an array into segments of length sequence_length."""
  assert np.ndim(tokens) == 1
  if include_loss_mask:
    loss_mask = loss_mask_from_tokens(tokens, split)

  for k in range(0, len(tokens), sequence_length):
    segment = pad_chunk(tokens[k:k + sequence_length], sequence_length)
    if include_loss_mask:
      segment_loss_mask = pad_chunk(
          loss_mask[k:k + sequence_length], sequence_length).astype(bool)
    else:
      segment_loss_mask = np.array(True, dtype=bool)  # Dummy mask.
    yield (segment, segment_loss_mask)


def nonzero_tokens(tokens: np.ndarray,
                   loss_mask: Optional[np.ndarray]) -> list[int]:
  """Removes tokens that are not predicted by the model."""
  # TODO(delesley): Fix the model so that it predicts the first token.
  # The language model doesn't predict the first token.
  toks = [int(tokens[i]) for i in range(1, len(tokens))
          if (tokens[i] != 0 and (loss_mask is None or loss_mask[i]))]
  return toks


def _find_subsequence_idxs(sequence: np.ndarray, subsequence: Sequence[int]):
  """Returns the indices where `subsequence` occurs in `sequence`."""
  subsequence = np.asarray(subsequence, dtype=np.int32)
  # Use np.where as an efficient way to iterate over the whole array; but we
  # can only test for a single token that way, unfortunately.
  potential_matches = np.where(sequence == subsequence[0])[0]
  match_indices = []
  for start_index in potential_matches:
    if np.array_equal(sequence[start_index:start_index + len(subsequence)],
                      subsequence):
      match_indices.append(start_index)
  return match_indices


def loss_mask_from_tokens(tokens: np.ndarray, split: str) -> np.ndarray:
  """Compute a mask for language modelling loss using start and end tokens."""
  assert np.ndim(tokens) == 1
  tokens = tokens.astype(np.int32)

  # Position offset of loss mask and target positions. Typically -1, which
  # indicates that targets are shifted 1 position left compared to inputs.
  offset = -1

  start_tokens, end_tokens = get_loss_mask_tokens(split=split)
  if not start_tokens:
    # Default to not masking out any loss.
    return np.ones_like(tokens, dtype=bool)

  start = 0
  end = len(tokens)  # Include end_tokens.
  start_indices = _find_subsequence_idxs(tokens, start_tokens)
  if start_indices:
    if end_tokens:
      end_indices = _find_subsequence_idxs(tokens, end_tokens)
    else:
      end_indices = []
    if len(start_indices) > 1 or len(end_indices) > 1:
      logging.error("Multiple start or end tokens for loss mask: %s, %s",
                    start_indices, end_indices)
    start = start_indices[0]
    if end_indices and end_indices[0] >= start:
      end = end_indices[0]

  # We include the start_tokens and the end_tokens, which represents that the
  # model must predict the location, the content, and the end of the
  # subsequence.
  start += offset
  start = max(0, start)  # To prevent the offset creating negative indices.
  end += len(end_tokens) + offset

  # Create the actual mask. Roughly equivalent to
  #   mask = np.array([start <= i < end for i in range(len(tokens))])
  mask = np.concatenate([
      np.zeros((start,), dtype=bool),
      np.ones((end - start,), dtype=bool),
      np.zeros((len(tokens) - end,), dtype=bool)
  ])
  return mask
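# --- Worked example with assumed token values (gin normally supplies the
# start/end tokens used by loss_mask_from_tokens). ---
seq = np.array([5, 2, 9, 9, 1, 7], dtype=np.int32)
print(_find_subsequence_idxs(seq, [9, 9]))  # [2]
print(_find_subsequence_idxs(seq, [2]))     # [1]
# With loss_mask_start_tokens=(2,), loss_mask_end_tokens=(1,), and the offset
# of -1 above: start = 1 - 1 = 0 and end = 4 + 1 - 1 = 4, so the loss mask
# would be [True, True, True, True, False, False].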
def _batched_interleave_generator(
    ds: tf.data.Dataset,
    flat_map_func: Callable[[str], Iterable[Tuple[np.ndarray, np.ndarray]]],
    post_map_func,
    batch_size: int,
    vocab: Optional[seqio.Vocabulary] = None,
    include_loss_mask: bool = False,
    auto_rewind: bool = False) -> Iterable[Dict[str, np.ndarray]]:
  """Generator which combines the interleave and batch dataset operations.

  Given a set of articles from ds, flat_map_func is mapped over the articles
  to break each article up into an iterable of chunks and their loss masks.
  The generator will return the examples from each article in sequential
  order, for transformer-XL style models that process long articles over
  multiple training steps.

  Articles are combined into batches of size batch_size, where each example
  in the batch is pulled from a different article. When one article ends, the
  generator will start pulling examples from the next article. The overall
  result is similar to tf.data.Dataset.interleave, except that interleave
  does not always maintain the same order of articles. If this generator
  starts pulling from article "foo" as the 3rd item in the batch, then
  consecutive examples from "foo" will remain as the 3rd item until the
  article ends. This guarantee is necessary to pass state from one training
  step to the next.

  If auto_rewind, then the generator will automatically grab a new iterator
  from ds at the end of the epoch, and increment the epoch counter.
  Otherwise, it will yield empty datasets until all articles in the batch
  have been completed.

  Args:
    ds: A dataset of articles.
    flat_map_func: A function which returns an iterator over chunks of tokens
      and the loss masks associated with those tokens.
    post_map_func: A function which post-processes each item to fixed size.
    batch_size: The number of articles in a batch.
    vocab: The vocabulary to detokenize strings and count characters.
    include_loss_mask: If true, will return a loss mask with the tokens.
    auto_rewind: Automatically rewind ds at end of epoch.

  Yields:
    Batches of consecutive examples from articles.
    Each example has type: {
        "targets": int32[batch_size, sequence_length],
        "start_of_sequence": bool[batch_size],
        "epoch": int32[batch_size],
        "loss_mask": bool[batch_size, sequence_length],
    }
  """

  ds_iter = ds.as_numpy_iterator()

  document_start = [True] * batch_size  # At start of each article.
  readers = [None] * batch_size         # Iterator for each article.
  still_reading = [True] * batch_size   # End of current article?
  item_epochs = [0] * batch_size        # Epoch of the given item.
  epoch = 0

  # Main generator loop.
  while any(still_reading):
    targets = [None] * batch_size
    loss_mask = [None] * batch_size
    for i in range(0, batch_size):
      targets_i = None
      loss_mask_i = None
      while targets_i is None and still_reading[i]:
        if readers[i] is not None:
          try:
            # Grab the next item from the article.
            targets_i, loss_mask_i = next(readers[i])
          except StopIteration:
            # Article has ended; continue the while loop to grab a new one.
            readers[i] = None
        else:
          # Grab the next article from ds if the current one has ended.
          dsi = None
          try:
            dsi = iter(flat_map_func(next(ds_iter)))
          except StopIteration:
            logging.info("End of epoch %d.", epoch)
            if auto_rewind:
              epoch = epoch + 1
              logging.info("Starting epoch %d.", epoch)
              ds_iter = ds.as_numpy_iterator()
              dsi = iter(flat_map_func(next(ds_iter)))
            else:
              still_reading[i] = False  # No more articles on i.
          if dsi is not None:
            # Start reading the new article.
            # Continue while loop to grab the first chunk.
            readers[i] = dsi
            document_start[i] = True
            item_epochs[i] = epoch

      # post_map_func must handle None values, and return stackable np.arrays.
      targets[i] = post_map_func(targets_i)  # Handles None.
      if include_loss_mask:
        loss_mask[i] = post_map_func(loss_mask_i).astype(bool)  # Handles None.

    # If we've reached the end of all articles, stop immediately.
    if not any(still_reading):
      break

    doc_start_orig = document_start.copy()  # Return doc_start_orig.
    for i in range(0, batch_size):
      # Now that we've read an item, set /start/ to false for each reader.
      document_start[i] = False

    # Decode the tokenized segment back to characters, to count the number
    # of characters for the bits-per-character computation.
    num_chars = [0] * batch_size
    nz_toks = [0] * batch_size
    for i in range(0, batch_size):
      lmask = loss_mask[i] if include_loss_mask else None
      toks = nonzero_tokens(targets[i], lmask)
      if vocab is not None:
        bchars = decode_tokens_1d(toks, vocab, max_length=len(targets[i]),
                                  raw_string=True)
        num_chars[i] = len(bchars)
      else:
        num_chars[i] = len(toks)
      nz_toks[i] = len(toks)

    item = {
        "targets": np.stack(targets),
        "start_of_sequence": np.array(doc_start_orig),
        "epoch": np.array(item_epochs),
        "num_chars": np.stack(num_chars),
        "nonzero_tokens": np.stack(nz_toks),
    }
    if include_loss_mask:
      item["loss_mask"] = np.stack(loss_mask)
    yield item


def split_and_batch(ds: tf.data.Dataset,
                    split: str,
                    extract_fn: Callable[[Any], Any],
                    sequence_length: int,
                    batch_size: int,
                    auto_rewind: bool = False,
                    vocab: Optional[seqio.Vocabulary] = None,
                    include_loss_mask: bool = False,
                    verbose: bool = False) -> tf.data.Dataset:
  """Converts articles to tokens and chops and batches them.

  See _batched_interleave_generator for more details.

  Args:
    ds: A dataset of articles.
    split: Which dataset split is to be computed, e.g. "train".
    extract_fn: Return a sequence of tokens from article.
    sequence_length: The number of tokens in each sequence.
    batch_size: The number of examples in each batch.
    auto_rewind: If True, will automatically rewind at end of epoch.
    vocab: Vocabulary, used to count characters.
    include_loss_mask: Return a loss mask for each batch.
    verbose: Write article info to log as they are read.

  Returns:
    A dataset which yields examples of shape {
        "targets": int32[batch_size, sequence_length],
        "start_of_sequence": bool[batch_size],
        "epoch": int32[batch_size],
        "loss_mask": bool[batch_size, sequence_length],
        "num_chars": A count of the number of detokenized characters.
        "nonzero_tokens": A count of the number of nonzero predicted tokens.
    }
  """

  # Tokenize article, compute loss mask, split into multiple chunks.
  # The entire article must fit into memory.
  def wrap_split_article(article):
    if verbose:
      logging.info("Reading article: %s", pretty_print_article(article, {}))
    else:
      logging.log_every_n_seconds(logging.INFO, "Reading article: %s", 60,
                                  pretty_print_article(article, {}))
    tokens = extract_fn(article)
    if isinstance(tokens, str) or isinstance(tokens, bytes):
      tokens = bytes_to_tokens(tokens)
    elif isinstance(tokens, np.ndarray):
      tokens = tokens.astype(np.int32)
    else:
      raise TypeError("Unsupported sequence type: %s" % str(type(tokens)))
    return split_article(tokens, sequence_length, split=split,
                         include_loss_mask=include_loss_mask)

  # Handle None values.
  def wrap_pad_chunk(s):
    return pad_chunk(s, sequence_length)

  def wrap_batched_interleave_generator():
    return _batched_interleave_generator(ds,
                                         flat_map_func=wrap_split_article,
                                         post_map_func=wrap_pad_chunk,
                                         batch_size=batch_size,
                                         vocab=vocab,
                                         include_loss_mask=include_loss_mask,
                                         auto_rewind=auto_rewind)

  out_sig = {
      "targets": tf.TensorSpec(shape=(batch_size, sequence_length),
                               dtype=tf.int32),
      "start_of_sequence": tf.TensorSpec(shape=(batch_size,), dtype=tf.bool),
      "epoch": tf.TensorSpec(shape=(batch_size,), dtype=tf.int32),
      "num_chars": tf.TensorSpec(shape=(batch_size,), dtype=tf.int32),
      "nonzero_tokens": tf.TensorSpec(shape=(batch_size,), dtype=tf.int32),
  }
  if include_loss_mask:
    out_sig["loss_mask"] = tf.TensorSpec(shape=(batch_size, sequence_length),
                                         dtype=tf.bool)

  cds = tf.data.Dataset.from_generator(wrap_batched_interleave_generator,
                                       output_signature=out_sig)
  return cds
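# --- Runnable sketch on a toy in-memory dataset (strings and sizes are
# assumptions). Each batch slot streams chunks of one article at a time, and
# start_of_sequence marks where a new article begins in that slot. ---
toy = tf.data.Dataset.from_tensor_slices(
    [b"first tiny article", b"second tiny article"])
toy = toy.map(lambda text: {"targets": text})
batched = split_and_batch(toy, split="train",
                          extract_fn=lambda a: a["targets"],
                          sequence_length=8, batch_size=2)
first = next(batched.as_numpy_iterator())
print(first["targets"].shape)      # (2, 8), int32 byte values.
print(first["start_of_sequence"])  # [ True  True ]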
def merge_articles(article_starts_ends, sequence_length):
  """Merge consecutive articles if their combined length < sequence_length."""
  cs = 0
  ce = 0
  for (s, e) in article_starts_ends:
    if ce == 0:
      ce = s
    if (e - cs) > sequence_length:
      if ce > cs:
        # print("Yield: ", cs, " to ", ce)
        yield (cs, ce)  # Yield prior merged articles.
      cs = s  # Reset to start of current article.
      ce = e
    else:
      ce = e  # Merge article with current set.
    # print("Article: ", s, " to ", e)
  if ce > 0:
    # print("Yield: ", cs, " to ", ce)
    yield (cs, ce)  # Yield final merged set.
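# --- Worked example with assumed character offsets: the first two spans
# (combined length 8 <= 10) are merged; the longer third span stands alone.
spans = [(0, 4), (4, 8), (8, 20)]
print(list(merge_articles(spans, sequence_length=10)))  # [(0, 8), (8, 20)]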
def _targets_to_tokens(article):
  return bytes_to_tokens(article["targets"])


def _wrap_text_in_dict(text):
  return {"targets": text}


# ---------------------

def load_enwik8(split: str,
                sequence_length: int,
                batch_size: int,
                data_dir: str) -> tf.data.Dataset:
  """Load the enwik8 dataset, partitioning into articles."""

  if data_dir is None:
    raise ValueError("Must specify a data directory for enwik8")

  filename = os.path.join(data_dir, "enwik8")
  filename = os.path.join(filename, "enwik8_" + split)

  # Don't attempt to split the data, just shuffle it differently for
  # each worker.
  local_seed = 42 + jax.process_index()

  logging.info("Enwik8: reading %s", filename)
  with gfile.Open(filename, "r") as f:
    text_data = f.read()

  logging.info("Enwik8: parsing %s", filename)
  article_starts = [m.start(0) for m in re.finditer("<page>", text_data)]
  article_ends = article_starts[1:] + [len(text_data)]
  logging.info("Enwik8: found %d articles.", len(article_starts))

  merged_se = merge_articles(zip(article_starts, article_ends),
                             sequence_length)
  articles = [text_data[s:e] for (s, e) in merged_se]
  num_articles = len(articles)
  logging.info("Enwik8: merged into %d articles.", num_articles)

  logging.info("Building dataset.")
  ds = tf.data.Dataset.from_tensor_slices(articles)
  ds = ds.map(_wrap_text_in_dict)
  ds = ds.shuffle(num_articles, reshuffle_each_iteration=True,
                  seed=local_seed)
  if sequence_length == 0:
    return ds  # Don't split and batch.

  return split_and_batch(ds,
                         split=split,
                         extract_fn=_targets_to_tokens,
                         sequence_length=sequence_length,
                         batch_size=batch_size,
                         auto_rewind=True,
                         verbose=False)

# ---------------------


def synthetic_data_short(split: str,
                         sequence_length: int,
                         batch_size: int,
                         auto_rewind: bool = True) -> tf.data.Dataset:
  """Return a synthetic data set of sequences."""

  strings = [
      b"The quick brown fox jumped over the lazy dog.",
      b"Humpty dumpty sat on a wall and had a great fall and went splat.",
      b"She sells sea shells by the sea shore.",
      b"Peter piper picked a peck of pickled peppercorns."
  ]
  logging.info("Building synthetic dataset (short).")
  ds = tf.data.Dataset.from_tensor_slices(strings)
  ds = ds.map(_wrap_text_in_dict)
  ds = ds.shuffle(4, reshuffle_each_iteration=True, seed=42)
  if sequence_length == 0:
    return ds  # Don't split and batch.

  return split_and_batch(ds,
                         split=split,
                         extract_fn=_targets_to_tokens,
                         sequence_length=sequence_length,
                         batch_size=batch_size,
                         auto_rewind=auto_rewind,
                         verbose=False)


def synthetic_data_long(split: str,
                        sequence_length: int,
                        batch_size: int,
                        auto_rewind: bool = True) -> tf.data.Dataset:
  """Returns a synthetic data set with several long articles."""
  articles = [
      synthetic_text_data.text1_illiad_book1,
      synthetic_text_data.text2_huckleberry_finn,
      synthetic_text_data.text3_call_of_the_wild,
      synthetic_text_data.text4_the_prince
  ]
  logging.info("Building synthetic dataset (long).")
  ds = tf.data.Dataset.from_tensor_slices(articles)
  ds = ds.map(_wrap_text_in_dict)
  ds = ds.shuffle(4, reshuffle_each_iteration=True, seed=42)
  if sequence_length == 0:
    return ds  # Don't split and batch.

  return split_and_batch(ds,
                         split=split,
                         extract_fn=_targets_to_tokens,
                         sequence_length=sequence_length,
                         batch_size=batch_size,
                         auto_rewind=auto_rewind,
                         verbose=False)
aglib/meliad/transformer/transformer_base.py
ADDED
@@ -0,0 +1,451 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for transformer layers."""

from typing import Any, Callable, Optional, Tuple

from absl import logging

from flax import linen as nn
import gin
import jax
import jax.numpy as jnp


from transformer import nn_components


Array = Any

# Tuple of scale factors.
AttnScaleTuple = Tuple[Optional[Array], Optional[Array]]

# Tuple of keys, values, queries.
KVQTuple = Tuple[Array, Array, Optional[Array], Optional[Array]]


@gin.configurable
class KVQLayer(nn.Module):
  """Generate keys, values, and queries for attention."""

  embedding_size: int
  num_heads: int
  head_size: int
  has_queries: bool = True
  has_queries2: bool = False  # For cross-attention, e.g. decoder or
                              # recurrence.

  normalize_keys: bool = True       # Normalize keys and queries.
  num_position_embeddings: int = 0  # Learned absolute position embeddings.
  pre_attn_dropout: bool = True
  dropout_rate: float = 0.0
  dtype: Any = jnp.float32

  def setup(self):
    kernel_init = nn.initializers.variance_scaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal")

    # Project to keys, values, queries.
    # Disable bias. This prevents a failure mode whereby the attention matrix
    # can become filled with very large uniform values, due to high bias.
    self.keys_layer = nn.Dense(
        features=self.num_heads * self.head_size,
        use_bias=False,  # No bias for keys.
        kernel_init=kernel_init,
        dtype=self.dtype)
    self.values_layer = nn.Dense(
        features=self.num_heads * self.head_size,
        use_bias=False,  # No bias for values.
        kernel_init=kernel_init,
        dtype=self.dtype)
    if self.has_queries:
      self.queries_layer = nn.Dense(
          features=self.num_heads * self.head_size,
          use_bias=False,  # No bias for queries.
          kernel_init=kernel_init,
          dtype=self.dtype)
    if self.has_queries2:
      self.queries2_layer = nn.Dense(
          features=self.num_heads * self.head_size,
          use_bias=False,  # No bias for queries.
          kernel_init=kernel_init,
          dtype=self.dtype)

    # When normalizing keys and queries, attention must be scaled with
    # learned parameters.
    if self.normalize_keys:
      self.attention_scale = self.param("attention_scale",
                                        jax.nn.initializers.ones,
                                        (self.num_heads,), jnp.float32)

    # Learned position embeddings for absolute positions.
    if self.num_position_embeddings > 0:
      # Embeddings for query elements.
      self.position_embeddings = self.param(
          "position_embeddings",
          jax.nn.initializers.normal(stddev=1.0),
          (self.num_position_embeddings, self.embedding_size),
          jnp.float32)

    # Layernorm.
    self.pre_attn_layernorm = nn_components.LayerNorm()

  def attention_scale_factor(self) -> Optional[Array]:
    """Returns the attention scale, when keys and queries are normalized."""
    if self.normalize_keys:
      return jnp.asarray(self.attention_scale, dtype=self.dtype)
    else:
      return None

  def _get_dropout_rng(self):
    return self.make_rng("dropout")

  def _normalize_kq(self, kq: Array) -> Array:
    """Normalize function for keys and queries."""
    epsilon = jnp.array(1.0e-6, dtype=self.dtype)
    kq_sum_sqr = jnp.sum(jnp.square(kq), axis=-1, keepdims=True)
    norm_kq = kq * jax.lax.rsqrt(kq_sum_sqr + epsilon)
    return jnp.asarray(norm_kq, dtype=self.dtype)

  def __call__(self, xs: Array, deterministic: bool = False) -> KVQTuple:
    """Takes a sequence of embeddings; returns keys, values, and queries.

    First apply pre_attn layernorm, and pre_attn dropout.
    Then add learned positional embeddings, if any.
    Return (keys, values, queries, queries2).

    Args:
      xs: input sequence of shape (batch_size, sequence_length,
        embedding_size)
      deterministic: if False, apply dropout.

    Returns:
      (keys, values, queries, queries2) of shape
      (batch_size, sequence_length, num_heads, head_size)
    """

    # Project inputs to (keys, values, queries).
    (batch_size, num_keys, _) = xs.shape
    drop_tile_shape = (1, 128, self.embedding_size)

    # Apply layernorm to the input, rather than the output.
    # This provides better gradients through the resnet, and also avoids
    # the need for a prolonged warmup phase (https://arxiv.org/abs/2002.04745)

    # Layernorm for self-attention.
    logging.info("kvq: pre_attn xs = %r", xs)
    xs = jnp.asarray(xs, dtype=self.dtype)
    xs = self.pre_attn_layernorm(xs)

    # Add (optional) learned position embeddings.
    if self.num_position_embeddings > 0:
      assert xs.ndim == 3  # (b, sequence_length, embedding_size)
      assert xs.shape[-2] == self.num_position_embeddings
      logging.info("kvq: learned positions.")
      xs_pos = jnp.asarray(self.position_embeddings, dtype=self.dtype)
      xs_pos = jnp.expand_dims(xs_pos, 0)  # Add batch dimension.
      xs = xs + xs_pos

    # Pre-attention dropout.
    if self.pre_attn_dropout:
      logging.info("kvq: pre_attn dropout.")
      xs = nn_components.tiled_dropout(xs, drop_tile_shape, self.dropout_rate,
                                       rng_function=self._get_dropout_rng,
                                       deterministic=deterministic)

    # Compute keys and values.
    keys = self.keys_layer(xs)  # (b, num_keys, num_heads * head_size)
    values = self.values_layer(xs)

    # Compute queries and cross-attention queries if necessary.
    if self.has_queries:
      queries = self.queries_layer(xs)  # (b, num_keys, n_heads * head_size)
      logging.info("kvq: queries = %r", queries)
    else:
      queries = None
    if self.has_queries2:
      queries2 = self.queries2_layer(xs)  # (b, num_keys, n_heads * head_size)
      logging.info("kvq: queries2 = %r", queries2)
    else:
      queries2 = None

    # Reshape to split num_heads, head_size into separate dimensions.
    kv_shape = (batch_size, num_keys, self.num_heads, self.head_size)
    keys = jnp.reshape(keys, kv_shape)
    values = jnp.reshape(values, kv_shape)
    if queries is not None:
      queries = jnp.reshape(queries, kv_shape)
    if queries2 is not None:
      queries2 = jnp.reshape(queries2, kv_shape)

    if self.normalize_keys:
      # Normalize both keys and queries.
      # The learned attention_scale_factor() will return non-None.
      logging.info("kvq: normalize keys, queries.")
      keys = self._normalize_kq(keys)
      if queries is not None:
        queries = self._normalize_kq(queries)
      if queries2 is not None:
        queries2 = self._normalize_kq(queries2)
    else:
      # Scale queries by 1 / sqrt(d) when using unnormalized keys, queries.
      d_scale = jax.lax.rsqrt(float(self.head_size)).astype(self.dtype)
      logging.info("kvq: scale queries by 1/sqrt(d).")
      if queries is not None:
        queries = queries * d_scale
      if queries2 is not None:
        queries2 = queries2 * d_scale

    # Return keys, values, and queries.
    return (keys, values, queries, queries2)
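# --- Standalone check of the _normalize_kq arithmetic (illustrative values):
# each head vector is rescaled to ~unit L2 norm, which is why attention is
# then scaled by the learned attention_scale parameter. ---
kq = jnp.array([[3.0, 4.0], [1.0, 0.0]])  # Two head vectors, head_size 2.
norm_kq = kq * jax.lax.rsqrt(
    jnp.sum(jnp.square(kq), axis=-1, keepdims=True) + 1.0e-6)
print(jnp.linalg.norm(norm_kq, axis=-1))  # ~[1. 1.]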
@gin.configurable
class TransformerBase(nn.Module):
  """TransformerBase implements everything except attention.

  It handles:
    - Projection to (keys, values, queries) before attention.
    - Projection MLP back to embedding_size after attention.
    - Final FFN layer.
    - Layernorm, dropout, and normalization of keys and queries.

  This functionality is encapsulated here so that it can be reused with more
  complicated attention mechanisms.
  """

  # Options set by parent module.
  mode: str
  embedding_size: int
  num_heads: int
  head_size: int

  cross_attention_q: bool = False   # Additional q for cross-attention.
  cross_attention_kv: bool = False  # Additional kv for cross-attention.
  num_position_embeddings: int = 0  # Learned position embeddings.
  num_cross_position_embeddings: int = 0  # Learned position embeddings.

  # Configurable hyperparameters.
  attn_mlp_factory: Callable[[int], nn.Module] = gin.REQUIRED
  ffn_factory: Callable[[int], nn.Module] = gin.REQUIRED
  gate_type: str = "residual"
  single_gate: bool = False
  skip_ffn: bool = False

  normalize_keys: bool = True
  dropout_rate: float = 0.0
  pre_attn_dropout: bool = True
  post_attn_dropout: bool = False
  pre_ffn_dropout: bool = False
  post_ffn_dropout: bool = True

  dtype: Any = jnp.float32

  def is_training(self) -> bool:
    return self.mode == "train"

  def _get_dropout_rng(self):
    return self.make_rng("dropout")

  def _normalize_kq(self, kq: Array) -> Array:
    """Normalize function for keys and queries."""
    epsilon = jnp.array(1.0e-6, dtype=self.dtype)
    kq_sum_sqr = jnp.sum(jnp.square(kq), axis=-1, keepdims=True)
    norm_kq = kq * jax.lax.rsqrt(kq_sum_sqr + epsilon)
    return jnp.asarray(norm_kq, dtype=self.dtype)

  def setup(self):
    # Keys, values, queries for self-attention; queries for cross-attention.
    self._kvq = KVQLayer(self.embedding_size, self.num_heads, self.head_size,
                         has_queries=True,
                         has_queries2=self.cross_attention_q,
                         num_position_embeddings=self.num_position_embeddings,
                         normalize_keys=self.normalize_keys,
                         pre_attn_dropout=self.pre_attn_dropout,
                         dropout_rate=self.dropout_rate,
                         dtype=self.dtype)

    # Keys, values, attention_scale for cross-attention.
    if self.cross_attention_kv:
      # Use a full kvq layer, with layernorm and attention scale.
      self._cross_kv = KVQLayer(
          self.embedding_size, self.num_heads, self.head_size,
          has_queries=False,
          has_queries2=False,
          num_position_embeddings=self.num_cross_position_embeddings,
          normalize_keys=self.normalize_keys,
          pre_attn_dropout=self.pre_attn_dropout,
          dropout_rate=self.dropout_rate,
          dtype=self.dtype)
    elif self.cross_attention_q:
      # No separate keys, values for cross-attention, but we may still need
      # cross-attention-scale, so we create our own.
      assert self.num_cross_position_embeddings == 0
      if self.normalize_keys:
        self.attention_scale2 = self.param("attention_scale2",
                                           jax.nn.initializers.ones,
                                           (self.num_heads,), jnp.float32)

    # Post-attention linear projection.
    if not self.single_gate:
      self.post_attn_mlp = self.attn_mlp_factory(
          self.embedding_size,
          gate_type=self.gate_type,
          final_activation=None,
          dtype=self.dtype)  # pytype: disable=wrong-keyword-args  # trace-all-classes

    # Final FFN.
    if not self.skip_ffn:
      self.ffn = self.ffn_factory(
          self.embedding_size,
          gate_type=self.gate_type,
          final_activation=("tanh" if self.single_gate else None),
          dtype=self.dtype)  # pytype: disable=wrong-keyword-args  # trace-all-classes

    # Layernorm.
    self.pre_ffn_layernorm = nn_components.LayerNorm()

  def force_init(self, xs: Array):
    """Force flax initialization of self, prior to use with lax.scan.

    Args:
      xs: The input sequence that the module will be called with.
    """
    logging.info("tbase: Begin forced initialization.")
    _ = self.kvq(xs)
    batch_size = xs.shape[0]
    seq_len = xs.shape[1]
    attn_ys_shape = (batch_size, seq_len, self.num_heads, self.head_size)
    dummy_attn_ys = jnp.zeros(attn_ys_shape, dtype=self.dtype)
    if self.cross_attention_kv or self.cross_attention_q:
      dummy_cross_attn_ys = dummy_attn_ys
    else:
      dummy_cross_attn_ys = None
    _ = self.post_attn_ffn(xs, dummy_attn_ys, dummy_cross_attn_ys)
    logging.info("tbase: End forced initialization.")

  def attention_scale_factors(self) -> AttnScaleTuple:
    """Returns the attention scales, when keys and queries are normalized.

    Returns: (scale for kv (i.e. queries), scale for cross_kv (i.e. queries2))
    """
    sfactor = self._kvq.attention_scale_factor()
    if self.cross_attention_kv:
      cross_sfactor = self._cross_kv.attention_scale_factor()
    elif self.cross_attention_q and self.normalize_keys:
      cross_sfactor = jnp.asarray(self.attention_scale2, dtype=self.dtype)
    else:
      cross_sfactor = None
    return (sfactor, cross_sfactor)

  def kvq(self, xs: Array) -> KVQTuple:
    enable_dropout = self.pre_attn_dropout and self.is_training()
    return self._kvq(xs, deterministic=not enable_dropout)

  def cross_kv(self, xs: Array) -> Tuple[Array, Array]:
    assert self.cross_attention_kv
    enable_dropout = self.pre_attn_dropout and self.is_training()
    (k, v, _, _) = self._cross_kv(xs, deterministic=not enable_dropout)
    return (k, v)

  def post_attn_ffn(self, xs: Array, attn_ys: Array,
                    cross_attn_ys: Optional[Array]) -> Array:
    """Combines the output of attention with the original input sequence.

    Post-attn MLP on attn_ys, followed by resnet/gate.
    Pre-FFN layernorm and dropout, then the FFN layer, followed by
    resnet/gate.

    Args:
      xs: Original input sequence of shape
        (batch_size, sequence_length, embedding_size)
      attn_ys: Output of the self-attention module, of shape
        (batch_size, sequence_length, num_heads, head_size)
      cross_attn_ys: Output of the cross-attention module, of shape
        (batch_size, sequence_length, num_heads, head_size)

    Returns:
      Array of shape (batch_size, sequence_length, embedding_size)
    """

    (batch_size, sequence_length, _) = xs.shape
    assert attn_ys.shape == (batch_size, sequence_length,
                             self.num_heads, self.head_size)
    no_dropout = not self.is_training()
    drop_tile_shape = (1, 128, self.embedding_size)

    # Concatenate cross-attention and self-attention results.
    if cross_attn_ys is not None:
      # Concatenate self-attention and cross-attention results, before
      # applying the projection layer.
      logging.info("tbase: using cross-attention.")
      assert cross_attn_ys.shape == (batch_size, sequence_length,
                                     self.num_heads, self.head_size)
      attn_ys = jnp.concatenate([attn_ys, cross_attn_ys], axis=2)
      att_ys_num_heads = self.num_heads * 2
    else:
      # Only use self-attention.
      att_ys_num_heads = self.num_heads

    logging.info("tbase: attn_ys = %r", attn_ys)
    attn_ys = attn_ys.reshape(
        (batch_size, sequence_length, att_ys_num_heads * self.head_size))

    if self.single_gate:
      logging.info("tbase: single gate.")
      assert not self.skip_ffn
      # Skip post-attention linear projection and residual connection.
      ys_hidden = xs    # The FFN (below) will be gated onto xs (the input).
      ffn_in = attn_ys  # The input to the FFN is the output of attention.
    else:
      logging.info("tbase: post-attention MLP.")
      # Standard transformer architecture.
      # The post-attention MLP applies a linear projection to project attn_ys
      # to embedding space. It then uses a residual connection or gate to
      # combine the projection with xs. Post-attention dropout is applied
      # before the residual/gate.
      post_attn_ys = self.post_attn_mlp(
          attn_ys, xs,
          apply_dropout=self.post_attn_dropout and not no_dropout,
          dropout_rate=self.dropout_rate,
          drop_tile_shape=drop_tile_shape,
          rng_function=self._get_dropout_rng)

      # The FFN (below) will be gated onto post_attn_ys (which gates onto xs).
      ys_hidden = post_attn_ys
      if self.skip_ffn:
        logging.info("tbase: skip final FFN. ys = %r", ys_hidden)
        return ys_hidden

      # The input to the FFN; layernorm is applied before the FFN.
      ffn_in = self.pre_ffn_layernorm(ys_hidden)
      logging.info("tbase: pre-FFN layernorm = %r", ffn_in)

      # Pre-FFN dropout.
      if self.pre_ffn_dropout:
        logging.info("tbase: pre-FFN dropout.")
        ffn_in = nn_components.tiled_dropout(
            ffn_in, drop_tile_shape, self.dropout_rate,
            rng_function=self._get_dropout_rng, deterministic=no_dropout)

    # FFN layer.
    # Large MLP with hidden layers followed by residual connection or gate.
    # The MLP will apply post-ffn dropout before the gate.
    logging.info("tbase: final FFN")
    ys = self.ffn(ffn_in, ys_hidden,
                  apply_dropout=self.post_ffn_dropout and not no_dropout,
                  dropout_rate=self.dropout_rate,
                  drop_tile_shape=drop_tile_shape,
                  rng_function=self._get_dropout_rng)

    logging.info("tbase: ys = %r", ys)
    return ys
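# --- Simplified sketch of the standard (non-single-gate) dataflow in
# post_attn_ffn, with plain residual adds standing in for the configurable
# gates and dropout omitted. `project`, `layernorm`, and `ffn` are
# hypothetical stand-ins for the gin-configured modules. ---
def post_attn_ffn_sketch(xs, attn_ys, project, layernorm, ffn):
  """Attention out -> project -> +xs -> pre-LN -> FFN -> residual add."""
  b, n, h, d = attn_ys.shape
  attn_flat = attn_ys.reshape((b, n, h * d))    # Merge heads.
  ys_hidden = xs + project(attn_flat)           # Post-attention residual.
  return ys_hidden + ffn(layernorm(ys_hidden))  # Pre-LN FFN residual.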
aglib/meliad/transformer/transformer_layer.py
ADDED
@@ -0,0 +1,817 @@
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A single transformer layer."""

from typing import Any, Mapping, NewType, Optional, Sequence, Tuple

from absl import logging

from flax import linen as nn
import gin

import jax
import jax.numpy as jnp

from transformer import attention
from transformer import memory_factory
from transformer import nn_components
from transformer import position
from transformer import position_fourier
from transformer import position_t5
from transformer import transformer_base


Array = jnp.ndarray
DecoderState = NewType("DecoderState", Mapping[str, Array])
WindowState = Optional[Tuple[attention.KVITuple, Array]]
KVITuple = attention.KVITuple

@gin.configurable
class TransformerLayer(nn.Module):
  """Full transformer layer, with attention."""

  # Set by DecoderStack
  mode: str
  batch_size: int
  embedding_size: int
  cross_attention: bool = False
  recurrent_attention: bool = False
  memory: Optional[memory_factory.MemoryManager] = None

  # Configurable hyper-parameters
  num_heads: int = gin.REQUIRED
  head_size: int = gin.REQUIRED

  window_length: int = gin.REQUIRED
  use_long_xl_architecture: bool = True
  max_unrolled_windows: int = -1  # Always unroll.
  relative_position_type: Optional[str] = "fourier"  # {None, "fourier", "t5", "rotary"}
  use_causal_mask: bool = True
  attn_dropout_rate: float = 0.0

  recurrent_num_states: int = 0
  recurrent_gate_type: str = "bias"
  recurrent_single_gate: bool = False
  recurrent_skip_ffn: bool = False

  compute_importance: bool = False
  memory_num_neighbors: int = 0
  memory_reset_on_new_doc: bool = True

  dtype: Any = jnp.float32

  # Modes which support caching of previous keys and values.
  supported_modes_for_cache: Sequence[str] = ("train", "test")
  update_memory_modes: Sequence[str] = ("train", "test")
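
  # Illustrative gin bindings for this layer (an assumption for exposition,
  # not part of the original file; the concrete values are hypothetical):
  #
  #   TransformerLayer.num_heads = 8
  #   TransformerLayer.head_size = 64
  #   TransformerLayer.window_length = 512
  #   TransformerLayer.relative_position_type = "t5"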

  def supports_generate(self) -> bool:
    return self.use_long_xl_architecture

  def _get_cache_name_from_mode(self, mode: str) -> Tuple[str, bool, bool]:
    """Get the name of the cache, and whether to update the cache, from mode."""
    # This is a hack to ensure that "generate" steps generate text as a
    # continuation of the text that is stored in the "test" cache,
    # but it does not update the "test" cache.
    if mode == "generate":
      assert "test" in self.supported_modes_for_cache
      return ("test", False, False)   # Use test cache, but don't update it.
    elif mode == "init":
      return ("train", False, False)  # Use training cache for initialization.
    else:
      return (mode, True, mode in self.update_memory_modes)
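
  # For example (illustrative): with the defaults above,
  # _get_cache_name_from_mode("generate") returns ("test", False, False), so
  # generation reads the "test" cache without writing to it, while
  # _get_cache_name_from_mode("train") returns ("train", True, True).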

  def _allocate_cached_kvi(self, mode: str) -> KVITuple:
    """Allocate (keys, values, importance) which can be cached between steps."""

    kv_shape = [self.batch_size, self.window_length,
                self.num_heads, self.head_size]
    imp_shape = [self.batch_size, self.window_length]

    def kv_initializer(shape):
      return jnp.zeros(shape, dtype=self.dtype)

    def imp_initializer(shape):
      return jnp.zeros(shape, dtype=self.dtype)

    pkeys = self.variable("state", "previous_keys_" + mode,
                          kv_initializer, kv_shape)
    pvals = self.variable("state", "previous_values_" + mode,
                          kv_initializer, kv_shape)
    if self.compute_importance:
      pimportance = self.variable("state", "previous_importance_" + mode,
                                  imp_initializer, imp_shape)
    else:
      pimportance = None
    return (pkeys, pvals, pimportance)

  def _allocate_cached_recurrent_state(self, mode: str):
    rec_num_states = self.recurrent_num_states
    st_shape = [self.batch_size, rec_num_states, self.embedding_size]

    def st_initializer(shape):
      return jnp.zeros(shape, dtype=self.dtype)

    return self.variable("state", "recurrent_state_" + mode,
                         st_initializer, st_shape)

  def setup(self):
    # Basic transformer functionality: everything except attention.

    self.tbase = transformer_base.TransformerBase(
        mode=self.mode,
        embedding_size=self.embedding_size,
        num_heads=self.num_heads,
        head_size=self.head_size,
        cross_attention_q=self.recurrent_attention or self.cross_attention,
        cross_attention_kv=False,  # or True to use separate k,v.
        num_position_embeddings=0,
        num_cross_position_embeddings=0,  # or self.recurrent_num_states w/ k,v.
        dtype=self.dtype)

    # Recurrent transformer functionality.
    self.recurrent_tbase = None
    if self.recurrent_attention:
      # Recurrent transformer layer.
      # We use a learned position embedding so that each element of the state
      # can learn to query and compute different summaries.
      self.recurrent_tbase = transformer_base.TransformerBase(
          mode="pure",  # Disable dropout, which breaks jax.lax.scan.
          embedding_size=self.embedding_size,
          num_heads=self.num_heads,
          head_size=self.head_size,
          cross_attention_q=True,
          cross_attention_kv=False,  # or True to use separate k,v.
          num_position_embeddings=self.recurrent_num_states,
          num_cross_position_embeddings=0,  # or self.window_length w/ k,v.
          gate_type=self.recurrent_gate_type,
          single_gate=self.recurrent_single_gate,
          skip_ffn=self.recurrent_skip_ffn,
          dtype=self.dtype)

      # Initial state at start of document.
      # We want this to be initially small, but large enough that adafactor
      # will scale updates to a reasonable value.
      self.recurrent_initial_state = self.param(
          "recurrent_initial_state",
          jax.nn.initializers.normal(stddev=0.1),
          (self.recurrent_num_states, self.embedding_size), jnp.float32)

      # Cached state from previous step for BPTT.
      rec_state = {}
      for mkey in self.supported_modes_for_cache:
        rec_state[mkey] = self._allocate_cached_recurrent_state(mkey)
      self.cached_recurrent_state = rec_state

    # Set up relative position encoding.
    if self.relative_position_type == "fourier":
      self.relative_positions = position_fourier.RelativeFourierPositions(
          num_heads=self.num_heads,
          max_number_of_keys=self.window_length,
          dtype=self.dtype)
    elif self.relative_position_type == "t5":
      self.relative_positions = position_t5.T5RelativePositionBiases(
          num_buckets=32,   # TODO(delesley): Let Gin configure these.
          max_distance=128,
          num_heads=self.num_heads,
          dtype=self.dtype)
    elif self.relative_position_type == "rotary":
      # Rotary position encodings (RoPE).  No learned bias parameters.
      self.relative_positions = None
    else:
      assert self.relative_position_type is None
      self.relative_positions = None

    # Set up cache for Transformer-XL style architectures.
    # A separate cache is created for each mode (e.g. train, test).
    cached_kvi = {}
    if self.use_long_xl_architecture:
      for mkey in self.supported_modes_for_cache:
        cached_kvi[mkey] = self._allocate_cached_kvi(mkey)
    self.cached_kvi = cached_kvi

    # Set up external memory.
    # A separate memory will be created for each mode (e.g. train, test).
    mem_layers = {}
    if self.memory is not None:
      self.memory_bias = self.param("external_memory_bias", nn.zeros,
                                    (self.num_heads,), "float32")
      for mkey in self.supported_modes_for_cache:
        mlayer = self.memory.create_memory_layer()
        # Use setattr to set up the name and module containership hierarchy.
        setattr(self, "mem_layer_" + mkey, mlayer)
        mem_layers[mkey] = mlayer
    self.mem_layers = mem_layers
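
  # Illustrative cache shapes (an assumption for exposition, with hypothetical
  # values batch_size=2, window_length=512, num_heads=8, head_size=64): each
  # "previous_keys_<mode>" variable allocated above has shape (2, 512, 8, 64),
  # and "previous_importance_<mode>" has shape (2, 512).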

  def _get_cached_kvi(self, start_of_sequence: Array,
                      mode: str) -> Optional[KVITuple]:
    """Returns cached (keys, values, importance) from the previous step."""
    if not self.use_long_xl_architecture:
      return None
    if mode not in self.cached_kvi:
      # No cache, but we're using XL / sliding window, so return zeros.
      logging.info("tlayer: using zero as initial XL cache value.")
      kvi_shape = (self.batch_size, self.window_length,
                   self.num_heads, self.head_size)
      return attention.initial_kvi(kvi_shape,
                                   self.compute_importance, dtype=self.dtype)

    # New documents start with zero_kv.
    # Continuing the same document will attend to previous keys/vals.
    (pkeys, pvals, pimportance) = self.cached_kvi[mode]
    (zkeys, zvals, zimportance) = attention.initial_kvi(
        pkeys.value.shape, self.compute_importance, dtype=self.dtype)

    # Broadcast start_of_sequence over non-batch dims.
    b = self.batch_size
    start_of_sequence_kv = jnp.reshape(start_of_sequence, [b, 1, 1, 1])
    prev_keys = jnp.where(start_of_sequence_kv, zkeys, pkeys.value)
    prev_vals = jnp.where(start_of_sequence_kv, zvals, pvals.value)
    if self.compute_importance:
      start_of_sequence_imp = jnp.reshape(start_of_sequence, [b, 1])
      prev_importance = jnp.where(start_of_sequence_imp, zimportance,
                                  pimportance.value)
    else:
      prev_importance = None
    logging.debug("tlayer: start_of_sequence = %r", start_of_sequence)
    logging.info("tlayer: prev_keys[%r] = %r", mode, prev_keys)
    logging.debug("tlayer: prev_importance[%r] = %r", mode, prev_importance)
    return (prev_keys, prev_vals, prev_importance)

  def _set_cached_kvi(self, next_kvi: KVITuple, mode: str):
    """Caches the last (keys, values, importance) from the current step."""
    if not self.use_long_xl_architecture:
      return
    if mode not in self.cached_kvi:
      return

    (pkeys, pvals, pimportance) = self.cached_kvi[mode]
    (nkeys, nvals, nimportance) = next_kvi  # From last window
    logging.info("tlayer: next_keys[%r] = %r", mode, nkeys)
    pkeys.value = nkeys
    pvals.value = nvals
    if self.compute_importance:
      logging.info("tlayer: next_importance[%r] = %r", mode, nimportance)
      pimportance.value = nimportance

  def _get_cached_recurrent_state(self, start_of_sequence: Array,
                                  mode: str) -> Optional[Array]:
    """Returns cached recurrent state from the previous step."""
    if not self.recurrent_attention:
      return None
    if mode not in self.cached_recurrent_state:
      return None

    b = self.batch_size
    rstate = self.cached_recurrent_state[mode].value
    istate = jnp.asarray(self.recurrent_initial_state, dtype=self.dtype)
    istate = istate[jnp.newaxis, :, :]  # Add batch dimension for broadcast.
    logging.info("tlayer: get_cached_recurrent_state %r, %r", istate, rstate)

    start_of_sequence_st = jnp.reshape(start_of_sequence, (b, 1, 1))
    return jnp.where(start_of_sequence_st, istate, rstate)

  def _set_cached_recurrent_state(self, next_state: Array, mode: str):
    """Store the next recurrent state in the cache."""
    if not self.recurrent_attention:
      return
    if mode not in self.cached_recurrent_state:
      return

    logging.info("tlayer: set_cached_recurrent_state %r", next_state)
    rstate = self.cached_recurrent_state[mode]
    rstate.value = next_state
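
  # A minimal sketch of the per-sequence reset pattern used by the two
  # "_get_cached_*" helpers above (values are hypothetical): with batch_size=2
  # and start_of_sequence=[True, False], reshaping the flag to (2, 1, 1, 1)
  # lets jnp.where select fresh zeros for the first sequence and the cached
  # keys/values for the second, with no python-level branching.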

  def _query_external_memory(self, keys: Array, values: Array, queries: Array,
                             start_of_sequence: Array,
                             mode: str, update_memory: bool):
    """Query and update external memory."""
    if self.memory is None:
      return None

    # Make sure we initialize (allocate) the external memories for all modes.
    # Per the flax lazy module initialization scheme, setup() will not be
    # invoked on a submodule until that module is actually used.
    if mode == "init":
      for (_, mlayer) in self.mem_layers.items():
        (_, _) = mlayer.topk_retrieval(queries, self.memory_num_neighbors)
      mode = "train"  # Pretend we're in training mode during initialization.

    if mode not in self.mem_layers:
      return None
    if self.memory_num_neighbors == 0:
      raise ValueError("Using memory, but num_neighbors == 0")

    # Grab the appropriate memory layer for the current mode.
    memory_layer = self.mem_layers[mode]

    # Clear the relevant memories at the start of each new document.
    if update_memory and self.memory_reset_on_new_doc:
      # The number of "datasets" is batch_dim * num_heads.
      # jnp.repeat will "broadcast" start_of_sequence over num_heads.
      # E.g. if start_of_sequence = [True, False] and 4 heads,
      # jnp.repeat will yield [T, T, T, T, F, F, F, F]
      memory_layer.reset(jnp.repeat(start_of_sequence, self.num_heads))

    # Query external memory, with queries.
    (rkeys, rvals) = memory_layer.topk_retrieval(queries,
                                                 self.memory_num_neighbors)
    logging.info("tlayer: query external memory (%r): rvals = %r", mode, rvals)

    # Sanity check all dimensions are as expected.
    assert rkeys.ndim == 5  # (b, seq_len, num_heads, num_neigh, head_dim)
    assert rvals.ndim == 5
    assert rkeys.shape == rvals.shape
    assert rkeys.shape[0] == queries.shape[0]  # batch size
    assert rkeys.shape[1] == queries.shape[1]  # sequence length
    assert rkeys.shape[2] == self.num_heads
    assert rkeys.shape[3] == self.memory_num_neighbors
    assert rkeys.shape[4] == self.head_size

    # Update external memory, with (keys, values).
    if update_memory:
      memory_layer.update(keys, values)
    return (rkeys, rvals)
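
  # Illustrative retrieval shape (hypothetical values, consistent with the
  # asserts above): with batch_size=2, sequence_length=512, num_heads=8,
  # memory_num_neighbors=32 and head_size=64, the returned (rkeys, rvals)
  # each have shape (2, 512, 8, 32, 64): 32 retrieved neighbors per query
  # position per head.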

  def __call__(self, xs: Array, start_of_sequence: Array,
               *,
               importance: Optional[Array] = None,
               cross_attention_kv: Optional[Tuple[Array, Array]] = None,
               window_state: Optional[WindowState] = None,
               decoder_state: Optional[DecoderState] = None) -> (
                   Tuple[Array, Optional[Array], Optional[WindowState],
                         Optional[DecoderState], Any]):
    """Computes attention over a sequence of inputs.

    Args:
      xs: input sequence of shape (batch_size, sequence_length, num_hidden)
      start_of_sequence: An input array of shape (batch_size)

      --- The following must be passed by keyword only. ---
      importance: Array of shape (batch_size, sequence_length).
                  An importance bias for attention.
      cross_attention_kv: Keys and values from encoder for cross-attention.
      window_state: State object which contains context from the prior
                    window when using a transformer-XL or sliding window.
                    Initially created with load_window_state().
      decoder_state: State object for autoregressive decoding, initially
                     created with init_decoder_state().

    Returns:
      (ys: outputs of shape (batch_size, sequence_length, num_hidden),
       importance: importance values for the next layer,
       next_window_state: state to pass to the next window,
       next_decoder_state: next decoder state for autoregressive decoding,
       viz_dict: dictionary of visualizations
      )
    """

    xs = jnp.asarray(xs, dtype=self.dtype)
    logging.info("tlayer: xs = %r", xs)
    logging.info("tlayer: recurrent = %r", self.recurrent_attention)
    logging.info("tlayer: cross-attention = %r", cross_attention_kv is not None)

    is_training = (self.mode == "train")

    # Compute keys, values and queries.
    # ---------------------------------
    logging.info("tlayer: compute keys,values,queries.")
    (keys, values, queries, queries2) = self.tbase.kvq(xs)
    attention_scale_factors = self.tbase.attention_scale_factors()
    (_, sequence_length, num_heads, _) = queries.shape  # (b, k, h, d)

    # Get biases and masks that are shared across windows.
    # ----------------------------------------------------
    if decoder_state is not None:
      logging.info("tlayer: using autoregressive decoder.")
      # When decoding, prior keys,values are loaded from the decoder state.
      # Other values are precomputed, and loaded from the decoder state.
      # The decoder state will be updated with the current token.
      assert window_state is None

      prev_kvi = None
      recurrent_state = None  # Use precomputed recurrent_kvq.
      cross_attention_kv = None
      rel_position_bias = decoder_state["relative_position_bias"]
      causal_mask = None
      dropout_multiplier = None

      # Reuse cached recurrent keys,values for each token.
      cached_recurrent_kvq = decoder_state["recurrent_kvq"]
      if cached_recurrent_kvq is not None:
        assert cross_attention_kv is None
        cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1])
      del cached_recurrent_kvq

      # Get a full window of keys,values and update decoder state.
      (decoder_state, keys, values) = self._next_decoder_state(
          decoder_state, keys, values)

      # Each query attends to window_length prior keys.
      assert keys.shape[1] == self.window_length
      kq_relative_offset = self.window_length
    else:
      logging.info("tlayer: windowed attention.")
      # When training, attention is done using windows or chunks, and prior
      # context (e.g. keys,values from the previous window) is stored in the
      # window_state object.
      (prev_kvi, recurrent_state) = window_state  # pytype: disable=attribute-error

      # Get the size of the sliding window for pos bias, dropout, & causal mask.
      (num_queries, num_keys) = attention.sliding_attention_window_shape(
          (keys, values, importance), prev_kvi, queries,
          window_length=self.window_length)
      kq_relative_offset = num_keys - num_queries

      # Get the relative position bias.
      # The bias doesn't depend on the query content, and so can be precomputed.
      if self.relative_positions is not None:
        rel_position_bias = self.relative_positions(num_queries, num_keys,
                                                    bidirectional=False)
        logging.info("tlayer: %s relative bias = %r",
                     self.relative_position_type, rel_position_bias)
      else:
        rel_position_bias = None

      # Get causal mask.
      if self.use_causal_mask:
        causal_mask = position.causal_mask(num_queries, num_keys,
                                           window_length=self.window_length)
        logging.info("tlayer: causal mask = %r", causal_mask)
      else:
        causal_mask = None

      # Apply dropout to the attention matrix.
      # The mask will be broadcast across batches and windows.
      if self.attn_dropout_rate > 0.0 and is_training:
        dropout_rng = self.make_rng("dropout")
        attn_shape = (self.num_heads, num_queries, num_keys)
        dropout_multiplier = nn_components.dropout_multiplier_mask(
            dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype)
        logging.info("tlayer: attn_dropout = %r", dropout_multiplier)
      else:
        dropout_multiplier = None

    # Load and store values into external memory, if memory is not None.
    # ------------------------------------------------------------------
    (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode)
    external_kv = self._query_external_memory(
        keys, values, queries,
        start_of_sequence=start_of_sequence, mode=mode,
        update_memory=decoder_state is None and update_memory)

    if self.memory is not None:
      external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype)
      external_memory_bias = jnp.reshape(external_memory_bias,
                                         (1, 1, num_heads, 1))
      external_memory_bias = jax.nn.sigmoid(external_memory_bias)
    else:
      external_memory_bias = None

    # Compute the number of windows.
    # ------------------------------
    if sequence_length < self.window_length:
      num_windows = 1  # Happens with autoregressive decoding.
    elif sequence_length == self.window_length:
      num_windows = 1
      if self.use_long_xl_architecture:
        assert prev_kvi is not None
    else:
      if not self.use_long_xl_architecture:
        raise ValueError("Can only use sliding window with Transformer XL.")
      num_windows = sequence_length // self.window_length
      if (num_windows * self.window_length) != sequence_length:
        raise ValueError(f"Window length {self.window_length} must be a " +
                         f"multiple of sequence length {sequence_length}")
    logging.info("tlayer: num_windows = %d.", num_windows)
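
    # Worked example (illustrative, hypothetical values): with
    # sequence_length=4096 and window_length=512, num_windows == 8, and the
    # scan below processes eight chunks of 512 tokens, each attending to its
    # own window plus the cached window that precedes it.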

    # Define the function to do attention within a single window.
    # ------------------------------------------------------------
    def single_window_attention(carry, inputs_w):
      # This function uses the following variables from the outer scope.
      # They are listed here for clarity.
      nonlocal rel_position_bias
      nonlocal causal_mask
      nonlocal kq_relative_offset
      nonlocal dropout_multiplier
      nonlocal attention_scale_factors
      nonlocal external_memory_bias
      nonlocal cross_attention_kv  # externally supplied.

      # keys,values,queries over the whole sequence will be split into chunks.
      # xs_w, kvqi_w, etc. are the chunk for the current window.
      (prev_kvi_w, rec_state) = carry  # carried from one window to the next.
      (kvqi_w, external_kv_w) = inputs_w  # inputs to the current window.
      # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w

      # Concatenate keys,values from the previous window with the current
      # window to implement sliding window attention.
      (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
      (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w

      # Perform recurrent attention within the current window to get the next
      # recurrent state, and set up cross attention.
      if rec_state is not None:
        logging.info("tlayer: recurrent attention.")

        # NOTE -- recurrent states and input tokens are handled separately,
        # because they have separate learned positional embeddings.  Due to
        # the way TransformerBase does cross-attention, this means that we use
        # separate key,value layers for rec_state and tokens_w.

        # Keys, values, queries from recurrent state.
        logging.info("tlayer: recurrent kvq.")
        rec_kvq = self.recurrent_tbase.kvq(rec_state)
        r_scale_factors = self.recurrent_tbase.attention_scale_factors()
        (r_keys, r_values, r_queries, r_queries2) = rec_kvq

        # Joint attention over both recurrent states and input tokens.
        logging.info("tlayer: recurrent self-attention.")
        r_attn_ys = attention.simple_attention(
            r_keys, r_values, r_queries, None,
            scale_factor=r_scale_factors[0],
            dtype=self.dtype)

        logging.info("tlayer: recurrent cross-attention.")
        r_cross_attn_ys = attention.simple_attention(
            keys_w, values_w, r_queries2, importance_w,
            scale_factor=r_scale_factors[1],
            dtype=self.dtype)

        # Recurrent post-attention FFN.
        logging.info("tlayer: recurrent ffn.")
        next_rec_state = self.recurrent_tbase.post_attn_ffn(
            rec_state, r_attn_ys, r_cross_attn_ys)

        # Get keys and values for cross-attention from recurrent state.
        assert cross_attention_kv is None
        local_cross_attention_kv = (r_keys, r_values)
      else:
        # Get keys and values for cross-attention from external argument.
        next_rec_state = None
        local_cross_attention_kv = cross_attention_kv

      # If using RoPE, keys and queries are rotated before self-attention.
      if self.relative_position_type == "rotary":
        logging.info("Using rotary position encodings (RoPE), offset = %d",
                     kq_relative_offset)
        (keys_w, queries_w) = position.rotate_kq(keys_w, queries_w,
                                                 max_wavelength=10_000,
                                                 offset=kq_relative_offset)

      # Self-attention over input tokens.
      logging.info("tlayer: self-attention.")
      attn_ys_w = attention.simple_attention(
          keys_w, values_w, queries_w, importance_w,
          relative_position_bias=rel_position_bias,
          scale_factor=attention_scale_factors[0],
          causal_mask=causal_mask,
          dropout_multiplier=dropout_multiplier,
          dtype=self.dtype)

      # Attention over external memory.
      if external_kv_w is not None:
        (external_keys_w, external_values_w) = external_kv_w
        y_ext = attention.external_attention(
            external_keys_w, external_values_w, queries_w,
            scale_factor=attention_scale_factors[0])
        if external_memory_bias is not None:
          ebias = external_memory_bias
          logging.info("tlayer: using external memory bias = %r", ebias)
          attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias)
        else:
          attn_ys_w += y_ext

      # Cross attention from input tokens to encoder or recurrent state.
      if local_cross_attention_kv is not None:
        logging.info("tlayer: cross-attention.")
        (c_keys, c_values) = local_cross_attention_kv

        # Cross-attention using queries2.
        cross_attn_ys_w = attention.simple_attention(
            c_keys, c_values, queries2_w, None,
            scale_factor=attention_scale_factors[1],
            dtype=self.dtype)
      else:
        cross_attn_ys_w = None

      # End function single_window_attention(...)
      return ((next_kvi_w, next_rec_state),
              (attn_ys_w, cross_attn_ys_w))
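
    # Sliding-window sketch (illustrative): after concat_kvqi, each window of
    # W = window_length queries attends to roughly 2*W keys (the cached
    # previous window plus the current one), so kq_relative_offset == W and
    # the causal mask above is computed for a (W, 2*W) attention matrix.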

    # Initialize recurrent_tbase before calling jax.lax.scan.
    # Otherwise flax will throw a tantrum.
    if (self.recurrent_attention and 0 <= self.max_unrolled_windows and
        self.max_unrolled_windows < num_windows):
      logging.info("tlayer: force initialization of recurrent_tbase.")
      self.recurrent_tbase.force_init(recurrent_state)

    # Perform sliding window attention over all keys,values,queries.
    # --------------------------------------------------------------
    initial_carry = (prev_kvi, recurrent_state)  # window state.
    kvqi = (keys, values, queries, queries2, importance)
    attn_inputs = (kvqi, external_kv)
    (next_carry, attn_outputs) = attention.split_and_scan(
        single_window_attention,
        initial_carry,
        attn_inputs,
        sections=num_windows,
        axis=1,
        max_unrolled_windows=self.max_unrolled_windows)
    (attn_ys, cross_attn_ys) = attn_outputs

    logging.info("tlayer: End windows.")

    # Post-attention MLP, resnet, and FFN.
    # ------------------------------------
    logging.info("tlayer: final FFN.")
    ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys)

    importance_output = None
    next_window_state = next_carry if window_state is not None else None
    viz_dict = {}  # Visualizations, not currently enabled.
    return (ys, importance_output, next_window_state, decoder_state, viz_dict)
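
  # A minimal sketch of the scan pattern used by split_and_scan above
  # (illustrative; the real split_and_scan lives in attention.py and also
  # handles unrolling): it splits its inputs into `sections` chunks along
  # `axis`, then threads the carry through single_window_attention much like
  #
  #   carry = initial_carry
  #   for inputs_w in chunks:
  #     (carry, outputs_w) = single_window_attention(carry, inputs_w)
  #
  # and finally concatenates the per-window outputs back along `axis`.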

  def load_window_state(self, start_of_sequence: Array) -> WindowState:
    """Load cached state that is passed from one window to the next."""

    (mode, _, _) = self._get_cache_name_from_mode(self.mode)
    prev_kvi = self._get_cached_kvi(start_of_sequence, mode)
    rec_state = self._get_cached_recurrent_state(start_of_sequence, mode)
    if prev_kvi is not None:
      logging.info("tlayer: Loaded keys,values for mode %s from cache %s",
                   self.mode, mode)
    else:
      logging.info("tlayer: Skipping XL cache for mode %s.", self.mode)
    if rec_state is not None:
      logging.info("tlayer: Loaded recurrent state for mode %s from cache %s.",
                   self.mode, mode)
    return (prev_kvi, rec_state)

  def store_window_state(self, window_state: WindowState):
    """Write window state to the cache."""

    (mode, update_cache, _) = self._get_cache_name_from_mode(self.mode)
    (next_kvi, next_rec_state) = window_state  # pytype: disable=attribute-error
    if update_cache and next_kvi is not None:
      logging.info("tlayer: Storing keys,values for mode %s in cache %s.",
                   self.mode, mode)
      self._set_cached_kvi(next_kvi, mode)
    else:
      logging.info("tlayer: Skipping XL cache update for mode %s.", self.mode)
    if update_cache and next_rec_state is not None:
      logging.info("tlayer: Storing recurrent state for mode %s in cache %s.",
                   self.mode, mode)
      self._set_cached_recurrent_state(next_rec_state, mode)

  def get_recurrent_kv(self, window_state: WindowState):
    """Get the recurrent keys,values from window_state."""

    # TODO(delesley): optimize.
    # This isn't ideal, because we wind up computing the recurrent keys,values
    # twice -- once within the sliding window above, and again in the
    # DecoderStack, so they can be passed to other layers.  However, the
    # plumbing is a lot simpler this way.
    if window_state is None:
      return None
    (_, rec_state) = window_state
    if rec_state is None:
      return None
    logging.info("tlayer: get_recurrent_kv.")
    (r_keys, r_values, _, _) = self.recurrent_tbase.kvq(rec_state)
    return (r_keys, r_values)

  def init_decoder_state(self, sequence_length: int,
                         start_of_sequence: Array) -> DecoderState:
    """Initialize decoder state for autoregressive generation.

    Args:
      sequence_length: The maximum length of the sequence to generate.
      start_of_sequence: Array of boolean of shape (batch_size,)
                         True if starting a new sequence (with no prefix).

    Returns:
      A state object that can be passed to __call__.
    """

    # Note that generate always uses a local context of size window_length.
    # Training should be set up appropriately.
    if not self.use_long_xl_architecture:
      raise ValueError("Generation is only supported for transformer XL.")
    if not self.use_causal_mask:
      raise ValueError("Generator must have been trained with a causal mask.")

    (mode, _, _) = self._get_cache_name_from_mode(self.mode)

    # Get relative position bias.
    if self.relative_positions is not None:
      # Relative positions for all tokens *prior* to the current token.
      # The causal mask prevents each token from attending to itself.
      rel_position_bias = self.relative_positions(1, self.window_length,
                                                  offset=self.window_length,
                                                  bidirectional=False)
    else:
      rel_position_bias = None

    # Initialize autoregressive storage for (key, value) pairs.
    # Include space for a prefix of window_length tokens.
    num_keys = sequence_length + self.window_length
    stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size)
    stored_keys = jnp.zeros(stored_shape, dtype=self.dtype)
    stored_values = jnp.zeros(stored_shape, dtype=self.dtype)
    start_index = self.window_length

    # Copy keys,values from cache into storage, for use as a prefix.
    prev_kvi = self._get_cached_kvi(start_of_sequence, mode)
    if prev_kvi is not None:
      (pkeys, pvals, prev_imps) = prev_kvi
      assert prev_imps is None  # Not yet supported.
      assert pkeys.ndim == 4
      assert pkeys.shape[1] == self.window_length  # (b, wlen, num_heads, d)

      stored_keys = jax.lax.dynamic_update_slice_in_dim(
          stored_keys, pkeys, 0, axis=1)
      stored_values = jax.lax.dynamic_update_slice_in_dim(
          stored_values, pvals, 0, axis=1)

    # Grab the current recurrent_state, and precompute keys,values,queries.
    rstate = self._get_cached_recurrent_state(start_of_sequence, mode)
    if rstate is not None:
      recurrent_kvq = self.recurrent_tbase.kvq(rstate)
    else:
      recurrent_kvq = None

    decoder_state_dict = {
        "keys": stored_keys,
        "values": stored_values,
        "current_index": start_index,
        "relative_position_bias": rel_position_bias,
        "recurrent_kvq": recurrent_kvq
    }
    return DecoderState(decoder_state_dict)
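
  # Illustrative generation loop (an assumption for exposition; `layer`,
  # `token_embeddings`, and `start` are hypothetical names, and the real
  # driver lives elsewhere in the codebase):
  #
  #   dstate = layer.init_decoder_state(max_len, start)
  #   for t in range(max_len):
  #     (y, _, _, dstate, _) = layer(token_embeddings[:, t:t + 1, :], start,
  #                                  decoder_state=dstate)
  #
  # Each call consumes one token and advances "current_index" by one.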

  def _next_decoder_state(self, decoder_state: DecoderState,
                          keys: Array, values: Array) -> Tuple[
                              DecoderState, Array, Array]:
    """Compute the next decoder state, and return keys,values to attend to.

    The keys,values returned from this function are drawn from the prior
    decoding state, and comprise a full window of local context.

    Args:
      decoder_state: The current decoder state, initially created using
                     init_decoder_state().
      keys: The key for the current token, of shape (batch_size, 1, dim)
      values: The value for the current token of shape (batch_size, 1, dim)

    Returns:
      (next_decoder_state,
       window of keys of shape (batch_size, window_length, dim),
       window of values of shape (batch_size, window_length, dim))
    """

    assert keys.shape[1] == 1  # single-token autoregressive decoding.

    logging.info("attn_layer: next decoder state; key = %r", keys)

    # Unpack decoder_state.
    stored_keys = decoder_state["keys"]
    stored_values = decoder_state["values"]
    curr_index = decoder_state["current_index"]

    # Slice to get window_length-sized chunk of previous keys,values.
    out_decoder_state = {}
    curr_win_index = curr_index - self.window_length
    out_keys = jax.lax.dynamic_slice_in_dim(
        stored_keys, curr_win_index, self.window_length, axis=1)
    out_values = jax.lax.dynamic_slice_in_dim(
        stored_values, curr_win_index, self.window_length, axis=1)

    # Write current keys,values to stored keys, values.
    stored_keys = jax.lax.dynamic_update_slice_in_dim(
        stored_keys, keys, curr_index, axis=1)
    stored_values = jax.lax.dynamic_update_slice_in_dim(
        stored_values, values, curr_index, axis=1)
    curr_index = curr_index + 1

    # Pack a new decoder_state object.
    out_decoder_state["keys"] = stored_keys
    out_decoder_state["values"] = stored_values
    out_decoder_state["current_index"] = curr_index
    out_decoder_state["relative_position_bias"] = (
        decoder_state["relative_position_bias"])
    out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"]

    return (DecoderState(out_decoder_state), out_keys, out_values)
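
# A small worked example of the decode buffer above (illustrative, with
# hypothetical numbers): with window_length=512 and current_index=512 (the
# first generated token), the slice returns stored positions [0, 512) -- the
# cached prefix -- while the new token's key,value are written at index 512
# and current_index advances to 513 for the next step.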