fixed flash_attention backward_compat
modeling_decilm.py

@@ -385,7 +385,6 @@ class DeciLMAttention(nn.Module):
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
-
         if self.config.pretraining_tp > 1:
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
@@ -497,7 +496,6 @@ class DeciLMFlashAttention2(DeciLMAttention):
                 "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
             )
-
         output_attentions = False
 
         bsz, q_len, _ = hidden_states.size()
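Note on the `static` cache error kept as context above: the sketch below shows the workaround the message suggests, loading the model with `attn_implementation="sdpa"` when a static KV cache is wanted. It is a minimal, hedged example; the checkpoint path is a placeholder and the from_pretrained/generate options are the standard transformers APIs, not code from this repository.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/decilm-checkpoint"  # placeholder, not a real checkpoint name

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # static cache is not compatible with flash_attention_2
    trust_remote_code=True,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# cache_implementation="static" requests the static KV cache in recent transformers releases.
outputs = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))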
transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py

@@ -15,12 +15,18 @@
 
 import inspect
 import os
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
+
 
 import torch
 import torch.nn.functional as F
 
-from
+from functools import lru_cache
+import importlib.metadata
+import importlib.util
+from packaging import version
+
+from transformers.utils import is_flash_attn_2_available
 
 
 if is_flash_attn_2_available():
@@ -32,6 +38,46 @@ if is_flash_attn_2_available():
     raise "Unable to import flash_attn"
 
 
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+    # Check if the package spec exists and grab its version to avoid importing a local directory
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            # Primary method to get the package version
+            package_version = importlib.metadata.version(pkg_name)
+        except importlib.metadata.PackageNotFoundError:
+            # Fallback method: Only for "torch" and versions containing "dev"
+            if pkg_name == "torch":
+                try:
+                    package = importlib.import_module(pkg_name)
+                    temp_version = getattr(package, "__version__", "N/A")
+                    # Check if the version contains "dev"
+                    if "dev" in temp_version:
+                        package_version = temp_version
+                        package_exists = True
+                    else:
+                        package_exists = False
+                except ImportError:
+                    # If the package can't be imported, it's not available
+                    package_exists = False
+            else:
+                # For packages other than "torch", don't attempt the fallback and set as not available
+                package_exists = False
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+
+@lru_cache()
+def is_flash_attn_greater_or_equal(library_version: str):
+    if not _is_package_available("flash_attn"):
+        return False
+
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+
+
 def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
     """
     Retrieves indexing data required to repad unpadded (ragged) tensors.