Switch import mechanism for flash_attn

#51
by nvwilliamz - opened
Files changed (1) hide show
  1. modeling_phimoe.py +4 -3
modeling_phimoe.py CHANGED
@@ -50,14 +50,15 @@ from transformers.utils.import_utils import is_torch_fx_available
50
  from .configuration_phimoe import PhiMoEConfig
51
 
52
  from einops import rearrange
53
- from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
54
 
55
-
56
- if is_flash_attn_2_available():
57
  from flash_attn import flash_attn_func, flash_attn_varlen_func
58
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
59
 
60
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
 
61
 
62
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
63
  # It means that the function will not be traced through and simply appear as a node in the graph.
 
50
  from .configuration_phimoe import PhiMoEConfig
51
 
52
  from einops import rearrange
 
53
 
54
+ try:
55
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
56
  from flash_attn import flash_attn_func, flash_attn_varlen_func
57
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
 
59
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
60
+ except ImportError:
61
+ pass
62
 
63
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
64
  # It means that the function will not be traced through and simply appear as a node in the graph.