Crystalcareai committed on
Commit
6551767
1 Parent(s): 44b539c

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +4 -3
modeling_quiet.py CHANGED
@@ -449,9 +449,10 @@ class QuietFlashAttention2(QuietAttention):
449
  key_states = key_states.to(target_dtype)
450
  value_states = value_states.to(target_dtype)
451
  # Reshape to the expected shape for Flash Attention
452
- query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
453
- key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
454
- value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
 
455
 
456
  attn_output = self._flash_attention_forward(
457
  query_states,
 
449
  key_states = key_states.to(target_dtype)
450
  value_states = value_states.to(target_dtype)
451
  # Reshape to the expected shape for Flash Attention
452
+ query_states = query_states.reshape(bsz, -1, self.num_heads, self.head_dim)
453
+ key_states = key_states.reshape(bsz, -1, self.num_key_value_heads, self.head_dim)
454
+ value_states = value_states.reshape(bsz, -1, self.num_key_value_heads, self.head_dim)
455
+
456
 
457
  attn_output = self._flash_attention_forward(
458
  query_states,