efederici committed on
Commit
2da0cf7
1 Parent(s): 4779b72

Update attention.py

Browse files
Files changed (1) hide show
  1. attention.py +1 -1
attention.py CHANGED
@@ -87,7 +87,7 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
87
 
88
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
- from .flash_attn_triton import flash_attn_func
91
  except:
92
  raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
93
  check_valid_inputs(query, key, value)
 
87
 
88
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
+ from flash_attn import flash_attn_triton
91
  except:
92
  raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
93
  check_valid_inputs(query, key, value)