keyfan committed
Commit 33fac84
1 Parent(s): 7de2158

Update modeling_grok.py

Files changed (1):
  1. modeling_grok.py +4 -4
modeling_grok.py CHANGED
@@ -84,7 +84,7 @@ class GrokRMSNorm(nn.Module):
         GrokRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
@@ -92,7 +92,7 @@ class GrokRMSNorm(nn.Module):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Grok
@@ -338,7 +338,7 @@ class GrokDecoderLayer(nn.Module):
         self.top_k = config.num_experts_per_tok
 
         self.multi_head_attention = GrokAttention(config, layer_idx)
-        self.router = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+        self.router = nn.Linear(self.hidden_size, self.num_experts, dtype=torch.float32, bias=False)
         self.moe = nn.ModuleList([GrokBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
         self.rms_norm = GrokRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -400,7 +400,7 @@ class GrokDecoderLayer(nn.Module):
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.router(hidden_states)
+        router_logits = self.router(hidden_states.to(torch.float))
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
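Below is a minimal, self-contained sketch of GrokRMSNorm as it reads after this commit, for anyone who wants to see the effect in isolation. The class body follows the diff above; the eps value and the bfloat16 test tensor in the usage lines are illustrative assumptions, not taken from the repository. The two edits keep the scale parameter in float32 and move the cast back to the input dtype so it happens after the weight multiply, meaning the whole normalize-and-scale computation runs in float32 even when the activations arrive in half precision.

import torch
import torch.nn as nn


class GrokRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps):
        """
        GrokRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Scale parameter created explicitly in float32 (first change in this commit).
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Multiply by the float32 scale first, then cast back to the caller's dtype
        # (second change in this commit).
        return (self.weight * hidden_states).to(input_dtype)


# Illustrative usage: eps=1e-5 and the bfloat16 input are placeholder values.
norm = GrokRMSNorm(hidden_size=8, eps=1e-5)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
print(norm(x).dtype)  # torch.bfloat16 -- output dtype matches the input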
 
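And a similar sketch of the expert-routing path in GrokDecoderLayer after this commit, reduced to the router call itself. The sizes below (hidden_size=16, num_experts=8, top_k=2) are illustrative placeholders; in the model they come from the decoder-layer attributes shown in the diff (self.hidden_size, self.num_experts, config.num_experts_per_tok).

import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder sizes; the real layer reads these from the Grok config.
hidden_size, num_experts, top_k = 16, 8, 2

# The router is now created explicitly in float32 (third change in this commit).
router = nn.Linear(hidden_size, num_experts, dtype=torch.float32, bias=False)

# Half-precision activations, as they would arrive inside the decoder layer.
hidden_states = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16)
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

# router_logits: (batch * sequence_length, n_experts)
# The input is cast to float so it matches the float32 router weights
# (fourth change in this commit).
router_logits = router(hidden_states.to(torch.float))

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
print(routing_weights.shape, selected_experts.shape)  # torch.Size([8, 2]) torch.Size([8, 2])

Without the .to(torch.float) cast, feeding half-precision activations into a float32 Linear would fail with a dtype mismatch in the underlying matmul, which is why the cast accompanies the dtype change on the router.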