hvaara committed
Commit
188f686
1 Parent(s): a9fc14a

Update modeling_intern_vit.py


Fixes https://huggingface.co/OpenGVLab/InternVL2-1B/discussions/2

Files changed (1)
  1. modeling_intern_vit.py +5 -4
modeling_intern_vit.py CHANGED
@@ -15,17 +15,18 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (BaseModelOutput,
                                            BaseModelOutputWithPooling)
 from transformers.modeling_utils import PreTrainedModel
+from transformers.utils.import_utils import is_flash_attn_greater_or_equal
 from transformers.utils import logging
 
 from .configuration_intern_vit import InternVisionConfig
 
 try:
-    try: # v1
-        from flash_attn.flash_attn_interface import \
-            flash_attn_unpadded_qkvpacked_func
-    except: # v2
+    if is_flash_attn_greater_or_equal("2.0.0"):
         from flash_attn.flash_attn_interface import \
             flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+    else:
+        from flash_attn.flash_attn_interface import \
+            flash_attn_unpadded_qkvpacked_func
 
     from flash_attn.bert_padding import pad_input, unpad_input
 
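For anyone applying the same fix elsewhere, below is a minimal self-contained sketch of the version-gated import pattern this commit adopts. The gated imports themselves match the committed change; the has_flash_attn flag, the except ImportError fallback, and the print message are assumptions about the surrounding guard and are not part of the diff above.

# Minimal sketch of the version-gated flash-attn import (assumptions noted above).
from transformers.utils.import_utils import is_flash_attn_greater_or_equal

try:
    if is_flash_attn_greater_or_equal("2.0.0"):
        # flash-attn v2 renamed flash_attn_unpadded_* to flash_attn_varlen_*;
        # alias it back to the v1 name so downstream code stays version-agnostic.
        from flash_attn.flash_attn_interface import \
            flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
    else:
        from flash_attn.flash_attn_interface import \
            flash_attn_unpadded_qkvpacked_func

    from flash_attn.bert_padding import pad_input, unpad_input
    has_flash_attn = True  # assumed flag, not shown in the diff
except ImportError:
    # Assumed graceful fallback when flash-attn is not installed.
    print('FlashAttention is not installed.')
    has_flash_attn = False

Checking the installed version explicitly avoids the old nested try/except, whose bare except could mask unrelated import errors while probing for the v1 symbol.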