import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """
    Position-wise feed-forward network.

    Two linear transformations with a ReLU activation in between:

        FFN(x) = max(0, x W1 + b1) W2 + b2

    d_model:  embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    """

    def __init__(self, d_model, d_expert):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert

        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_model, bias=True)

        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must match the last dimension of the input"

        # Project up to d_expert, apply ReLU, then project back to d_model.
        return self.fc2(F.relu(self.fc1(input)))

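
# A minimal usage sketch (illustrative; the helper name and shapes below are
# assumptions, not part of the original module):
def _expert_example():
    expert = Expert(d_model=512, d_expert=256)
    x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
    y = expert(x)                 # shape preserved: (2, 10, 512)
    assert y.shape == x.shape

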
class MixtureOfExperts(nn.Module):
    r"""
    Mixture-of-Experts layer as in DeepSeek.

        MoE(x) = x + \sum_{i=1}^{N_s} Expert^s_i(x)
                   + \sum_{i=1}^{N_r} g_i(x) * Expert^r_i(x)

    where g_i(x) is the top-K gate value for routed expert i (zero for
    experts that are not selected).

    d_model:  embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    K:   number of routed experts selected per token (top-K gating)
    N_s: number of shared experts
    N_r: number of routed experts
    alpha1: hyper-parameter; expert-level balance factor
    alpha2: hyper-parameter; device-level balance factor
    alpha3: hyper-parameter; communication balance factor

    D: number of devices in the distributed system
    M: number of devices used for device-limited routing
    """

    def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()

        assert D < N_r, "number of devices must be less than the number of routed experts"
        assert M <= D, "number of devices for device-limited routing must not exceed the total number of devices"

        self.d_model = d_model
        self.d_expert = d_expert

        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3

        self.D = D
        self.M = M

        # Shared experts process every token; routed experts are selected per token.
        self.shared_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_s)
        ])

        self.routed_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_r)
        ])

        # One centroid per routed expert; token-to-expert affinity is the dot
        # product between a token embedding and these centroids.
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input):
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must match the last dimension of the input"

        # Shared experts: every token passes through all of them.
        shared_output = torch.zeros_like(input)
        for expert in self.shared_experts:
            shared_output += expert(input)

        # Token-to-expert affinity: dot product with each expert centroid,
        # softmax-normalized over the N_r routed experts.
        # (Kept as an attribute, like last_gate, for testing/inspection.)
        self.similarities = torch.matmul(input, self.expert_centroids.transpose(0, 1))
        assert self.similarities.size(dim=-1) == self.N_r, \
            "last dimension of similarities must equal the number of routed experts"
        affinity = F.softmax(self.similarities, dim=-1)

        # Top-K gating: keep the K highest affinities per token, re-normalize
        # them with a softmax, and scatter back into a sparse gate tensor.
        values, indexes = torch.topk(affinity, self.K, dim=-1)
        values = F.softmax(values, dim=-1)
        gate = torch.zeros_like(affinity).scatter_(2, indexes, values)
        # kept for testing
        self.last_gate = gate
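
        # Worked example (illustrative numbers, not from the original code):
        # with N_r = 4 and K = 2, a token with affinity [0.10, 0.40, 0.30, 0.20]
        # selects experts 1 and 2; softmax over [0.40, 0.30] gives ~[0.52, 0.48],
        # so that token's gate row becomes [0.00, 0.52, 0.48, 0.00].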

        # Routed experts: dense computation for clarity; every expert runs on
        # the full input and the sparse gate masks out unselected experts.
        routed_output = torch.zeros_like(input)
        for i in range(self.N_r):
            routed_output += gate[:, :, i].unsqueeze(-1) * self.routed_experts[i](input)

        # Expert-level balance loss: T is the total number of tokens, f_i the
        # fraction of tokens routed to expert i (scaled by N_r/K), and P_i the
        # mean affinity of expert i.
        T = batch_size * seq_length
        f = self.N_r / (self.K * T) * torch.count_nonzero(gate, (0, 1))
        P = 1 / T * affinity.sum((0, 1))
        expert_loss = self.alpha1 * torch.matmul(f, P)
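
        # For intuition (illustrative numbers): with N_r = 4, K = 2 and T = 3
        # tokens selecting experts {0,1}, {0,2}, {0,3}, the per-expert counts
        # are [3, 1, 1, 1], so f = 4/(2*3) * [3, 1, 1, 1] ~= [2.0, 0.67, 0.67, 0.67];
        # perfectly uniform routing would give f_i = 1 for every expert.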

        # Device-level balance loss: group the routed experts into D device
        # partitions and balance aggregate load across partitions.
        # (torch.stack, rather than torch.tensor, keeps the autograd graph.)
        f1 = torch.stack([partition.mean() for partition in torch.tensor_split(f, self.D)])
        P1 = torch.stack([partition.sum() for partition in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.matmul(f1, P1)

        # Communication balance loss: encourage each device to receive a
        # balanced share of the tokens dispatched under device-limited routing.
        f2 = self.D / (self.M * T) * torch.stack([
            torch.count_nonzero(partition, (0, 1)).sum()
            for partition in torch.tensor_split(gate, self.D, dim=-1)
        ])
        P2 = P1
        commu_loss = self.alpha3 * torch.matmul(f2, P2)

        # Residual connection plus shared and routed outputs, together with
        # the three auxiliary balance losses.
        return input + shared_output + routed_output, expert_loss, device_loss, commu_loss
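

# A minimal smoke test (hypothetical hyper-parameters and shapes, chosen for
# illustration; not part of the original module):
if __name__ == "__main__":
    torch.manual_seed(0)
    moe = MixtureOfExperts(d_model=512, d_expert=256, K=2, N_s=1, N_r=8,
                           alpha1=0.01, alpha2=0.01, alpha3=0.01, D=4, M=3)
    x = torch.randn(2, 10, 512)
    out, expert_loss, device_loss, commu_loss = moe(x)
    print(out.shape)              # torch.Size([2, 10, 512])
    print(moe.last_gate.sum(-1))  # gate values per token sum to ~1.0
    # the auxiliary losses would be added to the main training loss
    total_loss = expert_loss + device_loss + commu_loss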