Clemspace's picture
Initial model upload
cb9e677
raw
history blame contribute delete
No virus
996 Bytes
from typing import List
import torch
import torch.nn as nn
from .args import MoeArgs
class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing.

    A gating module scores every token against every expert. Each token is
    sent to its ``num_experts_per_tok`` highest-scoring experts, and those
    experts' outputs are summed using weights given by a softmax over the
    selected scores only (so each token's weights sum to 1 over its chosen
    experts).

    Args:
        experts: The expert modules (at least one). Each is called on a 2-D
            batch of the tokens routed to it and must return a tensor of the
            same shape as that batch.
        gate: Routing module. Called on a 2-D ``(num_tokens, dim)`` tensor and
            expected to return one logit per expert along the last dimension.
        moe_args: Configuration object; only ``num_experts_per_tok`` is read.
    """

    def __init__(
        self, experts: List[nn.Module], gate: nn.Module, moe_args: "MoeArgs"
    ):
        super().__init__()
        assert len(experts) > 0, "MoeLayer requires at least one expert"
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and mix their outputs.

        Args:
            inputs: Tensor of shape ``(..., dim)``. Every index over the
                leading dimensions is treated as an independent token.

        Returns:
            Tensor with the same shape and dtype as ``inputs``.
        """
        # Collapse all leading dims into a single token axis so that topk,
        # softmax and the two-way torch.where below always see a 2-D layout
        # regardless of whether the caller passes (dim,), (tokens, dim) or
        # (batch, seq, dim).
        tokens = inputs.reshape(-1, inputs.shape[-1])
        gate_logits = self.gate(tokens)
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok, dim=-1
        )
        # Normalise over the k selected experts only; compute in float32 for
        # numerical stability, then cast back to the activation dtype.
        weights = torch.nn.functional.softmax(
            weights, dim=-1, dtype=torch.float
        ).to(inputs.dtype)
        results = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            # token_idx: rows routed to expert i; slot_idx: which of that
            # row's k top-k slots picked expert i (topk indices are distinct,
            # so each row appears at most once here).
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel() == 0:
                continue  # nothing routed to this expert; adding would be a no-op
            results[token_idx] += weights[token_idx, slot_idx, None] * expert(
                tokens[token_idx]
            )
        return results.reshape(inputs.shape)