Sentence Similarity
Transformers
Safetensors
English
mistral
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
Sentence Similarity
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
text-generation-inference
text-embeddings-inference
Inference Endpoints
Update attn_mask_utils.py
Browse files - attn_mask_utils.py +29 -7
attn_mask_utils.py
CHANGED
@@ -1,7 +1,19 @@
|
|
1 |
from typing import List, Optional, Tuple, Union
|
2 |
import torch
|
|
|
|
|
3 |
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def _prepare_4d_attention_mask_for_sdpa(
|
6 |
attention_mask: Optional[torch.Tensor],
|
7 |
input_shape: Union[torch.Size, Tuple, List],
|
@@ -59,9 +71,14 @@ def _prepare_4d_attention_mask_for_sdpa(
|
|
59 |
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
60 |
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
61 |
if query_length > 1:
|
62 |
-
|
63 |
-
expanded_4d_mask
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
return expanded_4d_mask
|
67 |
|
@@ -195,8 +212,13 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
|
|
195 |
# controlflow that can not be captured properly.
|
196 |
# TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
|
197 |
if query_length > 1 and not is_tracing:
|
198 |
-
|
199 |
-
expanded_4d_mask
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
-
return expanded_4d_mask
|
|
|
1 |
from typing import List, Optional, Tuple, Union
|
2 |
import torch
|
3 |
+
from packaging import version
|
4 |
+
import importlib.metadata
|
5 |
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
6 |
|
7 |
+
from transformers.utils.import_utils import _is_package_available
|
8 |
+
|
9 |
+
def is_transformers_attn_greater_or_equal_4_39():
|
10 |
+
if not _is_package_available("transformers"):
|
11 |
+
return False
|
12 |
+
|
13 |
+
return version.parse(importlib.metadata.version("transformers")) >= version.parse(
|
14 |
+
"4.39.0"
|
15 |
+
)
|
16 |
+
|
17 |
def _prepare_4d_attention_mask_for_sdpa(
|
18 |
attention_mask: Optional[torch.Tensor],
|
19 |
input_shape: Union[torch.Size, Tuple, List],
|
|
|
71 |
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
72 |
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
73 |
if query_length > 1:
|
74 |
+
if is_transformers_attn_greater_or_equal_4_39():
|
75 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
76 |
+
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
|
77 |
+
)
|
78 |
+
else:
|
79 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
80 |
+
expanded_4d_mask, attention_mask, unmasked_value=0.0
|
81 |
+
)
|
82 |
|
83 |
return expanded_4d_mask
|
84 |
|
|
|
212 |
# controlflow that can not be captured properly.
|
213 |
# TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
|
214 |
if query_length > 1 and not is_tracing:
|
215 |
+
if is_transformers_attn_greater_or_equal_4_39():
|
216 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
217 |
+
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
|
218 |
+
)
|
219 |
+
else:
|
220 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
221 |
+
expanded_4d_mask, attention_mask, unmasked_value=0.0
|
222 |
+
)
|
223 |
|
224 |
+
return expanded_4d_mask
|