import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: This is a simplified version of the communication balance loss.
# For the complete implementation with proper token-to-device mapping,
# the device-limited routing implementation,
# and more efficient calculations, please contact the author.
class Expert(nn.Module):
    """
    Position-wise Feed-Forward Network:
    two linear transformations with a ReLU activation in between.
        FFN(x) = max(0, x W1 + b1) W2 + b2
    d_model: embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    """
    def __init__(self, d_model, d_expert):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert
        # Linear transformation y = xW + b
        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_model, bias=True)
        # Xavier-initialize the weights (can help with training stability)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        # Check that the input matches the first FF layer's dimension
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"
        # max(0, x W_1 + b_1) W_2 + b_2
        return self.fc2(F.relu(self.fc1(input)))
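

# A minimal usage sketch for Expert (the values below are illustrative
# assumptions; the FFN preserves the input shape):
#
#   ffn = Expert(d_model=512, d_expert=256)
#   x = torch.randn(2, 10, 512)   # [batch, seq, d_model]
#   y = ffn(x)                    # -> [2, 10, 512]
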
class MixtureOfExperts(nn.Module):
    r"""
    Mixture of Experts as in DeepSeek:
        MoE(x) = x + \sum_i Expert^s_i(x) + \sum_i gate_i(x; K) * Expert^r_i(x)
    d_model: embedding dimension (e.g., 512)
    d_expert: expert hidden dimension (e.g., 256)
    K : top-K gate
    N_s: number of shared experts
    N_r: number of routed experts
    alpha1: hyper-parameter; expert-level balance factor
    alpha2: hyper-parameter; device-level balance factor
    alpha3: hyper-parameter; communication balance factor
    D: number of devices in the distributed system
    M: number of devices for Device-Limited Routing
    """
    def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()
        assert D < N_r, "Number of partitions needs to be less than the number of routed experts"
        assert M <= D, "Number of devices for Device-Limited Routing needs to be at most the total number of devices"
        self.d_model = d_model
        self.d_expert = d_expert
        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3
        self.D = D  # number of devices available
        self.M = M  # for Device-Limited Routing
        # Initialize shared experts and routed experts
        self.shared_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_s)
        ])
        self.routed_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_r)
        ])
        # Initialize centroids: learnable parameters, one vector per routed expert
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)  # [num_routed_experts, d_model]
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input):
        # Check that the input matches the expected dimension
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"
        # Shared experts process every token
        shared_output = torch.zeros_like(input)
        for expert in self.shared_experts:
            shared_output += expert(input)  # [batch, seq, d_model]
        # Calculate similarity between input tokens and expert centroids
        self.similarities = torch.matmul(input, self.expert_centroids.transpose(0, 1))  # [batch, seq, N_r]
        assert self.similarities.size(dim=-1) == self.N_r, \
            "last dimension of similarities must equal the number of routed experts"
        affinity = F.softmax(self.similarities, dim=-1)  # [batch, seq, N_r]
        ## Apply top-K to calculate the gate
        values, indexes = torch.topk(affinity, self.K)
        values = F.softmax(values, dim=-1)  # renormalize the top-K affinities
        # Scatter the renormalized values back into a dense [batch, seq, N_r] gate;
        # entries outside the top-K stay zero
        gate = torch.zeros_like(affinity).scatter_(2, indexes, values)
        self.last_gate = gate  # kept for testing / inspection
        # Each routed expert processes all tokens, weighted by its gate value;
        # a production implementation would dispatch only the selected tokens
        routed_output = torch.zeros_like(input)
        for i in range(self.N_r):
            routed_output += gate[:, :, i].unsqueeze(-1) * self.routed_experts[i](input)
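        # Notation for the auxiliary losses below (a simplified reading of the
        # DeepSeek-V2 balance losses; see the note at the top of this file):
        #   T    total number of tokens in the batch
        #   f_i  fraction of tokens routed to expert i, scaled by N_r / K
        #   P_i  mean affinity the router assigns to expert i
        #   loss = alpha * sum_i f_i * P_i, computed per expert or per device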
        ## Auxiliary Losses for Load Balance
        # Expert-Level Balance Loss
        T = batch_size * seq_length  # total number of tokens
        f = self.N_r / (self.K * T) * torch.count_nonzero(gate, (0, 1)).to(P.dtype if False else affinity.dtype)
        P = 1 / T * affinity.sum((0, 1))
        expert_loss = self.alpha1 * torch.matmul(f, P)
        # Device-Level Balance Loss
        # (torch.stack, unlike torch.tensor, keeps P1 in the autograd graph)
        f1 = torch.stack([partition.mean() for partition in torch.tensor_split(f, self.D)])
        P1 = torch.stack([partition.sum() for partition in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.matmul(f1, P1)
        # Communication Balance Loss
        # (simplified: counts expert selections per device rather than distinct
        # tokens sent to each device)
        f2 = self.D / (self.M * T) * torch.stack(
            [torch.count_nonzero(partition, (0, 1)).sum() for partition in torch.tensor_split(gate, self.D, dim=-1)]
        ).to(P.dtype)
        P2 = P1
        commu_loss = self.alpha3 * torch.matmul(f2, P2)
        return input + shared_output + routed_output, expert_loss, device_loss, commu_loss
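

if __name__ == "__main__":
    # A minimal smoke-test sketch. The hyper-parameter values below are
    # illustrative assumptions, not settings from the DeepSeek paper.
    torch.manual_seed(0)
    moe = MixtureOfExperts(
        d_model=512, d_expert=256, K=2, N_s=1, N_r=8,
        alpha1=0.01, alpha2=0.01, alpha3=0.01, D=4, M=3,
    )
    x = torch.randn(2, 10, 512)  # [batch=2, seq=10, d_model=512]
    out, expert_loss, device_loss, commu_loss = moe(x)
    print(out.shape)  # torch.Size([2, 10, 512]); the residual output keeps the input shape
    print(expert_loss.item(), device_loss.item(), commu_loss.item())
    # Each token should have exactly K non-zero gate entries
    print(torch.count_nonzero(moe.last_gate, dim=-1).unique())  # tensor([2])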