ver217 committed on
Commit: f27fd70
Parent: f7c6e7f

[hotfix] update gqa impl

Files changed (1)
  1. modeling_grok1.py +20 -0
modeling_grok1.py CHANGED
@@ -74,6 +74,21 @@ def load_balancing_loss_func(
     ) * (num_experts**2)


+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class RMSNorm(nn.Module):
     def __init__(
         self,
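For reference, the docstring's claim is easy to check: the expand-and-reshape in repeat_kv tiles each key/value head n_rep times along the head dimension, which matches torch.repeat_interleave on dim 1. A minimal sketch with illustrative sizes (not Grok-1's actual configuration), assuming only torch:

import torch

# illustrative sizes only: 8 KV heads repeated 4x to match 32 query heads
batch, num_kv_heads, seq_len, head_dim, n_rep = 2, 8, 16, 64, 4
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# same expand + reshape trick as repeat_kv above
tiled = (
    kv[:, :, None, :, :]
    .expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    .reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
)

# matches torch.repeat_interleave along the head axis
assert torch.equal(tiled, torch.repeat_interleave(kv, n_rep, dim=1))
print(tiled.shape)  # torch.Size([2, 32, 16, 64])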
 
@@ -194,6 +209,7 @@ class MultiHeadAttention(nn.Module):
         if num_key_value_heads is None:
             num_key_value_heads = num_heads
         self.num_key_value_heads = num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attn_output_multiplier = attn_output_multiplier
         self.max_attn_val = max_attn_val

@@ -259,6 +275,10 @@ class MultiHeadAttention(nn.Module):

         past_key_value = (key_states, value_states) if use_cache else None

+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
             torch.float
         )
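Taken together, the three hunks wire grouped-query attention (GQA) into MultiHeadAttention: queries keep num_heads heads, keys/values are projected to the smaller num_key_value_heads, and repeat_kv tiles them by num_key_value_groups just before the score matmul so the batched matrix multiply sees matching head counts. Below is a standalone sketch of that flow, with illustrative sizes and a plain 1/sqrt(head_dim) softmax standing in for Grok-1's attn_output_multiplier / max_attn_val scaling; it is not the module's actual forward.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # same expand + reshape as the function added in this commit
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    return (
        hidden_states[:, :, None, :, :]
        .expand(batch, num_kv_heads, n_rep, slen, head_dim)
        .reshape(batch, num_kv_heads * n_rep, slen, head_dim)
    )

# illustrative GQA sizes: 32 query heads share 8 KV heads (groups of 4)
batch, seq_len, head_dim = 2, 16, 64
num_heads, num_key_value_heads = 32, 8
num_key_value_groups = num_heads // num_key_value_heads

query_states = torch.randn(batch, num_heads, seq_len, head_dim)
key_states = torch.randn(batch, num_key_value_heads, seq_len, head_dim)
value_states = torch.randn(batch, num_key_value_heads, seq_len, head_dim)

# repeat k/v heads if n_kv_heads < n_heads, as in the hunk above
key_states = repeat_kv(key_states, num_key_value_groups)
value_states = repeat_kv(value_states, num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(torch.float)
attn_probs = torch.softmax(attn_weights * head_dim**-0.5, dim=-1)
attn_output = torch.matmul(attn_probs.to(value_states.dtype), value_states)
print(attn_output.shape)  # torch.Size([2, 32, 16, 64])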