|
from typing import List, Optional, Tuple, Union

import torch

from transformers.modeling_attn_mask_utils import AttentionMaskConverter


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
""" |
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
|
`(batch_size, key_value_length)` |
|
|
|
Args: |
|
attention_mask (`torch.Tensor` or `None`): |
|
A 2D attention mask of shape `(batch_size, key_value_length)` |
|
input_shape (`tuple(int)` or `list(int)` or `torch.Size`): |
|
The input shape should be a tuple that defines `(batch_size, query_length)`. |
|
inputs_embeds (`torch.Tensor`): |
|
The embedded inputs as a torch Tensor. |
|
past_key_values_length (`int`): |
|
The length of the key value cache. |
|
sliding_window (`int`, *optional*): |
|
If the model uses windowed attention, a sliding window should be passed. |
|
""" |
|
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    if attention_mask is not None and len(attention_mask.shape) == 2:
        # Expand the 2D padding mask to 4D and combine it with the causal mask.
        attention_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # The 4D mask has the correct shape: invert it (1 -> attend, 0 -> mask) and
            # fill masked positions with the dtype's minimum value.
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        # No attention mask was provided: build a plain causal 4D mask.
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

    return attention_mask
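

# Illustrative usage sketch (not part of the original helpers): shows how a decoder forward
# pass might build its 4D causal mask from a 2D padding mask with the helper above. All
# shapes and values below are made up for the example.
def _example_eager_mask_usage():
    batch_size, query_length, past_key_values_length, hidden_size = 2, 3, 4, 8
    inputs_embeds = torch.randn(batch_size, query_length, hidden_size)
    # 2D padding mask over past + current tokens: 1 = attend, 0 = padding.
    attention_mask = torch.ones(batch_size, past_key_values_length + query_length, dtype=torch.long)
    attention_mask[0, :2] = 0  # pretend the first sequence is left-padded
    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, query_length),
        inputs_embeds,
        past_key_values_length,
    )
    # mask_4d has shape (batch_size, 1, query_length, past_key_values_length + query_length),
    # with 0.0 at attendable positions and torch.finfo(dtype).min at masked positions.
    return mask_4d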
|
|
|
|
|
|
|


def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
""" |
|
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. |
|
|
|
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and |
|
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, |
|
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). |
|
""" |
|
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length
    batch_size, query_length = input_shape

    # torch.jit.trace, torch.fx symbolic tracing and torch.compile cannot capture data-dependent
    # control flow such as `is_causal=attention_mask is None and query_length > 1`, so detect
    # tracing/compilation here and always pass an explicit `attn_mask` to SDPA in that case.
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(inputs_embeds, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

    if attention_mask is not None:
        # A 4D mask is passed through largely as-is; it only needs to be converted to an
        # additive float mask.
        if len(attention_mask.shape) == 4:
            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
            if tuple(attention_mask.shape) != expected_shape:
                raise ValueError(
                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
                )
            else:
                # The 4D mask has the correct shape: invert it (1 -> attend, 0 -> mask) and
                # fill masked positions with the dtype's minimum value.
                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
                attention_mask = inverted_mask.masked_fill(
                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
                )
            return attention_mask

        elif not is_tracing and torch.all(attention_mask == 1):
            if query_length == 1:
                # For query_length == 1, causal attention and bi-directional attention are the same.
                attention_mask = None
            elif key_value_length == query_length:
                attention_mask = None
            else:
                # For query_length > 1 and key_value_length != query_length, SDPA's own causal
                # mask generation may be wrong, so keep the explicit attention_mask and set
                # `is_causal=False` in SDPA rather than dropping the mask here.
                pass
    elif query_length > 1 and key_value_length != query_length:
        # Same case as above but with no attention_mask provided: set it to `True` so the
        # control flow below builds an explicit causal 4D mask via `to_causal_4d`.
        attention_mask = True
    elif is_tracing:
        raise ValueError(
            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
        )

    if attention_mask is None:
        expanded_4d_mask = None
    elif attention_mask is True:
        expanded_4d_mask = attn_mask_converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
    else:
        expanded_4d_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            dtype=inputs_embeds.dtype,
            key_value_length=key_value_length,
        )

        # Rows of the expanded mask that attend to no position at all (e.g. the first rows
        # when using left padding) can make SDPA's memory-efficient backend produce NaNs,
        # so unmask them on CUDA devices.
        if not is_tracing and expanded_4d_mask.device.type == "cuda":
            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
            )

    return expanded_4d_mask
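

# Illustrative usage sketch (not part of the original helpers): shows how the SDPA variant
# is typically consumed. The helper may return `None` (every token attendable, so SDPA's
# `is_causal` flag can be used instead) or a 4D additive mask. All values are made up.
def _example_sdpa_mask_usage():
    batch_size, query_length, past_key_values_length, hidden_size = 2, 3, 0, 8
    inputs_embeds = torch.randn(batch_size, query_length, hidden_size)
    attention_mask = torch.ones(batch_size, query_length, dtype=torch.long)
    sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, query_length),
        inputs_embeds,
        past_key_values_length,
    )
    # Here no token is masked and key_value_length == query_length, so `sdpa_mask` is None
    # and the causal structure is delegated to SDPA's `is_causal` argument.
    query = key = value = torch.randn(batch_size, 1, query_length, hidden_size)
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=sdpa_mask,
        is_causal=sdpa_mask is None and query_length > 1,
    )
    return attn_output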
|
|