yangapku commited on
Commit
e3edce3
1 Parent(s): 2db302e

update support for flash attn

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +6 -3
modeling_qwen.py CHANGED
@@ -87,10 +87,13 @@ def _import_flash_attn():
87
 
88
  try:
89
  import flash_attn
90
- if int(flash_attn.__version__.split(".")[0]) >= 2:
91
- from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
92
- else:
93
  from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
 
 
 
 
 
94
  flash_attn_unpadded_func = __flash_attn_unpadded_func
95
  except ImportError:
96
  logger.warn(
 
87
 
88
  try:
89
  import flash_attn
90
+ if not hasattr(flash_attn, '__version__'):
 
 
91
  from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
92
+ else:
93
+ if int(flash_attn.__version__.split(".")[0]) >= 2:
94
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
95
+ else:
96
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
97
  flash_attn_unpadded_func = __flash_attn_unpadded_func
98
  except ImportError:
99
  logger.warn(