itlevy committed · verified
Commit c7f5725 · 1 Parent(s): 186a08a

fixed flash_attention backward_compat

modeling_decilm.py CHANGED
@@ -385,7 +385,6 @@ class DeciLMAttention(nn.Module):
             **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
-
         if self.config.pretraining_tp > 1:
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
@@ -497,7 +496,6 @@ class DeciLMFlashAttention2(DeciLMAttention):
                 "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
             )
-
         output_attentions = False
 
         bsz, q_len, _ = hidden_states.size()
transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py CHANGED
@@ -15,12 +15,18 @@
 
 import inspect
 import os
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
+
 
 import torch
 import torch.nn.functional as F
 
-from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
+from functools import lru_cache
+import importlib.metadata
+import importlib.util
+from packaging import version
+
+from transformers.utils import is_flash_attn_2_available
 
 
 if is_flash_attn_2_available():
@@ -32,6 +38,46 @@ if is_flash_attn_2_available():
     raise "Unable to import flash_attn"
 
 
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+    # Check if the package spec exists and grab its version to avoid importing a local directory
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            # Primary method to get the package version
+            package_version = importlib.metadata.version(pkg_name)
+        except importlib.metadata.PackageNotFoundError:
+            # Fallback method: Only for "torch" and versions containing "dev"
+            if pkg_name == "torch":
+                try:
+                    package = importlib.import_module(pkg_name)
+                    temp_version = getattr(package, "__version__", "N/A")
+                    # Check if the version contains "dev"
+                    if "dev" in temp_version:
+                        package_version = temp_version
+                        package_exists = True
+                    else:
+                        package_exists = False
+                except ImportError:
+                    # If the package can't be imported, it's not available
+                    package_exists = False
+            else:
+                # For packages other than "torch", don't attempt the fallback and set as not available
+                package_exists = False
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+
+@lru_cache()
+def is_flash_attn_greater_or_equal(library_version: str):
+    if not _is_package_available("flash_attn"):
+        return False
+
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+
+
 def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
     """
     Retrieves indexing data required to repad unpadded (ragged) tensors.
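
For context, a minimal usage sketch of the backported is_flash_attn_greater_or_equal helper: it lets the flash-attention code path gate keyword arguments on the installed flash_attn version without importing the helper from newer transformers releases. The `deterministic` flag below is an assumed example of such a version-gated argument, not something introduced by this diff.

# Hypothetical usage sketch (assumption, not part of this commit):
# gate a flash_attn kwarg on the installed flash_attn version.
import os

flash_kwargs = {}
if is_flash_attn_greater_or_equal("2.4.1"):
    # Newer flash_attn releases accept a `deterministic` flag; older ones do not,
    # so the kwarg is only added when the version check passes.
    flash_kwargs["deterministic"] = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"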