Update modeling_intern_vit.py
Browse filesFixes https://huggingface.co/OpenGVLab/InternVL2-1B/discussions/2
- modeling_intern_vit.py +5 -4
modeling_intern_vit.py
CHANGED
@@ -15,17 +15,18 @@ from transformers.activations import ACT2FN
|
|
15 |
from transformers.modeling_outputs import (BaseModelOutput,
|
16 |
BaseModelOutputWithPooling)
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
18 |
from transformers.utils import logging
|
19 |
|
20 |
from .configuration_intern_vit import InternVisionConfig
|
21 |
|
22 |
try:
|
23 |
-
|
24 |
-
from flash_attn.flash_attn_interface import \
|
25 |
-
flash_attn_unpadded_qkvpacked_func
|
26 |
-
except: # v2
|
27 |
from flash_attn.flash_attn_interface import \
|
28 |
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
|
|
|
|
|
|
29 |
|
30 |
from flash_attn.bert_padding import pad_input, unpad_input
|
31 |
|
|
|
15 |
from transformers.modeling_outputs import (BaseModelOutput,
|
16 |
BaseModelOutputWithPooling)
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.utils.import_utils import is_flash_attn_greater_or_equal
|
19 |
from transformers.utils import logging
|
20 |
|
21 |
from .configuration_intern_vit import InternVisionConfig
|
22 |
|
23 |
try:
|
24 |
+
if is_flash_attn_greater_or_equal("2.0.0"):
|
|
|
|
|
|
|
25 |
from flash_attn.flash_attn_interface import \
|
26 |
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
27 |
+
else:
|
28 |
+
from flash_attn.flash_attn_interface import \
|
29 |
+
flash_attn_unpadded_qkvpacked_func
|
30 |
|
31 |
from flash_attn.bert_padding import pad_input, unpad_input
|
32 |
|