Spaces:
Sleeping
Sleeping
from torch.nn.functional import * | |
from torch.nn.functional import ( | |
_mha_shape_check, | |
_canonical_mask, | |
_none_or_dtype, | |
_in_projection_packed, | |
) | |
from torch.nn import functional as F | |
import torch | |
# Tensor = torch.Tensor | |
# from typing import Callable, List, Optional, Tuple, Union | |
def multi_head_attention_forward_patched(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    cache: Optional[dict] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Multi-head attention forward pass with an optional incremental key/value cache.

    Follows the structure of ``torch.nn.functional.multi_head_attention_forward``.
    It adds a ``cache`` argument so projected keys/values from earlier calls can be
    reused during step-by-step decoding.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: `True`
            Note: `need_weights` defaults to `True`, but should be set to `False`
            for best performance when attention weights are not needed.
            *Setting need_weights to `True`
            leads to a significant performance degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask
            will be broadcasted for all the batches while a 3D mask allows to specify a
            different mask for the entries of each batch.
        is_causal: If specified, applies a causal mask as attention mask, and ignores
            attn_mask for computing scaled dot product attention.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that the attn_mask is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be
            averaged across heads. Otherwise, ``attn_weights`` are provided separately per head.
            Note that this flag only has an effect when ``need_weights=True``. Default: True
        cache: optional dict enabling incremental decoding. It must provide:
            ``"first_infer"``: when ``1``, this call's projected k/v replace the slot.
            Otherwise they are concatenated (along dim 0, the sequence dim) onto the
            k/v already stored there.
            ``"stage"``: index of the slot in ``"k"``/``"v"`` used by this call. It is
            advanced modulo ``"all_stage"`` after each call.
            ``"k"``, ``"v"``: indexable containers holding one cached tensor per slot.
            When a cache is used, attention runs over the whole cached sequence. A given
            ``attn_mask`` must therefore match the cached source length.

    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length,
          N is the batch size, E is the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length,
          N is the batch size, E is the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the
          source sequence length. If a FloatTensor is provided, it will be directly added to
          the value. If a BoolTensor is provided, the positions with the value of ``True``
          will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the
          source sequence length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch
          size, L is the target sequence length, S is the source sequence length. attn_mask
          ensures that position i is allowed to attend the unmasked positions. If a
          BoolTensor is provided, positions with ``True`` are not allowed to attend while
          ``False`` values will be unchanged. If a FloatTensor is provided, it will be added
          to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence
          length, N is the batch size, E is the embedding dimension. E/num_heads is the head
          dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence
          length, N is the batch size, E is the embedding dimension. E/num_heads is the head
          dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence
          length, N is the batch size, E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If
          ``average_attn_weights=True``, returns attention weights averaged across heads of
          shape :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N`
          is the batch size, :math:`L` is the target sequence length, and :math:`S` is the
          source sequence length. If ``average_attn_weights=False``, returns attention
          weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or
          :math:`(N, num_heads, L, S)`.
    """
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        # Dispatch with *this* function. A ``__torch_function__`` fallback that calls
        # ``func(*args, **kwargs)`` must accept the extra ``cache`` keyword, and the
        # stock ``multi_head_attention_forward`` does not.
        return handle_torch_function(
            multi_head_attention_forward_patched,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )
    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)
    # set up shape vars (inputs are sequence-first: (seq, batch, embed))
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )
    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we use the is_causal hint go as is_causal
        # indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )
        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False
    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"
    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        # ``_in_projection`` is private, so the star import does not bring it into
        # scope. Reach it through the ``F`` module alias instead.
        q, k, v = F._in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )
    if cache is not None:
        # Each cache slot (one per layer/stage) keeps its own keys/values.
        k, v = _update_kv_cache(cache, k, v)
        # Attention now spans every cached step, not only this call's keys.
        src_len = k.shape[0]
    # prep attention mask
    attn_mask = _canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=q.dtype,
        check_other=False,
    )
    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )
    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None
    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v
    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    # update source sequence length after adjustments
    src_len = k.size(1)
    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask
    # adjust dropout probability
    if not training:
        dropout_p = 0.0
    #
    # (deep breath) calculate attention and out projection
    #
    if need_weights:
        B, Nt, E = q.shape
        q_scaled = q / math.sqrt(E)
        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"
        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)
        attn_output = torch.bmm(attn_output_weights, v)
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)
        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None


def _update_kv_cache(cache: dict, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
    """Store or extend the current slot's cached k/v and return the full cached tensors.

    The slot is ``cache["stage"]``. On the first inference (``cache["first_infer"] == 1``)
    the slot is overwritten with ``k``/``v``. Otherwise ``k``/``v`` are appended along
    dim 0, the sequence dimension of these sequence-first projections. Afterwards
    ``cache["stage"]`` advances modulo ``cache["all_stage"]``, so the next call uses the
    next slot.
    """
    stage = cache["stage"]
    if cache["first_infer"] == 1:
        cache["k"][stage] = k
        cache["v"][stage] = v
    else:
        cache["k"][stage] = torch.cat([cache["k"][stage], k], 0)
        cache["v"][stage] = torch.cat([cache["v"][stage], v], 0)
    cache["stage"] = (stage + 1) % cache["all_stage"]
    return cache["k"][stage], cache["v"][stage]