Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files: modeling_gemmoe.py (+3 -4)
modeling_gemmoe.py
CHANGED
@@ -667,8 +667,6 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
667 |
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
668 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
669 |
|
670 |
-
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
|
671 |
-
|
672 |
expert_outputs = []
|
673 |
for i in range(self.num_experts):
|
674 |
expert_input = hidden_states[topk_idx[:, i]]
|
@@ -676,9 +674,10 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
676 |
expert_outputs.append(expert_output)
|
677 |
|
678 |
expert_outputs = torch.stack(expert_outputs, dim=1)
|
679 |
-
expert_outputs = expert_outputs.view(batch_size
|
|
|
680 |
|
681 |
-
final_hidden_states =
|
682 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
683 |
|
684 |
return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
|
|
|
667 |
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
668 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
669 |
|
|
|
|
|
670 |
expert_outputs = []
|
671 |
for i in range(self.num_experts):
|
672 |
expert_input = hidden_states[topk_idx[:, i]]
|
|
|
674 |
expert_outputs.append(expert_output)
|
675 |
|
676 |
expert_outputs = torch.stack(expert_outputs, dim=1)
|
677 |
+
expert_outputs = expert_outputs.view(batch_size, sequence_length, self.top_k, -1)
|
678 |
+
topk_weight = topk_weight.view(batch_size, sequence_length, self.top_k, 1)
|
679 |
|
680 |
+
final_hidden_states = (expert_outputs * topk_weight).sum(dim=2)
|
681 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
682 |
|
683 |
return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
|