Spaces:
Runtime error
Runtime error
File size: 22,126 Bytes
581eeac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 |
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is heavily inspired by https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
# and is, for the most part, a direct conversion of that implementation to JAX.
"""Relative Attention HEAVILY INSPIRED FROM https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
, flax attention, https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L143, most of the time just a flax/jax conversion """
import functools
from typing import Any, Callable, Optional, Tuple
from flax.linen.dtypes import promote_dtype
from flax.linen import initializers
from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.linear import DotGeneralT
from flax.linen.linear import PrecisionLike
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
# Loose type aliases used in the annotations of this module.
PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
# Vectorized `jnp.roll`: the array argument is mapped over its axis -2, the
# shift argument over its axis 0, and the third argument (the roll axis) is
# passed through unmapped. The mapped axis is placed back at position -2 of the
# result, so every slice along axis -2 is rolled by its own shift.
roll_vmap = jax.vmap(jnp.roll, in_axes=(-2, 0, None), out_axes=-2)
def _rel_shift(x):
zero_pad_shape = x.shape[:-2] + (x.shape[-2], 1)
zero_pad = jnp.zeros(zero_pad_shape, dtype=x.dtype)
x_padded = jnp.concatenate([zero_pad, x], axis=-1)
x_padded_shape = x.shape[:-2] + (x.shape[-1] + 1, x.shape[-2])
x_padded = x_padded.reshape(x_padded_shape)
# x_padded=jnp.swapaxes(x_padded,0,1)
x = jnp.take(x_padded, jnp.arange(1, x_padded.shape[-2]), axis=-2).reshape(x.shape)
return x
def dot_product_attention_weights(
    query: Array,
    key: Array,
    r_pos_embed,
    r_r_bias,
    r_w_bias,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
):
    """Computes relative-position dot-product attention weights.

    The logits are the sum of a content term, `(query + r_w_bias) . key`, and
    a position term, `(query + r_r_bias) . r_pos_embed`. Each query row of the
    position term is rolled along the key axis before the two terms are added.
    The sum is scaled by `sqrt(depth)`, optionally biased and masked, passed
    through a softmax and optionally subjected to dropout.

    Used by :func:`dot_product_attention`, which is what you'll most likely
    use. Call this function directly if you need the attention weights
    themselves, e.g. for introspection.

    Args:
      query: queries for calculating attention with shape of
        `[batch..., q_length, num_heads, qk_depth_per_head]`.
      key: keys for calculating attention with shape of
        `[batch..., kv_length, num_heads, qk_depth_per_head]`.
      r_pos_embed: projected relative position embeddings of shape
        `[rel_length, num_heads, qk_depth_per_head]` (no batch dims). The
        position logits are added to the content logits, so `rel_length` has
        to be compatible with `kv_length`.
      r_r_bias: bias added to the query for the position term; must broadcast
        against `query`.
      r_w_bias: bias added to the query for the content term; must broadcast
        against `query`.
      bias: bias for the attention weights. This should be broadcastable to
        the shape `[batch..., num_heads, q_length, kv_length]`.
      mask: mask for the attention weights. This should be broadcastable to
        the shape `[batch..., num_heads, q_length, kv_length]`. Attention
        weights are masked out where the mask value is `False`.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, deterministic or not (to apply dropout).
      dtype: the dtype of the computation (default: infer from inputs).
      precision: numerical precision of the computation see
        `jax.lax.Precision` for details.

    Returns:
      Output of shape `[batch..., num_heads, q_length, kv_length]`.
    """
    query, key = promote_dtype(query, key, dtype=dtype)
    dtype = query.dtype

    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    q_length = query.shape[-3]
    head_depth = query.shape[-1]

    # Both terms have shape (batch..., num_heads, q_length, kv_length).
    content_scores = jnp.einsum("...qhd,...khd->...hqk", query + r_w_bias, key, precision=precision)
    position_scores = jnp.einsum("...qhd,khd->...hqk", query + r_r_bias, r_pos_embed, precision=precision)

    # Query row i is rolled by i - (q_length - 1) along the key axis, so the
    # last query row is left unshifted.
    row_shifts = jnp.arange(0, q_length) - (q_length - 1)
    position_scores = roll_vmap(position_scores, row_shifts, -1)

    scores = (content_scores + position_scores) / jnp.sqrt(head_depth).astype(dtype)

    # Additive bias first, then hard masking with the most negative finite value.
    if bias is not None:
        scores = scores + bias
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(dtype).min)

    probs = jax.nn.softmax(scores).astype(dtype)

    # Nothing more to do unless dropout is both enabled and non-trivial.
    if deterministic or not dropout_rate > 0.0:
        return probs

    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
        # One dropout pattern shared across the batch and head dimensions.
        keep_shape = tuple([1] * (key.ndim - 2)) + probs.shape[-2:]
    else:
        keep_shape = probs.shape
    keep = random.bernoulli(dropout_rng, keep_prob, keep_shape)  # type: ignore
    return probs * (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype))
def dot_product_attention(
    query: Array,
    key: Array,
    value: Array,
    r_pos_embed,
    r_r_bias,
    r_w_bias,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
):
    """Computes relative-position dot-product attention for query, key, value.

    The attention weights come from :func:`dot_product_attention_weights`
    (content term plus relative-position term); they are then used to take a
    weighted sum of the values for every query position.

    Note: query, key, value needn't have any batch dimensions.

    Args:
      query: queries for calculating attention with shape of
        `[batch..., q_length, num_heads, qk_depth_per_head]`.
      key: keys for calculating attention with shape of
        `[batch..., kv_length, num_heads, qk_depth_per_head]`.
      value: values to be used in attention with shape of
        `[batch..., kv_length, num_heads, v_depth_per_head]`.
      r_pos_embed: projected relative position embeddings, forwarded
        unchanged to :func:`dot_product_attention_weights`.
      r_r_bias: query bias of the position term, forwarded unchanged.
      r_w_bias: query bias of the content term, forwarded unchanged.
      bias: bias for the attention weights. This should be broadcastable to
        the shape `[batch..., num_heads, q_length, kv_length]`.
      mask: mask for the attention weights. This should be broadcastable to
        the shape `[batch..., num_heads, q_length, kv_length]`. Attention
        weights are masked out where the mask value is `False`.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, deterministic or not (to apply dropout).
      dtype: the dtype of the computation (default: infer from inputs).
      precision: numerical precision of the computation see
        `jax.lax.Precision` for details.

    Returns:
      Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
    """
    query, key, value = promote_dtype(query, key, value, dtype=dtype)
    dtype = query.dtype

    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."

    weights = dot_product_attention_weights(
        query=query,
        key=key,
        r_pos_embed=r_pos_embed,
        r_r_bias=r_r_bias,
        r_w_bias=r_w_bias,
        bias=bias,
        mask=mask,
        broadcast_dropout=broadcast_dropout,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        deterministic=deterministic,
        dtype=dtype,
        precision=precision,
    )

    # Weighted sum over the key/value axis for each query position and head.
    attended = jnp.einsum("...hqk,...khd->...qhd", weights, value, precision=precision)
    return attended
class RelMultiHeadDotProductAttention(Module):
    """Multi-head dot-product attention with relative position terms.

    In addition to the usual query/key/value/output projections, the module
    owns a bias-free projection of the position embeddings
    (`pos_embed_mat`) and two learned query biases (`r_r_bias`, `r_w_bias`)
    of shape `[num_heads, head_dim]` that are handed to `attention_fn`.

    Attributes:
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: infer from inputs and params)
      param_dtype: the dtype passed to parameter initializers (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rate: dropout rate
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers; also used for
        the `r_r_bias` and `r_w_bias` parameters.
      use_bias: bool: whether pointwise QKVO dense transforms use bias.
      attention_fn: dot_product_attention or compatible function. It is called
        as `attention_fn(query, key, value, r_pos_embed, r_r_bias, r_w_bias,
        **kwargs)` and returns output of shape
        `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
      decode: whether to prepare and use an autoregressive cache.
      qkv_dot_general: dot_general used by the query/key/value and position
        projections.
      out_dot_general: dot_general used by the output projection.
    """

    num_heads: int
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    qkv_features: Optional[int] = None
    out_features: Optional[int] = None
    broadcast_dropout: bool = True
    dropout_rate: float = 0.0
    deterministic: Optional[bool] = None
    precision: PrecisionLike = None
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
    use_bias: bool = True
    attention_fn: Callable[..., Array] = dot_product_attention
    decode: bool = False
    qkv_dot_general: DotGeneralT = lax.dot_general
    out_dot_general: DotGeneralT = lax.dot_general

    @compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        pos_embed: Array,
        mask: Optional[Array] = None,
        deterministic: Optional[bool] = None,
    ):
        """Applies relative multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        projects the position embeddings, applies the attention function and
        projects the results to an output vector.

        Args:
          inputs_q: input queries of shape
            `[batch_sizes..., length, features]`.
          inputs_kv: key/values of shape
            `[batch_sizes..., length, features]`.
          pos_embed: relative position embeddings. They are projected over
            their last axis to `[..., num_heads, head_dim]`; the default
            attention function expects the projection to be rank 3, i.e.
            `pos_embed` of shape `[rel_length, features]` without batch dims.
            NOTE(review): presumably ordered from the most distant to the
            nearest position as in Transformer-XL -- confirm against callers.
          mask: attention mask of shape
            `[batch_sizes..., num_heads, query_length, key/value_length]`.
            Attention weights are masked out if their corresponding mask value
            is `False`.
          deterministic: if false, the attention weight is masked randomly
            using dropout, whereas if true, the attention weights
            are deterministic.

        Returns:
          output of shape `[batch_sizes..., length, features]`.

        Raises:
          ValueError: in decode mode, if the query shape does not match the
            single-step shape implied by the cache.
        """
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            f"Memory dimension ({qkv_features}) must be divisible by number of" f" heads ({self.num_heads})."
        )
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(
            DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            features=(self.num_heads, head_dim),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            precision=self.precision,
            dot_general=self.qkv_dot_general,
        )
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (
            dense(name="query")(inputs_q),
            dense(name="key")(inputs_kv),
            dense(name="value")(inputs_kv),
        )

        # The position projection is the same kind of layer as q/k/v but is
        # always created without a bias.
        dense_relpos = functools.partial(
            DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            features=(self.num_heads, head_dim),
            kernel_init=self.kernel_init,
            use_bias=False,
            precision=self.precision,
            dot_general=self.qkv_dot_general,
        )
        r_pos_embed = dense_relpos(name="pos_embed_mat")(pos_embed)

        # Learned per-head query biases for the position (r_r) and content
        # (r_w) terms. `param_dtype` is forwarded to the initializer so these
        # parameters use the same dtype as every other parameter of the module.
        r_r_bias = self.param("r_r_bias", self.bias_init, (self.num_heads, head_dim), self.param_dtype)
        r_w_bias = self.param("r_w_bias", self.bias_init, (self.num_heads, head_dim), self.param_dtype)

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")
            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
            cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                (
                    *batch_dims,
                    max_length,
                    num_heads,
                    depth_per_head,
                ) = cached_key.value.shape
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        "expected query shape %s instead got %s." % (expected_shape, query.shape)
                    )
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value, indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length),
                    ),
                )

        dropout_rng = None
        if self.dropout_rate > 0.0:  # Require `deterministic` only if using dropout.
            m_deterministic = merge_param("deterministic", self.deterministic, deterministic)
            if not m_deterministic:
                dropout_rng = self.make_rng("dropout")
        else:
            m_deterministic = True

        # apply attention
        x = self.attention_fn(
            query,
            key,
            value,
            r_pos_embed,
            r_r_bias,
            r_w_bias,
            mask=mask,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=m_deterministic,
            dtype=self.dtype,
            precision=self.precision,
        )  # pytype: disable=wrong-keyword-args

        # back to the original inputs dimensions
        out = DenseGeneral(
            features=features,
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            dot_general=self.out_dot_general,
            name="out",  # type: ignore[call-arg]
        )(x)
        return out
class SelfAttention(RelMultiHeadDotProductAttention):
    """Self-attention special case of relative multi-head dot-product attention."""

    @compact
    def __call__(  # type: ignore
        self,
        inputs_q: Array,
        mask: Optional[Array] = None,
        deterministic: Optional[bool] = None,
        pos_embed: Optional[Array] = None,
    ):
        """Applies relative multi-head dot product self-attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and project the results to an output vector.

        Args:
          inputs_q: input queries of shape
            `[batch_sizes..., length, features]`; also used as keys/values.
          mask: attention mask of shape
            `[batch_sizes..., num_heads, query_length, key/value_length]`.
            Attention weights are masked out if their corresponding mask value
            is `False`.
          deterministic: if false, the attention weight is masked randomly
            using dropout, whereas if true, the attention weights
            are deterministic.
          pos_embed: relative position embeddings forwarded to the parent
            module. Required: the parent cannot compute attention without
            them. It defaults to `None` only so that the positions of the
            earlier arguments stay unchanged.

        Returns:
          output of shape `[batch_sizes..., length, features]`.

        Raises:
          ValueError: if `pos_embed` is not provided.
        """
        if pos_embed is None:
            raise ValueError(
                "SelfAttention requires `pos_embed`: relative attention cannot be computed "
                "without position embeddings."
            )
        # `mask` and `deterministic` go by keyword so that `mask` can never end
        # up in the parent's positional `pos_embed` slot.
        return super().__call__(
            inputs_q,
            inputs_q,
            pos_embed,
            mask=mask,
            deterministic=deterministic,
        )
# mask-making utility functions
def make_attention_mask(
    query_input: Array,
    key_input: Array,
    pairwise_fn: Callable[..., Any] = jnp.multiply,
    extra_batch_dims: int = 0,
    dtype: Dtype = jnp.float32,
):
    """Builds an attention mask from per-position query and key inputs.

    For 1d inputs (i.e., `[batch..., len_q]` and `[batch..., len_kv]`) the
    attention weights have shape `[batch..., heads, len_q, len_kv]` and this
    helper returns a mask of shape `[batch..., 1, len_q, len_kv]`.

    Args:
      query_input: a batched, flat input of query_length size
      key_input: a batched, flat input of key_length size
      pairwise_fn: broadcasting elementwise comparison function applied to
        every (query position, key position) pair
      extra_batch_dims: number of extra singleton batch axes to prepend,
        none by default
      dtype: mask return dtype

    Returns:
      A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
    """
    # Lay queries out as a column and keys as a row so that `pairwise_fn`
    # broadcasts over every (query, key) pair.
    query_column = jnp.expand_dims(query_input, axis=-1)
    key_row = jnp.expand_dims(key_input, axis=-2)
    pairwise = pairwise_fn(query_column, key_row)

    # Singleton head axis, then any requested leading singleton batch axes.
    with_head_axis = jnp.expand_dims(pairwise, axis=-3)
    leading_axes = tuple(range(extra_batch_dims))
    return jnp.expand_dims(with_head_axis, axis=leading_axes).astype(dtype)
def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32) -> Array:
    """Builds a causal (lower-triangular) mask for self-attention.

    For 1d inputs (i.e., `[batch..., len]`) the self-attention weights have
    shape `[batch..., heads, len, len]` and this helper returns a causal mask
    of shape `[batch..., 1, len, len]`.

    Args:
      x: input array of shape `[batch..., len]`; only its shape is used
      extra_batch_dims: number of singleton batch axes to prepend,
        none by default
      dtype: mask return dtype

    Returns:
      A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
    """
    # Position index of every element along the last axis, tiled to x's shape.
    positions = jnp.arange(x.shape[-1], dtype=jnp.int32)
    positions = jnp.broadcast_to(positions, x.shape)
    # A query may attend to a key when query_index >= key_index.
    return make_attention_mask(
        positions,
        positions,
        jnp.greater_equal,
        extra_batch_dims=extra_batch_dims,
        dtype=dtype,
    )
def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Array:
    """Combines several attention masks with a logical AND.

    Args:
      *masks: set of attention mask arguments to combine, some can be None.
      dtype: dtype for the returned mask.

    Returns:
      Combined mask, reduced by logical and, returns None if no masks given.
    """
    present = [candidate for candidate in masks if candidate is not None]
    if not present:
        return None

    reference_rank = present[0].ndim
    assert all(
        candidate.ndim == reference_rank for candidate in present
    ), f"masks must have same rank: {tuple(candidate.ndim for candidate in present)}"

    combined = present[0]
    for extra in present[1:]:
        combined = jnp.logical_and(combined, extra)
    return combined.astype(dtype)
|