Crystalcareai committed
Commit cfc4ccd
1 Parent(s): 45f7601

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py +18 -48
modeling_gemmoe.py CHANGED
@@ -221,33 +221,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-class GemmoeMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-
-        self.act_fn = approx_gelu
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 class GemmoeAttention(nn.Module):
     """
@@ -569,17 +552,7 @@ class GemmoeSdpaAttention(GemmoeAttention):
     GemmoeAttention as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """
-    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-        """
-        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-        """
-        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-        if n_rep == 1:
-            return hidden_states
-        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
+
     def forward(
         self,
         hidden_states: torch.Tensor,
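For context on why the SDPA path needs repeat_kv at all (with this duplicated override deleted, the class presumably falls back to the single module-level helper, as far as this diff shows): a self-contained sketch of expanding grouped KV heads before PyTorch's scaled_dot_product_attention. Shapes and values here are hypothetical, not taken from the file:

    import torch
    import torch.nn.functional as F

    # 8 query heads sharing 2 KV heads, i.e. num_key_value_groups = 4 (hypothetical sizes).
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 2, 16, 64)
    v = torch.randn(1, 2, 16, 64)

    k = torch.repeat_interleave(k, repeats=4, dim=1)   # what repeat_kv produces
    v = torch.repeat_interleave(v, repeats=4, dim=1)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)                                   # torch.Size([1, 8, 16, 64])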
@@ -670,10 +643,12 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
+
         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+
+        self.act_fn = approx_gelu
 
     def forward(self, hidden_states):
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
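Net effect of the hunk above: each expert's activation is now the file's approx_gelu instead of a lookup through ACT2FN[config.hidden_act]. A minimal standalone sketch of the gated expert computation, using PyTorch's tanh-approximate GELU as a stand-in for approx_gelu and assuming w2 is applied as the down projection (both assumptions; the real helper and the rest of forward live elsewhere in modeling_gemmoe.py):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ExpertMLPSketch(nn.Module):
        """Gated FFN in the shape of GemmoeBlockSparseTop2MLP: w2(act(w1(x)) * w3(x))."""
        def __init__(self, hidden_dim: int, ffn_dim: int):
            super().__init__()
            self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # gate projection
            self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)   # down projection
            self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # up projection

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Stand-in for approx_gelu: the tanh approximation of GELU.
            return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))

    x = torch.randn(2, 5, 32)
    print(ExpertMLPSketch(hidden_dim=32, ffn_dim=128)(x).shape)    # torch.Size([2, 5, 32])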
@@ -734,20 +709,14 @@ class GemmoeSparseMoeBlock(nn.Module):
 
 
 class GemmoeDecoderLayer(nn.Module):
-    """
-    Decoder layer for the Gemmoe model.
-
-    Args:
-        config (GemmoeConfig): The configuration object for the Gemmoe model.
-        layer_idx (int): The index of the layer.
-    """
     def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
-        self.mlp = GemmoeMLP(config)
+
+        self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
         self.block_sparse_moe = GemmoeSparseMoeBlock(config)
-        self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
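The hunk above drops the dense self.mlp = GemmoeMLP(config) (whose class definition is removed in the first hunk) and keeps attention, the sparse MoE block, and the two RMSNorms. The layer's forward is not shown in this diff; as a hedged illustration only, these attributes suggest the usual pre-norm residual wiring, sketched below with stand-in modules (none of this is code from modeling_gemmoe.py):

    import torch
    import torch.nn as nn

    class PreNormBlockSketch(nn.Module):
        """Illustrative wiring only: residual around attention, then residual around the MoE FFN."""
        def __init__(self, hidden_size: int):
            super().__init__()
            self.input_layernorm = nn.LayerNorm(hidden_size)        # stand-in for GemmoeRMSNorm
            self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
            self.post_attention_layernorm = nn.LayerNorm(hidden_size)
            self.block_sparse_moe = nn.Linear(hidden_size, hidden_size)  # stand-in for GemmoeSparseMoeBlock

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.input_layernorm(x)
            attn_out, _ = self.self_attn(h, h, h)
            x = x + attn_out                                        # residual around attention
            x = x + self.block_sparse_moe(self.post_attention_layernorm(x))  # residual around the FFN/MoE
            return x

    x = torch.randn(2, 16, 64)
    print(PreNormBlockSketch(64)(x).shape)                          # torch.Size([2, 16, 64])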
@@ -901,6 +870,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         self.layers = nn.ModuleList(
             [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
+
        self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
        self.gradient_checkpointing = False