rogerxfeng8 committed on
Commit
e477000
1 Parent(s): bef256d

Moved flash_attn assert close to the caller


Moved the flash_attn package availability assert close to its caller, _apply_dense_attention. This allows non-CUDA devices to run the model.

Files changed (1)
  1. modeling_phi3_small.py +2 -0
modeling_phi3_small.py CHANGED
@@ -418,6 +418,8 @@ class Phi3SmallSelfAttention(nn.Module):
         avoid doing that.
 
         """
+        assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
+
         attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
         # Get into the correct shape for the Flash Attention API
         # shape: (bs, seq_len, nqp, hn)
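
For context, the change follows a common pattern: detect flash_attn once at import time, but only assert availability when the dense-attention path is actually taken. Below is a minimal sketch of that idea; the class and method bodies are illustrative stubs rather than the actual modeling_phi3_small.py contents, and only the flag name and the assert message come from the diff.

# Minimal sketch (illustrative, not the real file): detect flash_attn at
# import time, but defer the hard failure until the dense-attention path
# is actually used, so non-CUDA devices can still load and run the model
# through other code paths.
try:
    import flash_attn  # noqa: F401
    is_flash_attention_available = True
except ImportError:
    is_flash_attention_available = False


class Phi3SmallSelfAttention:  # stub for illustration only
    def _apply_dense_attention(self, query_states, key_states, value_states):
        # Asserting here, right before the flash_attn call, instead of at
        # module import keeps importing the model side-effect free on
        # CPU-only setups.
        assert is_flash_attention_available, (
            "Flash Attention is not available, but is needed for dense attention"
        )
        # ... call into flash_attn kernels with the prepared tensors ...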