Switch import mechanism for flash_attn

#50
by nvwilliamz - opened
Files changed (1)
  1. modeling_phimoe.py +3 -1
modeling_phimoe.py CHANGED
```diff
@@ -53,11 +53,13 @@ from einops import rearrange
 from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
 
 
-if is_flash_attn_2_available():
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except ImportError:
+    pass
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
```
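For reference, the new guard is the standard optional-dependency pattern: attempt the import and continue without flash-attn when it is missing, rather than gating on transformers' `is_flash_attn_2_available()` check. A minimal self-contained sketch of the guarded import (symbols taken from the diff; the comments are explanatory and not part of the PR):

```python
import inspect

# Optional dependency: flash-attn may not be installed (or may fail to
# import on unsupported hardware), so guard the import itself instead
# of relying on an availability check.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    # Only some flash-attn builds accept a `window_size` kwarg
    # (sliding-window attention), so detect it from the signature.
    _flash_supports_window_size = "window_size" in list(
        inspect.signature(flash_attn_func).parameters
    )
except ImportError:
    # flash-attn is absent: the guarded names stay undefined and the
    # model falls back to its non-flash attention paths.
    pass
```

One trade-off of `except ImportError: pass`: when the import fails, `_flash_supports_window_size` and the flash-attn symbols are simply undefined, so the surrounding model code presumably references them only on the flash-attn code path.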