Crystalcareai
commited on
Commit
•
706595d
1
Parent(s):
e6d7d0e
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +2 -1
modeling_gemmoe.py
CHANGED
@@ -689,7 +689,8 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
689 |
flat_topk_idx = topk_idx.view(-1)
|
690 |
for i in range(self.num_experts):
|
691 |
expert = self.experts[i]
|
692 |
-
|
|
|
693 |
|
694 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
695 |
|
|
|
689 |
flat_topk_idx = topk_idx.view(-1)
|
690 |
for i in range(self.num_experts):
|
691 |
expert = self.experts[i]
|
692 |
+
expert_output = expert(hidden_states[flat_topk_idx == i])
|
693 |
+
y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
|
694 |
|
695 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
696 |
|