Crystalcareai
committed on
Update modeling_gemmoe.py
modeling_gemmoe.py +14 -11
modeling_gemmoe.py
CHANGED
@@ -655,31 +655,34 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.num_experts = config.num_local_experts
         self.top_k = 2
 
+        # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+        # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
         routing_weights = F.softmax(router_logits, dim=1)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
 
-
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
+        y = torch.empty_like(hidden_states)
+
+        flat_topk_idx = topk_idx.view(-1)
         for i in range(self.num_experts):
-
-            expert_output =
-
+            expert = self.experts[i]
+            expert_output = expert(hidden_states[flat_topk_idx == i])
+            y[flat_topk_idx == i] = expert_output
 
-
-
-
-
-        final_hidden_states = (expert_outputs * topk_weight).sum(dim=2)
-        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
-
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
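For context, the added lines implement a dispatch-style top-2 routing: each token is duplicated top_k times, each copy is sent through exactly one expert, and the per-expert outputs are recombined with the renormalized gate weights. Below is a self-contained sketch of that path, with a plain nn.Linear standing in for GemmoeBlockSparseTop2MLP and made-up tensor sizes; it is an illustration, not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

hidden_dim, num_experts, top_k = 16, 4, 2          # toy sizes, not the Gemmoe config
gate = nn.Linear(hidden_dim, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])

x = torch.randn(2, 3, hidden_dim)                  # (batch, seq, hidden)
batch_size, sequence_length, _ = x.shape
hidden_states = x.view(-1, hidden_dim)             # (tokens, hidden)

# Router: one logit per expert for every token, softmax, keep the top-2.
router_logits = gate(hidden_states)                # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)  # renormalize over the kept experts

# Duplicate each token top_k times so every (token, chosen expert) pair is one row.
hidden_states = hidden_states.repeat_interleave(top_k, dim=0)      # (tokens * top_k, hidden)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)                  # expert id for each duplicated row

# Each expert processes only the rows routed to it.
for i in range(num_experts):
    mask = flat_topk_idx == i
    if mask.any():                                 # guard for experts that received no tokens
        y[mask] = experts[i](hidden_states[mask])

# Weight the top-2 outputs per token and sum them back into one vector per token.
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
print(final_hidden_states.shape)                   # torch.Size([2, 3, 16])

Because each expert only sees the rows routed to it, a per-expert batch can be empty on small inputs; the mask.any() guard above is only there to make the toy example robust and does not appear in the diff.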