# deepseek-moe / src/moe.py
# Author: bird-of-paradise
# Commit 354a706: Initial commit
import torch
import torch.nn as nn
import torch.nn.functional as F
# Note: This is a simplified version of the communication balance loss.
# For the complete implementation (with proper token-to-device mapping,
# device-limited routing, and more efficient calculations), please contact
# the author.
class Expert(nn.Module):
    """
    Position-wise feed-forward network used as a single expert.

    Two linear transformations with a ReLU activation in between:
        FFN(x) = max(0, x W1 + b1) W2 + b2

    d_model: embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)

    forward accepts any tensor whose last dimension is d_model, e.g.
    [batch, seq, d_model] or a flat [num_tokens, d_model], and returns a
    tensor of the same shape.
    """

    def __init__(self, d_model, d_expert):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert
        # Linear transformations y = xW + b
        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_model, bias=True)
        # Xavier/Glorot uniform initialisation of the weight matrices
        # (biases keep nn.Linear's default initialisation).
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        # Only the feature (last) dimension has to match d_model; nn.Linear
        # carries any leading dimensions through unchanged.
        d_input = input.size(-1)
        assert self.d_model == d_input, "d_model must be the same dimension as the input"
        # max(0, xW_1 + b_1)W_2 + b_2
        return self.fc2(F.relu(self.fc1(input)))
class MixtureOfExperts(nn.Module):
    r"""
    Mixture of Experts as in DeepSeekMoE / DeepSeek-V2.

    MoE(x) = x + \sum_i Expert^s_i(x) + \sum_i gate_i(x; K) * Expert^r_i(x)

    d_model: embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    K: number of routed experts selected per token (top-K gate)
    N_s: number of shared experts
    N_r: number of routed experts
    alpha1: hyper-parameter; expert-level balance factor
    alpha2: hyper-parameter; device-level balance factor
    alpha3: hyper-parameter; communication balance factor
    D: number of devices; the routed experts are split into D contiguous
       groups (one per device) with torch.tensor_split
    M: number of devices for Device-Limited Routing. It only appears in the
       normalisation of the communication loss; this module does not restrict
       routing to M devices.

    forward(input) expects [batch, seq, d_model] and returns
    (output, expert_loss, device_loss, commu_loss), where output has the
    input's shape and the three losses are scalar tensors.
    """

    def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()
        # Every device must hold at least one routed expert (D == N_r is one
        # expert per device); otherwise tensor_split yields empty groups.
        assert D <= N_r, "Number of partitions must not exceed the number of routed experts"
        assert M <= D, "Number of devices for Device-Limited Routing must not exceed the total number of devices"
        self.d_model = d_model
        self.d_expert = d_expert
        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3
        self.D = D  # number of devices available
        self.M = M  # for Device-Limited Routing
        # initialize shared experts and routed experts
        self.shared_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_s)
        ])
        self.routed_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_r)
        ])
        # Learnable centroids: one vector per routed expert
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)  # [num_routed_experts, d_model]
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input):
        # check input and first FF layer dimension matching
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # Shared experts process every token.
        shared_output = torch.zeros_like(input)
        for expert in self.shared_experts:
            shared_output += expert(input)  # [batch, seq, d_model]

        # Similarity between input tokens and expert centroids (kept on self
        # for inspection).
        self.similarities = torch.matmul(input, self.expert_centroids.transpose(0, 1))  # [batch, seq, N_r]
        assert self.similarities.size(dim=-1) == self.N_r, \
            "last dimension of similarities must be the same as the number of routed expert"
        affinity = F.softmax(self.similarities, dim=-1)  # [batch, seq, N_r]

        # Top-K gate. The selected affinities are renormalised to sum to 1 per
        # token by dividing by their sum. A second softmax over probabilities
        # (values in [0, 1]) would instead compress every gate toward 1/K.
        # The sum is positive because the top-1 softmax value is >= 1/N_r.
        values, indexes = torch.topk(affinity, self.K)  # [batch, seq, K]
        values = values / values.sum(dim=-1, keepdim=True)
        gate = torch.zeros_like(affinity).scatter_(2, indexes, values)  # [batch, seq, N_r]
        # for testing
        self.last_gate = gate

        # Dense evaluation: every routed expert runs on every token and is
        # weighted by its gate, which is zero for unselected experts.
        routed_output = torch.zeros_like(input)
        for i, expert in enumerate(self.routed_experts):
            routed_output += gate[:, :, i].unsqueeze(-1) * expert(input)

        # Selection mask built from the top-K indices (not from gate != 0), so
        # a gate value that underflows to zero is still counted as routed.
        routed_mask = torch.zeros_like(affinity, dtype=torch.long).scatter_(2, indexes, 1)

        # T is the number of tokens in the batch; max(.., 1) avoids dividing by
        # zero on an empty batch (all counts and affinities are then zero).
        T = max(batch_size * seq_length, 1)
        expert_loss, device_loss, commu_loss = self._balance_losses(affinity, routed_mask, T)
        return input + shared_output + routed_output, expert_loss, device_loss, commu_loss

    def _balance_losses(self, affinity, routed_mask, T):
        """Return (expert_loss, device_loss, commu_loss) as scalar tensors.

        affinity: [batch, seq, N_r] softmax routing scores (carry the gradient)
        routed_mask: [batch, seq, N_r] integer tensor, 1 where the expert is in
            the token's top-K and 0 elsewhere
        T: number of tokens in the batch
        """
        # Expert-level balance loss:
        #   f_i = N_r / (K T) * (#tokens routed to expert i)   -- count, no gradient
        #   P_i = (1 / T) * sum_t affinity_{t,i}               -- carries the gradient
        counts = routed_mask.sum(dim=(0, 1))  # [N_r]
        # Scale in float32 before casting so large counts cannot overflow half
        # precision; the cast makes f match P's dtype for torch.dot.
        f = (counts.to(torch.float32) * (self.N_r / (self.K * T))).to(affinity.dtype)
        P = affinity.sum(dim=(0, 1)) / T
        expert_loss = self.alpha1 * torch.dot(f, P)

        # Device-level balance loss: experts are grouped into D contiguous
        # partitions; f' is the mean of f and P' the sum of P in each group.
        # torch.stack (unlike torch.tensor) keeps P' attached to the autograd
        # graph and on the same device as the inputs.
        f_dev = torch.stack([part.mean() for part in torch.tensor_split(f, self.D)])
        P_dev = torch.stack([part.sum() for part in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.dot(f_dev, P_dev)

        # Communication balance loss:
        #   f''_d = D / (M T) * (#tokens with at least one top-K expert on device d)
        #   P''_d = P'_d
        tokens_per_device = torch.stack([
            (part.sum(dim=-1) > 0).sum()
            for part in torch.tensor_split(routed_mask, self.D, dim=-1)
        ])
        f_comm = (tokens_per_device.to(torch.float32) * (self.D / (self.M * T))).to(affinity.dtype)
        commu_loss = self.alpha3 * torch.dot(f_comm, P_dev)
        return expert_loss, device_loss, commu_loss