HugoVoxx committed on
Commit 15bcbe6 · verified · 1 Parent(s): f18cde5

Upload 20 files

aglib/meliad/transformer/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
aglib/meliad/transformer/attention.py ADDED
@@ -0,0 +1,443 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Transformer attention functions."""
+
+ import typing
+ from typing import Any, Callable, Mapping, NewType, Optional, Sequence, Tuple, Union
+
+ from absl import logging
+ from flax import linen as nn
+ import jax
+ import jax.numpy as jnp
+
+ from transformer import nn_components
+ from transformer import position
+
+
+ Array = jnp.ndarray
+ ArrayTree = Union[Array, Tuple["ArrayTree", ...]]
+ DecoderState = NewType("DecoderState", Mapping[str, Array])
+
+ # Tuple of keys, values, importance.
+ KVITuple = Tuple[Array, Array, Optional[Array]]
+
+ # Tuple of keys, values, queries, queries2, importance.
+ KVQITuple = Tuple[Array, Array, Array, Optional[Array], Optional[Array]]
+
+ # Tuple of scale factors. See TransformerBase.attention_scale_factors().
+ AttnScaleTuple = Tuple[Optional[Array], Optional[Array]]
+
+
+ def initial_kvi(shape: Sequence[int], use_importance: bool, dtype: Any):
+   """Returns initial (zero) keys/values/i that can be passed to prev_kvi."""
+   z = jnp.zeros(shape, dtype=dtype)
+   if use_importance:
+     i = jnp.zeros((shape[0], shape[1]), dtype=dtype)  # (bsize, window_length)
+   else:
+     i = None
+   return (z, z, i)
+
+
+ def concat_kvqi(kvqi: KVQITuple, prev_kvi: Optional[KVITuple]) -> (
+     Tuple[KVQITuple, Optional[KVITuple]]):
+   """Concatenate previous keys,values with current keys,values.
+
+   Args:
+     kvqi: Current keys, values, queries, queries2, importance.
+     prev_kvi: Previous keys, values, importance.
+
+   Returns:
+     (kvqi: Concatenated (keys, values, queries, importance),
+      next_kvi: Next (keys, values, importance))  (from kvqi)
+   """
+
+   (keys, values, queries, queries2, importance) = kvqi
+   # The current keys,values,importance will be passed to the next window.
+   next_kvi = (keys, values, importance)
+   (batch_size, _, num_heads, head_dim) = keys.shape  # (b, _, h, d)
+
+   if prev_kvi is None:
+     return (kvqi, None)  # If prev_kvi is None, next_kvi should be None.
+
+   # Unpack prev_kvi and check shapes.
+   (pkeys, pvalues, pimportance) = prev_kvi
+   num_pkeys = pkeys.shape[1]
+   assert pkeys.shape == (batch_size, num_pkeys, num_heads, head_dim)
+   assert pkeys.shape == pvalues.shape
+   if pimportance is not None:
+     assert pimportance.shape == (batch_size, num_pkeys)
+
+   # Concatenate keys and values.
+   keys = jnp.concatenate([pkeys, keys], axis=1)  # (b, k, h, d)
+   values = jnp.concatenate([pvalues, values], axis=1)  # (b, k, h, d)
+   if importance is not None:
+     assert pimportance is not None
+     importance = jnp.concatenate([pimportance, importance], axis=1)  # (b, k)
+     logging.info("attn: importance = %r", importance)
+
+   return ((keys, values, queries, queries2, importance), next_kvi)
+
+
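As a quick illustration of how `initial_kvi` and `concat_kvqi` fit together for a sliding window (a minimal sketch, not part of the commit; it assumes the meliad `transformer` package is importable and the shapes are illustrative only):

    import jax.numpy as jnp
    from transformer import attention

    (b, n, h, d) = (1, 4, 2, 8)  # batch, window, heads, head_dim
    kvq = jnp.ones((b, n, h, d))
    # Zero-filled previous window, as produced by initial_kvi.
    prev_kvi = attention.initial_kvi((b, n, h, d), use_importance=False,
                                     dtype=jnp.float32)
    ((keys, values, queries, _, _), next_kvi) = attention.concat_kvqi(
        (kvq, kvq, kvq, None, None), prev_kvi)
    assert keys.shape == (b, 2 * n, h, d)      # previous keys are prepended
    assert next_kvi[0].shape == (b, n, h, d)   # current keys, for the next window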
+ def simple_attention(keys: Array,
+                      values: Array,
+                      queries: Array,
+                      importance: Optional[Array],
+                      *,
+                      relative_position_bias: Optional[Array] = None,
+                      scale_factor: Optional[Array] = None,
+                      causal_mask: Optional[Array] = None,
+                      dropout_multiplier: Optional[Array] = None,
+                      dtype: Any = jnp.float32) -> Array:
+   """Simple attention from a set of queries to a set of keys,values.
+
+   Args:
+     keys: of shape [batch_size, num_keys, num_heads, head_dim].
+     values: of shape [batch_size, num_keys, num_heads, head_dim].
+     queries: of shape [batch_size, num_queries, num_heads, head_dim].
+     importance: of shape [batch_size, num_keys].
+
+     *: ---- the following arguments are passed by keyword only ----
+     relative_position_bias: A positional attention matrix of shape
+         [num_heads, num_queries, num_keys]
+     scale_factor: Learned scale factor for use with normalized keys,queries
+         of shape [num_heads]
+     causal_mask: A boolean array of shape [num_heads, num_queries, num_keys]
+     dropout_multiplier: A random mask of either 0.0 or 1.0/keep_prob,
+         of shape [num_heads, num_queries, num_keys]
+     dtype: data type to perform attention at.
+
+   Returns:
+     Attention outputs of shape [batch_size, num_queries, num_heads, head_size]
+   """
+
+   # (batch_size, num_keys, num_heads, head_dim)
+   (batch_size, num_keys, num_heads, head_dim) = keys.shape  # (b, k, h, d)
+   num_queries = queries.shape[1]
+   assert keys.shape == values.shape
+   assert queries.shape == (batch_size, num_queries, num_heads, head_dim)
+   if importance is not None:
+     assert importance.shape == (batch_size, num_keys)
+
+   logging.info("attn: keys = %r", keys)
+   logging.info("attn: queries = %r", queries)
+
+   # Compute attention matrix.
+   attn = jnp.einsum("...qhd,...khd->...hqk", queries, keys)  # (b, h, q, k)
+
+   logging.info("attn: content attn = %r", attn)
+
+   # Apply relative position bias.
+   if relative_position_bias is not None:
+     logging.info("attn: pbias = %r", relative_position_bias)
+     relative_position_bias = jnp.asarray(relative_position_bias, dtype=dtype)
+     pbias = position.broadcast_mask(relative_position_bias, attn)
+     attn = attn + pbias
+
+   # Apply learned attention scale.
+   if scale_factor is not None:
+     logging.info("attn: learned attention scale: %s", scale_factor)
+     # Broadcast scale over batch/keys/queries.
+     scale_factor = jnp.asarray(scale_factor, dtype=dtype)
+     scale_factor = scale_factor.reshape((1, num_heads, 1, 1))
+     attn = attn * scale_factor
+
+   # Apply causal mask.
+   if causal_mask is not None:
+     causal_mask = position.broadcast_mask(causal_mask, attn)
+     attn = jnp.where(causal_mask, attn, jnp.asarray(-1_000_000.0, dtype=dtype))
+
+   logging.info("attn: pre-softmax attn = %r", attn)
+
+   # Normalize attention matrix with softmax.
+   # min_x should be much smaller than minimum expected values in attn, but
+   # much larger than the masked_out values created by the causal mask. That
+   # way, if all tokens are masked out, then softmax will attend to nothing,
+   # rather than attend to everything equally.
+   min_x = jnp.asarray(-1000.0, dtype=dtype)
+   attn = nn_components.safe_softmax(attn, axis=-1, min_x=min_x)  # (b, h, q, k)
+
+   # Apply dropout to attention matrix.
+   if dropout_multiplier is not None:
+     logging.debug("attn: drop = %r", dropout_multiplier)
+     dropout_multiplier = jnp.asarray(dropout_multiplier, dtype=dtype)
+     attn = attn * dropout_multiplier
+
+   logging.info("attn: final attn = %r", attn)
+
+   # Compute output -- values weighted by attention matrix.
+   y = jnp.einsum("...hqk,...khd->...qhd", attn, values)  # (b, q, h, d)
+
+   logging.info("attn: y = %r", y)
+   return y
+
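A minimal sketch of calling `simple_attention` with a causal mask (not part of the commit; the mask is built in the documented `[num_heads, num_queries, num_keys]` shape, which `position.broadcast_mask` then broadcasts over the batch):

    import jax
    import jax.numpy as jnp
    from transformer import attention

    (b, q, k, h, d) = (2, 8, 8, 4, 16)
    (kr, vr, qr) = jax.random.split(jax.random.PRNGKey(0), 3)
    keys = jax.random.normal(kr, (b, k, h, d))
    values = jax.random.normal(vr, (b, k, h, d))
    queries = jax.random.normal(qr, (b, q, h, d))
    # Lower-triangular causal mask: query i may only attend to keys j <= i.
    causal = jnp.broadcast_to(jnp.tril(jnp.ones((q, k), dtype=jnp.bool_)),
                              (h, q, k))
    y = attention.simple_attention(keys, values, queries, None,
                                   causal_mask=causal)
    assert y.shape == (b, q, h, d)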
+
+ def external_attention(external_keys: Array,
+                        external_values: Array,
+                        queries: Array,
+                        *,
+                        scale_factor: Optional[Array] = None,
+                        dtype: Any = jnp.float32) -> Array:
+   """Attention over (keys, values) retrieved from external memory.
+
+   Args:
+     external_keys: per-query keys from external memory, of shape
+         [batch_size, num_queries, num_heads, num_neighbors, head_size]
+     external_values: per-query values from external memory, of shape
+         [batch_size, num_queries, num_heads, num_neighbors, head_size]
+     queries: current queries, of shape:
+         [batch_size, num_queries, num_heads, head_size]
+
+     *: ---- the following arguments are passed by keyword only. ---
+     scale_factor: Learned scale factor for use with normalized keys,queries
+         of shape [num_heads]
+     dtype: data type to perform attention at.
+
+   Returns:
+     Attention outputs of shape [batch_size, num_queries, num_heads, head_size]
+   """
+
+   (batch_size, num_queries, num_heads, _, head_dim) = external_keys.shape
+   assert queries.shape == (batch_size, num_queries, num_heads, head_dim)
+   assert external_values.shape == external_keys.shape
+
+   # Build attention matrix.
+   logging.info("ext_attn: external keys = %r", external_keys)
+   ext_attn = jnp.einsum("...qhd,...qhid->...hqi", queries, external_keys)
+
+   logging.info("ext_attn: external_mem_attn: %s", ext_attn)
+   if scale_factor is not None:
+     scale_factor = jnp.asarray(scale_factor, dtype=dtype)
+     scale_factor = scale_factor.reshape((1, num_heads, 1, 1))
+     logging.info("ext_attn: scaling external_mem_attn by %s", scale_factor)
+     ext_attn = ext_attn * scale_factor
+
+   ext_attn = nn.softmax(ext_attn, axis=-1)
+
+   # Compute weighted sum of values.
+   ext_y = jnp.einsum("...hqi,...qhid->...qhd", ext_attn, external_values)
+   logging.info("ext_attn: ext_y = %r", ext_y)
+   return ext_y
+
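A shape-level sketch of `external_attention` (not part of the commit; the `i` axis holds neighbors assumed to have been retrieved per query from external memory, and since the values are all ones, the softmax-weighted average is exactly one):

    import jax
    import jax.numpy as jnp
    from transformer import attention

    (b, q, h, i, d) = (1, 4, 2, 3, 8)  # i = retrieved neighbors per query
    ext_keys = jax.random.normal(jax.random.PRNGKey(1), (b, q, h, i, d))
    ext_values = jnp.ones((b, q, h, i, d))
    queries = jax.random.normal(jax.random.PRNGKey(2), (b, q, h, d))
    y = attention.external_attention(ext_keys, ext_values, queries)
    assert y.shape == (b, q, h, d)  # one output per query, as in local attention
    assert jnp.allclose(y, 1.0)     # softmax weights sum to 1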
+
+ def sliding_attention_window_shape(kvi: KVITuple,
+                                    prev_kvi: Optional[KVITuple],
+                                    queries: Array,
+                                    window_length: int) -> Tuple[int, int]:
+   """Return (num_queries, num_keys) for the sliding attention window."""
+
+   # Do error checking here.
+   (keys, values, importance) = kvi
+   assert keys.shape == queries.shape
+   assert values.shape == queries.shape
+
+   # Get sizes...
+   (batch_size, sequence_length, _, _) = queries.shape
+
+   if importance is not None:
+     assert importance.ndim == 2
+     assert importance.shape == (batch_size, sequence_length)
+
+   assert window_length > 0
+   if window_length >= sequence_length:
+     # No sliding window.
+     num_queries = sequence_length
+     num_keys = sequence_length
+     if prev_kvi is not None:
+       num_keys += prev_kvi[0].shape[1]
+   else:
+     # Sliding window.
+     if prev_kvi is not None:
+       assert prev_kvi[0].shape[1] == window_length
+     num_queries = window_length
+     num_keys = window_length * 2
+
+   return (num_queries, num_keys)
+
+
+ def split_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
+     Sequence[ArrayTree]):
+   """Recursively splits a possibly nested tuple of arrays along the given axis.
+
+   Args:
+     tree: A nested tree of tuples and arrays.
+     sections: The number of sections to split the tree into.
+     axis: The axis to do the split on arrays.
+
+   Returns:
+     A list of trees, of length sections, where each has the same shape as the
+     original, but with arrays of size 1/sections.
+   """
+
+   if tree is None:
+     return [None] * sections
+   elif isinstance(tree, jnp.ndarray):
+     return jnp.split(tree, sections, axis=axis)
+   elif isinstance(tree, tuple):
+     # Recursively split each element of the tuple into a list.
+     branch_lists = [split_tree(tree_i, sections, axis=axis) for tree_i in tree]
+     # Rearrange the tuple of lists into a list of tuples.
+     return [tuple([brs[i] for brs in branch_lists]) for i in range(sections)]
+   else:
+     raise ValueError("Argument %r must be an ndarray or tuple." % tree)
+
+
+ def concat_trees(tree_list: Sequence[ArrayTree], axis: int = 0) -> ArrayTree:
+   """Merges a list of trees into a single tree by concatenating their elements.
+
+   Args:
+     tree_list: A list of trees, all of the same shape.
+     axis: The axis to concatenate arrays on.
+
+   Returns:
+     A single tree, with the same shape as the trees in tree_list.
+   """
+
+   # All trees in the list are required to have the same shape.
+   # We return a tree with the same shape as each of the trees in the list.
+   first_tree = tree_list[0]
+   if first_tree is None:
+     # Merge a list of None into a single None.
+     for tree_i in tree_list:
+       assert tree_i is None
+     return None
+   elif isinstance(first_tree, jnp.ndarray):
+     # Concatenate a list of arrays.
+     for tree_i in tree_list:
+       assert isinstance(tree_i, jnp.ndarray)
+     return jnp.concatenate(tree_list, axis=axis)
+   elif isinstance(first_tree, tuple):
+     # Reshape a list of tuples into a tuple of concatenated lists.
+     for tree_i in tree_list:
+       assert isinstance(tree_i, tuple) and len(tree_i) == len(first_tree)
+     num_branches = len(first_tree)
+     return tuple([concat_trees([tree[b] for tree in tree_list], axis=axis)
+                   for b in range(num_branches)])
+   else:
+     raise ValueError("Argument %r must be an ndarray or tuple." % first_tree)
+
+
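A round-trip sketch for `split_tree` / `concat_trees` (not part of the commit):

    import jax.numpy as jnp
    from transformer import attention

    x = jnp.arange(12).reshape(6, 2)
    tree = (x, (x + 100, None))  # None leaves are allowed and preserved
    parts = attention.split_tree(tree, sections=3, axis=0)
    assert len(parts) == 3 and parts[0][0].shape == (2, 2)
    merged = attention.concat_trees(parts, axis=0)
    assert (merged[0] == x).all() and merged[1][1] is None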
+ def reshape_transpose_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
+     ArrayTree):
+   """Reshape and transpose arrays so that the window is dimension 0."""
+
+   # We could use jax tree utils for this, but we do it the hard way so the
+   # implementation can be compared with split_tree.
+   if tree is None:
+     return None
+   elif isinstance(tree, jnp.ndarray):
+     tree = typing.cast(Array, tree)  # Tell type-checker about isinstance
+     ndim = tree.ndim
+     wlen = tree.shape[axis] // sections
+     assert sections * wlen == tree.shape[axis]  # Must be evenly divisible.
+
+     # Break the axis dimension into sections * window_size
+     arr = tree
+     sh = list(arr.shape)
+     nshape = sh[0:axis] + [sections, wlen] + sh[axis + 1:]
+     arr = jnp.reshape(arr, nshape)
+
+     # Transpose sections to be dimension 0.
+     tdims = [axis] + list(range(0, axis)) + list(range(axis + 1, ndim + 1))
+     arr = jnp.transpose(arr, tdims)
+     return arr
+   elif isinstance(tree, tuple):
+     return tuple([reshape_transpose_tree(b, sections, axis) for b in tree])
+   else:
+     raise ValueError("Argument %r must be an ndarray or tuple." % tree)
+
+
+ def transpose_reshape_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
+     ArrayTree):
+   """Inverse of reshape_transpose_tree: folds dimension 0 back into the axis."""
+
+   # We could use jax tree utils for this, but we do it the hard way so the
+   # implementation can be compared with split_tree.
+   if tree is None:
+     return None
+   elif isinstance(tree, jnp.ndarray):
+     tree = typing.cast(Array, tree)  # Tell type-checker about isinstance
+     ndim = tree.ndim - 1  # Input tree has 1 extra dimension on front.
+     assert axis < ndim
+     wlen = tree.shape[axis + 1]  # Window length.
+
+     # Transpose dimension 0 back to its proper place.
+     arr = tree
+     tdims = list(range(1, axis + 1)) + [0] + list(range(axis + 1, ndim + 1))
+     arr = jnp.transpose(arr, tdims)
+
+     # Combine the sections and window_size dimensions.
+     sh = list(arr.shape)
+     nshape = sh[0:axis] + [sections * wlen] + sh[axis + 2:]
+     arr = jnp.reshape(arr, nshape)
+     return arr
+   elif isinstance(tree, tuple):
+     return tuple([transpose_reshape_tree(b, sections, axis) for b in tree])
+   else:
+     raise ValueError("Argument %r must be an ndarray or tuple." % tree)
+
+
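The two functions above are inverses; a small sketch demonstrating the round trip (not part of the commit):

    import jax.numpy as jnp
    from transformer import attention

    x = jnp.arange(24).reshape(2, 6, 2)  # (batch, sequence, features)
    w = attention.reshape_transpose_tree(x, sections=3, axis=1)
    assert w.shape == (3, 2, 2, 2)       # (sections, batch, window, features)
    back = attention.transpose_reshape_tree(w, sections=3, axis=1)
    assert (back == x).all()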
+ def split_and_scan(fn: Callable[[ArrayTree, ArrayTree],
+                                 Tuple[ArrayTree, ArrayTree]],
+                    carry: ArrayTree, input_arrays: ArrayTree,
+                    sections: int, axis: int = 0,
+                    max_unrolled_windows: int = -1) -> (
+                        Tuple[ArrayTree, ArrayTree]):
+   """Scan over a set of input arrays in chunks.
+
+   Splits each array in 'input_arrays' into the number of chunks given by
+   'sections', and then loops over the chunks using a scan operation.
+   Returns a concatenation of the results.
+
+   Args:
+     fn: A function from (carry, input_i) -> (carry, output_i).
+     carry: The initial state for the scan, that will be passed from one
+         iteration to the next.
+     input_arrays: A nested tree of tuples of arrays.
+     sections: The number of sections or chunks for the split.
+     axis: The axis to split each array along.
+     max_unrolled_windows: If 0 <= max_unrolled_windows < sections,
+         use jax.lax.scan rather than unrolling the windows with a python loop.
+
+   Returns:
+     (carry, output)
+   """
+
+   if sections == 1:
+     logging.info("Single window, no scan.")
+     return fn(carry, input_arrays)
+
+   if axis < 0:
+     raise ValueError(f"Axis must be non-negative. Got {axis}")
+
+   logging.info("Scanning over %d windows", sections)
+
+   if 0 <= max_unrolled_windows < sections:
+     logging.info("Using jax.lax.scan.")
+     in_arrs = reshape_transpose_tree(input_arrays, sections, axis)
+     (carry, out_arrs) = jax.lax.scan(fn, carry, in_arrs)
+     output_arrays = transpose_reshape_tree(out_arrs, sections, axis)
+     return (carry, output_arrays)
+
+   logging.info("Using unrolled for-loop.")
+   in_list = split_tree(input_arrays, sections, axis=axis)
+   out_list = []
+
+   for (k, in_chunk) in enumerate(in_list):
+     logging.info("Processing window %d", k)
+     (carry, out_chunk) = fn(carry, in_chunk)
+     out_list.append(out_chunk)
+
+   output_arrays = concat_trees(out_list, axis=axis)
+   return (carry, output_arrays)
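A sketch of `split_and_scan` with a scalar carry (not part of the commit; with the default `max_unrolled_windows=-1` this takes the unrolled for-loop path):

    import jax.numpy as jnp
    from transformer import attention

    xs = jnp.arange(8.0).reshape(1, 8)  # (batch, sequence)

    def shift_window(carry, x_w):
      # Shift each window by the running sum of all previous windows.
      return (carry + x_w.sum(), x_w + carry)

    (total, ys) = attention.split_and_scan(shift_window, jnp.float32(0.0), xs,
                                           sections=4, axis=1)
    assert ys.shape == xs.shape and total == xs.sum()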
aglib/meliad/transformer/decoder_stack.py ADDED
@@ -0,0 +1,426 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Hierarchical transformer."""
+
+ import functools
+ from typing import Any, Callable, Optional, Sequence, Tuple
+
+ from absl import logging
+
+ from flax import linen as nn
+ from flax import struct
+ import gin
+ import jax.numpy as jnp
+ from transformer import attention
+ from transformer import metric_utils
+ from transformer import nn_components
+ from transformer import position
+ from transformer import transformer_layer
+
+
+ Array = Any
+
+
+ # Basic task options are shared among multiple classes.
+ @gin.configurable
+ @struct.dataclass
+ class TransformerTaskConfig:
+   """Configuration hyperparameters for sequence-to-sequence tasks."""
+
+   dataset_name: str = "synthetic"
+   train_split: str = "train"
+   test_split: str = "test"
+   sequential_chunks: bool = True  # Process chunks of text in sequential order.
+
+   sequence_length: int = 4096
+   batch_size: int = 1  # per device batch size
+   vocab_size: int = 256
+
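A sketch of configuring this dataclass through Gin (not part of the commit; it assumes the unscoped class name is unambiguous in the Gin registry):

    import gin
    from transformer import decoder_stack

    gin.parse_config("""
    TransformerTaskConfig.dataset_name = "synthetic"
    TransformerTaskConfig.sequence_length = 1024
    TransformerTaskConfig.batch_size = 2
    """)
    task_config = decoder_stack.TransformerTaskConfig()
    assert task_config.sequence_length == 1024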
+
+ DStackDecoderState = Tuple[transformer_layer.DecoderState, ...]
+ DStackWindowState = Tuple[transformer_layer.WindowState, ...]
+
+
+ @gin.configurable
+ class DecoderStack(nn.Module):
+   """Stack of transformer decoder layers."""
+
+   mode: str
+   task_config: TransformerTaskConfig = gin.REQUIRED
+
+   # Configurable hyperparameters.
+   num_layers: int = gin.REQUIRED
+   embedding_size: int = gin.REQUIRED
+   embedding_stddev: float = 1.0
+
+   # The class to use for an individual transformer layer.
+   layer_factory: Any = gin.REQUIRED
+
+   # Window length to use for the decoder stack.
+   # If nonzero, use this instead of TransformerLayer.window_length.
+   dstack_window_length: int = 0
+   use_absolute_positions: bool = False
+   use_final_layernorm: bool = True
+   final_dropout_rate: float = 0.0
+   final_mlp_factory: Optional[Callable[[int], nn.Module]] = None
+
+   # Enable recurrence on particular layers.
+   recurrent_layer_indices: Sequence[int] = ()
+   feedback_recurrence: bool = True
+
+   # The factory function which creates a MemoryManager, or None.
+   memory_factory: Any = None
+   # Layers to equip with external memory.
+   memory_layer_indices: Sequence[int] = ()
+
+   dtype: Any = jnp.float32
+
+   def is_training(self):
+     return self.mode == "train"
+
+   def supports_generate(self) -> bool:
+     return all([lyr.supports_generate() for lyr in self.transformer_layers])
+
+   def setup(self):
+     task_config = self.task_config
+
+     embed_init = nn.initializers.normal(stddev=self.embedding_stddev,
+                                         dtype=jnp.float32)
+     self.embed = nn.Embed(num_embeddings=task_config.vocab_size,
+                           features=self.embedding_size,
+                           embedding_init=embed_init)
+
+     # Create a memory_factory.MemoryManager object, which is shared among
+     # all transformer layers. Each layer will use the MemoryManager object
+     # to instantiate a block of memory for that layer.
+     memory = None
+     if self.memory_factory is not None:
+       if self.memory_layer_indices:
+         memory = self.memory_factory(batch_size=task_config.batch_size,
+                                      mode=self.mode)
+       else:
+         logging.warning(
+             "Memory factory specified, but memory_layer_indices is empty.")
+
+     # Allow negative numbers in memory_layer_indices.
+     # Negative numbers refer to layers at the top of the stack.
+     for k in self.memory_layer_indices:
+       if k < -self.num_layers or k >= self.num_layers:
+         raise ValueError(f"Invalid memory layer index {k}")
+     # The % operator will convert negative k to self.num_layers + k.
+     mem_layer_indices = [
+         idx % self.num_layers for idx in self.memory_layer_indices
+     ]
+
+     # Allow negative numbers in recurrent_layer_indices.
+     for k in self.recurrent_layer_indices:
+       if k < -self.num_layers or k >= self.num_layers:
+         raise ValueError(f"Invalid recurrent layer index {k}")
+     recurrent_layer_indices = [
+         idx % self.num_layers for idx in self.recurrent_layer_indices
+     ]
+     # Turn on cross attention if there are recurrent layers with feedback.
+     enable_cross_attn = (self.feedback_recurrence and
+                          self.recurrent_layer_indices and
+                          self.dstack_window_length > 0)
+
+     layers = []
+     for i in range(0, self.num_layers):
+       mem = memory if (i in mem_layer_indices) else None
+       rec_i = i in recurrent_layer_indices
+       layer_fn = functools.partial(
+           self.layer_factory,
+           mode=self.mode,
+           batch_size=self.task_config.batch_size,
+           embedding_size=self.embedding_size,
+           name=f"transformer{i}",
+           recurrent_attention=rec_i,
+           cross_attention=enable_cross_attn and not rec_i)
+       if mem:
+         logging.info("Using external memory with transformer layer %d.", i)
+         layer_fn = functools.partial(
+             layer_fn,
+             memory=mem,
+             # We use partial function applications here only to avoid
+             # overwriting the head size unless memory is involved.
+             head_size=mem.key_size,
+             num_heads=mem.num_heads)
+       layers.append(layer_fn())
+     self.transformer_layers = layers
+
+     if self.use_final_layernorm:
+       self.final_layernorm = nn_components.LayerNorm()
+
+     if self.final_mlp_factory is not None:
+       self.final_mlp = self.final_mlp_factory(self.embedding_size)
+
+   def init_decoder_state(self, sequence_length: int,
+                          start_of_sequence: Array) -> DStackDecoderState:
+     """Return initial state for autoregressive generation."""
+     return tuple([
+         layer.init_decoder_state(sequence_length, start_of_sequence)
+         for layer in self.transformer_layers
+     ])
+
+   def load_window_state(self, start_of_sequence: Array) -> DStackWindowState:
+     """Load cached state that is passed from one window to the next."""
+     return tuple([
+         layer.load_window_state(start_of_sequence)
+         for layer in self.transformer_layers
+     ])
+
+   def store_window_state(self, window_state: DStackWindowState):
+     """Write window state to the cache."""
+     for (layer, wstate) in zip(self.transformer_layers, window_state):
+       layer.store_window_state(wstate)
+
+   def _eval_layer_stack(self, xs: Array, start_of_sequence: Array,
+                         window_state: Optional[DStackWindowState],
+                         decoder_state: Optional[DStackDecoderState]) -> (
+                             Tuple[Array, Optional[DStackWindowState],
+                                   Optional[DStackDecoderState], Any]):
+     """Evaluate a stack of transformer layers on an input."""
+
+     ys = xs  # (batch_size, seq_len, num_hidden)
+     importance = None  # (batch_size, sequence_length)
+     next_window_states = []
+     next_decoder_states = []
+     attn_viz_dicts = []
+
+     # If we have a recurrent layer, grab the keys and values from it.
+     # All other layers can then cross-attend to the recurrent keys and values.
+     recurrent_kv = None
+     enable_cross_attn = (self.feedback_recurrence and
+                          self.recurrent_layer_indices and
+                          self.dstack_window_length > 0)
+     if enable_cross_attn and window_state is not None:
+       # TODO(delesley): fix this so it works with the autoregressive decoder.
+       assert decoder_state is None
+       logging.info("dstack: using recurrent cross attention on all layers.")
+       for (layer, wstate_i) in zip(self.transformer_layers, window_state):
+         rkv = layer.get_recurrent_kv(wstate_i)
+         if rkv is not None:
+           recurrent_kv = rkv
+
+     # Apply transformer layers.
+     for (i, layer) in enumerate(self.transformer_layers):
+       if layer.recurrent_attention:
+         cross_kv = None  # The recurrent layer handles rkv internally.
+       else:
+         cross_kv = recurrent_kv  # Other layers cross-attend to recurrent one.
+
+       logging.info("dstack: ---- Layer %d ----", i)
+       wstate_i = None if window_state is None else window_state[i]
+       dstate_i = None if decoder_state is None else decoder_state[i]
+       (ys, importance, n_wstate_i, n_dstate_i, viz_dict) = layer(
+           ys, start_of_sequence,
+           importance=importance,
+           cross_attention_kv=cross_kv,  # cross-attend to recurrent_kv.
+           window_state=wstate_i,
+           decoder_state=dstate_i)
+       next_window_states.append(n_wstate_i)
+       next_decoder_states.append(n_dstate_i)
+       attn_viz_dicts.append(viz_dict)
+
+     window_state = tuple(next_window_states)
+     decoder_state = tuple(next_decoder_states)
+     return (ys, window_state, decoder_state, attn_viz_dicts)
+
+   def __call__(self,
+                input_tokens: Array,
+                target_tokens: Array,
+                start_of_sequence: Array,
+                decoder_state: Optional[DStackDecoderState] = None) -> (
+                    Tuple[Array, Optional[DStackDecoderState], Any]):
+     """Call the decoder stack.
+
+     This function will embed tokens, run the embeddings through a stack of
+     decoder layers, and then compute logits for the target tokens using the
+     transpose of the embeddings. It returns un-normalized (pre-softmax)
+     logits.
+
+     Args:
+       input_tokens: Integer array of shape [batch_size, sequence_length]
+       target_tokens: For compatibility. Ignored by this class.
+       start_of_sequence: Boolean array of shape [batch_size],
+           which indicates whether a sequence is at the start of sequence.
+       decoder_state: State object for autoregressive decoding,
+           created from init_decoder_state.
+
+     Returns:
+       (logits, of shape [batch_size, sequence_length, vocab_size],
+        next_decoder_state: for autoregressive decoding,
+        viz_dict: dictionary of visualizations,
+       )
+     """
+     del target_tokens
+     task_config = self.task_config
+
+     # Embed tokens.
+     embeddings = self.embed(input_tokens)  # (batch_size, seq_len, num_hidden)
+     embeddings = embeddings.astype(self.dtype)
+     sequence_length = embeddings.shape[1]
+     logging.info("dstack: embeddings = %r", embeddings)
+
+     # Add absolute position encodings if necessary.
+     if self.use_absolute_positions:
+       # Use a large max_wavelength so that only part of the input vector
+       # is used for positions.
+       positions = position.position_encoding(
+           num_positions=task_config.sequence_length,
+           input_dim=self.embedding_size,
+           max_wavelength=10_000)
+       positions = jnp.asarray(positions, dtype=self.dtype)
+       positions = jnp.expand_dims(positions, 0)  # Add batch dimension.
+       logging.info("dstack: absolute positions = %r", positions)
+       embeddings = embeddings + positions
+
+     # Function to run the whole transformer stack on a single window.
+     # ---------------------------------------------------------------
+     def single_window_stack(carry, inputs_w):
+       (window_state_w, start_of_seq_w) = carry
+       (outputs_w, window_state_w, _, _) = self._eval_layer_stack(
+           inputs_w, start_of_seq_w,
+           window_state=window_state_w, decoder_state=None)
+
+       # start_of_sequence is false after the first window.
+       bsize = self.task_config.batch_size
+       next_start_of_seq = jnp.asarray([False] * bsize, dtype=jnp.bool_)
+       return ((window_state_w, next_start_of_seq), outputs_w)
+
+     # Find the number of windows. A sequence may be split into multiple
+     # windows here, or alternatively, it may be split (or further split) within
+     # TransformerLayer, depending on configuration.
+     if (self.dstack_window_length == 0 or
+         self.dstack_window_length >= sequence_length):
+       num_windows = 1
+     else:
+       num_windows = sequence_length // self.dstack_window_length
+       assert (num_windows * self.dstack_window_length) == sequence_length
+
+     # Evaluate the stack of layers, scanning over windows if configured.
+     # ------------------------------------------------------------------
+     if decoder_state is None:
+       logging.info("dstack: scanning over %d windows.", num_windows)
+       # Load cached state from the previous training step, for truncated BPTT.
+       window_state = self.load_window_state(start_of_sequence)
+
+       # Scan single_window_stack over the sequence.
+       cstate = (window_state, start_of_sequence)
+       (cstate, ys) = attention.split_and_scan(single_window_stack,
+                                               cstate,
+                                               embeddings,
+                                               sections=num_windows,
+                                               axis=1)
+       (window_state, _) = cstate
+
+       # Cache state for the next training step, for truncated BPTT.
+       self.store_window_state(window_state)
+       attn_viz_dicts = {}  # Temporarily disabled.
+     else:
+       logging.info("dstack: autoregressive generator.")
+       # Run as an autoregressive decoder: evaluate the whole stack on a token.
+       # Do not load or store window_state; decoder_state is used instead.
+       (ys, _, decoder_state, _) = self._eval_layer_stack(
+           embeddings, start_of_sequence,
+           window_state=None, decoder_state=decoder_state)
+       attn_viz_dicts = {}
+
+     # Apply layernorm to the final output, before calculating logits.
+     # With a pre-layernorm architecture, this has to be done here.
+     if self.use_final_layernorm:
+       logging.info("dstack: Final layernorm.")
+       ys = self.final_layernorm(ys)
+
+     # Final dropout before token prediction.
+     drop_tile_shape = (1, 128, self.embedding_size)
+     get_dropout_rng = lambda: self.make_rng("dropout")
+     ys = nn_components.tiled_dropout(ys, drop_tile_shape,
+                                      self.final_dropout_rate,
+                                      rng_function=get_dropout_rng,
+                                      deterministic=not self.is_training())
+
+     # Apply an MLP at the very end to convert the output of the transformer
+     # into a vector to look up target tokens in the embedding table.
+     # This final layer allows the NN to distinguish between the "input context",
+     # which is returned by the transformer resnet, and the "predicted target".
+     if self.final_mlp_factory is not None:
+       logging.info("dstack: Final MLP layer.")
+       ys = self.final_mlp(ys)
+
+     # Reverse embedding to generate logits which predict the output tokens.
+     logits = self.embed.attend(ys)  # (..., seq_len, vocab_size)
+     logging.info("dstack: logits = %r", logits)
+
+     # Normalize so that the range of logits is reasonable.
+     logits = logits / jnp.sqrt(logits.shape[-1]).astype(self.dtype)
+
+     # Produce various visualizations in generate mode.
+     # TODO(delesley): Too many visualizations crashes the summary writer.
+     if self.mode == "generate":
+       img_dict = self._make_images(attn_viz_dicts, [])
+       hist_dict = {}  # metric_utils.make_histograms(attn_viz_dicts)
+       info_dict = {**img_dict, **hist_dict}
+     else:
+       info_dict = {}  # Don't output any visualizations.
+
+     return (logits, decoder_state, info_dict)
+
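The tied input/output embedding used in `__call__` above can be seen in isolation with Flax's `nn.Embed.attend`, which multiplies activations by the transpose of the embedding table (a standalone sketch, not part of the commit):

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    embed = nn.Embed(num_embeddings=256, features=64)
    tokens = jnp.array([[1, 2, 3]], dtype=jnp.int32)
    params = embed.init(jax.random.PRNGKey(0), tokens)
    ys = embed.apply(params, tokens)                          # (1, 3, 64)
    logits = embed.apply(params, ys, method=nn.Embed.attend)  # (1, 3, 256)
    assert logits.shape == (1, 3, 256)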
+   def _make_importance_image(self, importance_list, scaled=True) -> Array:
+     rows = []
+     for imp in importance_list:
+       rows += [imp] * 8  # Rows are 8 pixels high for better visibility.
+     image = jnp.stack(rows)
+     if scaled:
+       image = jnp.exp(image)
+     image = metric_utils.normalize_image(image, True)
+     return metric_utils.reshape_image(image)
+
+   def _make_images(self, viz_dicts, importance_list):
+     image_dict = {}
+     for (i, viz_dict) in enumerate(viz_dicts):
+       if "attn_importance_gate" in viz_dict:
+         imp_gate = viz_dict["attn_importance_gate"][0]  # First item in batch.
+         imp_strip = metric_utils.normalize_image(imp_gate[:, 0:8, :], True)
+       else:
+         imp_strip = None
+
+       for (k, attn_images) in viz_dict.items():
+         if k not in {"attn_content",
+                      "attn_pre_softmax",
+                      "attn_log",
+                      "attn",
+                      "attn_position_bias",
+                      "attn_importance_bias",
+                      "attn_importance_gate"}:
+           continue
+
+         attn_img = attn_images[0]  # Grab the first item in the batch.
+         attn_img = metric_utils.normalize_image(attn_img,
+                                                 as_group=(k != "attn"))
+         if imp_strip is not None and k in {"attn_log", "attn"}:
+           # Show importance bias in a strip at the bottom of the image.
+           attn_img = metric_utils.overlay_images(attn_img, imp_strip)
+         attn_img = metric_utils.reshape_image(attn_img)  # Returns None on fail.
+         if attn_img is not None:
+           image_dict[k + "_" + str(i)] = attn_img
+
+     if importance_list:
+       # Create an image out of the importance for each layer.
+       image_dict["importance_gate"] = self._make_importance_image(
+           importance_list, scaled=True)
+       image_dict["importance_raw"] = self._make_importance_image(
+           importance_list, scaled=False)
+     return image_dict
aglib/meliad/transformer/ht_main.py ADDED
@@ -0,0 +1,56 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ r"""Main program to train htransformer models.
+
+ """
+
+ from typing import Sequence
+
+ from absl import app
+ from absl import flags
+ from clu import platform
+ import jax
+ from transformer import launcher
+ import tensorflow.compat.v2 as tf
+
+
+ FLAGS = flags.FLAGS
+
+
+ def main(argv: Sequence[str]) -> None:
+   if len(argv) > 1:
+     raise app.UsageError("Too many command-line arguments.")
+
+   launcher.parse_gin_configuration()
+
+   # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
+   # it unavailable to JAX.
+   tf.config.experimental.set_visible_devices([], "GPU")
+
+   # Set global seed for datasets.
+   # tf.random.set_seed(1234)
+
+   # Add a note so that we can tell which task is which JAX host.
+   # (Depending on the platform task 0 is not guaranteed to be host 0)
+   platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
+                                        f"process_count: {jax.process_count()}")
+   platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
+                                        FLAGS.workdir, "workdir")
+
+   launcher.run_training_loop(testing=False)
+
+
+ if __name__ == "__main__":
+   app.run(main)
aglib/meliad/transformer/ht_main_inference.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ r"""Program to run a transformer model over a single article.
+
+ """
+
+ # This program is currently a template, which can be expanded to do more
+ # sophisticated analysis.
+
+ from typing import Sequence
+
+ from absl import app
+ from absl import flags
+ from clu import platform
+ import jax
+ from transformer import inference_utils
+ from transformer import tasks  # pylint: disable=unused-import
+ import tensorflow.compat.v2 as tf
+
+
+ flags.DEFINE_string("workdir", "", "Directory to save model checkpoints.")
+ flags.DEFINE_string("load_dir", "", "Directory to load pre-trained model.")
+ flags.DEFINE_integer("num_steps", 110, "Number of steps.")
+
+ flags.DEFINE_list(
+     "gin_search_paths",
+     ["transformer/configs"],
+     "List of paths where the Gin config files are located.")
+ flags.DEFINE_multi_string(
+     "gin_file", ["base_htrans.gin"], "List of Gin config files.")
+ flags.DEFINE_multi_string(
+     "gin_param", None, "Newline separated list of Gin parameter bindings.")
+
+ FLAGS = flags.FLAGS
+
+
+ def main(argv: Sequence[str]) -> None:
+   if len(argv) > 1:
+     raise app.UsageError("Too many command-line arguments.")
+
+   # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
+   # it unavailable to JAX.
+   tf.config.experimental.set_visible_devices([], "GPU")
+
+   # Add a note so that we can tell which task is which JAX host.
+   # (Depending on the platform task 0 is not guaranteed to be host 0)
+   platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
+                                        f"process_count: {jax.process_count()}")
+   platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
+                                        FLAGS.workdir, "workdir")
+
+   inference_utils.parse_gin_configuration(FLAGS.gin_file, FLAGS.gin_param,
+                                           gin_paths=FLAGS.gin_search_paths)
+
+   article_data = inference_utils.read_article(verbose=True)
+   (_, vocab) = article_data
+   (task, task_state, _) = inference_utils.create_model_and_task(
+       vocab, load_dir=FLAGS.load_dir)
+   outs = inference_utils.run_model(task, task_state, article_data,
+                                    verbose=True)
+   inference_utils.get_token_losses(outs)
+
+ if __name__ == "__main__":
+   app.run(main)
aglib/meliad/transformer/inference_utils.py ADDED
@@ -0,0 +1,271 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ r"""Various utility functions for doing inference on data.
+
+ This file provides a simple procedural API for loading a model, loading data,
+ and running the model over data. It is intended for use in, e.g., colabs.
+ """
+
+ from typing import Any, Dict, Optional, Sequence, Tuple
+
+ from absl import logging
+ import gin
+ import jax
+ import training_loop
+ from transformer import decoder_stack
+ from transformer import models
+ from transformer import text_dataset
+ import numpy as np
+ import seqio
+
+
+ Trainer = training_loop.Trainer
+ TrainState = training_loop.TrainState
+ TrainingTask = training_loop.TrainingTask
+ PRNGKeys = training_loop.PRNGKeys
+
+ ModelInput = Dict[str, Any]      # Input to model.
+ MetricsOutput = Dict[str, Any]   # Metrics output by model.
+ ArticleData = Tuple[Sequence[ModelInput], seqio.Vocabulary]
+ TaskState = Tuple[TrainState, int]
+
+
+ DEFAULT_GIN_PATHS = [
+     "transformer/configs"
+ ]
+
+
+ def parse_gin_configuration(gin_files: Optional[Sequence[str]],
+                             gin_params: Optional[Sequence[str]],
+                             gin_paths: Optional[Sequence[str]] = None):
+   """Load gin configuration options.
+
+   Args:
+     gin_files: A list of gin file names with the configuration to load.
+     gin_params: A list of additional parameter overrides.
+     gin_paths: A list of paths to search for gin_files.
+   """
+
+   # We allow None values to more easily handle command-line flags.
+   if gin_files is None:
+     gin_files = []
+   if gin_params is None:
+     gin_params = []
+   if gin_paths is None:
+     gin_paths = DEFAULT_GIN_PATHS
+
+   logging.info("Parsing gin configuration.")
+   for path in gin_paths:
+     logging.info("Added Gin search path %s", path)
+     gin.add_config_file_search_path(path)
+   for file_name in gin_files:
+     logging.info("Loading Gin config file %s", file_name)
+   for param in gin_params:
+     logging.info("Overriding Gin param %s", param)
+   gin.parse_config_files_and_bindings(gin_files, gin_params)
+
+
+ def read_article(split: Optional[str] = None,
+                  verbose: bool = False) -> ArticleData:
+   """Read a single article from the dataset and save it as a list of blocks.
+
+   This routine will return blocks for a single article; so the tokens will
+   have a batch size of 1. The blocks can be fed to the model directly as input.
+
+   Args:
+     split: The dataset split to load from. Defaults to the test split.
+     verbose: If True, will dump the contents of the article to the log.
+
+   Returns:
+     A pair of (list_of_blocks, vocabulary)
+   """
+
+   logging.info("Reading article.")
+
+   text_dataset.set_default_data_directory()
+   task_config = decoder_stack.TransformerTaskConfig()
+   batch_size = 1
+
+   if split is None:
+     split = task_config.test_split
+
+   (test_ds, vocab) = text_dataset.load_text_dataset(
+       name=task_config.dataset_name,
+       split=split,
+       sequence_length=task_config.sequence_length,
+       batch_size=batch_size,
+       sequential=task_config.sequential_chunks,
+       shard_dataset=False)
+
+   logging.info("Configured vocab_size = %d", task_config.vocab_size)
+   logging.info("Task vocabulary size = %d", vocab.vocab_size)
+   if task_config.vocab_size < vocab.vocab_size:
+     raise ValueError(
+         "Task vocabulary size does not match configured vocab_size: " +
+         f"{task_config.vocab_size} < {vocab.vocab_size}")
+
+   article_segments = []
+   ds_iter = test_ds.as_numpy_iterator()
+   vocab_map = {"targets": vocab}
+
+   segment_num = 0
+   while True:
+     try:
+       x = next(ds_iter)
+     except StopIteration:
+       logging.info("End of epoch? Something went wrong.")
+       break
+
+     # Make sure we've started reading, otherwise it immediately quits...
+     if article_segments:
+       if x["start_of_sequence"][0]:
+         break
+
+     if verbose:
+       logging.info("Segment %d = %s", segment_num,
+                    text_dataset.pretty_print_article(x, vocab_map,
+                                                      max_length=10_000))
+     article_segments.append(x)
+     segment_num += 1
+
+   logging.info("Done reading article: %d segments.", segment_num)
+   logging.info("Num tokens = %d", segment_num * task_config.sequence_length)
+   return (article_segments, vocab)
+
+
+ def create_model_and_task(vocab: seqio.Vocabulary,
+                           load_dir: Optional[str] = None) -> (
+                               Tuple[TrainingTask, TaskState, Trainer]):
+   """Initialize the model and get a task for inference.
+
+   The task will be configured to take test (inference) steps with the model.
+   The task will also be configured to run on a single replica, at batch size 1.
+
+   Args:
+     vocab: The vocabulary for the training data, used for logging and decoding.
+     load_dir: A directory which contains a pre-trained model.
+
+   Returns:
+     (task -- has a run_step method to take individual steps with the model,
+      state -- contains trainable parameters and other state,
+      trainer -- a Trainer object (see training_loop.py))
+   """
+
+   logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
+   logging.info("JAX local devices: %r", jax.local_devices())
+
+   # This task won't be pulling from a dataset.
+   def null_iter_fn():
+     return None
+
+   trainer = training_loop.Trainer(
+       get_training_dataset_iterator=null_iter_fn,
+       get_test_dataset_iterator=None,
+       pretty_print_input_function=None,
+       process_summaries_function=models.process_summaries_function(vocab),
+       load_dir=load_dir,
+       workdir="",  # Don't log or save checkpoints.
+       replicate_mode=False)  # Run on a single device at batch size 1.
+
+   # Create and initialize the model.
+   (tstate, start_step, imodel, prngs) = trainer.initialize_model()
+
+   # Create an inference task.
+   writers = {}
+   task = trainer.create_training_task("test", imodel, prngs, writers)
+
+   # Register any additional actions.
+   # Actions are cleared first for use with colab.
+   training_loop.clear_interstep_callbacks()
+   training_loop.register_interstep_callbacks()
+
+   task_state = (tstate, start_step)
+   return (task, task_state, trainer)
+
+
+ def run_model(task: TrainingTask, task_state: TaskState,
+               article_data: ArticleData, verbose: bool = False) -> (
+                   Sequence[MetricsOutput]):
+   """Run the model on an article, and return the outputs for each segment.
+
+   Args:
+     task: The task to run, from create_model_and_task.
+     task_state: The state of the model, from create_model_and_task.
+     article_data: The article and vocabulary, from read_article.
+     verbose: If True, will send input and output to the log.
+
+   Returns:
+     A sequence of model outputs for each block.
+   """
+
+   logging.info("Running the model.")
+
+   (article_segments, vocab) = article_data
+   (tstate, start_step) = task_state
+   vocab_map = {"targets": vocab}
+
+   # Ignore the iterator for the test task, and loop over the article.
+   step = start_step
+   segment_num = 0
+
+   # Loop over the article, and run the model on each segment.
+   segment_outputs = []
+   for x in article_segments:
+     if verbose:
+       logging.info("Segment [%d] = %s", segment_num,
+                    text_dataset.pretty_print_article(x, vocab_map,
+                                                      max_length=10_000))
+     else:
+       logging.info("Segment %d, step %d.", segment_num, step)
+
+     (tstate, metrics_np) = task.run_step(tstate, x, step)
+     training_loop.run_interstep_callbacks("test", step)
+     segment_outputs.append(metrics_np)
+
+     if verbose:
+       logging.info("Output [%d] = %s", segment_num, metrics_np)
+
+     del x
+     segment_num += 1
+     step += 1
+
+   logging.info("Done running the model: %d segments.", segment_num)
+   return segment_outputs
+
+
+ def get_token_losses(segment_outputs: Sequence[Any]) -> np.ndarray:
+   """Return the loss for each token in a sequence.
+
+   Given a list of model outputs, extract the token losses from each output
+   and concatenate them together.
+
+   Args:
+     segment_outputs: the outputs from run_model().
+
+   Returns:
+     An array of shape (batch_size, sequence_length), of float.
+   """
+
+   block_token_losses = []
+   for seg in segment_outputs:
+     if "token_losses" in seg:
+       block_token_losses.append(seg["token_losses"])
+     else:
+       raise ValueError("Token losses were not recorded.")
+
+   logging.info("Got token losses for %d segments", len(block_token_losses))
+   token_losses = np.concatenate(block_token_losses, axis=-1)
+   logging.info("token_losses.shape = %r", token_losses.shape)
+   return token_losses
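Putting the procedural API together, a colab-style sketch (not part of the commit; the checkpoint directory is hypothetical, and a dataset must be configured via Gin first):

    from transformer import inference_utils

    inference_utils.parse_gin_configuration(["base_htrans.gin"], [])
    article_data = inference_utils.read_article(verbose=False)
    (_, vocab) = article_data
    (task, task_state, _) = inference_utils.create_model_and_task(
        vocab, load_dir="/path/to/checkpoint")  # hypothetical checkpoint dir
    outs = inference_utils.run_model(task, task_state, article_data)
    token_losses = inference_utils.get_token_losses(outs)
    print(token_losses.shape)  # (1, num_segments * sequence_length)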
aglib/meliad/transformer/launcher.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Setup the data pipeline and launch the main training loop."""
16
+
17
+ from absl import flags
18
+ from absl import logging
19
+
20
+ import gin
21
+ import jax
22
+
23
+ import training_loop
24
+ from transformer import decoder_stack
25
+ from transformer import models
26
+ from transformer import tasks # pylint: disable=unused-import
27
+ from transformer import text_dataset
28
+
29
+
30
+ flags.DEFINE_string("workdir", "", "Directory to save model checkpoints.")
31
+ flags.DEFINE_string("load_dir", "", "Directory to load pre-trained model.")
32
+ flags.DEFINE_integer("num_steps", 110, "Number of steps.")
33
+
34
+ flags.DEFINE_list(
35
+ "gin_search_paths",
36
+ ["transformer/configs"],
37
+ "List of paths where the Gin config files are located.")
38
+ flags.DEFINE_multi_string(
39
+ "gin_file", ["base_htrans.gin"], "List of Gin config files.")
40
+ flags.DEFINE_multi_string(
41
+ "gin_param", None, "Newline separated list of Gin parameter bindings.")
42
+
43
+ FLAGS = flags.FLAGS
44
+
45
+
46
+ def parse_gin_configuration():
47
+ """Load and parse Gin configuration from command-line flags."""
48
+ for gin_file_path in FLAGS.gin_search_paths:
49
+ logging.info("Added Gin search path %s", gin_file_path)
50
+ gin.add_config_file_search_path(gin_file_path)
51
+ for gin_file in FLAGS.gin_file:
52
+ logging.info("Loading Gin config file %s", gin_file)
53
+ if FLAGS.gin_param:
54
+ for gin_param in FLAGS.gin_param:
55
+ logging.info("Overriding Gin param %s", gin_param)
56
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
57
+
58
+
59
+ def run_training_loop(testing: bool = False):
60
+ """Setup data pipeline and launch the main training loop."""
61
+
62
+ logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
63
+ logging.info("JAX local devices: %r", jax.local_devices())
64
+
65
+ text_dataset.set_default_data_directory()
66
+ task_config = decoder_stack.TransformerTaskConfig()
67
+ batch_size = task_config.batch_size * jax.local_device_count()
68
+
69
+ (train_ds, vocab) = text_dataset.load_text_dataset(
70
+ name=task_config.dataset_name,
71
+ split=task_config.train_split, # train
72
+ sequence_length=task_config.sequence_length,
73
+ batch_size=batch_size,
74
+ sequential=task_config.sequential_chunks,
75
+ shard_dataset=True)
76
+
77
+ (test_ds, test_vocab) = text_dataset.load_text_dataset(
78
+ name=task_config.dataset_name,
79
+ split=task_config.test_split, # test
80
+ sequence_length=task_config.sequence_length,
81
+ batch_size=batch_size,
82
+ sequential=task_config.sequential_chunks,
83
+ shard_dataset=False)
84
+
85
+ logging.info("Configured vocab_size = %d", task_config.vocab_size)
86
+ logging.info("Task vocabulary size = %d", vocab.vocab_size)
87
+ assert vocab.vocab_size == test_vocab.vocab_size # Sanity check.
88
+ if task_config.vocab_size < vocab.vocab_size:
89
+ raise ValueError(
90
+ "Task vocabulary size does not match configured vocab_size: " +
91
+ f"{task_config.vocab_size} < {vocab.vocab_size}")
92
+
93
+ # Pretty printing depends on the vocabulary object.
94
+ def pretty_print_article_fn(article) -> str:
95
+ return text_dataset.pretty_print_article(article, {"targets": vocab}, 32768)
96
+
97
+ train_ds_iter_fn = text_dataset.get_iterator_function(train_ds)
98
+ test_ds_iter_fn = text_dataset.get_iterator_function(test_ds)
99
+
100
+ if testing:
101
+ # Build trainer, which is configurable by Gin, and run training loop.
102
+ trainer = training_loop.Trainer(
103
+ get_training_dataset_iterator=train_ds_iter_fn,
104
+ get_test_dataset_iterator=test_ds_iter_fn,
105
+ pretty_print_input_function=pretty_print_article_fn,
106
+ process_summaries_function=models.process_summaries_function(vocab),
107
+ num_steps=FLAGS.num_steps, # Ignore Gin config for these options.
108
+ load_dir=FLAGS.load_dir,
109
+ workdir=FLAGS.workdir)
110
+ else:
111
+ trainer = training_loop.Trainer(
112
+ get_training_dataset_iterator=train_ds_iter_fn,
113
+ get_test_dataset_iterator=test_ds_iter_fn,
114
+ pretty_print_input_function=pretty_print_article_fn,
115
+ process_summaries_function=models.process_summaries_function(vocab),
116
+ load_dir=FLAGS.load_dir,
117
+ workdir=FLAGS.workdir)
118
+
119
+ trainer.train()
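
The entry-point module that turns these functions into a binary is not part of this upload, so the wrapper below is a hypothetical sketch, assuming the usual absl conventions:

from absl import app

def main(argv):
  del argv  # Flags (including --gin_file / --gin_param) are parsed by absl.
  parse_gin_configuration()
  run_training_loop()

if __name__ == "__main__":
  app.run(main)

app.run parses the flags defined above before main executes, so FLAGS.gin_file and FLAGS.gin_param are already populated when parse_gin_configuration runs.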
aglib/meliad/transformer/memory_factory.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Flax modules and functions for using external memory."""
16
+
17
+ from typing import Any, Optional, Tuple
18
+
19
+ from absl import logging
20
+ from flax import linen
21
+ import gin
22
+ import jax
23
+ from transformer import memory_layer
24
+
25
+
26
+
27
+ PRNGKey = Any
28
+ Shape = Tuple[int]
29
+ Dtype = Any
30
+ Array = Any
31
+ MemoryResource = Any
32
+
33
+
34
+ class MemoryManager:
35
+ """Manages any external resources that may be required by external memory.
36
+
37
+ MemoryManager also functions as a factory, to create Flax modules that will
38
+ read and write to whatever external memory has been configured.
39
+ """
40
+
41
+ def __init__(self,
42
+ batch_size: int,
43
+ mode: str,
44
+ num_heads: int,
45
+ key_size: int,
46
+ value_size: int,
47
+ database_size: Optional[int] = None,
48
+ dtype: Dtype = "float32",
49
+ off_device_memory: Optional[MemoryResource] = None):
50
+ """Create a MemoryManager object.
51
+
52
+ A MemoryManager configures external memory, and is used as a factory to
53
+ construct flax modules that read or write to the memory.
54
+
55
+ Args:
56
+ batch_size: The number of separate documents in a batch.
57
+ mode: The mode, e.g. "train" or "test".
58
+ num_heads: The number of transformer heads.
59
+ key_size: The length of the key vectors.
60
+ value_size: The length of the value vectors.
61
+ database_size: The total number of tokens in the database.
62
+ dtype: The datatype used for keys and values.
63
+ off_device_memory: An object which manages underlying SCAM memory.
64
+ If None, then the model will use on-device memory.
65
+ """
66
+ self.batch_size = batch_size
67
+ self.mode = mode
68
+ self.num_heads = num_heads
69
+ self.key_size = key_size
70
+ self.value_size = value_size
71
+ self.database_size = database_size
72
+ self.dtype = dtype
73
+ self.off_device_memory = off_device_memory
74
+
75
+ def create_memory_layer(self) -> linen.Module:
76
+ """Create a flax Module that implements external memory."""
77
+
78
+ num_datasets = (
79
+ self.batch_size * self.num_heads #
80
+ if self.off_device_memory is None #
81
+ else self.num_heads)
82
+ if self.off_device_memory is not None:
83
+ mem_layer = None
84
+ if mem_layer is None:
85
+ raise ValueError("Off-device memory is not supported at this time.")
86
+ return memory_layer.BatchedMemory(
87
+ mem_layer,
88
+ split_dimensions=(-2,),
89
+ )
90
+ else:
91
+ assert self.database_size is not None
92
+ mem_layer = memory_layer.MemoryOnTpu(num_datasets=num_datasets,
93
+ key_features=self.key_size,
94
+ value_features=self.value_size,
95
+ database_size=self.database_size,
96
+ dtype=self.dtype)
97
+ # Handle queries of shape [batch_size, seq_len, num_heads, kv_features]
98
+ return memory_layer.BatchedMemory(mem_layer,
99
+ split_dimensions=(0, -2))
100
+
101
+
102
+ @gin.configurable
103
+ def memory_on_tpu_factory(batch_size: int,
104
+ mode: str,
105
+ num_heads: int = gin.REQUIRED,
106
+ key_size: int = gin.REQUIRED,
107
+ value_size: int = gin.REQUIRED,
108
+ database_size: int = gin.REQUIRED,
109
+ dtype: Dtype = gin.REQUIRED) -> MemoryManager:
110
+ """Implement SCAM memory on device."""
111
+ return MemoryManager(batch_size=batch_size,
112
+ mode=mode,
113
+ num_heads=num_heads,
114
+ key_size=key_size,
115
+ value_size=value_size,
116
+ database_size=database_size,
117
+ dtype=dtype,
118
+ off_device_memory=None)
119
+
120
+
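
A minimal usage sketch for this factory; the sizes below are illustrative only and are not taken from any config in this upload:

manager = MemoryManager(
    batch_size=2,
    mode="train",
    num_heads=8,
    key_size=64,
    value_size=64,
    database_size=8192,
    off_device_memory=None)
mem = manager.create_memory_layer()  # BatchedMemory wrapping MemoryOnTpu

With off_device_memory=None, create_memory_layer builds num_datasets = batch_size * num_heads independent databases and splits incoming tensors by both the batch and head dimensions (split_dimensions=(0, -2)).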
aglib/meliad/transformer/memory_layer.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """FLAX layers for on-TPU memory."""
16
+
17
+ import abc
18
+ import functools
19
+ from typing import Callable, Sequence, Tuple, TypeVar, Union
20
+
21
+ from absl import logging
22
+ from flax import linen
23
+ import gin
24
+ import jax
25
+ from jax import lax
26
+ import jax.numpy as jnp
27
+ import numpy as np # use with care!
28
+
29
+ Shape = Sequence[int]
30
+ Dtype = jnp.dtype
31
+ Array = jnp.ndarray
32
+
33
+ Axes = Union[int, Tuple[int, ...]]
34
+ F = TypeVar('F', bound=Callable)
35
+
36
+
37
+ class MemoryLayer(linen.Module, metaclass=abc.ABCMeta):
38
+ """Internal interface for memory layers without batch dim.
39
+
40
+ See BatchedMemory for a layer that can be used in Flax models.
41
+ """
42
+ num_datasets: int
43
+
44
+ @abc.abstractmethod
45
+ def update(self, key: Array, value: Array) -> int:
46
+ """Adds key/value pairs to memory.
47
+
48
+ Args:
49
+ key: of shape (num_kv, num_datasets, k_features)
50
+ value: of shape (num_kv, num_datasets, v_features)
51
+
52
+ Returns:
53
+ Dummy value so that TPU operations can wait for the update to finish if
54
+ desired.
55
+ """
56
+ raise NotImplementedError()
57
+
58
+ @abc.abstractmethod
59
+ def topk_retrieval(self, query: Array,
60
+ num_neighbors: int) -> Tuple[Array, Array]:
61
+ """Retrieves the nearest neighbors for each query.
62
+
63
+ Args:
64
+ query: of shape (num_queries, num_datasets, k_features)
65
+ num_neighbors: int indicating the number of neighbors to retrieve
66
+
67
+ Returns:
68
+ Tuple of selected keys and selected values of shapes
69
+ (num_queries, num_datasets, num_neighbors, k_features), and
70
+ (num_queries, num_datasets, num_neighbors, v_features)
71
+ """
72
+ raise NotImplementedError()
73
+
74
+ @abc.abstractmethod
75
+ def reset(self, datasets: Array) -> int:
76
+ """Reset some or all of the datasets in the memory.
77
+
78
+ Args:
79
+ datasets: A vector of shape (num_datasets) of type bool. Each position
80
+ indicates whether the dataset with the same index should be reset.
81
+
82
+ Returns:
83
+ Dummy value so that TPU operations can wait for the update to finish if
84
+ desired.
85
+ """
86
+ raise NotImplementedError()
87
+
88
+ def __call__(self, query, num_neighbors):
89
+ return self.topk_retrieval(query, num_neighbors)
90
+
91
+
92
+ def _target_dimensions(shape: Shape,
93
+ source_dimensions: Sequence[int]) -> Sequence[int]:
94
+ target_dimensions = range(-2, -2 - len(source_dimensions), -1)
95
+ assert len(source_dimensions) == len(target_dimensions)
96
+ return sorted(d % len(shape) for d in target_dimensions)
97
+
98
+
99
+ def _rearrange_dimensions_shapes(
100
+ shape: Shape, split_dimensions: Sequence[int]) -> Tuple[Shape, Shape]:
101
+ split_shape = tuple(shape[d] for d in split_dimensions)
102
+ remaining_shape = tuple(
103
+ shape[d] for d in range(len(shape)) if d not in split_dimensions)
104
+ batch_shape = remaining_shape[:-1]
105
+ return split_shape, batch_shape
106
+
107
+
108
+ def _rearrange_dimensions(x: Array, split_dimensions: Sequence[int]) -> Array:
109
+ """Rearrange array so that we can split by a single dimension.
110
+
111
+ Turns an array of shape [d1, ..., dn, features] and a list of dimensions to
112
+ split by into [prod(remaining_dimensions), prod(split_dimensions),
113
+ features]
114
+
115
+ Args:
116
+ x: array of shape [d1, ..., dn, features]
117
+ split_dimensions: list of dimensions that should end up in dimension -2.
118
+
119
+ Returns:
120
+ Rearranged array as described above.
121
+ """
122
+ split_dimensions = [d % len(x.shape) for d in split_dimensions]
123
+ split_dimensions = sorted(split_dimensions)
124
+ split_shape, batch_shape = _rearrange_dimensions_shapes(
125
+ x.shape, split_dimensions)
126
+
127
+ target_dimensions = _target_dimensions(x.shape, split_dimensions)
128
+ x = jnp.moveaxis(x, split_dimensions, target_dimensions)
129
+ assert len(x.shape) > len(split_dimensions)
130
+ assert all(isinstance(d, int) and d >= 0 for d in batch_shape)
131
+ assert all(isinstance(d, int) and d >= 0 for d in split_shape)
132
+ new_shape = [
133
+ # The use of numpy is okay here, since shapes are concrete at jit time.
134
+ np.prod(batch_shape),
135
+ np.prod(split_shape),
136
+ x.shape[-1] # features dimension
137
+ ]
138
+ res = x.reshape(new_shape)
139
+ assert res.ndim == 3
140
+ return res
141
+
142
+
143
+ def _restore_dimensions(x: Array, original_shape: Shape,
144
+ split_dimensions: Sequence[int]) -> Array:
145
+ """Restores arrays encoded with _rearrange_dimensions.
146
+
147
+ Args:
148
+ x: Array of shape [prod(batch_shape), prod(split_shape), feature...]
149
+ original_shape: Shape of the array to restore to.
150
+ split_dimensions: Dimensions that were multiplied into dimension 2.
151
+
152
+ Returns:
153
+ Array of the original shape and axis order for all dimensions in batch_shape
154
+ and split_shape. Feature dimensions may have changed (can include additional
155
+ dimensions for neighbors, for example).
156
+ """
157
+ split_dimensions = [d % len(original_shape) for d in split_dimensions]
158
+ split_dimensions = sorted(split_dimensions)
159
+ split_shape, batch_shape = _rearrange_dimensions_shapes(
160
+ original_shape, split_dimensions)
161
+
162
+ features_shape = x.shape[2:]
163
+ x = x.reshape((*batch_shape, *split_shape, *features_shape))
164
+
165
+ # rearrange
166
+ target_dimensions = _target_dimensions(original_shape, split_dimensions)
167
+ x = jnp.moveaxis(x, target_dimensions, split_dimensions)
168
+ return x
169
+
170
+
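
A quick round trip through the two helpers above, on a toy tensor (shapes chosen for illustration; assumes the definitions in this file are in scope):

import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)
flat = _rearrange_dimensions(x, split_dimensions=(0, -2))
assert flat.shape == (3, 8, 5)  # (prod(batch_shape), prod(split_shape), features)
back = _restore_dimensions(flat, x.shape, split_dimensions=(0, -2))
assert back.shape == x.shape and bool(jnp.all(back == x))

Dimensions 0 and -2 (sizes 2 and 4) are multiplied into the middle "datasets" axis; the remaining dimensions supply the flattened batch axis and the feature axis.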
171
+ @gin.configurable
172
+ class BatchedMemory(linen.Module):
173
+ """Equips a memory module with a batch dimension."""
174
+
175
+ # We wrap this linen.Module:
176
+ wrapped: MemoryLayer
177
+
178
+ # `split_dimensions` indicates the dimensions of the query and update tensors
179
+ # that will go to separate databases. By default, we use a separate database
180
+ # for each head.
181
+ # Note that some implementations of the memory share memory across all hosts
182
+ # and devices (memory_on_borg, unless configured otherwise) or just across
183
+ # devices of each host (memory_on_host).
184
+ # Default is (-2,) to split by head only; use (0, -2) to also split by the
185
+ # batch dimension.
186
+ split_dimensions: Tuple[int, ...] = (-2,)
187
+
188
+ query_stride: int = 1
189
+ update_stride: int = 1
190
+
191
+ def update(self, key: Array, value: Array):
192
+ """Adds key/value pairs to memory.
193
+
194
+ Args:
195
+ key: typically of shape (batch, kv_len, num_heads, k_features). This
196
+ tensor is split up into datasets according to `split_dimensions`.
197
+ value: typically of shape (batch, kv_len, num_heads, v_features). This
198
+ tensor is split up into datasets according to `split_dimensions`.
199
+
200
+ Returns:
201
+ A dummy value 0, once the operation has completed.
202
+ """
203
+ if key.ndim != 4 or value.ndim != 4:
204
+ raise ValueError('Expected batched inputs; got shapes: %s and %s.' %
205
+ (key.shape, value.shape))
206
+ key = _rearrange_dimensions(key, self.split_dimensions)
207
+ value = _rearrange_dimensions(value, self.split_dimensions)
208
+ update_stride = self.update_stride
209
+ if update_stride == 1:
210
+ return self.wrapped.update(key, value)
211
+ return self.wrapped.update(key[update_stride - 1::update_stride, ...],
212
+ value[update_stride - 1::update_stride, ...])
213
+
214
+ def topk_retrieval(self, query: Array, num_neighbors: int):
215
+ """Retrieves the nearest neighbors for each query.
216
+
217
+ Args:
218
+ query: typically of shape (batch, q_len, num_heads, k_features). This
219
+ tensor is split up into datasets according to `split_dimensions`.
220
+ num_neighbors: number of neighbors to retrieve
221
+
222
+ Returns:
223
+ Tuple of tensors with the retrieved keys and values of the same shape as
224
+ query, but with an extra dimension of length num_neighbors - typically:
225
+ (batch, q_len, num_heads, num_neighbors, k_features)
226
+ """
227
+ if query.ndim != 4:
228
+ raise ValueError('Expected batched inputs; got shape: %s.' % query.shape)
229
+ query_stride = self.query_stride
230
+ original_shape = query.shape
231
+ query = _rearrange_dimensions(query, self.split_dimensions)
232
+ if query_stride == 1:
233
+ key, value = self.wrapped.topk_retrieval(query, num_neighbors)
234
+ else:
235
+ num_queries, num_heads, k_features = query.shape
236
+ throttled_query = query[0::query_stride, ...]
237
+ key = jnp.zeros(
238
+ shape=(num_queries, num_heads, num_neighbors, k_features),
239
+ dtype=query.dtype)
240
+ throttled_key, throttled_value = (
241
+ self.wrapped.topk_retrieval(throttled_query, num_neighbors))
242
+ _, _, _, v_features = throttled_value.shape
243
+ value = jnp.zeros(
244
+ shape=(num_queries, num_heads, num_neighbors, v_features),
245
+ dtype=query.dtype)
246
+ key = key.at[0::query_stride, ...].set(throttled_key)
247
+ value = value.at[0::query_stride, ...].set(throttled_value)
248
+ key = _restore_dimensions(key, original_shape, self.split_dimensions)
249
+ # Note that `original_shape` here may have the wrong feature dimension (if
250
+ # k_features != v_features). But `_restore_dimensions` does not depend on
251
+ # that dimension and the tests cover this case.
252
+ value = _restore_dimensions(value, original_shape, self.split_dimensions)
253
+ assert key.ndim == len(original_shape) + 1
254
+ return key, value
255
+
256
+ def reset(self, datasets: Array) -> int:
257
+ """Resets the memory.
258
+
259
+ Args:
260
+ datasets: of shape (num_datasets,), typically the same as (num_heads,).
261
+
262
+ Returns:
263
+ A dummy value 0, once the operation has completed.
264
+ """
265
+ return self.wrapped.reset(datasets)
266
+
267
+
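
A sketch of driving BatchedMemory through the standard Flax init/apply calls. The sizes are illustrative, MemoryOnTpu is defined later in this file, and the "database" collection must be threaded through apply as a mutable collection:

import jax
import jax.numpy as jnp

mem = BatchedMemory(
    MemoryOnTpu(num_datasets=2, key_features=8, value_features=8,
                database_size=128))
k = jnp.ones((1, 16, 2, 8))  # (batch, kv_len, num_heads, k_features)
v = jnp.ones((1, 16, 2, 8))
variables = mem.init(jax.random.PRNGKey(0), k, v, method=mem.update)
# Write the key/value pairs, keeping the updated database variables.
_, variables = mem.apply(variables, k, v, method=mem.update,
                         mutable=["database"])
# Query: each retrieved tensor gains a num_neighbors axis.
(sel_k, sel_v), _ = mem.apply(variables, k, num_neighbors=4,
                              method=mem.topk_retrieval, mutable=["database"])
assert sel_k.shape == (1, 16, 2, 4, 8)

With the default split_dimensions=(-2,), num_datasets must equal num_heads (here 2).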
268
+ @functools.partial(jax.jit, static_argnames=('num_buckets', 'bucket_size'))
269
+ def _chunking_sparsify(query: Array, key: Array, num_buckets: int,
270
+ bucket_size: int) -> Tuple[Array, Array, Array]:
271
+ """Approximate top k operation for a single head."""
272
+ # q = q_length, f = qk features, d = database_size
273
+ scores = jnp.einsum('qf,df->qd', query, key)
274
+ mask = (key.sum(-1) == 0).astype(jnp.bfloat16) * -1e6
275
+ scores += mask
276
+
277
+ num_queries, _ = scores.shape
278
+ reshaped_scores = jnp.reshape(scores, (num_queries, bucket_size, num_buckets))
279
+
280
+ sparse_scores = linen.softmax(reshaped_scores * 1e6, axis=1)
281
+
282
+ # topk_scores and topk_indices will only be computed if we depend on their
283
+ # results.
284
+ topk_scores = jnp.max(reshaped_scores, axis=1)
285
+ local_indices = jnp.argmax(reshaped_scores, axis=1)
286
+ topk_indices = (
287
+ local_indices * num_buckets + jnp.arange(num_buckets).reshape(
288
+ (1, num_buckets)))
289
+ return sparse_scores, topk_scores, topk_indices
290
+
291
+
292
+ def _retrieve_topk_gatherless(
293
+ query: Array, key: Array, value: Array,
294
+ num_neighbors: int) -> Tuple[Array, Array, Array, Array]:
295
+ """Retrieves for a single head - used to simplify array accesses."""
296
+ num_kv, query_features = query.shape
297
+ database_size, key_features = key.shape
298
+ _, value_features = value.shape
299
+ assert query_features == key_features
300
+ num_buckets = num_neighbors
301
+ if num_buckets > database_size:
302
+ raise ValueError('More buckets than items in database. %s > %s' %
303
+ (num_buckets, database_size))
304
+ if database_size % num_buckets:
305
+ raise ValueError('Buckets must divide database: %s %% %s.' %
306
+ (database_size, num_buckets))
307
+ bucket_size = database_size // num_buckets
308
+
309
+ sparse_scores, topk_scores, topk_indices = _chunking_sparsify(
310
+ query, key, num_buckets, bucket_size)
311
+ key = key.reshape(bucket_size, num_buckets, key_features)
312
+ value = value.reshape(bucket_size, num_buckets, value_features)
313
+ selected_keys = jnp.einsum('qbn,bnd->qnd', sparse_scores, key)
314
+ selected_values = jnp.einsum('qbn,bnd->qnd', sparse_scores, value)
315
+
316
+ assert selected_keys.shape == (num_kv, num_neighbors, key_features)
317
+ assert selected_values.shape == (num_kv, num_neighbors, value_features)
318
+ return selected_keys, selected_values, topk_scores, topk_indices
319
+
320
+
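
A toy run of the bucketed retrieval above: 4 orthogonal keys split into 2 buckets of 2, with arbitrary values (sizes are illustrative):

import jax.numpy as jnp

q = jnp.eye(4, dtype=jnp.float32)   # 4 queries that match the 4 keys exactly
db_k = jnp.eye(4, dtype=jnp.float32)
db_v = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)
sel_k, sel_v, scores, idx = _retrieve_topk_gatherless(
    q, db_k, db_v, num_neighbors=2)
assert sel_k.shape == (4, 2, 4)  # (num_queries, num_neighbors, key_features)
assert sel_v.shape == (4, 2, 2)

Rather than gathering rows by index, each "neighbor" is the softmax-weighted sum of one bucket; the sharp softmax (scores * 1e6) makes this close to a hard per-bucket argmax while staying TPU-friendly.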
321
+ class MemoryOnTpu(MemoryLayer):
322
+ """Approximate top K search on TPU."""
323
+ # database_size must be integer multiple of prod(batch_dims) * num_neighbors.
324
+ database_size: int
325
+ dtype: Dtype = jnp.float32 # pylint: disable=g-bare-generic
326
+ key_features: int = 64
327
+ value_features: int = 64
328
+ report_scores_and_indices: bool = False
329
+
330
+ def setup(self):
331
+ self.db_index = self.variable('database', 'database_index',
332
+ functools.partial(jnp.zeros, dtype=jnp.int32),
333
+ (self.num_datasets,))
334
+ self.key_db = self.variable(
335
+ 'database', 'key_db', functools.partial(jnp.zeros, dtype=self.dtype),
336
+ (self.num_datasets, self.database_size, self.key_features))
337
+ self.value_db = self.variable(
338
+ 'database', 'value_db', functools.partial(jnp.zeros, dtype=self.dtype),
339
+ (self.num_datasets, self.database_size, self.value_features))
340
+
341
+ self.retrieved_indices = self.variable(
342
+ 'database', 'retrieved_indices',
343
+ functools.partial(jnp.zeros, dtype=jnp.int32), (0, 0, 0))
344
+ self.retrieved_indices_scores = self.variable(
345
+ 'database', 'retrieved_indices_scores',
346
+ functools.partial(jnp.zeros, dtype=jnp.float32), (0, 0, 0))
347
+
348
+ def _update_kv_database(self, database, new_values, start_index):
349
+ num_datasets, database_size, _ = database.shape
350
+ assert database_size == self.database_size, f'{database_size} vs {self.database_size}'
351
+ assert num_datasets == self.num_datasets
352
+ assert new_values.ndim == 3
353
+ assert start_index.shape == (self.num_datasets,)
354
+
355
+ def _update(database, new_values, start_index):
356
+ return lax.dynamic_update_slice(
357
+ database, new_values, start_indices=(start_index, 0))
358
+
359
+ return jax.vmap(
360
+ _update, in_axes=(0, 0, 0), out_axes=0)(database, new_values,
361
+ start_index)
362
+
363
+ def update(self, key: Array, value: Array) -> int:
364
+ """Add keys and values to the memory; overwrite oldest if memory is full."""
365
+ key = lax.stop_gradient(key)
366
+ value = lax.stop_gradient(value)
367
+ assert len(key.shape) == len(value.shape)
368
+ assert key.shape[:-1] == value.shape[:-1]
369
+ num_kv, num_datasets, key_features = key.shape
370
+ assert num_datasets == self.num_datasets
371
+ assert key_features == self.key_features
372
+ assert value.shape[-1] == self.value_features
373
+ assert self.database_size % num_kv == 0, (
374
+ 'Database size must be integer multiple of num_kv.')
375
+ key = jnp.moveaxis(key, source=1, destination=0) # split by dataset
376
+ value = jnp.moveaxis(value, source=1, destination=0) # split by dataset
377
+
378
+ # start_index can be larger than DB - we use that to detect which entries
379
+ # are not written to yet
380
+ start_index = self.db_index.value % self.database_size
381
+ self.key_db.value = self._update_kv_database(self.key_db.value, key,
382
+ start_index)
383
+ self.value_db.value = self._update_kv_database(self.value_db.value, value,
384
+ start_index)
385
+ self.db_index.value = self.db_index.value + num_kv
386
+ return 0
387
+
388
+ def topk_retrieval(self, query: Array,
389
+ num_neighbors: int) -> Tuple[Array, Array]:
390
+ """Nearest neighbors by full multiplication and approximate top k on TPU."""
391
+ query = lax.stop_gradient(query)
392
+ unused_num_kv, num_datasets, query_features = query.shape
393
+ assert num_datasets == self.num_datasets
394
+ assert query_features == self.key_features
395
+ query = jnp.moveaxis(query, source=1, destination=0)
396
+
397
+ # Process different heads sequentially
398
+ selected_keys, selected_values, topk_scores, topk_indices = lax.map(
399
+ lambda x: _retrieve_topk_gatherless(*x, num_neighbors),
400
+ (query, self.key_db.value, self.value_db.value))
401
+
402
+ if self.report_scores_and_indices:
403
+ # TODO(mrabe): These variable updates may not work perfectly yet. Find out
404
+ # why Flax does not like them.
405
+ self.retrieved_indices.value = topk_indices
406
+ self.retrieved_indices_scores.value = topk_scores
407
+
408
+ assert selected_keys.ndim == selected_values.ndim == 4
409
+ selected_keys = jnp.moveaxis(selected_keys, source=0, destination=1)
410
+ selected_values = jnp.moveaxis(selected_values, source=0, destination=1)
411
+ return selected_keys, selected_values
412
+
413
+ def reset(self, datasets: Array) -> int:
414
+ """Resets specified datasets."""
415
+ datasets = lax.stop_gradient(datasets)
416
+ assert datasets.shape == (self.num_datasets,)
417
+ assert datasets.dtype == jnp.bool_
418
+
419
+ def _reset_single_dataset(input_tuple):
420
+ """Resets a single head; reset is a single bool."""
421
+ database, reset = input_tuple
422
+ assert reset.shape == tuple(), reset.shape
423
+ assert reset.dtype == jnp.bool_
424
+ return database * (1 - reset)
425
+
426
+ self.db_index.value = self.db_index.value * (1 - datasets)
427
+ self.key_db.value = lax.map(
428
+ _reset_single_dataset, xs=(self.key_db.value, datasets))
429
+ self.value_db.value = lax.map(
430
+ _reset_single_dataset, xs=(self.value_db.value, datasets))
431
+ return 0
aglib/meliad/transformer/metric_utils.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Helper routines for recording various training metrics."""
16
+
17
+ from typing import Any
18
+ import jax.numpy as jnp
19
+
20
+
21
+ Array = Any
22
+
23
+
24
+ def compute_accuracy_sum(logits, targets, valid_loss_mask=None):
25
+ """Compute accuracy for logits and targets.
26
+
27
+ Args:
28
+ logits: [batch, length, num_classes] float array.
29
+ targets: categorical targets [batch, length] int array.
30
+ valid_loss_mask: None or array of shape bool[batch, length]
31
+
32
+ Returns:
33
+ The number of correct tokens in the output.
34
+ """
35
+ if logits.shape[:-1] != targets.shape:
36
+ raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
37
+ (logits.shape, targets.shape))
38
+ if valid_loss_mask is not None and valid_loss_mask.shape != targets.shape:
39
+ raise ValueError("Incorrect shapes. Got shape %s targets and %s mask" %
40
+ (targets.shape, valid_loss_mask.shape))
41
+
42
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
43
+ if valid_loss_mask is not None:
44
+ accuracy = jnp.logical_and(accuracy, valid_loss_mask)
45
+ return jnp.sum(accuracy) # Sum of the number of True values.
46
+
47
+
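
For example, with three positions where two predictions are correct and the third is masked out:

import jax.numpy as jnp

logits = jnp.array([[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]])  # [1, 3, 2]
targets = jnp.array([[1, 0, 0]])                            # [1, 3]
mask = jnp.array([[True, True, False]])
assert int(compute_accuracy_sum(logits, targets, mask)) == 2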
48
+ def reshape_image(image):
49
+ """Reshape image to something that tensorboard recognizes.
50
+
51
+ Args:
52
+ image: Array of shape [xsize, ysize] or [num_images, xsize, ysize]
53
+
54
+ Returns:
55
+ Array of shape [num_images, xsize, ysize, 1]
56
+ """
57
+
58
+ # Reshape to [num_images, xdim, ydim, rgb] for tensorboard.
59
+ sh = image.shape
60
+ if image.ndim == 2:
61
+ return jnp.reshape(image, [1, sh[0], sh[1], 1]).astype(jnp.float32)
62
+ elif image.ndim == 3:
63
+ return jnp.reshape(image, [sh[0], sh[1], sh[2], 1]).astype(jnp.float32)
64
+ else:
65
+ return None # Not an image.
66
+
67
+
68
+ def normalize_image(images: Array, as_group: bool = False) -> Array:
69
+ """Rescale the values in images to between 0.0 and 1.0.
70
+
71
+ Args:
72
+ images: Array of size [batch_size, xsize, ysize]
73
+ as_group: Scale all images in the batch by the same amount if True.
74
+
75
+ Returns:
76
+ A rescaled image of the same shape.
77
+ """
78
+
79
+ images = images.astype(jnp.float32) # Return images as float32.
80
+ if as_group:
81
+ # Normalize the batch of images as a group.
82
+ min_img = jnp.min(images)
83
+ max_img = jnp.max(images)
84
+ else:
85
+ # Normalize each image in the batch individually.
86
+ min_img = jnp.min(images, axis=(-2, -1), keepdims=True)
87
+ max_img = jnp.max(images, axis=(-2, -1), keepdims=True)
88
+ norm_image = (images - min_img) / (max_img - min_img + 1e-6)
89
+ return jnp.where(jnp.isfinite(norm_image), norm_image, 0.0)
90
+
91
+
92
+ def overlay_images(image1: Array, image2: Array) -> Array:
93
+ """Place image1 on top of image2, broadcasting image2 if necessary.
94
+
95
+ Args:
96
+ image1: array of shape [num_images, xsize, ysize]
97
+ image2: array of shape [num_images, xsize, ysize]
98
+
99
+ Returns:
100
+ A combined image.
101
+ """
102
+
103
+ assert image1.ndim == 3 # (num_images, xsize, ysize)
104
+ assert image2.ndim == 3
105
+ image2 = jnp.broadcast_to(image2, image1.shape)
106
+ return jnp.concatenate([image1, image2], axis=1)
107
+
108
+
109
+ def make_histograms(viz_dicts):
110
+ """Generate image histograms."""
111
+ hist_dict = {}
112
+ for (i, viz_dict) in enumerate(viz_dicts):
113
+ for (k, images) in viz_dict.items():
114
+ hist_dict["h_" + k + "_" + str(i)] = images
115
+ return hist_dict
aglib/meliad/transformer/models.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Sequence to sequence model."""
16
+
17
+ from typing import Any, Callable, Dict, Tuple
18
+
19
+ from absl import logging
20
+ from flax import linen as nn
21
+ from flax.training import common_utils
22
+ import gin
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import metrics_summary
26
+ from transformer import decoder_stack
27
+ from transformer import metric_utils
28
+ from transformer import text_dataset
29
+ import numpy as np
30
+ import seqio
31
+
32
+
33
+ Array = jnp.ndarray
34
+ MetricsSummary = metrics_summary.MetricsSummary
35
+
36
+
37
+ # TODO(mrabe): Remove this function and find a better way to turn text metrics
38
+ # into text on tensorboard.
39
+ def process_summaries(vocab: seqio.Vocabulary,
40
+ met_summary: MetricsSummary,
41
+ mode: str) -> MetricsSummary:
42
+ """Compute some additional summaries, and convert tokens to text.
43
+
44
+ Args:
45
+ vocab: The vocabulary to detokenize generated text.
46
+ met_summary: The summary object to process.
47
+ mode: The mode of the summary (e.g. "test", "train")
48
+
49
+ Returns:
50
+ The modified summary dictionary.
51
+ """
52
+
53
+ mdict = met_summary.current_metric_dict()
54
+
55
+ # Calculate perplexity from the average nats_per_token over all replicas.
56
+ # This has to be done here, because the perplexities themselves can't be
57
+ # averaged in the usual way.
58
+ if "nats_per_token" in mdict:
59
+ nats_per_token = mdict["nats_per_token"].to_value()
60
+ met_summary.add({"perplexity": np.exp(nats_per_token)})
61
+
62
+ if mode == "generate" and "gen_tokens" in mdict:
63
+ # Convert output tokens to example output text.
64
+ # Write text to both the summary, and pretty-print to the log file.
65
+ gen_toks = mdict["gen_tokens"].to_value()
66
+ if np.ndim(gen_toks) != 2:
67
+ raise ValueError("Unsupported shape for gen_tokens: %s" % gen_toks.shape)
68
+
69
+ ntoks = gen_toks.shape[-1]
70
+ gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
71
+ logging.info("Generated text = %s", gen_text)
72
+ met_summary.add_text({"gen_text": gen_text})
73
+ del mdict["gen_tokens"] # Otherwise it will turn into a histogram.
74
+
75
+ return met_summary
76
+
77
+
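
The exponentiation above is the standard nats-to-perplexity conversion; on a toy value (nats per token equal to ln 2):

import numpy as np

nats_per_token = float(np.log(2.0))
perplexity = np.exp(nats_per_token)         # == 2.0
bits_per_token = nats_per_token * 1.442695  # log2(e) factor; == 1.0

Averaging nats across replicas and exponentiating once is not the same as averaging per-replica perplexities, which is why the conversion happens here rather than inside the model.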
78
+ @gin.configurable
79
+ def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
80
+ [MetricsSummary, str], MetricsSummary]:
81
+ """Return a function that processes summaries with the given vocabulary."""
82
+ # For use with training_loop.process_summaries_function
83
+ def process_fn(met_summary: MetricsSummary, mode: str):
84
+ return process_summaries(vocab, met_summary, mode)
85
+ return process_fn
86
+
87
+
88
+ @gin.configurable
89
+ class DecoderOnlyLanguageModel(nn.Module):
90
+ """Decoder only language modeling."""
91
+
92
+ mode: str
93
+ task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
94
+ decoder_factory: Callable[[], Any] = gin.REQUIRED
95
+
96
+ sample_method: str = "sample" # Can be {"sample", "greedy"}
97
+ output_token_losses: bool = False
98
+
99
+ def get_fake_input(self):
100
+ """Returns a fake input for initialization of the appropriate shape."""
101
+ b = self.task_config.batch_size
102
+ fake_input_dict = {
103
+ "targets": jnp.ones([b, self.task_config.sequence_length],
104
+ dtype=jnp.int32),
105
+ "start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
106
+ "epoch": jnp.ones([b], dtype=jnp.int32),
107
+ }
108
+ if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
109
+ # We are not adding the loss mask to the dummy input by default as it can
110
+ # cause a slowdown during evaluation and perhaps inference.
111
+ fake_input_dict["loss_mask"] = jnp.ones(
112
+ [b, self.task_config.sequence_length], dtype=jnp.bool_)
113
+ return fake_input_dict
114
+
115
+ def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
116
+ """Summary operation to use for recorded metrics."""
117
+ metric_ops = {
118
+ "loss": "mean",
119
+ "nats_per_token": "mean",
120
+ "bits_per_token": "mean",
121
+ "bits_per_char": "mean",
122
+ "accuracy": "mean",
123
+ "num_tokens": "mean",
124
+ "num_chars_per_device": "mean",
125
+ "num_chars_per_batch": "mean",
126
+ "nonzero_tokens": "mean",
127
+ "num_tokens_per_device": "mean",
128
+ "num_tokens_per_batch": "mean",
129
+ "epoch": "mean",
130
+ }
131
+ if aggregate_over == "steps":
132
+ return metric_ops
133
+ elif aggregate_over == "devices":
134
+ # Ensure that statistics that refer to the total batch size stay constant
135
+ # as TPU topologies change. For those we have to sum over devices, but
136
+ # compute the mean over steps.
137
+ metric_ops.update({
138
+ "num_tokens_per_batch": "sum",
139
+ "num_chars_per_batch": "sum",
140
+ "loss": "sum"})
141
+ return metric_ops
142
+ else:
143
+ raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)
144
+
145
+ def setup(self):
146
+ self.decoder = self.decoder_factory(mode=self.mode,
147
+ task_config=self.task_config) # pytype: disable=wrong-keyword-args # trace-all-classes
148
+
149
+ def __call__(self, inputs: ...):
150
+ task_config = self.task_config
151
+
152
+ input_tokens = inputs["targets"] # [b, seq_len]
153
+ start_of_sequence = inputs["start_of_sequence"] # [b]
154
+ epochs = inputs["epoch"] # [b]
155
+ if "loss_mask" in inputs:
156
+ loss_mask = inputs["loss_mask"] # [b, seq_len]
157
+ else:
158
+ loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)
159
+
160
+ input_tokens = jnp.asarray(input_tokens)
161
+ assert input_tokens.ndim == 2
162
+ assert input_tokens.shape[0] == task_config.batch_size
163
+ assert input_tokens.shape[1] == task_config.sequence_length
164
+ assert start_of_sequence.shape[0] == task_config.batch_size
165
+
166
+ # Sanity check to avoid out-of-bounds on token lookup.
167
+ input_tokens = input_tokens % task_config.vocab_size
168
+
169
+ logging.info("langmodel: Compiling model for mode %s", self.mode)
170
+ logging.info("langmodel: input_tokens = %r", input_tokens)
171
+ logging.info("langmodel: start_of_sequece = %r", start_of_sequence)
172
+ logging.info("langmodel: epochs = %r", epochs)
173
+
174
+ # The target outputs are the next character in each sequence.
175
+ # Shift tokens left and pad with a zero at the end.
176
+ # TODO(delesley): We don't predict the first token of each sequence.
177
+ target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
178
+ logging.info("langmodel: target_tokens = %r", target_tokens)
179
+
180
+ # Invoke the decoder stack.
181
+ # The decoder will return pre-softmax logits for the predicted targets.
182
+ (logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
183
+ target_tokens=target_tokens,
184
+ start_of_sequence=start_of_sequence)
185
+
186
+ # Softmax cross-entropy loss on target tokens.
187
+ logits = nn.log_softmax(logits, axis=-1) # (b, seq_len, vocab_size)
188
+ logging.info("langmodel: logits = %r", logits)
189
+ soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
190
+ logging.info("langmodel: soft_targets = %r", soft_targets)
191
+
192
+ losses = -jnp.sum(soft_targets * logits, axis=-1) # (b, seq_len)
193
+ logging.info("langmodel: losses = %r", losses)
194
+
195
+ # Don't predict null tokens which are past the end-of-sequence.
196
+ # Also don't predict the 0 at the end of the sequence.
197
+ # TODO(delesley): Predict the final end-of-sequence marker.
198
+ loss_mask = jnp.logical_and(
199
+ loss_mask,
200
+ input_tokens > 0)
201
+ loss_mask = jnp.logical_and(
202
+ loss_mask,
203
+ target_tokens > 0)
204
+ logging.info("langmodel: loss_mask = %r", loss_mask)
205
+
206
+ losses = jnp.where(loss_mask, losses, 0.0) # (batch_size, seq_len)
207
+ loss = jnp.sum(losses) # total loss on device
208
+
209
+ token_count = jnp.sum(loss_mask) # tokens on device
210
+ token_count_nz = token_count + 1.0e-6
211
+ loss_per_token = loss / token_count_nz
212
+ bits_per_token = loss_per_token * 1.442695 # log(e)/log(2)
213
+ accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
214
+ loss_mask)
215
+ accuracy = accuracy / token_count_nz # Percent correct.
216
+ epoch = jnp.mean(epochs)
217
+
218
+ if self.mode == "generate" and self.decoder.supports_generate():
219
+ # Generate example text.
220
+ logging.info("lang_model: text inference.")
221
+ gen_tokens = self.generate(inputs, task_config.sequence_length)
222
+
223
+ # Return generated text, along with visualizations and histograms.
224
+ metrics = {"gen_tokens": gen_tokens, **d_metrics}
225
+ return (loss, metrics)
226
+
227
+ # Just return metrics related to the loss.
228
+ metrics = {
229
+ "loss": loss, # will be summed over devices
230
+ "nats_per_token": (loss_per_token, token_count),
231
+ "bits_per_token": (bits_per_token, token_count),
232
+ "accuracy": (accuracy, token_count),
233
+ "num_tokens_per_device": token_count,
234
+ "num_tokens_per_batch": token_count, # will be summed over devices
235
+ "epoch": epoch,
236
+ }
237
+
238
+ # Compute bits per character if we have the number of characters.
239
+ if "num_chars" in inputs:
240
+ num_chars = jnp.sum(inputs["num_chars"])
241
+ bits_per_char = loss / (num_chars + 1e-6) * 1.442695
242
+ metrics["num_chars_per_device"] = num_chars
243
+ metrics["num_chars_per_batch"] = num_chars # will be summed over devices
244
+ metrics["bits_per_char"] = (bits_per_char, num_chars)
245
+
246
+ # Provided to make sure that the data pipeline and the model agree
247
+ # on the number of tokens with a loss.
248
+ if "nonzero_tokens" in inputs:
249
+ nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
250
+ metrics["nonzero_tokens"] = nonzero_tokens
251
+
252
+ if self.output_token_losses:
253
+ metrics["token_losses"] = losses
254
+
255
+ return (loss, metrics)
256
+
257
+ def generate(self, inputs: ..., sequence_length: int) -> Array:
258
+ """Generate an output sequence.
259
+
260
+ Args:
261
+ inputs: the same as the argument to __call__.
262
+ sequence_length: the length of sequence to generate.
263
+
264
+ Returns:
265
+ An array of generated tokens of shape (batch_size, sequence_length).
266
+ """
267
+ # TODO(delesley): Add support for passing the prefix as an argument.
268
+ # TODO(delesley): Add support for temperature, gumbel softmax, beam search.
269
+
270
+ batch_size = self.task_config.batch_size
271
+ input_tokens = inputs["targets"] # [b,seq_len]
272
+ start_of_sequence = inputs["start_of_sequence"] # [b]
273
+
274
+ # Initialize decoder.
275
+ dstate = self.decoder.init_decoder_state(sequence_length,
276
+ start_of_sequence)
277
+
278
+ # TODO(delesley): Handle start-of-sequence in a better way.
279
+ # There is no special token for start of sequence, so we grab the first
280
+ # one from the ground-truth input data.
281
+ first_token = input_tokens[:, 0:1]
282
+ no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
283
+ sample_method = self.sample_method
284
+ sample_prng = self.make_rng("sample")
285
+
286
+ # Greedy autoregressive decoder function.
287
+ def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
288
+ prng = jax.random.fold_in(sample_prng, i)
289
+ (dstate, input_token) = scan_state
291
+ (logits, dstate, _) = self.decoder(input_tokens=input_token,
292
+ target_tokens=None,
293
+ start_of_sequence=no_start_of_seq,
294
+ decoder_state=dstate)
295
+ if sample_method == "sample":
296
+ logging.info("Using categorical sampling.")
297
+ output_token = jax.random.categorical(prng, logits, axis=-1)
298
+ elif sample_method == "greedy":
299
+ logging.info("Using greedy sampling.")
300
+ output_token = jnp.argmax(logits, axis=-1)
301
+ else:
302
+ raise ValueError(f"Invalid sampling method: {sample_method}")
303
+ logging.info("generate_loop_fn: output_token = %r", output_token)
304
+ return ((dstate, output_token), output_token)
305
+
306
+ # Scan over the sequence length.
307
+ iterations = jnp.arange(sequence_length)
308
+ initial_scan_state = (dstate, first_token)
309
+ (_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
310
+ logging.info("generate: output_tokens = %r", output_tokens)
311
+
312
+ # Output_tokens has shape (sequence_length, batch_size, 1)
313
+ assert output_tokens.shape == (sequence_length, batch_size, 1)
314
+ output_tokens = jnp.reshape(
315
+ output_tokens, (sequence_length, self.task_config.batch_size))
316
+ output_tokens = output_tokens.transpose([1, 0])
317
+ return output_tokens
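
A shape-level sketch of the scan-driven decoding loop above, with a stub standing in for self.decoder (the stub simply predicts token + 1; everything here is illustrative):

import jax
import jax.numpy as jnp

batch, vocab, seq_len = 2, 11, 5

def stub_decoder(dstate, token):                       # token: [batch, 1]
  logits = jax.nn.one_hot((token + 1) % vocab, vocab)  # [batch, 1, vocab]
  return logits, dstate

def loop_fn(carry, i):
  dstate, token = carry
  logits, dstate = stub_decoder(dstate, token)
  next_token = jnp.argmax(logits, axis=-1)             # greedy sampling
  return (dstate, next_token), next_token

first = jnp.zeros((batch, 1), dtype=jnp.int32)
_, out = jax.lax.scan(loop_fn, (0, first), jnp.arange(seq_len))
out = out.reshape(seq_len, batch).transpose([1, 0])    # [batch, seq_len]
assert bool(jnp.all(out[0] == jnp.arange(1, seq_len + 1)))

As in generate(), the carry holds the decoder state plus the previously sampled token, and the stacked scan outputs are reshaped from time-major to batch-major at the end.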
aglib/meliad/transformer/nn_components.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Core NN components used in models.
16
+ """
17
+
18
+ from typing import Any, Callable, Optional, Tuple, Union
19
+
20
+ from absl import logging
21
+ from flax import linen as nn
22
+ import gin
23
+ import jax
24
+ from jax import lax
25
+ from jax.nn import initializers
26
+ import jax.numpy as jnp
27
+
28
+
29
+ PRNGKey = Any
30
+ Array = jnp.ndarray
31
+ Shape = Tuple[int, ...]
32
+ Dtype = Union[jnp.dtype, str]
33
+
34
+
35
+ def scalar_initializer(x):
36
+ """Like linen.zeros, but initializes a parameter to a scalar value."""
37
+ def init_fun(key, shape, dtype):
38
+ del key
39
+ return jnp.broadcast_to(jnp.array(x, dtype=dtype), shape)
40
+ return init_fun
41
+
42
+
43
+ def swish(x: Array) -> Array:
44
+ """Swish function, which is very similar to gelu."""
45
+ return x * nn.sigmoid(x)
46
+
47
+
48
+ def soft_abs(x: Array) -> Array:
49
+ """Soft version of absolute value, that is smoothly differentiable."""
50
+ return jnp.sqrt(jnp.square(x) + 1) - 1
51
+
52
+
53
+ def get_activation_function(fname: Optional[str]) -> Callable[[Array], Array]:
54
+ """Get activation function from the specified string."""
55
+ if fname is None:
56
+ return lambda x: x
57
+ elif fname == "relu":
58
+ return nn.relu
59
+ elif fname == "swish":
60
+ return swish
61
+ elif fname == "sigmoid":
62
+ return nn.sigmoid
63
+ elif fname == "tanh":
64
+ return nn.tanh
65
+ else:
66
+ raise ValueError("Unknown activation function %s" % fname)
67
+
68
+
69
+ # Adapted from flax.linen.softmax.
70
+ def safe_softmax(x: Array,
71
+ axis: Optional[Union[int, Tuple[int, ...]]] = -1,
72
+ min_x: Optional[Array] = None) -> Array:
73
+ r"""Softmax function.
74
+
75
+ Computes the function which rescales elements to the range :math:`[0, 1]`
76
+ such that the elements along :code:`axis` sum to :math:`1`.
77
+
78
+ This version of softmax is intended for use with causal attention masks, and
79
+ safely covers the situation where all elements are masked out. If min_x is
80
+ not None, then probability will be distributed between the values in x and
81
+ min_x. If x >> min_x, then the probability allocated to min_x will be zero,
82
+ and this function will be the same as the usual softmax. However, if
83
+ x << min_x, (because all the values in x are masked out) then probability
84
+ will be allocated to min_x instead, and the probability allocated to x will
85
+ be 0. I.e., attention will attend to nothing if everything is masked out.
86
+
87
+ .. math::
88
+ \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
89
+
90
+ Args:
91
+ x: input array
92
+ axis: the axis or axes along which the softmax should be computed. The
93
+ softmax output summed across these dimensions should sum to :math:`1`.
94
+ Either an integer or a tuple of integers.
95
+ min_x: the value of a minimum element which will be included in the
96
+ softmax sum. The value of min_x should be small when compared to the
97
+ expected values in x. If all of the values in x are smaller than
98
+ min_x, then probability will be allocated to the minimum element
99
+ instead, and the result of softmax will sum to less than 1.
100
+
101
+ Returns:
102
+ An array of the same shape as x.
103
+ """
104
+ # Subtract maximum value in x for numerical stability, so that the exponent
105
+ # never exceeds numerical precision.
106
+ x_max = lax.stop_gradient(jnp.max(x, axis, initial=min_x, keepdims=True))
107
+ if min_x is not None:
108
+ min_x = jnp.asarray(min_x, dtype=x.dtype)
109
+ x_max = jnp.maximum(x_max, min_x)
110
+ unnormalized = jnp.exp(x - x_max)
111
+ x_sum = jnp.sum(unnormalized, axis=axis, keepdims=True)
112
+ if min_x is not None:
113
+ x_sum = x_sum + jnp.exp(min_x - x_max)
114
+ return unnormalized / x_sum
115
+
116
+
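
The min_x behavior is easiest to see when every logit is masked out; probability then flows to the phantom min_x element and the result sums to (nearly) zero instead of going uniform:

import jax.numpy as jnp

x = jnp.array([-1e9, -1e9, -1e9])  # all positions masked out
p = safe_softmax(x, min_x=jnp.array(-1e6))
assert float(p.sum()) < 1e-6       # attends to nothing

An ordinary softmax on the same input would return a uniform distribution over the masked positions.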
117
+ def dropout_multiplier_mask(rng, dropout_rate: float, shape: Shape,
118
+ dtype: Dtype):
119
+ """Returns an array which can be multiplied by an input to perform dropout.
120
+
121
+ Args:
122
+ rng: A random number generator.
123
+ dropout_rate: The rate at which to drop.
124
+ shape: The shape of the output array.
125
+ dtype: The type of the output array.
126
+
127
+ Returns:
128
+ An array of the given shape, where values are {0.0, 1.0/keep_probability}.
129
+ """
130
+ if dropout_rate <= 0.0:
131
+ return jnp.ones(shape, dtype=dtype)
132
+
133
+ logging.info("dropout mask: %s", shape)
134
+ keep_prob = 1.0 - dropout_rate
135
+ keep = jax.random.bernoulli(rng, keep_prob, shape)
136
+ dropout_multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype))
137
+ return dropout_multiplier
138
+
139
+
140
+ def tiled_dropout(x: Array, shape: Shape, dropout_rate: float,
141
+ rng_function: Callable[[], jax.random.KeyArray],
142
+ deterministic: bool) -> Array:
143
+ """Tiles a dropout mask over a larger array.
144
+
145
+ This will generate a smaller dropout mask of the given shape, and tile it
146
+ over a larger array, which reduces the computational cost and memory
147
+ associated with generating a large dropout mask.
148
+
149
+ Args:
150
+ x: The input array.
151
+ shape: The shape of the dropout mask to tile.
152
+ dropout_rate: The rate at which to drop.
153
+ rng_function: A function which returns a random number generator, e.g.
154
+ lambda: self.make_rng("dropout"). The function will not
155
+ be called if dropout is not enabled.
156
+ deterministic: If True, don't do dropout.
157
+
158
+ Returns:
159
+ An array of the same shape as x, with some values dropped out.
160
+ """
161
+ if deterministic or dropout_rate <= 0.0:
162
+ return x
163
+
164
+ if x.ndim != len(shape):
165
+ raise ValueError("Shapes must have same number of dimensions %r, %r." %
166
+ (x.shape, shape))
167
+ for (xd, sd) in zip(x.shape, shape):
168
+ if (xd % sd) != 0:
169
+ raise ValueError("Incompatible shapes %r, %r" % (x.shape, shape))
170
+
171
+ # Get random number generator for dropout.
172
+ rng = rng_function()
173
+
174
+ repeats = [(1 if sd == 1 else xd // sd) for (xd, sd) in zip(x.shape, shape)]
175
+ logging.info("tiled dropout %r, tile: %r", x.shape, shape)
176
+
177
+ dtype = x.dtype
178
+ keep_prob = 1.0 - dropout_rate
179
+ keep = jax.random.bernoulli(rng, keep_prob, shape)
180
+ keep = jnp.tile(keep, repeats)
181
+ keep = jnp.broadcast_to(keep, x.shape)
182
+ x_scaled = x / jnp.asarray(keep_prob, dtype=dtype)
183
+ return lax.select(keep, x_scaled, jnp.zeros_like(x, dtype=dtype))
184
+
185
+
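
For instance, a (1, 128) tile over a (4, 512) input draws a single mask of 128 Bernoulli samples and reuses it across the rest of the array:

import jax
import jax.numpy as jnp

x = jnp.ones((4, 512))
y = tiled_dropout(x, shape=(1, 128), dropout_rate=0.1,
                  rng_function=lambda: jax.random.PRNGKey(0),
                  deterministic=False)
assert y.shape == x.shape  # entries are either 0.0 or x / keep_prob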
186
+ @gin.configurable
187
+ class MLP(nn.Module):
188
+ """Implements a multi-layer perceptron, with optional resnet or gate."""
189
+
190
+ # Arguments to module.
191
+ num_output_features: int # Length of output vectors.
192
+
193
+ # Gin configurable parameters.
194
+ num_layers: int = gin.REQUIRED # Number of layers in the MLP.
195
+ num_hidden_units: int = gin.REQUIRED # Length of hidden unit vectors.
196
+ hidden_activation: Optional[str] = "relu" # Hidden layer activation fn.
197
+ final_activation: Optional[str] = None # Final layer activation fn.
198
+ use_bias: bool = True # Use a bias in each dense layer.
199
+ gate_type: Optional[str] = None  # { "residual", "bias", "full", "lstm" }
200
+ initializer_scale: float = 1.0 # Scale of initial values.
201
+ dtype: Any = jnp.float32
202
+
203
+ def setup(self):
204
+ kernel_init = jax.nn.initializers.variance_scaling(
205
+ scale=self.initializer_scale, mode="fan_in",
206
+ distribution="truncated_normal")
207
+
208
+ assert self.num_layers > 0
209
+ hlayers = []
210
+ for i in range(0, self.num_layers - 1):
211
+ assert self.num_hidden_units > 0
212
+ hlayer = nn.Dense(self.num_hidden_units,
213
+ use_bias=self.use_bias,
214
+ kernel_init=kernel_init,
215
+ dtype=self.dtype,
216
+ name=f"hidden{i}")
217
+ hlayers.append(hlayer)
218
+ self.hidden_layers = hlayers
219
+ self.output_layer = nn.Dense(self.num_output_features,
220
+ use_bias=self.use_bias,
221
+ kernel_init=kernel_init,
222
+ dtype=self.dtype)
223
+
224
+ if self.gate_type is None or self.gate_type == "residual":
225
+ return
226
+
227
+ # We use a low but non-zero bias so that adafactor knows how to scale it.
228
+ gate_bias_init = jax.nn.initializers.normal(stddev=0.1)
229
+ # Also use a lower-than-normal kernel initialization scale.
230
+ gate_kernel_init = jax.nn.initializers.variance_scaling(
231
+ scale=0.1, mode="fan_in", distribution="truncated_normal")
232
+
233
+ if self.gate_type == "bias":
234
+ self.gate_bias = self.param("gate_bias", gate_bias_init,
235
+ (self.num_output_features,), jnp.float32)
236
+ elif self.gate_type == "full":
237
+ self.gate_layer = nn.Dense(self.num_output_features,
238
+ use_bias=True,
239
+ bias_init=gate_bias_init,
240
+ kernel_init=gate_kernel_init,
241
+ dtype=self.dtype)
242
+ elif self.gate_type == "lstm":
243
+ self.input_gate = nn.Dense(self.num_output_features,
244
+ use_bias=True,
245
+ bias_init=gate_bias_init,
246
+ kernel_init=gate_kernel_init,
247
+ dtype=self.dtype)
248
+ self.forget_gate = nn.Dense(self.num_output_features,
249
+ use_bias=True,
250
+ bias_init=gate_bias_init,
251
+ kernel_init=gate_kernel_init,
252
+ dtype=self.dtype)
253
+ else:
254
+ raise ValueError("Unsupported gate_type: %s" % self.gate_type)
255
+
256
+ def _gate(self, y_hidden: Array, state: Array, y_out: Array) -> Array:
257
+ """Compute the value to use for the gate."""
258
+
259
+ if self.gate_type == "residual":
260
+       # Residual connection: just add y_out to the state.
+       logging.info("mlp: residual")
+       return state + y_out
+
+     elif self.gate_type == "bias":
+       # Simple gate: use a gru_style gate with a learned bias (no kernel).
+       bias = jnp.asarray(self.gate_bias, dtype=self.dtype)
+       bias = jnp.reshape(bias, (1,) * (y_out.ndim - 1) + (-1,))  # batch dims.
+       g = jax.nn.sigmoid(bias)
+       logging.info("mlp: gate bias = %r", g)
+       return (state * g) + (y_out * (1 - g))
+
+     elif self.gate_type == "full":
+       # Normal GRU style gate -- compute g using both a kernel and bias.
+       g = jax.nn.sigmoid(self.gate_layer(y_hidden) + 1)  # biased to remember
+       logging.info("mlp: gate full = %r", g)
+       return (state * g) + (y_out * (1 - g))
+
+     elif self.gate_type == "lstm":
+       # LSTM style gate with input and forget gates.
+       fg = jax.nn.sigmoid(self.forget_gate(y_hidden) + 1)  # biased to remember
+       ig = jax.nn.sigmoid(self.input_gate(y_hidden) - 1)
+       logging.info("mlp: gate lstm = %r, %r", ig, fg)
+       return (state * fg) + (y_out * ig)
+
+     else:
+       raise ValueError("Unsupported gate type %s" % self.gate_type)
+
+   def __call__(self, x: Array, state: Optional[Array],
+                apply_dropout: bool = False,
+                dropout_rate: float = 0.0,
+                drop_tile_shape: Optional[Shape] = None,
+                rng_function: Optional[Callable[[], Any]] = None) -> Array:
+     """Apply the multi-layer perceptron to the input x.
+
+     For simple MLPs, returns f(x), where f is the MLP function.
+     For resnets and gated architectures, it returns
+       state + f(x)           -- for resnet.
+       g*state + (1-g)*f(x)   -- for gated architecture, where g is the gate.
+
+     Args:
+       x: The input to the MLP.
+       state: The prior value, if this MLP is used as part of a resnet or
+         gated architecture.
+       apply_dropout: If true, applies dropout to the result.
+       dropout_rate: The dropout rate to use.
+       drop_tile_shape: The dropout tile shape.
+       rng_function: Gets a random number seed for dropout.
+
+     Returns:
+       The combination of f(x) and the (optional) prior state.
+     """
+
+     x = jnp.asarray(x, self.dtype)
+     hidden_act_fun = get_activation_function(self.hidden_activation)
+     final_act_fun = get_activation_function(self.final_activation)
+     if self.hidden_layers:
+       # Apply some number of hidden layers.
+       y = x
+       for layer in self.hidden_layers:
+         logging.info("mlp: hidden %d, %s", self.num_hidden_units,
+                      self.hidden_activation)
+         y = hidden_act_fun(layer(y))
+     else:
+       # Apply the hidden activation function to the input.
+       logging.info("mlp: activation = %s", self.hidden_activation)
+       y = hidden_act_fun(x)
+
+     y_hidden = y  # The hidden layer right before the output.
+     logging.info("mlp: final activation = %s", self.final_activation)
+     y_out = self.output_layer(y_hidden)  # The MLP final output.
+     y_out = final_act_fun(y_out)  # Apply final activation function.
+     logging.info("mlp: final = %r", y_out)
+
+     # Optionally apply dropout to the output.
+     if apply_dropout:
+       if drop_tile_shape is None:
+         raise ValueError("drop_tile_shape must be specified for dropout.")
+       if rng_function is None:
+         raise ValueError("rng_function must be specified for dropout.")
+       logging.info("mlp: dropout rate = %s", dropout_rate)
+       y_out = tiled_dropout(
+           y_out, shape=drop_tile_shape, dropout_rate=dropout_rate,
+           rng_function=rng_function, deterministic=False)
+
+     if state is None:
+       # Simple MLP. No gate to combine y_out with the state.
+       assert self.gate_type is None
+       logging.info("mlp: gate type = None.")
+       return y_out
+
+     # When using state, gate_type must be specified.
+     assert self.gate_type is not None
+     return self._gate(y_hidden, state, y_out)
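+
+
+ # Illustrative sketch (hypothetical helper, not used elsewhere): how the
+ # "bias" gate above mixes the previous state with the new MLP output.
+ def _example_bias_gate(state: Array, y_out: Array, gate_bias: float) -> Array:
+   g = jax.nn.sigmoid(jnp.asarray(gate_bias))  # e.g. sigmoid(2.2) ~= 0.9
+   # Keep ~90% of the old state and blend in ~10% of the new output.
+   return (state * g) + (y_out * (1 - g))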
+
+
+ # Modified slightly from the flax implementation.
+ @gin.configurable
+ class LayerNorm(nn.Module):
+   """Layer normalization (https://arxiv.org/abs/1607.06450).
+
+   Operates on the last axis of the input data.
+
+   It normalizes the activations of the layer for each given example in a
+   batch independently, rather than across a batch like Batch Normalization.
+   i.e. applies a transformation that maintains the mean activation within
+   each example close to 0 and the activation standard deviation close to 1.
+
+   Attributes:
+     epsilon: A small float added to variance to avoid dividing by zero.
+     dtype: the dtype of the computation (default: float32).
+     use_bias: If True, bias (beta) is added.
+     use_scale: If True, multiply by scale (gamma).
+     use_mean: If True, compute and adjust for the mean.
+       Note that the T5X layernorm does not use the mean.
+       Empirically, ignoring the mean can stabilize learning in transformers.
+     use_scalar_scale_bias: If True, use a single scalar for scale & bias.
+     enable_layernorm: If False, does not perform layernorm.
+     bias_init: Initializer for bias, by default, zero.
+     scale_init: Initializer for scale, by default, one.
+   """
+   epsilon: float = 1e-6
+   dtype: Any = jnp.float32
+   use_scale: bool = True  # Apply a learned scale.
+   use_bias: bool = False  # Apply a learned bias.
+   use_mean: bool = False  # Calculate and adjust for the mean.
+   use_scalar_scale_bias: bool = False  # Learn a single scalar scale & bias.
+   enable_layernorm: bool = True  # Turn off layernorm if false.
+   bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
+   scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
+
+   @nn.compact
+   def __call__(self, x):
+     """Applies layer normalization on the input.
+
+     Args:
+       x: the inputs
+
+     Returns:
+       Normalized inputs (the same shape as inputs).
+     """
+     if not self.enable_layernorm:
+       return x
+     x = jnp.asarray(x)
+
+     # Calculate mean and variance at higher precision.
+     xf = jnp.asarray(x, jnp.float32)
+     if self.use_mean:
+       mean = jnp.mean(xf, axis=-1, keepdims=True)
+       xf = xf - mean
+     var = jnp.mean(lax.square(xf), axis=-1, keepdims=True)
+     mul = lax.rsqrt(var + self.epsilon)
+
+     # Rescale x.
+     # If not use_mean, then rescale around zero instead. (A simplification.)
+     if self.use_mean:
+       y = (x - mean) * mul
+     else:
+       y = x * mul
+
+     if self.use_scalar_scale_bias:
+       # Learn a single scalar value for bias and scale.
+       # (Which mirrors the single value for mean and stddev above.)
+       num_scale_bias_features = 1
+     else:
+       # Learn a different value per neuron/feature for bias and scale.
+       num_scale_bias_features = x.shape[-1]
+
+     # Apply learned scale and bias.
+     if self.use_scale:
+       y = y * jnp.asarray(
+           self.param("scale", self.scale_init, (num_scale_bias_features,)),
+           dtype=self.dtype)
+     if self.use_bias:
+       y = y + jnp.asarray(
+           self.param("bias", self.bias_init, (num_scale_bias_features,)),
+           dtype=self.dtype)
+     return y.astype(self.dtype)
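+
+
+ # Reference sketch (hypothetical helper): with the default flags above
+ # (use_mean=False, use_scale=True, use_bias=False), LayerNorm reduces to an
+ # RMS-style normalization, y = x * rsqrt(mean(x**2) + eps) * scale.
+ def _example_rms_norm(x: Array, scale: Array, epsilon: float = 1e-6) -> Array:
+   xf = jnp.asarray(x, jnp.float32)
+   var = jnp.mean(lax.square(xf), axis=-1, keepdims=True)
+   return (x * lax.rsqrt(var + epsilon)) * scale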
aglib/meliad/transformer/position.py ADDED
@@ -0,0 +1,242 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Functions for dealing with relative and absolute positions, and masks."""
+
+ from typing import Optional, Tuple, Union
+
+ import jax.numpy as jnp
+ import numpy as np
+
+
+ Array = jnp.ndarray
+ NpArray = np.ndarray
+ Dtype = Union[jnp.dtype, str]
+
+
+ def relative_positions(num_queries: int, num_keys: int,
+                        offset: Optional[int] = None):
+   """Returns a jax array of relative positions between query and key.
+
+   If num_keys >= num_queries, e.g. for transformer XL or sliding window,
+   then offset should be (num_keys - num_queries) to make the last N queries
+   line up with the last N keys. This is the default if offset is None.
+
+   Args:
+     num_queries: Number of queries.
+     num_keys: Number of keys.
+     offset: Offset of the first query wrt. the first key.
+
+   Returns:
+     A /jax/ array of shape [num_queries, num_keys] with the signed distance
+     from each query to each key.
+   """
+
+   # Get the offset of each query wrt. each key.
+   # If not specified, assume the last N queries line up with the last N keys.
+   if offset is None:
+     if num_keys < num_queries:
+       raise ValueError("Number of keys %d must be >= number of queries %d" %
+                        (num_keys, num_queries))
+     offset = num_keys - num_queries
+   qidx = jnp.arange(0, num_queries, dtype=jnp.int32).reshape(num_queries, 1)
+   kidx = jnp.arange(0, num_keys, dtype=jnp.int32).reshape(1, num_keys)
+   return kidx - (qidx + offset)
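+
+
+ # Worked example (hypothetical helper): relative_positions(2, 4) defaults to
+ # offset = 2, so the 2 queries line up with the last 2 keys:
+ #   [[-2, -1,  0,  1],
+ #    [-3, -2, -1,  0]]
+ def _example_relative_positions() -> Array:
+   return relative_positions(num_queries=2, num_keys=4)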
+
+
+ def relative_positions_np(num_queries: int, num_keys: int,
+                           offset: Optional[int] = None):
+   """Returns a numpy array of relative positions between query and key.
+
+   If num_keys >= num_queries, e.g. for transformer XL or sliding window,
+   then offset should be (num_keys - num_queries) to make the last N queries
+   line up with the last N keys. This is the default if offset is None.
+
+   Args:
+     num_queries: Number of queries.
+     num_keys: Number of keys.
+     offset: Offset of the first query wrt. the first key.
+
+   Returns:
+     A /numpy/ array of shape [num_queries, num_keys] with the signed distance
+     from each query to each key.
+   """
+
+   # Get the offset of each query wrt. each key.
+   # If not specified, assume the last N queries line up with the last N keys.
+   if offset is None:
+     if num_keys < num_queries:
+       raise ValueError("Number of keys %d must be >= number of queries %d" %
+                        (num_keys, num_queries))
+     offset = num_keys - num_queries
+   qidx = np.arange(0, num_queries, dtype=np.int32).reshape(num_queries, 1)
+   kidx = np.arange(0, num_keys, dtype=np.int32).reshape(1, num_keys)
+   return kidx - (qidx + offset)
+
+
+ def broadcast_mask(mask: Array, attn: Array):
+   """Broadcast a mask or bias over all the dimensions of attn."""
+
+   # Add leading dimensions for batch_size, num_heads if necessary.
+   if mask.ndim < attn.ndim:
+     mask = jnp.expand_dims(mask, axis=tuple(range(0, attn.ndim - mask.ndim)))
+   return mask
+
+
+ def causal_mask(num_queries: int, num_keys: int, window_length: int = 0):
+   """Returns a causal mask of shape (num_queries, num_keys)."""
+
+   # The mask ranges over the window_length positions prior to each query.
+   if window_length == 0:
+     window_length = num_queries
+
+   kqpos = relative_positions(num_queries, num_keys)  # 2D mask
+
+   # The causal mask includes only those tokens *before* the current token.
+   # This slightly improves perplexity in practice, and simplifies generation.
+   # Each token attends to exactly window_length prior tokens.
+   mask = (kqpos < 0) & (kqpos >= -window_length)
+   return mask
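+
+
+ # Worked example (hypothetical helper): causal_mask(3, 3) admits only
+ # strictly-earlier tokens, so the diagonal itself is masked out:
+ #   [[False, False, False],
+ #    [ True, False, False],
+ #    [ True,  True, False]]
+ def _example_causal_mask() -> Array:
+   return causal_mask(num_queries=3, num_keys=3)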
+
+
+ def position_encoding(num_positions: int,
+                       input_dim: int,
+                       *,
+                       offset: int = 0,
+                       max_wavelength: float = 0) -> NpArray:
+   """Returns a position encoding of shape (num_positions, input_dim).
+
+   Positions are encoded as sin/cos pairs at geometrically increasing
+   wavelengths.
+
+   The length of a half-wave (peak to trough) increases geometrically from 1
+   to max_wavelength. (Technically, it's slightly less; the last sin/cos pair
+   has a wavelength of max_wavelength**((d-1)/d), where d = input_dim/2.)
+
+   NOTE: unlike prior published position encodings, we multiply the position
+   of each token by pi to convert from fractions of a wave
+   (position/wavelength) to radians. Thus, the highest frequency wave
+   alternates between -1 and 1 on every token, whereas in prior published work
+   the highest frequency alternates between -1 and 1 every pi tokens. The
+   max_wavelength is also effectively 1/pi times as long, so a prior published
+   factor of 10,000 (e.g. https://arxiv.org/abs/1706.03762) would equate to a
+   max_wavelength of 31,416.
+
+   This encoding also does not alternate between sin/cos values, but puts all
+   of the cos values on one side, and the sin values on the other. That makes
+   it easier to split the sin,cos values to construct or apply a rotation
+   matrix.
+
+   The default value for max_wavelength is 2 * num_positions.
+
+   Args:
+     num_positions: The number of positions.
+     input_dim: The dimension of the position vector.
+     *: --- The following are keyword arguments only. ---
+     offset: Positions count from offset to (offset + num_positions).
+     max_wavelength: The maximum length of a half-wave (peak to trough).
+
+   Returns:
+     Numpy matrix of shape (num_positions, input_dim) containing the encodings.
+     Position encodings are packed as concat(cos_values, sin_values, axis=1).
+   """
+
+   if max_wavelength == 0:
+     max_wavelength = 2 * num_positions
+   assert max_wavelength > 1
+
+   assert (input_dim % 2) == 0
+   idim2 = input_dim // 2
+
+   # t ranges from 0 <= t < 1
+   t = np.arange(0, idim2, dtype=np.float32) / idim2
+
+   # wavelength (columns)
+   # The length of a half-wave (trough to peak) increases geometrically
+   # 1 <= wavelength < max_wavelength
+   wavelength = float(max_wavelength)**t
+   wavelength = np.reshape(wavelength, (1, idim2))  # broadcast over rows
+
+   # k is the position in the sequence (rows)
+   k = np.arange(offset, num_positions + offset, dtype=np.float32)
+   k = np.reshape(k, (num_positions, 1))  # broadcast over columns
+
+   # For each position (row) compute an angle (column) at various wavelengths.
+   # NOTE: unlike prior published work, we multiply by pi to convert to
+   # radians.
+   pi_f = np.array(np.pi, dtype=np.float32)
+   angles = pi_f * k / wavelength  # shape (num_positions, idim2)
+   posx = np.cos(angles, dtype=np.float32)
+   posy = np.sin(angles, dtype=np.float32)
+   return np.concatenate([posx, posy], axis=1)  # shape (num_positions, idim)
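+
+
+ # Shape check (hypothetical helper): 4 positions encoded in 8 dimensions pack
+ # 4 cosine columns followed by 4 sine columns; max_wavelength defaults to 8.
+ def _example_position_encoding() -> NpArray:
+   enc = position_encoding(4, 8)
+   assert enc.shape == (4, 8)
+   return enc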
+
+
+ def rotate_kq(keys: Array, queries: Array,
+               *,  # the following args must be passed by keyword.
+               max_wavelength: float,
+               offset: Optional[int] = None,
+               dtype: Optional[Dtype] = None) -> Tuple[Array, Array]:
+   """Rotate keys and queries by the relative distance between query and key.
+
+   Implements rotary position embeddings (RoPE)
+   https://arxiv.org/abs/2104.09864.
+
+   Args:
+     keys: array of shape (batch_size, num_keys, num_heads, head_size)
+     queries: array of shape (batch_size, num_queries, num_heads, head_size)
+     max_wavelength: The maximum length of a half-wave (peak to trough).
+     offset: The relative positional offset from keys[i] to queries[i].
+       Defaults to num_keys - num_queries if not specified.
+     dtype: The precision to perform the rotation at.
+       Defaults to keys.dtype.
+
+   Returns:
+     (keys, queries) after rotation.
+   """
+
+   (batch_size, num_keys, num_heads, head_size) = keys.shape
+   (_, num_queries, _, _) = queries.shape
+   assert queries.shape == (batch_size, num_queries, num_heads, head_size)
+
+   if offset is None:
+     assert num_keys >= num_queries
+     offset = num_keys - num_queries
+
+   if dtype is None:
+     dtype = keys.dtype
+
+   def rotate_k_or_q(kq: Array, num_kq: int, kq_offset: int) -> Array:
+     nonlocal max_wavelength
+     nonlocal dtype
+
+     # Get position encodings, which can be used to do a rotation.
+     kq_pos = position_encoding(num_kq, head_size, offset=kq_offset,
+                                max_wavelength=max_wavelength)
+     # Broadcast over batch_size and num_heads.
+     kq_pos = np.reshape(kq_pos, (1, num_kq, 1, head_size))
+     # Split position encoding into separate sin/cos values in order to
+     # construct a rotation matrix.
+     (cosa, sina) = np.split(kq_pos, 2, axis=-1)
+     cosa = jnp.asarray(cosa, dtype=dtype)  # convert from numpy -> jax
+     sina = jnp.asarray(sina, dtype=dtype)  # convert from numpy -> jax
+
+     # Split keys/queries into real & imaginary (i.e. x & y) parts.
+     (kqx, kqy) = jnp.split(kq, 2, axis=-1)
+     # Apply rotation matrix.
+     kqx_rot = (kqx * cosa) - (kqy * sina)
+     kqy_rot = (kqx * sina) + (kqy * cosa)
+     # Concatenate back into keys/queries.
+     return jnp.concatenate([kqx_rot, kqy_rot], axis=-1)
+
+   keys = rotate_k_or_q(keys, num_keys, -offset)  # pylint: disable=invalid-unary-operand-type
+   queries = rotate_k_or_q(queries, num_queries, 0)
+   return (keys, queries)
+
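+
+ # Usage sketch (hypothetical helper): apply RoPE to dummy keys and queries.
+ # Shapes follow the docstring above; offset defaults to 6 - 4 = 2.
+ def _example_rotate_kq() -> Tuple[Array, Array]:
+   keys = jnp.ones((1, 6, 2, 8))     # (batch, num_keys, heads, head_size)
+   queries = jnp.ones((1, 4, 2, 8))  # (batch, num_queries, heads, head_size)
+   return rotate_kq(keys, queries, max_wavelength=32)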
aglib/meliad/transformer/position_fourier.py ADDED
@@ -0,0 +1,218 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Class for Fourier relative position biases.
+
+ This implementation uses the same Fourier position encodings that are used
+ in the absolute position encoding. However, instead of adding the positions
+ to the input, where the position vector and content vectors become entangled,
+ the relative encoding computes a relative position bias matrix, which is then
+ added to the content-based attention matrix before applying softmax.
+
+ The bias matrix is computed as follows. First, a learned transformation is
+ applied to each query position, which transforms it so that it matches a set
+ of key positions. The relative position bias between query 'i' and key 'j' is
+ the dot product between the transformed position 'i', and position 'j'.
+
+ The learned transformation is designed so that the match between query and key
+ is a function of the relative distance between the two. Although absolute
+ positions are fed as inputs, the rest of the network can't "see" the absolute
+ positions; it can only transform them by some relative amount.
+
+ A position vector consists of a sequence of (sin, cos) pairs, which have
+ geometrically increasing wavelengths that span from 2 (for the first pair
+ in each vector) to twice the length of the token sequence (for the last pair).
+ Each sin/cos pair encodes the (x, y) value of a 2D unit vector at a particular
+ angle. For each sin/cos pair in the query position vector, we apply a learned
+ 2x2 rotation matrix, which will rotate and scale the pair by some amount.
+
+ The dot product of two (sin, cos) pairs is the cosine of the angle between
+ them. The dot product of the query position and key position vectors is thus
+ the sum of such cosines. By rotating and scaling the query position, it is
+ possible to approximate any function over relative position as a Fourier
+ series: a sum of cosine waves at different wavelengths. The rotation provides
+ phase, and the scale provides magnitude.
+
+ Put another way, rotating the (sin, cos) pairs of a query position will
+ compute a relative offset from the /query/ position to some target /key/
+ position.
+ """
+
+ from typing import Any, Optional
+
+ from flax import linen as nn
+ import gin
+ import jax.numpy as jnp
+ from transformer import position
+ import numpy as np
+
+
+ Array = jnp.ndarray
+
+
+ def _initialize_frel_rotation_matrix(rng, num_heads, vec_size):
+   """Initialize the rotation matrices."""
+   # Initialize each rotation matrix to the identity * scale.
+   #
+   # Initially scale by 1 / number of sine waves = 1/2 the position vector
+   # size. With this initialization, the initial position bias terms should be
+   # between -1.0 and 1.0 after the rotation matrix has been applied.
+   del rng  # required for init function but unused
+   scale = float(2.0 / vec_size)
+   tmat_a = jnp.ones([num_heads, vec_size // 2], dtype=jnp.float32) * scale
+   tmat_b = jnp.zeros([num_heads, vec_size // 2], dtype=jnp.float32)
+   return jnp.concatenate([tmat_a, tmat_b], axis=1)
+
+
+ @gin.configurable
+ class RelativeFourierPositions(nn.Module):
+   """An implementation of Fourier relative positions."""
+
+   # The number of attention heads.
+   num_heads: int = 8
+
+   # The maximum number of keys to attend to.
+   # The sin/cos wavelengths of the position vectors will be tuned to this max.
+   max_number_of_keys: int = 1024
+
+   # Size of the position vector. Needs to be large enough to address the keys.
+   position_vector_size: int = 128
+
+   # Data type to use for the rotation matrices.
+   dtype: Any = jnp.float32
+
+   @nn.compact
+   def __call__(self, num_queries: int, num_keys: int,
+                offset: Optional[int] = None,
+                bidirectional: bool = True) -> Array:
+     """Returns relative positional attention matrix.
+
+     If num_keys >= num_queries, e.g. for transformer XL or sliding window,
+     then offset should be (num_keys - num_queries) to make the last N queries
+     line up with the last N keys. This is the default if offset is None.
+
+     Args:
+       num_queries: Number of queries.
+       num_keys: Number of keys.
+       offset: Offset of the first query with respect to the first key.
+         (See position.relative_positions() for more info.)
+       bidirectional: Unused, included for compatibility.
+         Relative positions are always bidirectional.
+
+     Returns:
+       Attention matrix of shape (num_heads, num_queries, num_keys)
+     """
+
+     # Get the offset of each query with respect to each key.
+     # If not specified, the last N queries line up with the last N keys.
+     if offset is None:
+       assert num_keys >= num_queries
+       offset = num_keys - num_queries
+     max_wavelength = 2 * self.max_number_of_keys
+
+     # Compute absolute position vectors for keys.
+     # Use numpy to compute these arrays statically.
+     # ks : (num_keys, pvec_size)
+     ks = position.position_encoding(num_keys,
+                                     self.position_vector_size,
+                                     offset=0,  # offset of queries wrt. keys
+                                     max_wavelength=max_wavelength)
+
+     # Compute absolute position vectors for queries.
+     # qs : (num_queries, pvec_size)
+     if offset >= 0 and offset + num_queries <= num_keys:
+       # Query positions are a subset of the key positions.
+       qs = ks[offset:offset + num_queries]
+     else:
+       # Query positions must be computed separately.
+       qs = position.position_encoding(num_queries,
+                                       self.position_vector_size,
+                                       offset=offset,
+                                       max_wavelength=max_wavelength)
+
+     # Split qs into x and y coordinates for rotation.
+     (qx, qy) = np.split(qs, 2, axis=-1)
+     qs_xs = np.concatenate([qx, qx], axis=-1)
+     qs_ys = np.concatenate([qy, qy], axis=-1)
+     del qs
+
+     # Convert from numpy to jax.
+     ks = jnp.asarray(ks, dtype=self.dtype)
+     qs_xs = jnp.asarray(qs_xs, dtype=self.dtype)
+     qs_ys = jnp.asarray(qs_ys, dtype=self.dtype)
+
+     # Initialize the rotation matrices to the identity.
+     rotation_matrix = self.param("rotation_matrix",
+                                  _initialize_frel_rotation_matrix,
+                                  self.num_heads,
+                                  self.position_vector_size)
+
+     rotation_matrix = jnp.asarray(rotation_matrix, dtype=self.dtype)
+
+     # Unpack rotation_matrix to a set of 2x2 matrices.
+     rmat1 = rotation_matrix  # [rm_a, rm_b]
+     (rm_a, rm_b) = jnp.split(rotation_matrix, 2, axis=-1)
+     rmat2 = jnp.concatenate([-rm_b, rm_a], axis=-1)
+
+     # Vectors in qs consist of a set of (x,y) (e.g. sin,cos) pairs.
+     # We transform each (x,y) pair with a 2D rotation matrix:
+     #
+     #   x' = a*x + -b*y
+     #   y' = b*x + a*y
+     #
+     # or equivalently, x' + y'i = (a + bi)(x + yi) where i = sqrt(-1).
+     #
+     # For an angle theta, and scale s, a = cos(theta)*s, b = sin(theta)*s,
+     # and a + bi = s*exp(i*theta). We avoid computing sin,cos by training a,b
+     # directly.
+     #
+     # qs_xs = [x0 .. xn; x0 .. xn]      -- layout of qs_xs
+     # qs_ys = [y0 .. yn; y0 .. yn]
+     # rmat1 = [a0 .. an; b0 .. bn]      -- layout of (a,b) values in rmat1
+     # rmat2 = [-b0 .. -bn; a0 .. an]
+     #
+     # rot_qs: (num_heads, num_queries, pvec_size)
+
+     # Broadcast qs over the number of heads.
+     # Broadcast rmat over the number of queries.
+     qs_xs = qs_xs[jnp.newaxis, ...]     # (1, num_queries, pvec_size)
+     qs_ys = qs_ys[jnp.newaxis, ...]
+     rmat1 = rmat1[:, jnp.newaxis, ...]  # (num_heads, 1, pvec_size)
+     rmat2 = rmat2[:, jnp.newaxis, ...]
+     rot_qs = ((rmat1 * qs_xs) + (rmat2 * qs_ys))
+
+     # Compute the dot product of each position vector in ks by the rotated qs.
+     #
+     # The dot product of each (x, y) pair in ks, and each (x', y') in rot_qs,
+     # is equal to the cosine of the angle between them, times the length
+     # of (x', y').
+     #
+     # The angle of the cosine for each pair depends on:
+     #   - The distance between the key and the query, divided by the
+     #     wavelength. (From the initial position encoding for ks and qs).
+     #   - The rotation performed by (a,b).
+     #
+     # The length of (x', y') is equal to the scale of (a, b).
+     #
+     # The dot product of two complete position vectors is the sum of the
+     # cosines for all pairs. The cosines form a progression of geometrically
+     # increasing wavelengths, and each wave has a scale and phase provided by
+     # the rotation matrix. The sum of such waves can thus approximate any
+     # function of position.
+     #
+     # pbias: (num_heads, num_queries, num_keys)
+     pbias = jnp.einsum("hqd,kd->hqk", rot_qs, ks)
+
+     # Add batch dimension; --> shape (1, num_heads, num_queries, num_keys)
+     pbias = jnp.expand_dims(pbias, 0)
+     return pbias.astype(self.dtype)
+
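+
+ # Usage sketch (hypothetical helper): initialize the module and produce a
+ # bias of shape (1, num_heads, num_queries, num_keys). `import jax` is local
+ # because only jax.numpy is imported at module level above.
+ def _example_relative_fourier_bias() -> Array:
+   import jax
+   module = RelativeFourierPositions(num_heads=2, max_number_of_keys=16,
+                                     position_vector_size=8)
+   variables = module.init(jax.random.PRNGKey(0), 4, 8)
+   return module.apply(variables, 4, 8)  # shape (1, 2, 4, 8)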
aglib/meliad/transformer/position_t5.py ADDED
@@ -0,0 +1,155 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Class for T5 relative position biases.
+
+ Adapted from flaxformer.components.relative_position_biases.py
+ """
+
+ from typing import Any, Callable, Optional
+
+ from flax import linen as nn
+ import gin
+ from jax import lax
+ import jax.numpy as jnp
+ from transformer import position
+ import numpy as np
+
+
+ Array = Any
+
+
+ @gin.configurable
+ class T5RelativePositionBiases(nn.Module):
+   """Adds T5-style relative positional embeddings to the attention logits.
+
+   Attributes:
+     num_buckets: Number of buckets to bucket distances between key and query
+       positions into.
+     max_distance: Maximum distance before everything is lumped into the last
+       distance bucket.
+     num_heads: Number of heads in the attention layer. Each head will get a
+       different relative position weighting.
+     dtype: Type of arrays through this module.
+     embedding_init: initializer for relative embedding table.
+   """
+   num_buckets: int
+   max_distance: int
+   num_heads: int
+   dtype: Any
+   embedding_init: Callable[..., Array] = nn.linear.default_embed_init
+
+   @staticmethod
+   def _relative_position_bucket(relative_position,
+                                 bidirectional=True,
+                                 num_buckets=32,
+                                 max_distance=128):
+     """Translate relative position to a bucket number for relative attention.
+
+     The relative position is defined as memory_position - query_position,
+     i.e. the distance in tokens from the attending position to the
+     attended-to position. If bidirectional=False, then positive relative
+     positions are invalid.
+     We use smaller buckets for small absolute relative_position and larger
+     buckets for larger absolute relative_positions. All relative
+     positions >= max_distance map to the same bucket. All relative
+     positions <= -max_distance map to the same bucket. This should allow for
+     more graceful generalization to longer sequences than the model has been
+     trained on.
+
+     Args:
+       relative_position: an int32 array
+       bidirectional: a boolean - whether the attention is bidirectional
+       num_buckets: an integer
+       max_distance: an integer
+
+     Returns:
+       a Tensor with the same shape as relative_position, containing int32
+       values in the range [0, num_buckets)
+     """
+     ret = 0
+     n = -relative_position
+     if bidirectional:
+       num_buckets //= 2
+       ret += (n < 0).astype(np.int32) * num_buckets
+       n = np.abs(n)
+     else:
+       n = np.maximum(n, 0)
+     # Now n is in the range [0, inf).
+     max_exact = num_buckets // 2
+     is_small = (n < max_exact)
+     val_if_large = max_exact + (
+         np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
+         np.log(max_distance / max_exact) *
+         (num_buckets - max_exact)).astype(np.int32)
+     val_if_large = np.minimum(val_if_large, num_buckets - 1)
+     ret += np.where(is_small, n, val_if_large)
+     return ret
+
+   @nn.compact
+   def __call__(self, num_queries, num_keys, offset: Optional[int] = None,
+                bidirectional=True):
+     """Produce relative position embedding attention biases.
+
+     Args:
+       num_queries: Number of queries.
+       num_keys: Number of keys.
+       offset: Offset of the first query with respect to the first key.
+         (See position.relative_positions() for more info.)
+       bidirectional: whether to allow positive memory-query relative position
+         embeddings.
+
+     Returns:
+       output: `(1, num_heads, num_queries, num_keys)` attention bias
+     """
+
+     # Find the distance between each query and each key.
+     # This is where this implementation differs from the T5 implementation;
+     # this version lines the /last/ N queries up with the /last/ N keys
+     # (which is appropriate for XL/sliding window), while the T5 version
+     # lines up the /first/ N queries with the first N keys, in cases where
+     # the number of keys and queries differ.
+     relative_position = position.relative_positions_np(
+         num_queries=num_queries, num_keys=num_keys, offset=offset)
+
+     rp_bucket = self._relative_position_bucket(
+         relative_position,
+         bidirectional=bidirectional,
+         num_buckets=self.num_buckets,
+         max_distance=self.max_distance)
+     relative_attention_bias = self.param('rel_embedding', self.embedding_init,
+                                          (self.num_heads, self.num_buckets),
+                                          jnp.float32)
+
+     relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
+     # Instead of using a slow gather, we create a leading-dimension one-hot
+     # array from rp_bucket and use it to perform the gather-equivalent via a
+     # contraction, i.e.:
+     # (num_head, num_buckets) x (num_buckets one-hot, num_queries, num_keys).
+     # This is equivalent to relative_attention_bias[:, rp_bucket].
+     bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
+     rp_bucket_one_hot = jnp.array(
+         rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
+     # --> shape (num_heads, num_queries, num_keys)
+     values = lax.dot_general(
+         relative_attention_bias,
+         rp_bucket_one_hot,
+         (
+             ((1,), (0,)),  # rhs, lhs contracting dims
+             ((), ())))  # no batched dims
+     # Add a singleton batch dimension.
+     # --> shape (1, num_heads, num_queries, num_keys)
+     out = values[jnp.newaxis, ...]
+
+     return out
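+
+
+ # Worked example (hypothetical helper): bucket a few signed relative
+ # positions with 8 bidirectional buckets and max_distance=16. Distances with
+ # |n| < 2 get exact buckets, larger distances fall into log-spaced buckets,
+ # and "future" positions (positive relative_position) are offset into the
+ # upper half of the bucket range:
+ #   [-4, -1, 0, 1, 4]  ->  [2, 1, 0, 5, 6]
+ def _example_buckets() -> np.ndarray:
+   rel = np.array([-4, -1, 0, 1, 4], dtype=np.int32)
+   return T5RelativePositionBiases._relative_position_bucket(
+       rel, bidirectional=True, num_buckets=8, max_distance=16)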
aglib/meliad/transformer/synthetic_text_data.py ADDED
The diff for this file is too large to render. See raw diff
 
aglib/meliad/transformer/tasks.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Add Tasks to registry."""
+
+ import functools
+
+ from transformer import text_dataset
+ import seqio
+ import t5.data
+ from t5.data import preprocessors
+ import tensorflow as tf
+
+
+ TaskRegistry = seqio.TaskRegistry
+
+
+ def define_pg19_task(name: str, vocab: seqio.Vocabulary):
+   seqio.TaskRegistry.add(
+       name,
+       seqio.TfdsDataSource(
+           tfds_name="pg19:0.1.1"
+       ),
+       preprocessors=[
+           functools.partial(text_dataset.rekey_articles,
+                             rekey={"book_text": "targets"},
+                             keep={"book_title", "book_id",
+                                   "publication_date"}),
+           seqio.preprocessors.tokenize,
+       ],
+       output_features={
+           "targets": seqio.Feature(vocab,
+                                    add_eos=False, dtype=tf.int32),
+       }
+   )
+
+
+ T5_DEFAULT_VOCABULARY = t5.data.get_default_vocabulary()
+ define_pg19_task("pg19_bytes", seqio.ByteVocabulary())
+ define_pg19_task("pg19_tokens", T5_DEFAULT_VOCABULARY)
+
+
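+ # Usage sketch (hypothetical helper): once this module is imported, the
+ # registered tasks can be fetched from the seqio registry by name.
+ def _example_get_pg19_task() -> seqio.Task:
+   return seqio.get_mixture_or_task("pg19_tokens")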
aglib/meliad/transformer/text_dataset.py ADDED
@@ -0,0 +1,775 @@
+ # Copyright 2022 Google.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Load text datasets for long-range transformer models."""
+
+ import os
+ import re
+ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Set, Tuple, Union
+
+ from absl import flags
+ from absl import logging
+ import gin
+ import jax
+ from transformer import synthetic_text_data
+ import numpy as np
+ import seqio
+ import tensorflow.compat.v2 as tf
+
+
+
+ flags.DEFINE_string("default_data_dir", None,
+                     "Default directory where data is stored.")
+ FLAGS = flags.FLAGS
+
+
+ _DEFAULT_DATA_DIRECTORY = None
+
+
+ @gin.configurable
+ def set_default_data_directory(directory_name=None):
+   """Set the default directory where training data is located."""
+   global _DEFAULT_DATA_DIRECTORY
+   # If the data directory has been overridden with a command-line flag,
+   # use it. If not, then see if directory_name has been configured by Gin.
+   # Otherwise, use the default tfds directory.
+   if FLAGS.default_data_dir:
+     directory_name = FLAGS.default_data_dir
+   if directory_name is not None:
+     seqio.set_tfds_data_dir_override(directory_name)
+     _DEFAULT_DATA_DIRECTORY = directory_name
+
+
+ def get_iterator_function(dataset: Optional[tf.data.Dataset]):
+   """Returns a function which gets an iterator over the given dataset."""
+   if dataset is None:
+     return None
+   else:
+     return dataset.as_numpy_iterator
+
+
+ @gin.configurable
+ def get_loss_mask_tokens(
+     split: str,
+     loss_mask_start_tokens: Sequence[int] = (),
+     loss_mask_end_tokens: Sequence[int] = (),
+     splits: Sequence[str] = ("all",)
+ ) -> Tuple[Sequence[int], Sequence[int]]:
+   """Returns two token sequences to indicate start and end of the loss.
+
+   Please configure loss_mask_start_tokens, loss_mask_end_tokens, and
+   splits via gin. Example gin config to only apply loss between tokens 2
+   and 1 for the test set (and everywhere for any other data split):
+
+   ```
+   text_dataset.get_loss_mask_tokens:
+     loss_mask_start_tokens=(2,)
+     loss_mask_end_tokens=(1,)
+     splits=("test",)
+   ```
+
+   Args:
+     split: The mode ("test", "train", ...)
+     loss_mask_start_tokens: token sequence that starts the loss
+     loss_mask_end_tokens: token sequence that stops the loss
+     splits: Only compute the loss mask for splits in this list.
+       By default it is 'all', which is a reserved split string that applies
+       to all splits.
+   """
+   if "all" in splits or split in splits:
+     return loss_mask_start_tokens, loss_mask_end_tokens
+   return (), ()
+
+
+ @gin.configurable
+ def load_text_dataset(name: str,
+                       split: str,
+                       sequence_length: int,
+                       batch_size: int,
+                       sequential: bool = True,
+                       shard_dataset: bool = True,
+                       verbose: bool = False,
+                       ) -> Tuple[tf.data.Dataset, seqio.Vocabulary]:
+   """Load a text dataset of long articles or books, and split_and_batch them.
+
+   The input dataset must produce complete books or articles, where each
+   article is a dictionary containing a "tokens" field.
+   See split_and_batch for more information on the output dataset.
+
+   Args:
+     name: The name of the seqio task which produces the dataset.
+     split: The name of the split to use, e.g. "train" or "test".
+     sequence_length: Split text into sequences of this length.
+     batch_size: Draw from batch_size articles in each batch.
+     sequential: If True, return the chunks of each article in sequence.
+     shard_dataset: If True, split data set into shards.
+     verbose: Log (an excerpt of) every text example loaded from disk.
+       If False, will only print 1 excerpt every 60 seconds.
+
+   Returns:
+     (dataset, vocabulary)
+     where vocabulary is the seqio.Vocabulary which is used to encode
+     "targets".
+   """
+
+   logging.info("Loading text data set %s, split=%s, shape=(%d, %d)",
+                name, split, batch_size, sequence_length)
+
+   if name == "synthetic":
+     ds = synthetic_data_long(split, sequence_length, batch_size)
+     return (ds, seqio.PassThroughVocabulary(256, 0))
+   elif name == "synthetic_short":
+     ds = synthetic_data_short(split, sequence_length, batch_size)
+     return (ds, seqio.PassThroughVocabulary(256, 0))
+   elif name == "enwik8":
+     # TODO(delesley): Encapsulate enwik8 into a Task.
+     ds = load_enwik8(split, sequence_length, batch_size,
+                      data_dir=_DEFAULT_DATA_DIRECTORY)
+     return (ds, seqio.PassThroughVocabulary(256, 0))
+
+   # Bypass the seqio "feature converter", and get the task directly.
+   task = seqio.get_mixture_or_task(name)
+   vocab = task.output_features["targets"].vocabulary
+
+   # Create the task input pipeline.
+   if shard_dataset:
+     logging.info("Shards: %d of %d", jax.process_index(), jax.process_count())
+     shard_info = seqio.ShardInfo(index=jax.process_index(),
+                                  num_shards=jax.process_count())
+   else:
+     shard_info = None
+
+   if sequential:
+     task_seqlen = None  # We do our own splitting.
+     shuffle_buffer_size = 1000  # Number of full-length books.
+   else:
+     task_seqlen = {"targets": sequence_length}  # Ask the task to split.
+     shuffle_buffer_size = 10_000  # Number of chunks.
+
+   ds = task.get_dataset(
+       sequence_length=task_seqlen,
+       split=split,
+       use_cached=False,
+       shuffle=True,
+       shuffle_buffer_size=shuffle_buffer_size,
+       seed=None,
+       shard_info=shard_info,
+       num_epochs=1)
+
+   if sequence_length == 0:
+     return (ds, vocab)  # Don't chop into subsequences.
+
+   def extract_fn(article):
+     return article["targets"]
+
+   include_loss_mask = bool(get_loss_mask_tokens(split)[0])
+   ds = split_and_batch(ds,
+                        split=split,
+                        extract_fn=extract_fn,
+                        sequence_length=sequence_length,
+                        batch_size=batch_size,
+                        auto_rewind=True,
+                        vocab=vocab,
+                        include_loss_mask=include_loss_mask,
+                        verbose=verbose)
+   return (ds, vocab)
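+
+
+ # Usage sketch (hypothetical values; assumes the pg19 task from tasks.py has
+ # been registered and its data is available): load PG19 as 8 parallel
+ # streams of 4096-token segments, then pull one batch.
+ def _example_load_pg19():
+   ds, vocab = load_text_dataset("pg19_tokens", split="train",
+                                 sequence_length=4096, batch_size=8)
+   batch = next(ds.as_numpy_iterator())
+   return batch["targets"].shape, vocab  # ((8, 4096), ...)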
+
+
+ def rekey_articles(ds: tf.data.Dataset,
+                    rekey: Mapping[str, str],
+                    keep: Optional[Set[str]] = None) -> tf.data.Dataset:
+   """Rekey the articles in ds.
+
+   Fields in rekey will be renamed, fields in keep will be kept, and others
+   will be discarded. E.g., for PG19:
+
+   rekey_articles(ds,
+                  rekey={"book_text": "targets"},
+                  keep={"book_title", "book_id"})
+
+   Args:
+     ds: The dataset to rekey.
+     rekey: Dictionary which contains fields to rename.
+     keep: Set of fields to keep.
+
+   Returns:
+     A rekeyed dataset.
+   """
+
+   def rekey_fn(article):
+     result_dict = {}
+     for (k, v) in article.items():
+       if k in rekey:
+         result_dict[rekey[k]] = v
+       elif k in keep:
+         result_dict[k] = v
+     return result_dict
+
+   return ds.map(rekey_fn)
+
+
+ def pretty_print_article(article,
+                          vocab_map: Mapping[str, Optional[seqio.Vocabulary]],
+                          max_length: int = 60) -> str:
+   """Convert the contents of a long article to a short string."""
+   if not hasattr(article, "items"):
+     return pretty_print_value(article, max_length)  # Not a dictionary.
+   dstr = "{"
+   for (k, v) in article.items():
+     if vocab_map and k in vocab_map:
+       vstr = decode_tokens(v, vocab_map[k], max_length)
+     else:
+       vstr = pretty_print_value(v, max_length)
+     dstr += "\n  " + k + ": " + vstr
+   return dstr + "\n}"
+
+
+ def pretty_print_value(value, max_length: int) -> str:
+   """Convert a possibly large value to a short string."""
+   if isinstance(value, bytes):
+     if len(value) <= max_length:
+       return str(value)
+     else:
+       return f"bytes[{len(value)}] " + str(value[:max_length]) + "..."
+   elif isinstance(value, str):
+     if len(value) <= max_length:
+       return value
+     else:
+       return f"str[{len(value)}] " + value[:max_length] + "..."
+   elif isinstance(value, np.ndarray):
+     vstr = f"ndarray({value.shape}, {value.dtype.str})"
+     if value.size <= (max_length / 4):
+       vstr += " = " + str(value)
+     return vstr
+   elif np.ndim(value) == 0:
+     return str(value)  # Scalar data.
+   else:
+     return str(type(value))
+
+
+ def decode_tokens(tokens: Any, vocab: seqio.Vocabulary,
+                   max_length: int) -> str:
+   """Convert tokens to a human-readable string."""
+   if isinstance(tokens, np.ndarray):
+     tstr = f"ndarray({tokens.shape}, {tokens.dtype.str}) = "
+   else:
+     tstr = f"{str(type(tokens))} = "
+
+   if np.ndim(tokens) == 1:
+     tstr += decode_tokens_1d(tokens, vocab, max_length)
+   elif np.ndim(tokens) == 2:
+     jtstr = ",\n  ".join([decode_tokens_1d(s, vocab, max_length)
+                           for s in tokens])
+     tstr += f"[\n  {jtstr}\n  ]"
+   else:
+     tstr = pretty_print_value(tokens, max_length)
+   return tstr
+
+
+ def decode_tokens_1d(tokens: Any, vocab: Any, max_length: int,
+                      raw_string: bool = False) -> Union[str, bytes]:
+   """Convert a 1D array of tokens to a human-readable string.
+
+   Args:
+     tokens: 1-dimensional array of integers.
+     vocab: The vocabulary to detokenize the array.
+     max_length: The maximum number of tokens to detokenize.
+     raw_string: If True, return the string as bytes.
+       If False, pretty-print it (e.g. with "\n").
+
+   Returns:
+     The detokenized string.
+   """
+
+   assert np.ndim(tokens) == 1
+   # The type of tokens is np.ndarray((sequence_length,), "int32").
+   # We have to convert this to an actual list of python integers, NOT numpy
+   # integers, or decode will blow up, and fail to marshall the data to C++.
+   dtoks = [int(i) for i in tokens[:max_length]]
+   tstr = vocab.decode(dtoks)
+
+   # Convert the decoded string to a byte string.
+   # PassThroughVocabulary returns a list, not a string.
+   if isinstance(tstr, str):
+     tstr = bytes(tstr.encode("utf-8"))
+   else:
+     tstr = bytes(tstr)
+
+   # If raw_string, return immediately.
+   if raw_string:
+     return tstr
+
+   # Otherwise format it for pretty-printing.
+   # Converting bytes to str will render, e.g., newlines as "\n".
+   tstr = str(tstr)
+   if len(tokens) > max_length:
+     tstr += "..."
+   return tstr
+
+
+ def bytes_to_tokens(s: str):
+   """Convert a byte string to an array of integers."""
+   return np.fromiter((char for char in s), count=len(s), dtype=np.int32)
+
+
+ def pad_chunk(s: Optional[np.ndarray], sequence_length: int):
+   """Pad an array s out to the given sequence_length."""
+   if s is None:
+     return np.zeros(sequence_length, dtype=np.int32)
+   assert np.ndim(s) == 1
+   chunk_len = len(s)
+   assert chunk_len <= sequence_length
+   if chunk_len == sequence_length:
+     return s
+   else:
+     return np.pad(s, (0, sequence_length - chunk_len),
+                   mode="constant", constant_values=0)
+
+
+ def split_article(tokens: np.ndarray, sequence_length: int, split: str,
+                   include_loss_mask: bool) -> (
+                       Iterable[Tuple[np.ndarray, np.ndarray]]):
+   """Split an array into segments of length sequence_length."""
+   assert np.ndim(tokens) == 1
+   if include_loss_mask:
+     loss_mask = loss_mask_from_tokens(tokens, split)
+
+   for k in range(0, len(tokens), sequence_length):
+     segment = pad_chunk(tokens[k:k + sequence_length], sequence_length)
+     if include_loss_mask:
+       segment_loss_mask = pad_chunk(
+           loss_mask[k:k + sequence_length], sequence_length).astype(bool)
+     else:
+       segment_loss_mask = np.array(True, dtype=bool)  # Dummy mask.
+     yield (segment, segment_loss_mask)
+
+
+ def nonzero_tokens(tokens: np.ndarray,
+                    loss_mask: Optional[np.ndarray]) -> list[int]:
+   """Removes tokens that are not predicted by the model."""
+   # TODO(delesley): Fix the model so that it predicts the first token.
+   # The language model doesn't predict the first token.
+   toks = [int(tokens[i]) for i in range(1, len(tokens))
+           if (tokens[i] != 0 and (loss_mask is None or loss_mask[i]))]
+   return toks
+
+
+ def _find_subsequence_idxs(sequence: np.ndarray, subsequence: Sequence[int]):
+   """Returns the indices where `subsequence` occurs in `sequence`."""
+   subsequence = np.asarray(subsequence, dtype=np.int32)
+   # Use np.where as an efficient way to iterate over the whole array; but we
+   # can only test for a single token this way, unfortunately.
+   potential_matches = np.where(sequence == subsequence[0])[0]
+   match_indices = []
+   for start_index in potential_matches:
+     if np.array_equal(sequence[start_index:start_index + len(subsequence)],
+                       subsequence):
+       match_indices.append(start_index)
+   return match_indices
+
+
+ def loss_mask_from_tokens(tokens: np.ndarray, split: str) -> np.ndarray:
+   """Compute a mask for language modelling loss using start and end tokens."""
+   assert np.ndim(tokens) == 1
+   tokens = tokens.astype(np.int32)
+
+   # Position offset of loss mask and target positions. Typically -1, which
+   # indicates that targets are shifted 1 position left compared to inputs.
+   offset = -1
+
+   start_tokens, end_tokens = get_loss_mask_tokens(split=split)
+   if not start_tokens:
+     # Default to not masking out any loss.
+     return np.ones_like(tokens, dtype=bool)
+
+   start = 0
+   end = len(tokens)  # Include end_tokens.
+   start_indices = _find_subsequence_idxs(tokens, start_tokens)
+   if start_indices:
+     if end_tokens:
+       end_indices = _find_subsequence_idxs(tokens, end_tokens)
+     else:
+       end_indices = []
+     if len(start_indices) > 1 or len(end_indices) > 1:
+       logging.error("Multiple start or end tokens for loss mask: %s, %s",
+                     start_indices, end_indices)
+     start = start_indices[0]
+     if end_indices and end_indices[0] >= start:
+       end = end_indices[0]
+
+   # We include the start_tokens and the end_tokens, which represents that the
+   # model must predict the location, the content, and the end of the
+   # subsequence.
+   start += offset
+   start = max(0, start)  # To prevent offset creating negative indices.
+   end += len(end_tokens) + offset
+
+   # Create the actual mask. Roughly equivalent to:
+   #   mask = np.array([start <= i < end for i in range(len(tokens))])
+   mask = np.concatenate([
+       np.zeros((start,), dtype=bool),
+       np.ones((end - start,), dtype=bool),
+       np.zeros((len(tokens) - end,), dtype=bool)
+   ])
+   return mask
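+
+
+ # Worked example (hypothetical helper; assumes the gin config shown in
+ # get_loss_mask_tokens, i.e. start_tokens=(2,) and end_tokens=(1,)).
+ def _example_loss_mask() -> np.ndarray:
+   toks = np.array([5, 5, 2, 7, 7, 1, 5], dtype=np.int32)
+   # Start marker at index 2, end marker at index 5; with offset = -1 the
+   # result is [F, T, T, T, T, F, F], which covers *predicting* tokens 2..5
+   # (start marker, span, and end marker) under the one-position target shift.
+   return loss_mask_from_tokens(toks, split="test")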
423
+
424
+
425
+ def _batched_interleave_generator(
426
+ ds: tf.data.Dataset,
427
+ flat_map_func: Callable[[str], Iterable[Tuple[np.ndarray, np.ndarray]]],
428
+ post_map_func,
429
+ batch_size: int,
430
+ vocab: Optional[seqio.Vocabulary] = None,
431
+ include_loss_mask: bool = False,
432
+ auto_rewind: bool = False) -> Iterable[Dict[str, np.ndarray]]:
433
+ """Generator which combines the interleave and batch dataset operations.
434
+
435
+ Given a set of articles from ds, flat_map_func is mapped over the articles
436
+ to break each article up into an iterable of chunks and their loss masks.
437
+ The generator will return the examples from each article in sequential order,
438
+ for transformer-XL style models that process long articles over multiple
439
+ training steps.
440
+
441
+ Articles are combined into batches of size batch_size, where each example in
442
+ the batch is pulled from a different article. When one article ends, the
443
+ generator will start pulling examples from the next article. The overall
444
+ result is similar to tf.Data.Dataset.interleave, except that interleave does
445
+ not always maintain the same order of articles. If this generator starts
446
+ pulling from article "foo" as the 3rd item in the batch, then consecutive
447
+ examples from "foo" will remain as the 3rd item until the article ends. This
448
+ guarantee is necessary to pass state from one training step to the next.
449
+
450
+ If auto_rewind, then the generator will automatically grab a new iterator
451
+ from ds at the end of the epoch, and increment the epoch counter. Otherwise,
452
+ it will yield empty datasets until all articles in the batch have been
453
+ completed.
454
+
455
+ Args:
456
+ ds: A dataset of articles.
457
+ flat_map_func: A function which returns an iterator over chunks of tokens
458
+ and the loss masks associated with those tokens.
459
+ post_map_func: A function which post-processes each item to fixed size.
460
+ batch_size: The number of articles in a batch.
461
+ vocab: The vocabulary to detokenize strings and count characters.
462
+ include_loss_mask: If true, will return a loss mask with the tokens.
463
+ auto_rewind: Automatically rewind ds at end of epoch.
464
+
465
+ Yields:
466
+ Batches of consecutive examples from articles.
467
+ Each example has type: {
468
+ "targets": int32[batch_size, sequence_length],
469
+ "start_of_sequence": bool[batch_size],
470
+ "epoch": int32[batch_size],
471
+ "loss_mask": bool[batch_size, sequence_length],
472
+ }
473
+ """
474
+
475
+ ds_iter = ds.as_numpy_iterator()
476
+
477
+ document_start = [True] * batch_size # At start of each article.
478
+ readers = [None] * batch_size # Iterator for each article
479
+ still_reading = [True] * batch_size # End of current article?
480
+ item_epochs = [0] * batch_size # Epoch of the given item.
481
+ epoch = 0
482
+
483
+ # Main generator loop
484
+ while any(still_reading):
485
+ targets = [None] * batch_size
486
+ loss_mask = [None] * batch_size
487
+ for i in range(0, batch_size):
488
+ targets_i = None
489
+ loss_mask_i = None
490
+ while targets_i is None and still_reading[i]:
491
+ if readers[i] is not None:
492
+ try:
493
+ # Grab the next item from the article.
494
+ targets_i, loss_mask_i = next(readers[i])
495
+ except StopIteration:
496
+ # Article has ended; continue the while loop to grab a new one.
497
+ readers[i] = None
498
+ else:
499
+ # Grab the next article from ds if the current one has ended.
500
+ dsi = None
501
+ try:
502
+ dsi = iter(flat_map_func(next(ds_iter)))
503
+ except StopIteration:
504
+ logging.info("End of epoch %d.", epoch)
505
+ if auto_rewind:
506
+ epoch = epoch + 1
507
+ logging.info("Starting epoch %d.", epoch)
508
+ ds_iter = ds.as_numpy_iterator()
509
+ dsi = iter(flat_map_func(next(ds_iter)))
510
+ else:
511
+ still_reading[i] = False # No more articles on i
512
+ if dsi is not None:
513
+ # Start reading the new article.
514
+ # Continue while loop to grab the first chunk.
515
+ readers[i] = dsi
516
+ document_start[i] = True
517
+ item_epochs[i] = epoch
518
+
519
+ # post_map_func must handle None values, and return stackable np.arrays.
520
+ targets[i] = post_map_func(targets_i) # handles None
521
+ if include_loss_mask:
522
+ loss_mask[i] = post_map_func(loss_mask_i).astype(bool) # handles None
523
+
524
+ # If we've reached the end of all articles, stop immediately.
525
+ if not any(still_reading):
526
+ break
527
+
528
+ doc_start_orig = document_start.copy() # Return doc_start_orig.
529
+ for i in range(0, batch_size):
530
+ # Now that we've read an item, set /start/ to false for each reader.
531
+ document_start[i] = False
532
+
533
+ # Decode the tokenized segement back to characters, to count the number
534
+ # of characters for the bits-per-character computation.
535
+ num_chars = [0] * batch_size
536
+ nz_toks = [0] * batch_size
537
+ for i in range(0, batch_size):
538
+ lmask = loss_mask[i] if include_loss_mask else None
539
+ toks = nonzero_tokens(targets[i], lmask)
540
+ if vocab is not None:
541
+ bchars = decode_tokens_1d(toks, vocab, max_length=len(targets[i]),
542
+ raw_string=True)
543
+ num_chars[i] = len(bchars)
544
+ else:
545
+ num_chars[i] = len(toks)
546
+ nz_toks[i] = len(toks)
547
+
548
+ item = {
549
+ "targets": np.stack(targets),
550
+ "start_of_sequence": np.array(doc_start_orig),
551
+ "epoch": np.array(item_epochs),
552
+ "num_chars": np.stack(num_chars),
553
+ "nonzero_tokens": np.stack(nz_toks),
554
+ }
555
+ if include_loss_mask:
556
+ item["loss_mask"] = np.stack(loss_mask)
557
+ yield item
558
+
559
+
560
+ def split_and_batch(ds: tf.data.Dataset,
561
+ split: str,
562
+ extract_fn: Callable[[Any], Any],
563
+ sequence_length: int,
564
+ batch_size: int,
565
+ auto_rewind: bool = False,
566
+ vocab: Optional[seqio.Vocabulary] = None,
567
+ include_loss_mask: bool = False,
568
+ verbose: bool = False) -> tf.data.Dataset:
569
+ """Converts articles to tokens and chops and batches them.
570
+
571
+ See batched_interleave_generator for more details.
572
+
573
+ Args:
574
+ ds: A dataset of articles.
575
+ split: Which dataset split is to be computed, e.g. 'train'.
576
+ extract_fn: Return a sequence of tokens from article.
577
+ sequence_length: The number of tokens in each sequence.
578
+ batch_size: The number of examples in each batch.
579
+ auto_rewind: If True, will automatically rewind at end of epoch.
580
+ vocab: Vocabulary, used to count characters.
581
+ include_loss_mask: Return a loss mask for each batch.
582
+ verbose: Write article info to log as they are read.
583
+
584
+ Returns:
585
+ A dataset which yields examples of shape {
586
+ "targets": int32[batch_size, sequence_length],
587
+ "start_of_sequence": bool[batch_size],
588
+ "epoch": int32[batch_size],
589
+ "loss_mask": bool[batch_size, sequence_length],
590
+ "num_chars": A count of the number of detokenized characters.
591
+ "nonzero_tokens": A count of the number of nonzero predicted tokens.
592
+ }
593
+ """
594
+
595
+ # Tokenize article, compute loss mask, split into multiple chunks.
596
+ # The entire article must fit into memory.
597
+ def wrap_split_article(article):
598
+ if verbose:
599
+ logging.info("Reading article: %s", pretty_print_article(article, {}))
600
+ else:
601
+ logging.log_every_n_seconds(logging.INFO, "Reading article: %s", 60,
602
+ pretty_print_article(article, {}))
603
+ tokens = extract_fn(article)
604
+ if isinstance(tokens, (str, bytes)):
605
+ tokens = bytes_to_tokens(tokens)
606
+ elif isinstance(tokens, np.ndarray):
607
+ tokens = tokens.astype(np.int32)
608
+ else:
609
+ raise TypeError("Unusupported sequence type: %s" % str(type(tokens)))
610
+ return split_article(tokens, sequence_length, split=split,
611
+ include_loss_mask=include_loss_mask)
612
+
613
+ # Handle None values.
614
+ def wrap_pad_chunk(s):
615
+ return pad_chunk(s, sequence_length)
616
+
617
+ def wrap_batched_interleave_generator():
618
+ return _batched_interleave_generator(ds,
619
+ flat_map_func=wrap_split_article,
620
+ post_map_func=wrap_pad_chunk,
621
+ batch_size=batch_size,
622
+ vocab=vocab,
623
+ include_loss_mask=include_loss_mask,
624
+ auto_rewind=auto_rewind)
625
+
626
+ out_sig = {
627
+ "targets": tf.TensorSpec(shape=(batch_size, sequence_length),
628
+ dtype=tf.int32),
629
+ "start_of_sequence": tf.TensorSpec(shape=(batch_size,), dtype=tf.bool),
630
+ "epoch": tf.TensorSpec(shape=(batch_size,), dtype=tf.int32),
631
+ "num_chars": tf.TensorSpec(shape=(batch_size,), dtype=tf.int32),
632
+ "nonzero_tokens": tf.TensorSpec(shape=(batch_size,), dtype=tf.int32),
633
+ }
634
+ if include_loss_mask:
635
+ out_sig["loss_mask"] = tf.TensorSpec(shape=(batch_size, sequence_length),
636
+ dtype=tf.bool)
637
+
638
+ cds = tf.data.Dataset.from_generator(wrap_batched_interleave_generator,
639
+ output_signature=out_sig)
640
+ return cds
641
+
642
+
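+ # Example usage (a minimal sketch; `articles_ds` is a hypothetical
+ # tf.data.Dataset of {"targets": bytes} articles, as produced below):
+ #
+ #   batched_ds = split_and_batch(articles_ds,
+ #                                split="train",
+ #                                extract_fn=_targets_to_tokens,
+ #                                sequence_length=4096,
+ #                                batch_size=8,
+ #                                auto_rewind=True)
+ #   for batch in batched_ds.as_numpy_iterator():
+ #     print(batch["targets"].shape)  # (8, 4096)
+ #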
643
+ def merge_articles(article_starts_ends, sequence_length):
644
+ """Merge consecutive articles if their combined length < sequence_length."""
645
+ cs = 0
646
+ ce = 0
647
+ for (s, e) in article_starts_ends:
648
+ if ce == 0:
649
+ ce = s
650
+ if (e - cs) > sequence_length:
651
+ if ce > cs:
652
+ # print("Yield: ", cs, " to ", ce)
653
+ yield (cs, ce) # Yield prior merged articles
654
+ cs = s # Reset to start of current article
655
+ ce = e
656
+ else:
657
+ ce = e # Merge article with current set.
658
+ # print("Article: ", s, " to ", e)
659
+ if ce > 0:
660
+ # print("Yield: ", cs, " to ", ce)
661
+ yield (cs, ce) # Yield final merged set.
662
+
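+ # Worked example (sketch): with sequence_length=10 and article spans
+ # [(0, 4), (4, 8), (8, 20)], the first two articles merge into (0, 8)
+ # because 8 - 0 <= 10, but (8, 20) is too long to join them, so the
+ # generator yields (0, 8) and then (8, 20):
+ #
+ #   list(merge_articles([(0, 4), (4, 8), (8, 20)], 10))
+ #   # => [(0, 8), (8, 20)]
+ #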
663
+
664
+ def _targets_to_tokens(article):
665
+ return bytes_to_tokens(article["targets"])
666
+
667
+
668
+ def _wrap_text_in_dict(text):
669
+ return {"targets": text}
670
+
671
+
672
+ # ---------------------
673
+
674
+ def load_enwik8(split: str,
675
+ sequence_length: int,
676
+ batch_size: int,
677
+ data_dir: str) -> tf.data.Dataset:
678
+ """Load the enwik8 dataset, partitioning into articles."""
679
+
680
+ if data_dir is None:
681
+ raise ValueError("Must specify a data directory for enwik8")
682
+
683
+ filename = os.path.join(data_dir, "enwik8")
684
+ filename = os.path.join(filename, "enwik8_" + split)
685
+
686
+ # Don't attempt to split the data, just shuffle it differently for
687
+ # each worker.
688
+ local_seed = 42 + jax.process_index()
689
+
690
+ logging.info("Enwik8: reading %s", filename)
691
+ with gfile.Open(filename, "r") as f:
692
+ text_data = f.read()
693
+
694
+ logging.info("Enwik8: parsing %s", filename)
695
+ article_starts = [m.start(0) for m in re.finditer("<page>", text_data)]
696
+ article_ends = article_starts[1:] + [len(text_data)]
697
+ logging.info("Enwik8: found %d articles.", len(article_starts))
698
+
699
+ merged_se = merge_articles(zip(article_starts, article_ends),
700
+ sequence_length)
701
+ articles = [text_data[s:e] for (s, e) in merged_se]
702
+ num_articles = len(articles)
703
+ logging.info("Enwik8: merged into %d articles.", num_articles)
704
+
705
+ logging.info("Building dataset.")
706
+ ds = tf.data.Dataset.from_tensor_slices(articles)
707
+ ds = ds.map(_wrap_text_in_dict)
708
+ ds = ds.shuffle(num_articles, reshuffle_each_iteration=True, seed=local_seed)
709
+ if sequence_length == 0:
710
+ return ds # Don't split and batch
711
+
712
+ return split_and_batch(ds,
713
+ split=split,
714
+ extract_fn=_targets_to_tokens,
715
+ sequence_length=sequence_length,
716
+ batch_size=batch_size,
717
+ auto_rewind=True,
718
+ verbose=False)
719
+
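+ # Example (sketch; the directory is hypothetical and must contain the
+ # raw "enwik8/enwik8_train" file):
+ #
+ #   ds = load_enwik8(split="train", sequence_length=4096, batch_size=8,
+ #                    data_dir="/path/to/data")
+ #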
720
+ # ---------------------
721
+
722
+
723
+ def synthetic_data_short(split: str,
724
+ sequence_length: int,
725
+ batch_size: int,
726
+ auto_rewind: bool = True) -> tf.data.Dataset:
727
+ """Return a synthetic data set of sequences."""
728
+
729
+ strings = [
730
+ b"The quick brown fox jumped over the lazy dog.",
731
+ b"Humpty dumpty sat on a wall and had a great fall and went splat.",
732
+ b"She sells sea shells by the sea shore.",
733
+ b"Peter piper picked a peck of pickled peppercorns."
734
+ ]
735
+ logging.info("Building synthetic dataset (short).")
736
+ ds = tf.data.Dataset.from_tensor_slices(strings)
737
+ ds = ds.map(_wrap_text_in_dict)
738
+ ds = ds.shuffle(4, reshuffle_each_iteration=True, seed=42)
739
+ if sequence_length == 0:
740
+ return ds # Don't split and batch
741
+
742
+ return split_and_batch(ds,
743
+ split=split,
744
+ extract_fn=_targets_to_tokens,
745
+ sequence_length=sequence_length,
746
+ batch_size=batch_size,
747
+ auto_rewind=auto_rewind,
748
+ verbose=False)
749
+
750
+
751
+ def synthetic_data_long(split: str,
752
+ sequence_length: int,
753
+ batch_size: int,
754
+ auto_rewind: bool = True) -> tf.data.Dataset:
755
+ """Returns a synthetic data set with several long articles."""
756
+ articles = [
757
+ synthetic_text_data.text1_illiad_book1,
758
+ synthetic_text_data.text2_huckleberry_finn,
759
+ synthetic_text_data.text3_call_of_the_wild,
760
+ synthetic_text_data.text4_the_prince
761
+ ]
762
+ logging.info("Building synthetic dataset (long).")
763
+ ds = tf.data.Dataset.from_tensor_slices(articles)
764
+ ds = ds.map(_wrap_text_in_dict)
765
+ ds = ds.shuffle(4, reshuffle_each_iteration=True, seed=42)
766
+ if sequence_length == 0:
767
+ return ds # Don't split and batch
768
+
769
+ return split_and_batch(ds,
770
+ split=split,
771
+ extract_fn=_targets_to_tokens,
772
+ sequence_length=sequence_length,
773
+ batch_size=batch_size,
774
+ auto_rewind=auto_rewind,
775
+ verbose=False)
aglib/meliad/transformer/transformer_base.py ADDED
@@ -0,0 +1,451 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Base class for transformer layers."""
16
+
17
+ from typing import Any, Callable, Optional, Tuple
18
+
19
+ from absl import logging
20
+
21
+ from flax import linen as nn
22
+ import gin
23
+ import jax
24
+ import jax.numpy as jnp
25
+
26
+
27
+ from transformer import nn_components
28
+
29
+
30
+ Array = Any
31
+
32
+ # Tuple of scale factors
33
+ AttnScaleTuple = Tuple[Optional[Array], Optional[Array]]
34
+
35
+ # Tuple of keys,values,queries
36
+ KVQTuple = Tuple[Array, Array, Optional[Array], Optional[Array]]
37
+
38
+
39
+ @gin.configurable
40
+ class KVQLayer(nn.Module):
41
+ """Generate keys, values, and queries for attention."""
42
+
43
+ embedding_size: int
44
+ num_heads: int
45
+ head_size: int
46
+ has_queries: bool = True
47
+ has_queries2: bool = False # For cross-attention, e.g. decoder or recurrence.
48
+
49
+ normalize_keys: bool = True # Normalize keys and queries.
50
+ num_position_embeddings: int = 0 # Learned absolute position embeddings.
51
+ pre_attn_dropout: bool = True
52
+ dropout_rate: float = 0.0
53
+ dtype: Any = jnp.float32
54
+
55
+ def setup(self):
56
+ kernel_init = nn.initializers.variance_scaling(
57
+ scale=1.0, mode="fan_in", distribution="truncated_normal")
58
+
59
+ # Project to keys,values,queries
60
+ # Disable bias. This prevents a failure mode whereby the attention matrix
61
+ # can become filled with very large uniform values, due to high bias.
62
+ self.keys_layer = nn.Dense(
63
+ features=self.num_heads * self.head_size,
64
+ use_bias=False, # No bias for keys.
65
+ kernel_init=kernel_init,
66
+ dtype=self.dtype)
67
+ self.values_layer = nn.Dense(
68
+ features=self.num_heads * self.head_size,
69
+ use_bias=False, # No bias for values.
70
+ kernel_init=kernel_init,
71
+ dtype=self.dtype)
72
+ if self.has_queries:
73
+ self.queries_layer = nn.Dense(
74
+ features=self.num_heads * self.head_size,
75
+ use_bias=False, # No bias for queries.
76
+ kernel_init=kernel_init,
77
+ dtype=self.dtype)
78
+ if self.has_queries2:
79
+ self.queries2_layer = nn.Dense(
80
+ features=self.num_heads * self.head_size,
81
+ use_bias=False, # No bias for queries.
82
+ kernel_init=kernel_init,
83
+ dtype=self.dtype)
84
+
85
+ # When normalizing keys and queries, attention must be scaled with
86
+ # learned parameters.
87
+ if self.normalize_keys:
88
+ self.attention_scale = self.param("attention_scale",
89
+ jax.nn.initializers.ones,
90
+ (self.num_heads,), jnp.float32)
91
+
92
+ # Learned position embeddings for absolute positions.
93
+ if self.num_position_embeddings > 0:
94
+ # Embeddings for query elements.
95
+ self.position_embeddings = self.param(
96
+ "position_embeddings",
97
+ jax.nn.initializers.normal(stddev=1.0),
98
+ (self.num_position_embeddings, self.embedding_size),
99
+ jnp.float32)
100
+
101
+ # Layernorm
102
+ self.pre_attn_layernorm = nn_components.LayerNorm()
103
+
104
+ def attention_scale_factor(self) -> Optional[Array]:
105
+ """Returns the attention scale, when keys and queries are normalized."""
106
+ if self.normalize_keys:
107
+ return jnp.asarray(self.attention_scale, dtype=self.dtype)
108
+ else:
109
+ return None
110
+
111
+ def _get_dropout_rng(self):
112
+ return self.make_rng("dropout")
113
+
114
+ def _normalize_kq(self, kq: Array) -> Array:
115
+ """Normalize function for keys and queries."""
116
+ epsilon = jnp.array(1.0e-6, dtype=self.dtype)
117
+ kq_sum_sqr = jnp.sum(jnp.square(kq), axis=-1, keepdims=True)
118
+ norm_kq = kq * jax.lax.rsqrt(kq_sum_sqr + epsilon)
119
+ return jnp.asarray(norm_kq, dtype=self.dtype)
120
+
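+ # Normalization sketch: each key/query vector is scaled to (approximately)
+ # unit L2 norm along the head dimension, so attention logits become cosine
+ # similarities scaled by the learned attention_scale. For example
+ # (assuming float32 inputs):
+ #
+ #   kq = jnp.array([[3.0, 4.0]])  # norm 5
+ #   kq * jax.lax.rsqrt(jnp.sum(kq**2, axis=-1, keepdims=True) + 1e-6)
+ #   # => approx [[0.6, 0.8]]
+ #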
121
+ def __call__(self, xs: Array, deterministic: bool = False) -> KVQTuple:
122
+ """Takes a sequence of embeddings as input, and returns keys,values,queries.
123
+
124
+ First applies pre_attn layernorm, then adds learned positional embeddings
126
+ (if any), and finally applies pre_attn dropout.
126
+ Return (keys, values, queries, queries2).
127
+
128
+ Args:
129
+ xs: input sequence of shape (batch_size, sequence_length, embedding_size)
130
+ deterministic: if False, apply dropout.
131
+
132
+ Returns:
133
+ (keys, values, queries, queries2) of shape
134
+ (batch_size, sequence_length, num_heads, head_size)
135
+ """
136
+
137
+ # Project inputs to (keys, values, queries).
138
+ (batch_size, num_keys, _) = xs.shape
139
+ drop_tile_shape = (1, 128, self.embedding_size)
140
+
141
+ # Apply layernorm to input, rather than the output.
142
+ # This provides better gradients through the resnet, and also avoids
143
+ # the need for a prolonged warmup phase (https://arxiv.org/abs/2002.04745)
144
+
145
+ # Layernorm for self-attention.
146
+ logging.info("kvq: pre_attn xs = %r", xs)
147
+ xs = jnp.asarray(xs, dtype=self.dtype)
148
+ xs = self.pre_attn_layernorm(xs)
149
+
150
+ # Add (optional) learned position embeddings.
151
+ if self.num_position_embeddings > 0:
152
+ assert xs.ndim == 3 # (b, sequence_length, embedding_size)
153
+ assert xs.shape[-2] == self.num_position_embeddings
154
+ logging.info("kvq: learned positions.")
155
+ xs_pos = jnp.asarray(self.position_embeddings, dtype=self.dtype)
156
+ xs_pos = jnp.expand_dims(xs_pos, 0) # Add batch dimension.
157
+ xs = xs + xs_pos
158
+
159
+ # Pre-attention dropout.
160
+ if self.pre_attn_dropout:
161
+ logging.info("kvq: pre_attn dropout.")
162
+ xs = nn_components.tiled_dropout(xs, drop_tile_shape, self.dropout_rate,
163
+ rng_function=self._get_dropout_rng,
164
+ deterministic=deterministic)
165
+
166
+ # Compute keys and values.
167
+ keys = self.keys_layer(xs) # (b, num_keys, num_heads * head_size)
168
+ values = self.values_layer(xs)
169
+
170
+ # Compute queries and cross-attention queries if necessary.
171
+ if self.has_queries:
172
+ queries = self.queries_layer(xs) # (b, num_keys, n_heads * head_size)
173
+ logging.info("kvq: queries = %r", queries)
174
+ else:
175
+ queries = None
176
+ if self.has_queries2:
177
+ queries2 = self.queries2_layer(xs) # (b, num_keys, n_heads * head_size)
178
+ logging.info("kvq: queries2 = %r", queries2)
179
+ else:
180
+ queries2 = None
181
+
182
+ # Reshape to split num_heads, head_size into separate dimensions.
183
+ kv_shape = (batch_size, num_keys, self.num_heads, self.head_size)
184
+ keys = jnp.reshape(keys, kv_shape)
185
+ values = jnp.reshape(values, kv_shape)
186
+ if queries is not None:
187
+ queries = jnp.reshape(queries, kv_shape)
188
+ if queries2 is not None:
189
+ queries2 = jnp.reshape(queries2, kv_shape)
190
+
191
+ if self.normalize_keys:
192
+ # Normalize both keys and queries.
193
+ # The learned attention_scale_factors() will return non-None.
194
+ logging.info("kvq: normalize keys, queries.")
195
+ keys = self._normalize_kq(keys)
196
+ if queries is not None:
197
+ queries = self._normalize_kq(queries)
198
+ if queries2 is not None:
199
+ queries2 = self._normalize_kq(queries2)
200
+ else:
201
+ # Scale queries by 1 / sqrt(d) when using unnormalized keys,queries.
202
+ d_scale = jax.lax.rsqrt(float(self.head_size)).astype(self.dtype)
203
+ logging.info("kvq: scale queries by 1/sqrt(d).")
204
+ if queries is not None:
205
+ queries = queries * d_scale
206
+ if queries2 is not None:
207
+ queries2 = queries2 * d_scale
208
+
209
+ # Return keys, values, and queries.
210
+ return (keys, values, queries, queries2)
211
+
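+ # Example shape check (a sketch with toy sizes; `rng` is assumed to be a
+ # jax.random.PRNGKey):
+ #
+ #   layer = KVQLayer(embedding_size=64, num_heads=4, head_size=16)
+ #   xs = jnp.zeros((2, 8, 64))  # (batch, sequence_length, embedding_size)
+ #   params = layer.init({"params": rng}, xs, deterministic=True)
+ #   (k, v, q, q2) = layer.apply(params, xs, deterministic=True)
+ #   # k.shape == v.shape == q.shape == (2, 8, 4, 16); q2 is None.
+ #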
212
+
213
+ @gin.configurable
214
+ class TransformerBase(nn.Module):
215
+ """TransformerBase implements everything except attention.
216
+
217
+ It handles:
218
+ - Projection to (keys, values, queries) before attention.
219
+ - Projection MLP back to embedding_size after attention.
220
+ - Final FFN layer.
221
+ - layernorm, dropout, and normalization of keys and queries.
222
+
223
+ This functionality is encapsulated here so that it can be reused with more
224
+ complicated attention mechanisms.
225
+ """
226
+
227
+ # Options set by parent module.
228
+ mode: str
229
+ embedding_size: int
230
+ num_heads: int
231
+ head_size: int
232
+
233
+ cross_attention_q: bool = False # Additional q for cross-attention.
234
+ cross_attention_kv: bool = False # Additional kv for cross-attention.
235
+ num_position_embeddings: int = 0 # Learned position embeddings.
236
+ num_cross_position_embeddings: int = 0 # Learned position embeddings.
237
+
238
+ # Configurable hyperparameters.
239
+ attn_mlp_factory: Callable[[int], nn.Module] = gin.REQUIRED
240
+ ffn_factory: Callable[[int], nn.Module] = gin.REQUIRED
241
+ gate_type: str = "residual"
242
+ single_gate: bool = False
243
+ skip_ffn: bool = False
244
+
245
+ normalize_keys: bool = True
246
+ dropout_rate: float = 0.0
247
+ pre_attn_dropout: bool = True
248
+ post_attn_dropout: bool = False
249
+ pre_ffn_dropout: bool = False
250
+ post_ffn_dropout: bool = True
251
+
252
+ dtype: Any = jnp.float32
253
+
254
+ def is_training(self) -> bool:
255
+ return self.mode == "train"
256
+
257
+ def _get_dropout_rng(self):
258
+ return self.make_rng("dropout")
259
+
260
+ def _normalize_kq(self, kq: Array) -> Array:
261
+ """Normalize function for keys and queries."""
262
+ epsilon = jnp.array(1.0e-6, dtype=self.dtype)
263
+ kq_sum_sqr = jnp.sum(jnp.square(kq), axis=-1, keepdims=True)
264
+ norm_kq = kq * jax.lax.rsqrt(kq_sum_sqr + epsilon)
265
+ return jnp.asarray(norm_kq, dtype=self.dtype)
266
+
267
+ def setup(self):
268
+ # Keys,values,queries for self-attention; queries for cross-attention.
269
+ self._kvq = KVQLayer(self.embedding_size, self.num_heads, self.head_size,
270
+ has_queries=True,
271
+ has_queries2=self.cross_attention_q,
272
+ num_position_embeddings=self.num_position_embeddings,
273
+ normalize_keys=self.normalize_keys,
274
+ pre_attn_dropout=self.pre_attn_dropout,
275
+ dropout_rate=self.dropout_rate,
276
+ dtype=self.dtype)
277
+
278
+ # Keys,values, attention_scale for cross-attention.
279
+ if self.cross_attention_kv:
280
+ # Use a full kvq layer, with layernorm and attention scale.
281
+ self._cross_kv = KVQLayer(
282
+ self.embedding_size, self.num_heads, self.head_size,
283
+ has_queries=False,
284
+ has_queries2=False,
285
+ num_position_embeddings=self.num_cross_position_embeddings,
286
+ normalize_keys=self.normalize_keys,
287
+ pre_attn_dropout=self.pre_attn_dropout,
288
+ dropout_rate=self.dropout_rate,
289
+ dtype=self.dtype)
290
+ elif self.cross_attention_q:
291
+ # No separate keys,values for cross-attention, but we may still need
292
+ # cross-attention-scale, so we create our own.
293
+ assert self.num_cross_position_embeddings == 0
294
+ if self.normalize_keys:
295
+ self.attention_scale2 = self.param("attention_scale2",
296
+ jax.nn.initializers.ones,
297
+ (self.num_heads,), jnp.float32)
298
+
299
+ # Post-attention linear projection.
300
+ if not self.single_gate:
301
+ self.post_attn_mlp = self.attn_mlp_factory(
302
+ self.embedding_size,
303
+ gate_type=self.gate_type,
304
+ final_activation=None,
305
+ dtype=self.dtype) # pytype: disable=wrong-keyword-args # trace-all-classes
306
+
307
+ # Final FFN.
308
+ if not self.skip_ffn:
309
+ self.ffn = self.ffn_factory(
310
+ self.embedding_size,
311
+ gate_type=self.gate_type,
312
+ final_activation=("tanh" if self.single_gate else None),
313
+ dtype=self.dtype) # pytype: disable=wrong-keyword-args # trace-all-classes
314
+
315
+ # Layernorm.
316
+ self.pre_ffn_layernorm = nn_components.LayerNorm()
317
+
318
+ def force_init(self, xs: Array):
319
+ """Force flax initialization of self, prior to use with lax.scan.
320
+
321
+ Args:
322
+ xs: The input sequence that the module will be called with.
323
+ """
324
+ logging.info("tbase: Begin forced initialization.")
325
+ _ = self.kvq(xs)
326
+ batch_size = xs.shape[0]
327
+ seq_len = xs.shape[1]
328
+ attn_ys_shape = (batch_size, seq_len, self.num_heads, self.head_size)
329
+ dummy_attn_ys = jnp.zeros(attn_ys_shape, dtype=self.dtype)
330
+ if self.cross_attention_kv or self.cross_attention_q:
331
+ dummy_cross_attn_ys = dummy_attn_ys
332
+ else:
333
+ dummy_cross_attn_ys = None
334
+ _ = self.post_attn_ffn(xs, dummy_attn_ys, dummy_cross_attn_ys)
335
+ logging.info("tbase: End forced initialization.")
336
+
337
+ def attention_scale_factors(self) -> AttnScaleTuple:
338
+ """Returns the attention scales, when keys and queries are normalized.
339
+
340
+ Returns: (scale for kv (i.e. queries), scale for cross_kv (i.e. queries2))
341
+ """
342
+ sfactor = self._kvq.attention_scale_factor()
343
+ if self.cross_attention_kv:
344
+ cross_sfactor = self._cross_kv.attention_scale_factor()
345
+ elif self.cross_attention_q and self.normalize_keys:
346
+ cross_sfactor = jnp.asarray(self.attention_scale2, dtype=self.dtype)
347
+ else:
348
+ cross_sfactor = None
349
+ return (sfactor, cross_sfactor)
350
+
351
+ def kvq(self, xs: Array) -> KVQTuple:
352
+ enable_dropout = self.pre_attn_dropout and self.is_training()
353
+ return self._kvq(xs, deterministic=not enable_dropout)
354
+
355
+ def cross_kv(self, xs: Array) -> Tuple[Array, Array]:
356
+ assert self.cross_attention_kv
357
+ enable_dropout = self.pre_attn_dropout and self.is_training()
358
+ (k, v, _, _) = self._cross_kv(xs, deterministic=not enable_dropout)
359
+ return (k, v)
360
+
361
+ def post_attn_ffn(self, xs: Array, attn_ys: Array,
362
+ cross_attn_ys: Optional[Array]) -> Array:
363
+ """Combines the output of attention with the original input sequence.
364
+
365
+ Post-attn MLP on attn_ys, followed by resnet/gate.
366
+ Pre-FFN layernorm and dropout, then the FFN layer, followed by resnet/gate.
367
+
368
+ Args:
369
+ xs: Original input sequence of shape
370
+ (batch_size, sequence_length, embedding_size)
371
+ attn_ys: Output of the self-attention module, of shape
372
+ (batch_size, sequence_length, num_heads, head_size)
373
+ cross_attn_ys: Output of the cross-attention module, of shape
374
+ (batch_size, sequence_length, num_heads, head_size)
375
+
376
+ Returns:
377
+ Array of shape (batch_size, sequence_length, embedding_size)
378
+ """
379
+
380
+ (batch_size, sequence_length, _) = xs.shape
381
+ assert attn_ys.shape == (batch_size, sequence_length,
382
+ self.num_heads, self.head_size)
383
+ no_dropout = not self.is_training()
384
+ drop_tile_shape = (1, 128, self.embedding_size)
385
+
386
+ # Concatenate cross-attention and self-attention results.
387
+ if cross_attn_ys is not None:
388
+ # Concatenate self-attention and cross-attention results, before
389
+ # applying the projection layer.
390
+ logging.info("tbase: using cross-attention.")
391
+ assert attn_ys.shape == (batch_size, sequence_length,
392
+ self.num_heads, self.head_size)
393
+ attn_ys = jnp.concatenate([attn_ys, cross_attn_ys], axis=2)
394
+ att_ys_num_heads = self.num_heads * 2
395
+ else:
396
+ # Only use self-attention.
397
+ att_ys_num_heads = self.num_heads
398
+
399
+ logging.info("tbase: attn_ys = %r", attn_ys)
400
+ attn_ys = attn_ys.reshape(
401
+ (batch_size, sequence_length, att_ys_num_heads * self.head_size))
402
+
403
+ if self.single_gate:
404
+ logging.info("tbase: single gate.")
405
+ assert not self.skip_ffn
406
+ # Skip post-attention linear projection and residual connection.
407
+ ys_hidden = xs # The FFN (below) will be gated onto xs (the input).
408
+ ffn_in = attn_ys # The input to the FFN is the output of attention.
409
+ else:
410
+ logging.info("tbase: post-attention MLP.")
411
+ # Standard transformer architecture.
412
+ # The post-attention MLP applies a linear projection to project attn_ys
413
+ # to embedding space. It then uses a residual connection or gate to
414
+ # combine the projection with xs. Post-attention dropout is applied
415
+ # before the residual/gate.
416
+ post_attn_ys = self.post_attn_mlp(
417
+ attn_ys, xs,
418
+ apply_dropout=self.post_attn_dropout and not no_dropout,
419
+ dropout_rate=self.dropout_rate,
420
+ drop_tile_shape=drop_tile_shape,
421
+ rng_function=self._get_dropout_rng)
422
+
423
+ # The FFN (below) will be gated onto post_attn_ys (which gates onto xs).
424
+ ys_hidden = post_attn_ys
425
+ if self.skip_ffn:
426
+ logging.info("tbase: skip final FFN. ys = %r", ys_hidden)
427
+ return ys_hidden
428
+
429
+ # The input to the FFN; Layernorm is applied before the FFN.
430
+ ffn_in = self.pre_ffn_layernorm(ys_hidden)
431
+ logging.info("tbase: pre-FFN layernorm = %r", ffn_in)
432
+
433
+ # Pre-FFN dropout.
434
+ if self.pre_ffn_dropout:
435
+ logging.info("tbase: pre-FFN dropout.")
436
+ ffn_in = nn_components.tiled_dropout(
437
+ ffn_in, drop_tile_shape, self.dropout_rate,
438
+ rng_function=self._get_dropout_rng, deterministic=no_dropout)
439
+
440
+ # FFN layer.
441
+ # Large MLP with hidden layers followed by residual connection or gate.
442
+ # The MLP will apply post-ffn dropout before the gate.
443
+ logging.info("tbase: final FFN")
444
+ ys = self.ffn(ffn_in, ys_hidden,
445
+ apply_dropout=self.post_ffn_dropout and not no_dropout,
446
+ dropout_rate=self.dropout_rate,
447
+ drop_tile_shape=drop_tile_shape,
448
+ rng_function=self._get_dropout_rng)
449
+
450
+ logging.info("tbase: ys = %r", ys)
451
+ return ys
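+ # Typical call pattern (sketch): an attention layer built on TransformerBase
+ # projects the inputs, runs attention (not provided here), then projects
+ # back to embedding space; `my_attention_fn` is a hypothetical stand-in:
+ #
+ #   (keys, values, queries, queries2) = tbase.kvq(xs)
+ #   attn_ys = my_attention_fn(keys, values, queries)
+ #   ys = tbase.post_attn_ffn(xs, attn_ys, None)
+ #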
aglib/meliad/transformer/transformer_layer.py ADDED
@@ -0,0 +1,817 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A single transformer layer."""
16
+
17
+ from typing import Any, Mapping, NewType, Optional, Sequence, Tuple
18
+
19
+ from absl import logging
20
+
21
+ from flax import linen as nn
22
+ import gin
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+
27
+ from transformer import attention
28
+ from transformer import memory_factory
29
+ from transformer import nn_components
30
+ from transformer import position
31
+ from transformer import position_fourier
32
+ from transformer import position_t5
33
+ from transformer import transformer_base
34
+
35
+
36
+ Array = jnp.ndarray
37
+ DecoderState = NewType("DecoderState", Mapping[str, Array])
38
+ WindowState = Optional[Tuple[attention.KVITuple, Array]]
39
+ KVITuple = attention.KVITuple
40
+
41
+
42
+ @gin.configurable
43
+ class TransformerLayer(nn.Module):
44
+ """Full transformer layer, with attention."""
45
+
46
+ # Set by DecoderStack
47
+ mode: str
48
+ batch_size: int
49
+ embedding_size: int
50
+ cross_attention: bool = False
51
+ recurrent_attention: bool = False
52
+ memory: Optional[memory_factory.MemoryManager] = None
53
+
54
+ # Configurable hyper-parameters
55
+ num_heads: int = gin.REQUIRED
56
+ head_size: int = gin.REQUIRED
57
+
58
+ window_length: int = gin.REQUIRED
59
+ use_long_xl_architecture: bool = True
60
+ max_unrolled_windows: int = -1 # Always unroll.
61
+ relative_position_type: Optional[str] = "fourier"  # {None, "fourier", "t5", "rotary"}
62
+ use_causal_mask: bool = True
63
+ attn_dropout_rate: float = 0.0
64
+
65
+ recurrent_num_states: int = 0
66
+ recurrent_gate_type: str = "bias"
67
+ recurrent_single_gate: bool = False
68
+ recurrent_skip_ffn: bool = False
69
+
70
+ compute_importance: bool = False
71
+ memory_num_neighbors: int = 0
72
+ memory_reset_on_new_doc: bool = True
73
+
74
+ dtype: Any = jnp.float32
75
+
76
+ # Modes which support caching of previous keys and values.
77
+ supported_modes_for_cache: Sequence[str] = ("train", "test")
78
+ update_memory_modes: Sequence[str] = ("train", "test")
79
+
80
+ def supports_generate(self) -> bool:
81
+ return self.use_long_xl_architecture
82
+
83
+ def _get_cache_name_from_mode(self, mode: str) -> Tuple[str, bool, bool]:
84
+ """Get the name of the cache, and whether to update the cache, from mode."""
85
+ # This is a hack to ensure that "generate" steps generate text as a
86
+ # continuation of the text that is stored in the "test" cache,
87
+ # but it does not update the "test" cache.
88
+ if mode == "generate":
89
+ assert "test" in self.supported_modes_for_cache
90
+ return ("test", False, False) # Use test cache, but don't update it.
91
+ elif mode == "init":
92
+ return ("train", False, False) # Use training cache for initialization.
93
+ else:
94
+ return (mode, True, mode in self.update_memory_modes)
95
+
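+ # Mode mapping example, with the default settings above:
+ #   "generate" -> ("test",  update_cache=False, update_memory=False)
+ #   "init"     -> ("train", update_cache=False, update_memory=False)
+ #   "train"    -> ("train", update_cache=True,  update_memory=True)
+ #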
96
+ def _allocate_cached_kvi(self, mode: str) -> KVITuple:
97
+ """Allocate (keys, values, importance) which can be cached between steps."""
98
+
99
+ kv_shape = [self.batch_size, self.window_length,
100
+ self.num_heads, self.head_size]
101
+ imp_shape = [self.batch_size, self.window_length]
102
+
103
+ def kv_initializer(shape):
104
+ return jnp.zeros(shape, dtype=self.dtype)
105
+
106
+ def imp_initializer(shape):
107
+ return jnp.zeros(shape, dtype=self.dtype)
108
+
109
+ pkeys = self.variable("state", "previous_keys_" + mode,
110
+ kv_initializer, kv_shape)
111
+ pvals = self.variable("state", "previous_values_" + mode,
112
+ kv_initializer, kv_shape)
113
+ if self.compute_importance:
114
+ pimportance = self.variable("state", "previous_importance_" + mode,
115
+ imp_initializer, imp_shape)
116
+ else:
117
+ pimportance = None
118
+ return (pkeys, pvals, pimportance)
119
+
120
+ def _allocate_cached_recurrent_state(self, mode: str):
121
+ rec_num_states = self.recurrent_num_states
122
+ st_shape = [self.batch_size, rec_num_states, self.embedding_size]
123
+
124
+ def st_initializer(shape):
125
+ return jnp.zeros(shape, dtype=self.dtype)
126
+
127
+ return self.variable("state", "recurrent_state_" + mode,
128
+ st_initializer, st_shape)
129
+
130
+ def setup(self):
131
+ # Basic transformer functionality: everything except attention.
132
+
133
+ self.tbase = transformer_base.TransformerBase(
134
+ mode=self.mode,
135
+ embedding_size=self.embedding_size,
136
+ num_heads=self.num_heads,
137
+ head_size=self.head_size,
138
+ cross_attention_q=self.recurrent_attention or self.cross_attention,
139
+ cross_attention_kv=False, # or True to use separate k,v.
140
+ num_position_embeddings=0,
141
+ num_cross_position_embeddings=0, # or self.recurrent_num_states w/ k,v.
142
+ dtype=self.dtype)
143
+
144
+ # Recurrent transformer functionality.
145
+ self.recurrent_tbase = None
146
+ if self.recurrent_attention:
147
+ # Recurrent transformer layer.
148
+ # We use a learned position embedding so that each element of the state
149
+ # can learn to query and compute different summaries.
150
+ self.recurrent_tbase = transformer_base.TransformerBase(
151
+ mode="pure", # Disable dropout, which breaks jax.lax.scan.
152
+ embedding_size=self.embedding_size,
153
+ num_heads=self.num_heads,
154
+ head_size=self.head_size,
155
+ cross_attention_q=True,
156
+ cross_attention_kv=False, # or True to use separate k,v.
157
+ num_position_embeddings=self.recurrent_num_states,
158
+ num_cross_position_embeddings=0, # or self.window_length w/ k,v.
159
+ gate_type=self.recurrent_gate_type,
160
+ single_gate=self.recurrent_single_gate,
161
+ skip_ffn=self.recurrent_skip_ffn,
162
+ dtype=self.dtype)
163
+
164
+ # Initial state at start of document.
165
+ # We want this to be initially small, but large enough that adafactor
166
+ # will scale updates to a reasonable value.
167
+ self.recurrent_initial_state = self.param(
168
+ "recurrent_initial_state",
169
+ jax.nn.initializers.normal(stddev=0.1),
170
+ (self.recurrent_num_states, self.embedding_size), jnp.float32)
171
+
172
+ # Cached state from previous step for BPTT.
173
+ rec_state = {}
174
+ for mkey in self.supported_modes_for_cache:
175
+ rec_state[mkey] = self._allocate_cached_recurrent_state(mkey)
176
+ self.cached_recurrent_state = rec_state
177
+
178
+ # Set up relative position encoding.
179
+ if self.relative_position_type == "fourier":
180
+ self.relative_positions = position_fourier.RelativeFourierPositions(
181
+ num_heads=self.num_heads,
182
+ max_number_of_keys=self.window_length,
183
+ dtype=self.dtype)
184
+ elif self.relative_position_type == "t5":
185
+ self.relative_positions = position_t5.T5RelativePositionBiases(
186
+ num_buckets=32, # TODO(delesley): Let Gin configure these.
187
+ max_distance=128,
188
+ num_heads=self.num_heads,
189
+ dtype=self.dtype)
190
+ elif self.relative_position_type == "rotary":
191
+ # Rotary position encodings (RoPE). No learned bias parameters.
192
+ self.relative_positions = None
193
+ else:
194
+ assert self.relative_position_type is None
195
+ self.relative_positions = None
196
+
197
+ # Set up cache for Transformer-XL style architectures.
198
+ # A separate cache is created for each mode (e.g. train, test).
199
+ cached_kvi = {}
200
+ if self.use_long_xl_architecture:
201
+ for mkey in self.supported_modes_for_cache:
202
+ cached_kvi[mkey] = self._allocate_cached_kvi(mkey)
203
+ self.cached_kvi = cached_kvi
204
+
205
+ # Set up external memory.
206
+ # A separate memory will be created for each mode (e.g. train, test)
207
+ mem_layers = {}
208
+ if self.memory is not None:
209
+ self.memory_bias = self.param("external_memory_bias", nn.zeros,
210
+ (self.num_heads,), "float32")
211
+ for mkey in self.supported_modes_for_cache:
212
+ mlayer = self.memory.create_memory_layer()
213
+ # Use setattr to set up the name and module containership hierarchy.
214
+ setattr(self, "mem_layer_" + mkey, mlayer)
215
+ mem_layers[mkey] = mlayer
216
+ self.mem_layers = mem_layers
217
+
218
+ def _get_cached_kvi(self, start_of_sequence: Array,
219
+ mode: str) -> Optional[KVITuple]:
220
+ """Returns cached (keys, values, importance) from the previous step."""
221
+ if not self.use_long_xl_architecture:
222
+ return None
223
+ if mode not in self.cached_kvi:
224
+ # No cache, but we're using XL / sliding window, so return zeros.
225
+ logging.info("tlayer: using zero as initial XL cache value.")
226
+ kvi_shape = (self.batch_size, self.window_length,
227
+ self.num_heads, self.head_size)
228
+ return attention.initial_kvi(kvi_shape,
229
+ self.compute_importance, dtype=self.dtype)
230
+
231
+ # New documents start with zero_kv.
232
+ # Continuing the same document will attend to previous keys/vals.
233
+ (pkeys, pvals, pimportance) = self.cached_kvi[mode]
234
+ (zkeys, zvals, zimportance) = attention.initial_kvi(
235
+ pkeys.value.shape, self.compute_importance, dtype=self.dtype)
236
+
237
+ # Broadcast start_of_sequence over non-batch dims.
238
+ b = self.batch_size
239
+ start_of_sequence_kv = jnp.reshape(start_of_sequence, [b, 1, 1, 1])
240
+ prev_keys = jnp.where(start_of_sequence_kv, zkeys, pkeys.value)
241
+ prev_vals = jnp.where(start_of_sequence_kv, zvals, pvals.value)
242
+ if self.compute_importance:
243
+ start_of_sequence_imp = jnp.reshape(start_of_sequence, [b, 1])
244
+ prev_importance = jnp.where(start_of_sequence_imp, zimportance,
245
+ pimportance.value)
246
+ else:
247
+ prev_importance = None
248
+ logging.debug("tlayer: start_of_sequence = %r", start_of_sequence)
249
+ logging.info("tlayer: prev_keys[%r] = %r", mode, prev_keys)
250
+ logging.debug("tlayer: prev_importance[%r] = %r", mode, prev_importance)
251
+ return (prev_keys, prev_vals, prev_importance)
252
+
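+ # Broadcast sketch: reshaping start_of_sequence from (batch,) to
+ # (batch, 1, 1, 1) lets jnp.where select, per batch element, either the
+ # zero cache (new document) or the previous window's keys/values:
+ #
+ #   sos = jnp.reshape(jnp.array([True, False]), (2, 1, 1, 1))
+ #   prev = jnp.where(sos, zkeys, pkeys)  # zeros for row 0, cache for row 1
+ #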
253
+ def _set_cached_kvi(self, next_kvi: KVITuple, mode: str):
254
+ """Caches the last (keys, values, importance) from the current step."""
255
+ if not self.use_long_xl_architecture:
256
+ return
257
+ if mode not in self.cached_kvi:
258
+ return
259
+
260
+ (pkeys, pvals, pimportance) = self.cached_kvi[mode]
261
+ (nkeys, nvals, nimportance) = next_kvi # From last window
262
+ logging.info("tlayer: next_keys[%r] = %r", mode, nkeys)
263
+ pkeys.value = nkeys
264
+ pvals.value = nvals
265
+ if self.compute_importance:
266
+ logging.info("tlayer: next_importance[%r] = %r", mode, nimportance)
267
+ pimportance.value = nimportance
268
+
269
+ def _get_cached_recurrent_state(self, start_of_sequence: Array,
270
+ mode: str) -> Optional[Array]:
271
+ """Returns cached recurrent state from the previous step."""
272
+ if not self.recurrent_attention:
273
+ return None
274
+ if mode not in self.cached_recurrent_state:
275
+ return None
276
+
277
+ b = self.batch_size
278
+ rstate = self.cached_recurrent_state[mode].value
279
+ istate = jnp.asarray(self.recurrent_initial_state, dtype=self.dtype)
280
+ istate = istate[jnp.newaxis, :, :] # Add batch dimension for broadcast.
281
+ logging.info("tlayer: get_cached_recurrent_state %r, %r", istate, rstate)
282
+
283
+ start_of_sequence_st = jnp.reshape(start_of_sequence, (b, 1, 1))
284
+ return jnp.where(start_of_sequence_st, istate, rstate)
285
+
286
+ def _set_cached_recurrent_state(self, next_state: Array, mode: str):
287
+ """Store the next recurrent state in the cache."""
288
+ if not self.recurrent_attention:
289
+ return
290
+ if mode not in self.cached_recurrent_state:
291
+ return
292
+
293
+ logging.info("tlayer: set_cached_recurrent_state %r", next_state)
294
+ rstate = self.cached_recurrent_state[mode]
295
+ rstate.value = next_state
296
+
297
+ def _query_external_memory(self, keys: Array, values: Array, queries: Array,
298
+ start_of_sequence: Array,
299
+ mode: str, update_memory: bool):
300
+ """Query and update external memory."""
301
+ if self.memory is None:
302
+ return None
303
+
304
+ # Make sure we initialize (allocate) the external memories for all modes.
305
+ # Per the flax lazy module initialization scheme, setup() will not be
306
+ # invoked on a submodule until that module is actually used.
307
+ if mode == "init":
308
+ for (_, mlayer) in self.mem_layers.items():
309
+ (_, _) = mlayer.topk_retrieval(queries, self.memory_num_neighbors)
310
+ mode = "train" # Pretend we're in training mode during initialization.
311
+
312
+ if mode not in self.mem_layers:
313
+ return None
314
+ if self.memory_num_neighbors == 0:
315
+ raise ValueError("Using memory, but num_neighbors == 0")
316
+
317
+ # Grab the appropriate memory layer for the current mode.
318
+ memory_layer = self.mem_layers[mode]
319
+
320
+ # Clear the relevant memories at the start of each new document.
321
+ if update_memory and self.memory_reset_on_new_doc:
322
+ # The number of "datasets" is batch_dim * num_heads.
323
+ # jnp.repeat will "broadcast" start_of_sequence over num_heads.
324
+ # E.g. if start_of_sequence = [True, False] and 4 heads,
325
+ # jnp.repeat will yield [T, T, T, T, F, F, F, F]
326
+ memory_layer.reset(jnp.repeat(start_of_sequence, self.num_heads))
327
+
328
+ # Query external memory, with queries.
329
+ (rkeys, rvals) = memory_layer.topk_retrieval(queries,
330
+ self.memory_num_neighbors)
331
+ logging.info("tlayer: query external memory (%r): rvals = %r", mode, rvals)
332
+
333
+ # Sanity check all dimensions are as expected.
334
+ assert rkeys.ndim == 5 # (b, seq_len, num_heads, num_neigh, head_dim)
335
+ assert rvals.ndim == 5
336
+ assert rkeys.shape == rvals.shape
337
+ assert rkeys.shape[0] == queries.shape[0] # batch size
338
+ assert rkeys.shape[1] == queries.shape[1] # sequence length
339
+ assert rkeys.shape[2] == self.num_heads
340
+ assert rkeys.shape[3] == self.memory_num_neighbors
341
+ assert rkeys.shape[4] == self.head_size
342
+
343
+ # Update external memory, with (keys, values).
344
+ if update_memory:
345
+ memory_layer.update(keys, values)
346
+ return (rkeys, rvals)
347
+
348
+ def __call__(self, xs: Array, start_of_sequence: Array,
349
+ *,
350
+ importance: Optional[Array] = None,
351
+ cross_attention_kv: Optional[Tuple[Array, Array]] = None,
352
+ window_state: Optional[WindowState] = None,
353
+ decoder_state: Optional[DecoderState] = None) -> (
354
+ Tuple[Array, Optional[Array], Optional[WindowState],
355
+ Optional[DecoderState], Any]):
356
+ """Computes attention over a sequence of inputs.
357
+
358
+ Args:
359
+ xs: input sequence of shape (batch_size, sequence_length, num_hidden)
360
+ start_of_sequence: An input array of shape (batch_size)
361
+
362
+ --- The following must be passed by keyword only. ---
363
+ importance: Array of shape (batch_size, sequence_length).
364
+ An importance bias for attention.
365
+ cross_attention_kv: Keys and values from encoder for cross-attention.
366
+ window_state: State object which contains context from the prior
367
+ window when using a transformer-XL or sliding window.
368
+ Initially created with load_window_state().
369
+ decoder_state: State object for autoregressive decoding, initially
370
+ created with init_decoder_state().
371
+
372
+ Returns:
373
+ (ys: outputs of shape (batch_size, sequence_length, num_hidden),
374
+ importance: importance values for the next layer,
375
+ next_window_state: state to pass to the next window,
376
+ next_decoder_state: next decoder state for autoregressive decoding,
377
+ viz_dict: dictionary of visualizations
378
+ )
379
+ """
380
+
381
+ xs = jnp.asarray(xs, dtype=self.dtype)
382
+ logging.info("tlayer: xs = %r", xs)
383
+ logging.info("tlayer: recurrent = %r", self.recurrent_attention)
384
+ logging.info("tlayer: cross-attention = %r", cross_attention_kv is not None)
385
+
386
+ is_training = (self.mode == "train")
387
+
388
+ # Compute keys, values and queries.
389
+ # ---------------------------------
390
+ logging.info("tlayer: compute keys,values,queries.")
391
+ (keys, values, queries, queries2) = self.tbase.kvq(xs)
392
+ attention_scale_factors = self.tbase.attention_scale_factors()
393
+ (_, sequence_length, num_heads, _) = queries.shape # (b, k, h, d)
394
+
395
+ # Get biases and masks that are shared across windows.
396
+ # ----------------------------------------------------
397
+ if decoder_state is not None:
398
+ logging.info("tlayer: using autoregressive decoder.")
399
+ # When decoding, prior keys,values are loaded from the decoder state.
400
+ # Other values are precomputed, and loaded from the decoder state.
401
+ # The decoder state will be updated with the current token.
402
+ assert window_state is None
403
+
404
+ prev_kvi = None
405
+ recurrent_state = None # Use precomputed recurrent_kvq.
406
+ cross_attention_kv = None
407
+ rel_position_bias = decoder_state["relative_position_bias"]
408
+ causal_mask = None
409
+ dropout_multiplier = None
410
+
411
+ # Reuse cached recurrent keys,values for each token.
412
+ cached_recurrent_kvq = decoder_state["recurrent_kvq"]
413
+ if cached_recurrent_kvq is not None:
414
+ assert cross_attention_kv is None
415
+ cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1])
416
+ del cached_recurrent_kvq
417
+
418
+ # Get a full window of keys,values and update decoder state.
419
+ (decoder_state, keys, values) = self._next_decoder_state(
420
+ decoder_state, keys, values)
421
+
422
+ # Each query attends to window_length prior keys.
423
+ assert keys.shape[1] == self.window_length
424
+ kq_relative_offset = self.window_length
425
+ else:
426
+ logging.info("tlayer: windowed attention.")
427
+ # When training, attention is done using windows or chunks, and prior
428
+ # context (e.g. keys,values from the previous window) is stored in the
429
+ # window_state object.
430
+ (prev_kvi, recurrent_state) = window_state # pytype: disable=attribute-error
431
+
432
+ # Get the size of the sliding window for pos bias, dropout, & causal mask.
433
+ (num_queries, num_keys) = attention.sliding_attention_window_shape(
434
+ (keys, values, importance), prev_kvi, queries,
435
+ window_length=self.window_length)
436
+ kq_relative_offset = num_keys - num_queries
437
+
438
+ # Get the relative position bias.
439
+ # The bias doesn't depend on the query content, and so can be precomputed.
440
+ if self.relative_positions is not None:
441
+ rel_position_bias = self.relative_positions(num_queries, num_keys,
442
+ bidirectional=False)
443
+ logging.info("tlayer: %s relative bias = %r",
444
+ self.relative_position_type, rel_position_bias)
445
+ else:
446
+ rel_position_bias = None
447
+
448
+ # Get causal mask.
449
+ if self.use_causal_mask:
450
+ causal_mask = position.causal_mask(num_queries, num_keys,
451
+ window_length=self.window_length)
452
+ logging.info("tlayer: causal mask = %r", causal_mask)
453
+ else:
454
+ causal_mask = None
455
+
456
+ # Apply dropout to the attention matrix.
457
+ # The mask will be broadcast across batches and windows.
458
+ if self.attn_dropout_rate > 0.0 and is_training:
459
+ dropout_rng = self.make_rng("dropout")
460
+ attn_shape = (self.num_heads, num_queries, num_keys)
461
+ dropout_multiplier = nn_components.dropout_multiplier_mask(
462
+ dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype)
463
+ logging.info("tlayer: attn_dropout = %r", dropout_multiplier)
464
+ else:
465
+ dropout_multiplier = None
466
+
467
+ # Load and store values into external memory, if memory is not None.
468
+ # ------------------------------------------------------------------
469
+ (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode)
470
+ external_kv = self._query_external_memory(
471
+ keys, values, queries,
472
+ start_of_sequence=start_of_sequence, mode=mode,
473
+ update_memory=decoder_state is None and update_memory)
474
+
475
+ if self.memory is not None:
476
+ external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype)
477
+ external_memory_bias = jnp.reshape(external_memory_bias,
478
+ (1, 1, num_heads, 1))
479
+ external_memory_bias = jax.nn.sigmoid(external_memory_bias)
480
+ else:
481
+ external_memory_bias = None
482
+
483
+ # Compute the number of windows.
484
+ # ------------------------------
485
+ if sequence_length < self.window_length:
486
+ num_windows = 1 # Happens with autoregressive decoding.
487
+ elif sequence_length == self.window_length:
488
+ num_windows = 1
489
+ if self.use_long_xl_architecture:
490
+ assert prev_kvi is not None
491
+ else:
492
+ if not self.use_long_xl_architecture:
493
+ raise ValueError("Can only use sliding window with Transformer XL.")
494
+ num_windows = sequence_length // self.window_length
495
+ if (num_windows * self.window_length) != sequence_length:
496
+ raise ValueError(f"Window length {self.window_length} must be a " +
497
+ f"multiple of sequence length {sequence_length}")
498
+ logging.info("tlayer: num_windows = %d.", num_windows)
499
+
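+ # For example, with sequence_length=512 and window_length=128, attention
+ # runs over num_windows = 4 chunks; each chunk of 128 queries attends to
+ # the 128 keys carried over from the previous window plus its own keys.
+ #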
500
+ # Define the function to do attention within a single window.
501
+ # ---------------------------------------------------------
502
+ def single_window_attention(carry, inputs_w):
503
+ # This function uses the following variables from the outer scope.
504
+ # They are listed here for clarity.
505
+ nonlocal rel_position_bias
506
+ nonlocal causal_mask
507
+ nonlocal kq_relative_offset
508
+ nonlocal dropout_multiplier
509
+ nonlocal attention_scale_factors
510
+ nonlocal external_memory_bias
511
+ nonlocal cross_attention_kv # externally supplied.
512
+
513
+ # keys,values,queries over the whole sequence will be split into chunks.
514
+ # xs_w, kvqi_w, etc. are the chunk for the current window.
515
+ (prev_kvi_w, rec_state) = carry # carried from one window to the next.
516
+ (kvqi_w, external_kv_w) = inputs_w # inputs to the current window.
517
+ # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w
518
+
519
+ # Concatenate keys,values from the previous window with the current
520
+ # window to implement sliding window attention.
521
+ (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
522
+ (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w
523
+
524
+ # Perform recurrent attention within the current window to get the next
525
+ # recurrent state, and set up cross attention.
526
+ if rec_state is not None:
527
+ logging.info("tlayer: recurrent attention.")
528
+
529
+ # NOTE -- recurrent states and input tokens are handled separately,
530
+ # because they have separate learned positional embeddings. Due to
531
+ # the way TransformerBase does cross-attention, this means that we use
532
+ # separate key,value layers for rec_state and tokens_w.
533
+
534
+ # Keys, values, queries from recurrent state.
535
+ logging.info("tlayer: recurrent kvq.")
536
+ rec_kvq = self.recurrent_tbase.kvq(rec_state)
537
+ r_scale_factors = self.recurrent_tbase.attention_scale_factors()
538
+ (r_keys, r_values, r_queries, r_queries2) = rec_kvq
539
+
540
+ # Joint attention over both recurrent states and input tokens.
541
+ logging.info("tlayer: recurrent self-attention.")
542
+ r_attn_ys = attention.simple_attention(
543
+ r_keys, r_values, r_queries, None,
544
+ scale_factor=r_scale_factors[0],
545
+ dtype=self.dtype)
546
+
547
+ logging.info("tlayer: recurrent cross-attention.")
548
+ r_cross_attn_ys = attention.simple_attention(
549
+ keys_w, values_w, r_queries2, importance_w,
550
+ scale_factor=r_scale_factors[1],
551
+ dtype=self.dtype)
552
+
553
+ # Recurrent post-attention FFN.
554
+ logging.info("tlayer: recurrent ffn.")
555
+ next_rec_state = self.recurrent_tbase.post_attn_ffn(
556
+ rec_state, r_attn_ys, r_cross_attn_ys)
557
+
558
+ # Get keys and values for cross-attention from recurrent state.
559
+ assert cross_attention_kv is None
560
+ local_cross_attention_kv = (r_keys, r_values)
561
+ else:
562
+ # Get keys and values for cross-attention from external argument.
563
+ next_rec_state = None
564
+ local_cross_attention_kv = cross_attention_kv
565
+
566
+ # If using RoPE, keys and queries are rotated before self-attention.
567
+ if self.relative_position_type == "rotary":
568
+ logging.info("Using rotary position encodings (RoPE), offset = %d",
569
+ kq_relative_offset)
570
+ (keys_w, queries_w) = position.rotate_kq(keys_w, queries_w,
571
+ max_wavelength=10_000,
572
+ offset=kq_relative_offset)
573
+
574
+ # Self-attention over input tokens.
575
+ logging.info("tlayer: self-attention.")
576
+ attn_ys_w = attention.simple_attention(
577
+ keys_w, values_w, queries_w, importance_w,
578
+ relative_position_bias=rel_position_bias,
579
+ scale_factor=attention_scale_factors[0],
580
+ causal_mask=causal_mask,
581
+ dropout_multiplier=dropout_multiplier,
582
+ dtype=self.dtype)
583
+
584
+ # Attention over external memory.
585
+ if external_kv_w is not None:
586
+ (external_keys_w, external_values_w) = external_kv_w
587
+ y_ext = attention.external_attention(
588
+ external_keys_w, external_values_w, queries_w,
589
+ scale_factor=attention_scale_factors[0])
590
+ if external_memory_bias is not None:
591
+ ebias = external_memory_bias
592
+ logging.info("tlayer: using external memory bias = %r", ebias)
593
+ attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias)
594
+ else:
595
+ attn_ys_w += y_ext
596
+
597
+ # Cross attention from input tokens to encoder or recurrent state.
598
+ if local_cross_attention_kv is not None:
599
+ logging.info("tlayer: cross-attention.")
600
+ (c_keys, c_values) = local_cross_attention_kv
601
+
602
+ # Cross-attention using queries2.
603
+ cross_attn_ys_w = attention.simple_attention(
604
+ c_keys, c_values, queries2_w, None,
605
+ scale_factor=attention_scale_factors[1],
606
+ dtype=self.dtype)
607
+ else:
608
+ cross_attn_ys_w = None
609
+
610
+ # End function single_window_attention(...)
611
+ return ((next_kvi_w, next_rec_state),
612
+ (attn_ys_w, cross_attn_ys_w))
613
+
614
+ # Initialize recurrent_tbase before calling jax.lax.scan.
615
+ # Otherwise flax will throw a tantrum.
616
+ if (self.recurrent_attention and 0 <= self.max_unrolled_windows and
617
+ self.max_unrolled_windows < num_windows):
618
+ logging.info("tlayer: force initialization of recurrent_tbase.")
619
+ self.recurrent_tbase.force_init(recurrent_state)
620
+
621
+ # Perform sliding window attention over all keys,values,queries.
622
+ # --------------------------------------------------------------
623
+ initial_carry = (prev_kvi, recurrent_state) # window state.
624
+ kvqi = (keys, values, queries, queries2, importance)
625
+ attn_inputs = (kvqi, external_kv)
626
+ (next_carry, attn_outputs) = attention.split_and_scan(
627
+ single_window_attention,
628
+ initial_carry,
629
+ attn_inputs,
630
+ sections=num_windows,
631
+ axis=1,
632
+ max_unrolled_windows=self.max_unrolled_windows)
633
+ (attn_ys, cross_attn_ys) = attn_outputs
634
+
635
+ logging.info("tlayer: End windows.")
636
+
637
+ # Post-attention MLP, resnet, and FFN.
638
+ # ------------------------------------
639
+ logging.info("tlayer: final FFN.")
640
+ ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys)
641
+
642
+ importance_output = None
643
+ next_window_state = next_carry if window_state is not None else None
644
+ viz_dict = {} # Visualizations, not currently enabled.
645
+ return (ys, importance_output, next_window_state, decoder_state, viz_dict)
646
+
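+ # Example step loop (sketch; `tlayer` and `batch` are hypothetical):
+ #
+ #   window_state = tlayer.load_window_state(batch["start_of_sequence"])
+ #   (ys, _, window_state, _, _) = tlayer(
+ #       batch["xs"], batch["start_of_sequence"], window_state=window_state)
+ #   tlayer.store_window_state(window_state)
+ #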
647
+ def load_window_state(self, start_of_sequence: Array) -> WindowState:
648
+ """Load cached state that is passed from one window to the next."""
649
+
650
+ (mode, _, _) = self._get_cache_name_from_mode(self.mode)
651
+ prev_kvi = self._get_cached_kvi(start_of_sequence, mode)
652
+ rec_state = self._get_cached_recurrent_state(start_of_sequence, mode)
653
+ if prev_kvi is not None:
654
+ logging.info("tlayer: Loaded keys,values for mode %s from cache %s",
655
+ self.mode, mode)
656
+ else:
657
+ logging.info("tlayer: Skipping XL cache for mode %s.", self.mode)
658
+ if rec_state is not None:
659
+ logging.info("tlayer: Loaded recurrent state for mode %s from cache %s.",
660
+ self.mode, mode)
661
+ return (prev_kvi, rec_state)
662
+
663
+ def store_window_state(self, window_state: WindowState):
664
+ """Write window state to the cache."""
665
+
666
+ (mode, update_cache, _) = self._get_cache_name_from_mode(self.mode)
667
+ (next_kvi, next_rec_state) = window_state # pytype: disable=attribute-error
668
+ if update_cache and next_kvi is not None:
669
+ logging.info("tlayer: Storing keys,values for mode %s in cache %s.",
670
+ self.mode, mode)
671
+ self._set_cached_kvi(next_kvi, mode)
672
+ else:
673
+ logging.info("tlayer: Skipping XL cache update for mode %s.", self.mode)
674
+ if update_cache and next_rec_state is not None:
675
+ logging.info("tlayer: Storing recurrent state for mode %s in cache %s.",
676
+ self.mode, mode)
677
+ self._set_cached_recurrent_state(next_rec_state, mode)
678
+
679
+ def get_recurrent_kv(self, window_state: WindowState):
680
+ """Get the recurrent keys,values from window_state."""
681
+
682
+ # TODO(delesley): optimize.
683
+ # This isn't ideal, because we wind up computing the recurrent keys,values
684
+ # twice -- once within the sliding window above, and again in the
685
+ # DecoderStack, so they can be passed to other layers. However, the
686
+ # plumbing is a lot simpler this way.
687
+ if window_state is None:
688
+ return None
689
+ (_, rec_state) = window_state
690
+ if rec_state is None:
691
+ return None
692
+ logging.info("tlayer: get_recurrent_kv.")
693
+ (r_keys, r_values, _, _) = self.recurrent_tbase.kvq(rec_state)
694
+ return (r_keys, r_values)
695
+
696
+   def init_decoder_state(self, sequence_length: int,
+                          start_of_sequence: Array) -> DecoderState:
+     """Initialize decoder state for autoregressive generation.
+
+     Args:
+       sequence_length: The maximum length of the sequence to generate.
+       start_of_sequence: Boolean array of shape (batch_size,);
+         True if starting a new sequence (with no prefix).
+
+     Returns:
+       A state object that can be passed to __call__.
+     """
+
+     # Note that generate always uses a local context of size window_length.
+     # Training should be set up appropriately.
+     if not self.use_long_xl_architecture:
+       raise ValueError("Generation is only supported for transformer XL.")
+     if not self.use_causal_mask:
+       raise ValueError("Generator must have been trained with a causal mask.")
+
+     (mode, _, _) = self._get_cache_name_from_mode(self.mode)
+
+     # Get relative position bias.
+     if self.relative_positions is not None:
+       # Relative positions for all tokens *prior* to the current token.
+       # The causal mask prevents each token from attending to itself.
+       rel_position_bias = self.relative_positions(1, self.window_length,
+                                                   offset=self.window_length,
+                                                   bidirectional=False)
+     else:
+       rel_position_bias = None
+
+     # Initialize autoregressive storage for (key, value) pairs.
+     # Include space for a prefix of window_length tokens.
+     num_keys = sequence_length + self.window_length
+     stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size)
+     stored_keys = jnp.zeros(stored_shape, dtype=self.dtype)
+     stored_values = jnp.zeros(stored_shape, dtype=self.dtype)
+     start_index = self.window_length
+
+     # Copy keys,values from cache into storage, for use as a prefix.
+     prev_kvi = self._get_cached_kvi(start_of_sequence, mode)
+     if prev_kvi is not None:
+       (pkeys, pvals, prev_imps) = prev_kvi
+       assert prev_imps is None  # Not yet supported.
+       assert pkeys.ndim == 4
+       assert pkeys.shape[1] == self.window_length  # (b, wlen, num_heads, d)
+
+       stored_keys = jax.lax.dynamic_update_slice_in_dim(
+           stored_keys, pkeys, 0, axis=1)
+       stored_values = jax.lax.dynamic_update_slice_in_dim(
+           stored_values, pvals, 0, axis=1)
+
+     # Grab the current recurrent_state, and precompute keys,values,queries.
+     rstate = self._get_cached_recurrent_state(start_of_sequence, mode)
+     if rstate is not None:
+       recurrent_kvq = self.recurrent_tbase.kvq(rstate)
+     else:
+       recurrent_kvq = None
+
+     decoder_state_dict = {
+         "keys": stored_keys,
+         "values": stored_values,
+         "current_index": start_index,
+         "relative_position_bias": rel_position_bias,
+         "recurrent_kvq": recurrent_kvq
+     }
+     return DecoderState(decoder_state_dict)
+
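The prefix copy above works because the key/value buffers are preallocated with window_length extra slots at the front, and jax.lax.dynamic_update_slice_in_dim writes the cached window into those slots. A toy example of the semantics:

import jax
import jax.numpy as jnp

buf = jnp.zeros((2, 10, 4))    # (batch, num_keys, dim), all slots empty
prefix = jnp.ones((2, 3, 4))   # a cached prefix window of length 3
buf = jax.lax.dynamic_update_slice_in_dim(buf, prefix, 0, axis=1)
# buf[:, :3] now holds the prefix, and generation would begin at index 3,
# mirroring start_index = self.window_length above.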
+   def _next_decoder_state(self, decoder_state: DecoderState,
+                           keys: Array, values: Array) -> Tuple[
+                               DecoderState, Array, Array]:
+     """Compute the next decoder state, and return keys,values to attend to.
+
+     The keys,values returned from this function are drawn from the prior
+     decoding state, and comprise a full window of local context.
+
+     Args:
+       decoder_state: The current decoder state, initially created using
+         init_decoder_state().
+       keys: The key for the current token, of shape (batch_size, 1, dim).
+       values: The value for the current token, of shape (batch_size, 1, dim).
+
+     Returns:
+       (next_decoder_state,
+        window of keys of shape (batch_size, window_length, dim),
+        window of values of shape (batch_size, window_length, dim))
+     """
+
+     assert keys.shape[1] == 1  # single-token autoregressive decoding.
+
+     logging.info("attn_layer: next decoder state; key = %r", keys)
+
+     # Unpack decoder_state.
+     stored_keys = decoder_state["keys"]
+     stored_values = decoder_state["values"]
+     curr_index = decoder_state["current_index"]
+
+     # Slice to get window_length-sized chunk of previous keys,values.
+     out_decoder_state = {}
+     curr_win_index = curr_index - self.window_length
+     out_keys = jax.lax.dynamic_slice_in_dim(
+         stored_keys, curr_win_index, self.window_length, axis=1)
+     out_values = jax.lax.dynamic_slice_in_dim(
+         stored_values, curr_win_index, self.window_length, axis=1)
+
+     # Write current keys,values to stored keys, values.
+     stored_keys = jax.lax.dynamic_update_slice_in_dim(
+         stored_keys, keys, curr_index, axis=1)
+     stored_values = jax.lax.dynamic_update_slice_in_dim(
+         stored_values, values, curr_index, axis=1)
+     curr_index = curr_index + 1
+
+     # Pack a new decoder_state object.
+     out_decoder_state["keys"] = stored_keys
+     out_decoder_state["values"] = stored_values
+     out_decoder_state["current_index"] = curr_index
+     out_decoder_state["relative_position_bias"] = (
+         decoder_state["relative_position_bias"])
+     out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"]
+
+     return (DecoderState(out_decoder_state), out_keys, out_values)
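Note the ordering: the window of keys/values is sliced out *before* the current token's key/value are written, so each generated token attends only to the window_length positions preceding it, consistent with the causal mask and the offset=window_length relative-position bias set up in init_decoder_state. A toy illustration:

import jax
import jax.numpy as jnp

wlen = 4
stored = jnp.arange(12.0).reshape(1, 12, 1)  # toy (batch, num_keys, dim)
curr_index = 6
window = jax.lax.dynamic_slice_in_dim(stored, curr_index - wlen, wlen, axis=1)
# window covers positions [2, 3, 4, 5]; the entry for position 6 is written
# only after this read, so the new token never attends to itself.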