remove flash_attn imports and usage (#1)
Browse files- remove flash_attn imports and usage (edc7c55e9ce3fe0ccd6c1a2af7a500f898e7faf4)
- Update modeling_minicpm.py (9fc8a74191cf8fd01d98d9cd988f0cdfc994ee18)
- modeling_minicpm.py +6 -6
modeling_minicpm.py
CHANGED
@@ -51,11 +51,11 @@ from transformers.utils.import_utils import is_torch_fx_available
|
|
51 |
from .configuration_minicpm import MiniCPMConfig
|
52 |
import re
|
53 |
|
54 |
-
try:
|
55 |
-
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
56 |
-
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
57 |
-
except:
|
58 |
-
pass
|
59 |
|
60 |
|
61 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
@@ -755,7 +755,7 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
|
|
755 |
|
756 |
MINICPM_ATTENTION_CLASSES = {
|
757 |
"eager": MiniCPMAttention,
|
758 |
-
"flash_attention_2": MiniCPMFlashAttention2,
|
759 |
"sdpa": MiniCPMSdpaAttention,
|
760 |
}
|
761 |
|
|
|
51 |
from .configuration_minicpm import MiniCPMConfig
|
52 |
import re
|
53 |
|
54 |
+
#try:
|
55 |
+
# from flash_attn import flash_attn_func, flash_attn_varlen_func
|
56 |
+
# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
57 |
+
#except:
|
58 |
+
# pass
|
59 |
|
60 |
|
61 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
|
|
755 |
|
756 |
MINICPM_ATTENTION_CLASSES = {
|
757 |
"eager": MiniCPMAttention,
|
758 |
+
#"flash_attention_2": MiniCPMFlashAttention2,
|
759 |
"sdpa": MiniCPMSdpaAttention,
|
760 |
}
|
761 |
|