from typing import List

import torch
import torch.nn as nn

from .args import MoeArgs


class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing per token.

    For every token, ``gate`` produces one score per expert. The
    ``moe_args.num_experts_per_tok`` highest scores are kept and renormalised
    with a softmax. The token's output is the weighted sum of the selected
    experts' outputs.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        """Build the layer.

        Args:
            experts: Expert sub-modules. Expert ``i`` is selected when the gate
                ranks column ``i`` of its output among the top-k.
            gate: Module mapping token features to per-expert scores
                (presumably one column per entry of ``experts`` — confirm).
            moe_args: Config object; only ``num_experts_per_tok`` is read here.

        Raises:
            ValueError: If ``experts`` is empty.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not experts:
            raise ValueError("MoeLayer requires at least one expert")
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route each token through its top-k experts and mix the results.

        Args:
            inputs: Tensor whose last dimension holds the token features. Any
                leading dimensions (e.g. batch and sequence) are treated as
                independent tokens.

        Returns:
            Tensor with the same shape and dtype as ``inputs``.
        """
        orig_shape = inputs.shape
        # Collapse all leading dims so each row is one token; the routing
        # below relies on a 2-D (tokens, features) layout.
        flat = inputs.reshape(-1, orig_shape[-1])

        gate_logits = self.gate(flat)
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok, dim=-1
        )
        # Normalise only over the selected experts; softmax is computed in
        # float32 (dtype= casts before the op) and cast back to the input dtype.
        weights = torch.nn.functional.softmax(
            weights, dim=-1, dtype=torch.float
        ).to(inputs.dtype)

        results = torch.zeros_like(flat)
        for i, expert in enumerate(self.experts):
            # token_idx: which tokens chose expert i.
            # nth_expert: in which of their top-k slots it was chosen.
            token_idx, nth_expert = torch.where(selected_experts == i)
            if token_idx.numel() == 0:
                # No token routed here; skip the (possibly empty-unsafe) call.
                continue
            # topk returns distinct indices per token, so token_idx holds no
            # duplicates and the indexed in-place add is safe.
            results[token_idx] += weights[token_idx, nth_expert, None] * expert(
                flat[token_idx]
            )
        return results.reshape(orig_shape)