Crystalcareai committed on
Commit
44b539c
1 Parent(s): 7621f1c

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +5 -6
modeling_quiet.py CHANGED
@@ -448,11 +448,10 @@ class QuietFlashAttention2(QuietAttention):
448
  query_states = query_states.to(target_dtype)
449
  key_states = key_states.to(target_dtype)
450
  value_states = value_states.to(target_dtype)
451
-
452
- # Reashape to the expected shape for Flash Attention
453
- query_states = query_states.transpose(1, 2)
454
- key_states = key_states.transpose(1, 2)
455
- value_states = value_states.transpose(1, 2)
456
 
457
  attn_output = self._flash_attention_forward(
458
  query_states,
@@ -462,7 +461,7 @@ class QuietFlashAttention2(QuietAttention):
462
  q_len,
463
  dropout=dropout_rate,
464
  use_sliding_windows=use_sliding_windows,
465
- )
466
 
467
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
468
  attn_output = self.o_proj(attn_output)
 
448
  query_states = query_states.to(target_dtype)
449
  key_states = key_states.to(target_dtype)
450
  value_states = value_states.to(target_dtype)
451
+ # Reshape to the expected shape for Flash Attention
452
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
453
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
454
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
 
455
 
456
  attn_output = self._flash_attention_forward(
457
  query_states,
 
461
  q_len,
462
  dropout=dropout_rate,
463
  use_sliding_windows=use_sliding_windows,
464
+ )
465
 
466
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
467
  attn_output = self.o_proj(attn_output)