Crystalcareai commited on
Commit
706595d
1 Parent(s): e6d7d0e

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +2 -1
modeling_gemmoe.py CHANGED
@@ -689,7 +689,8 @@ class GemmoeSparseMoeBlock(nn.Module):
689
  flat_topk_idx = topk_idx.view(-1)
690
  for i in range(self.num_experts):
691
  expert = self.experts[i]
692
- y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
 
693
 
694
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
695
 
 
689
  flat_topk_idx = topk_idx.view(-1)
690
  for i in range(self.num_experts):
691
  expert = self.experts[i]
692
+ expert_output = expert(hidden_states[flat_topk_idx == i])
693
+ y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
694
 
695
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
696