Update modeling_llava_qwen2.py

Browse files

modeling_llava_qwen2.py  CHANGED  (+4 −4)
@@ -863,11 +863,11 @@ from transformers.utils import (
 from configuration_llava_qwen2 import Qwen2Config
 
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
-
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
 
 logger = logging.get_logger(__name__)