import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

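# Make the parent directory of this test file importable so `from moe import ...`
# also works when the tests are run directly as a script.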
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from moe import Expert, MixtureOfExperts


class TestExpert(unittest.TestCase):
    """Test the Expert module of the DeepSeek MoE implementation."""

    def setUp(self):
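        # Fix the random seed so inputs and weight initialization are reproducible.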
        torch.manual_seed(42)

        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128

        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)
        self.expert = Expert(self.d_model, self.d_expert)

    def test_expert_init(self):
        """Test expert initialization."""
        self.assertEqual(self.expert.fc1.in_features, self.d_model)
        self.assertEqual(self.expert.fc1.out_features, self.d_expert)
        self.assertEqual(self.expert.fc2.in_features, self.d_expert)
        self.assertEqual(self.expert.fc2.out_features, self.d_model)

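        # Freshly initialized weights should lie strictly within (-1, 1).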
        self.assertTrue(torch.all(self.expert.fc1.weight < 1.0))
        self.assertTrue(torch.all(self.expert.fc1.weight > -1.0))

    def test_expert_forward(self):
        """Test the forward pass of the expert module."""
        output = self.expert(self.inputs)

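        # The expert should preserve the input shape and should not act as an identity map.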
        self.assertEqual(output.shape, self.inputs.shape)
        self.assertFalse(torch.allclose(output, self.inputs))

        single_input = torch.randn(1, 1, self.d_model)

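        # Recompute fc2(relu(fc1(x))) by hand and compare it to the module's forward pass.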
        fc1_output = self.expert.fc1(single_input)
        relu_output = F.relu(fc1_output)
        expected_output = self.expert.fc2(relu_output)

        actual_output = self.expert(single_input)

        self.assertTrue(torch.allclose(actual_output, expected_output))


class TestMixtureOfExperts(unittest.TestCase):
    """Test the MixtureOfExperts module."""

    def setUp(self):
        torch.manual_seed(42)

        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128
        self.K = 2
        self.N_s = 2
        self.N_r = 8
        self.alpha1 = 0.01
        self.alpha2 = 0.01
        self.alpha3 = 0.01
        self.D = 4
        self.M = 3

        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        self.moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=self.alpha1,
            alpha2=self.alpha2,
            alpha3=self.alpha3,
            D=self.D,
            M=self.M
        )

    def test_moe_init(self):
        """Test MoE initialization."""
        self.assertEqual(len(self.moe.shared_experts), self.N_s)
        self.assertEqual(len(self.moe.routed_experts), self.N_r)

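        # One routing centroid per routed expert, each of dimension d_model.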
        self.assertEqual(self.moe.expert_centroids.shape, (self.N_r, self.d_model))

    def test_moe_forward(self):
        """Test the forward pass of the MoE layer."""
        output, expert_loss, device_loss, commu_loss = self.moe(self.inputs)

        self.assertEqual(output.shape, self.inputs.shape)

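        # All three auxiliary losses should be non-negative scalars.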
        self.assertEqual(expert_loss.dim(), 0)
        self.assertEqual(device_loss.dim(), 0)
        self.assertEqual(commu_loss.dim(), 0)

        self.assertGreaterEqual(expert_loss.item(), 0.0)
        self.assertGreaterEqual(device_loss.item(), 0.0)
        self.assertGreaterEqual(commu_loss.item(), 0.0)

    def test_topk_routing(self):
        """Test the top-K routing mechanism."""
        self.moe(self.inputs)

        self.assertEqual(self.moe.last_gate.shape, (self.batch_size, self.seq_len, self.N_r))

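        # Every token should activate exactly K routed experts, with gate values summing to 1.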
        for b in range(self.batch_size):
            for s in range(self.seq_len):
                active_experts = torch.count_nonzero(self.moe.last_gate[b, s])
                self.assertEqual(active_experts, self.K)

                gate_sum = self.moe.last_gate[b, s].sum().item()
                self.assertAlmostEqual(gate_sum, 1.0, places=5)

    def test_expert_contribution(self):
        """Test that both shared and routed experts contribute to the output."""
        special_input = torch.zeros_like(self.inputs)
        special_input[:, 0, 0] = 1.0

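        # Run once with the routing centroids zeroed out.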
        with torch.no_grad():
            self.moe.expert_centroids.data.fill_(0.0)
        shared_only_output, _, _, _ = self.moe(special_input)

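        # Re-initialize the centroids and run again; the routed experts should now change the output.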
        with torch.no_grad():
            nn.init.xavier_uniform_(self.moe.expert_centroids)
        full_output, _, _, _ = self.moe(special_input)

        self.assertFalse(torch.allclose(shared_only_output, full_output))

    def test_residual_connection(self):
        """Test that the residual connection is properly implemented."""
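        # Zero out every expert so that only the residual path can contribute to the output.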
        with torch.no_grad():
            for expert in self.moe.shared_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)

            for expert in self.moe.routed_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)

            nn.init.xavier_uniform_(self.moe.expert_centroids)

        output, _, _, _ = self.moe(self.inputs)

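        # With all experts zeroed, the layer should reduce to the identity.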
        self.assertTrue(torch.allclose(output, self.inputs))


class TestLoadBalancing(unittest.TestCase):
    """Test the load balancing mechanisms of the MixtureOfExperts."""

    def setUp(self):
        torch.manual_seed(42)

        self.batch_size = 16
        self.seq_len = 32
        self.d_model = 64
        self.d_expert = 128
        self.K = 2
        self.N_s = 2
        self.N_r = 8

        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

    def test_expert_balance_loss(self):
        """Test that the expert balance loss penalizes imbalanced routing."""
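        # Two otherwise identical configurations; only the first weights the expert balance loss (alpha1=1.0).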
        moe_balanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=1.0,
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        moe_unbalanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        skewed_inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

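        # Point the first centroid at the mean input to skew routing towards a single expert,
        # then share the same centroids between both models so only the loss weights differ.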
        with torch.no_grad():
            prototype = skewed_inputs.mean(dim=(0, 1))
            moe_unbalanced.expert_centroids[0] = prototype * 10

            moe_balanced.expert_centroids.data.copy_(moe_unbalanced.expert_centroids.data)

        _, unbalanced_loss, _, _ = moe_unbalanced(skewed_inputs)
        _, balanced_loss, _, _ = moe_balanced(skewed_inputs)

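        # The model that weights the balance term should report the larger loss.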
        self.assertGreater(balanced_loss.item(), unbalanced_loss.item())

    def test_device_balance_loss(self):
        """Test that the device balance loss works as expected."""
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=1.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        _, _, device_loss, _ = moe(self.inputs)

        self.assertGreater(device_loss.item(), 0.0)

    def test_communication_balance_loss(self):
        """Test that the communication balance loss works as expected."""
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=0.0,
            alpha3=1.0,
            D=2,
            M=1
        )

        _, _, _, commu_loss = moe(self.inputs)

        self.assertGreater(commu_loss.item(), 0.0)


if __name__ == '__main__':
    unittest.main()