Crystalcareai commited on
Commit
6f6cbec
·
verified ·
1 Parent(s): 53006e5

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +3 -4
modeling_gemmoe.py CHANGED
@@ -667,8 +667,6 @@ class GemmoeSparseMoeBlock(nn.Module):
667
  topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
668
  topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
669
 
670
- hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
671
-
672
  expert_outputs = []
673
  for i in range(self.num_experts):
674
  expert_input = hidden_states[topk_idx[:, i]]
@@ -676,9 +674,10 @@ class GemmoeSparseMoeBlock(nn.Module):
676
  expert_outputs.append(expert_output)
677
 
678
  expert_outputs = torch.stack(expert_outputs, dim=1)
679
- expert_outputs = expert_outputs.view(batch_size * sequence_length, self.top_k, -1)
 
680
 
681
- final_hidden_states = torch.einsum("bke,bkd->bed", topk_weight, expert_outputs)
682
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
683
 
684
  return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
 
667
  topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
668
  topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
669
 
 
 
670
  expert_outputs = []
671
  for i in range(self.num_experts):
672
  expert_input = hidden_states[topk_idx[:, i]]
 
674
  expert_outputs.append(expert_output)
675
 
676
  expert_outputs = torch.stack(expert_outputs, dim=1)
677
+ expert_outputs = expert_outputs.view(batch_size, sequence_length, self.top_k, -1)
678
+ topk_weight = topk_weight.view(batch_size, sequence_length, self.top_k, 1)
679
 
680
+ final_hidden_states = (expert_outputs * topk_weight).sum(dim=2)
681
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
682
 
683
  return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)