# deepseek-moe/src/tests/test_moe.py
import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

# Add the parent directory to the path so we can import the module,
# whether the tests are run as part of the package or as a standalone script.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from moe import Expert, MixtureOfExperts

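# The tests below exercise the interface these modules are expected to expose:
# Expert(d_model, d_expert) with an fc1 -> ReLU -> fc2 stack, and
# MixtureOfExperts(...) whose forward returns
# (output, expert_loss, device_loss, commu_loss) and which keeps the most
# recent routing weights in `last_gate` alongside `expert_centroids`,
# `shared_experts`, and `routed_experts`.
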
class TestExpert(unittest.TestCase):
    """Test the Expert module of the DeepSeek MoE implementation."""

    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Common parameters for tests
        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128

        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        # Create expert
        self.expert = Expert(self.d_model, self.d_expert)

    def test_expert_init(self):
        """Test expert initialization."""
        # Check layer parameters
        self.assertEqual(self.expert.fc1.in_features, self.d_model)
        self.assertEqual(self.expert.fc1.out_features, self.d_expert)
        self.assertEqual(self.expert.fc2.in_features, self.d_expert)
        self.assertEqual(self.expert.fc2.out_features, self.d_model)

        # Check if Xavier initialization was applied
        # Just check if weights are within a reasonable range
        self.assertTrue(torch.all(self.expert.fc1.weight < 1.0))
        self.assertTrue(torch.all(self.expert.fc1.weight > -1.0))

    def test_expert_forward(self):
        """Test the forward pass of the expert module."""
        output = self.expert(self.inputs)

        # Check output shape
        self.assertEqual(output.shape, self.inputs.shape)

        # Ensure output is different from input (transformation happened)
        self.assertFalse(torch.allclose(output, self.inputs))

        # Test the expert with a single example (easier to verify calculations)
        single_input = torch.randn(1, 1, self.d_model)

        # Step-by-step execution to verify correctness
        fc1_output = self.expert.fc1(single_input)
        relu_output = F.relu(fc1_output)
        expected_output = self.expert.fc2(relu_output)

        actual_output = self.expert(single_input)

        # Verify that the output matches our manual calculation
        self.assertTrue(torch.allclose(actual_output, expected_output))

class TestMixtureOfExperts(unittest.TestCase):
    """Test the MixtureOfExperts module."""

    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Common parameters for tests
        self.batch_size = 8
        self.seq_len = 16
        self.d_model = 64
        self.d_expert = 128
        self.K = 2          # Top-K experts per token
        self.N_s = 2        # Number of shared experts
        self.N_r = 8        # Number of routed experts
        self.alpha1 = 0.01  # Expert balance factor
        self.alpha2 = 0.01  # Device balance factor
        self.alpha3 = 0.01  # Communication balance factor
        self.D = 4          # Number of devices
        self.M = 3          # Device limit for routing

        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        # Create MoE layer
        self.moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=self.alpha1,
            alpha2=self.alpha2,
            alpha3=self.alpha3,
            D=self.D,
            M=self.M
        )

    def test_moe_init(self):
        """Test MoE initialization."""
        # Check expert counts
        self.assertEqual(len(self.moe.shared_experts), self.N_s)
        self.assertEqual(len(self.moe.routed_experts), self.N_r)

        # Check centroid initialization
        self.assertEqual(self.moe.expert_centroids.shape, (self.N_r, self.d_model))

    def test_moe_forward(self):
        """Test the forward pass of the MoE layer."""
        output, expert_loss, device_loss, commu_loss = self.moe(self.inputs)

        # Check output shape
        self.assertEqual(output.shape, self.inputs.shape)

        # Check that losses are scalars
        self.assertEqual(expert_loss.dim(), 0)
        self.assertEqual(device_loss.dim(), 0)
        self.assertEqual(commu_loss.dim(), 0)

        # Check that losses are non-negative
        self.assertGreaterEqual(expert_loss.item(), 0.0)
        self.assertGreaterEqual(device_loss.item(), 0.0)
        self.assertGreaterEqual(commu_loss.item(), 0.0)

    def test_topk_routing(self):
        """Test the top-K routing mechanism."""
        # Forward pass to compute gate values
        self.moe(self.inputs)

        # Check gate shape
        self.assertEqual(self.moe.last_gate.shape, (self.batch_size, self.seq_len, self.N_r))

        # Check that exactly K experts are activated per token
        for b in range(self.batch_size):
            for s in range(self.seq_len):
                # Count non-zero gate values for this token
                active_experts = torch.count_nonzero(self.moe.last_gate[b, s])
                self.assertEqual(active_experts, self.K)

                # Check that gate values sum to approximately 1.0
                gate_sum = self.moe.last_gate[b, s].sum().item()
                self.assertAlmostEqual(gate_sum, 1.0, places=5)

    def test_expert_contribution(self):
        """Test that both shared and routed experts contribute to the output."""
        # Create an input where we can track contributions
        special_input = torch.zeros_like(self.inputs)
        special_input[:, 0, 0] = 1.0  # Set a specific element to 1.0

        # Process with shared experts only (zero out routed expert centroids)
        with torch.no_grad():
            self.moe.expert_centroids.data.fill_(0.0)
            shared_only_output, _, _, _ = self.moe(special_input)

        # Process with both shared and routed experts
        with torch.no_grad():
            # Reset centroids
            nn.init.xavier_uniform_(self.moe.expert_centroids)
            full_output, _, _, _ = self.moe(special_input)

        # Check that outputs are different, indicating routed experts contributed
        self.assertFalse(torch.allclose(shared_only_output, full_output))

    def test_residual_connection(self):
        """Test that the residual connection is properly implemented."""
        # Zero out all expert weights to isolate residual behavior
        with torch.no_grad():
            for expert in self.moe.shared_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)

            for expert in self.moe.routed_experts:
                expert.fc1.weight.fill_(0.0)
                expert.fc1.bias.fill_(0.0)
                expert.fc2.weight.fill_(0.0)
                expert.fc2.bias.fill_(0.0)

            # Reset centroids to ensure routing still happens
            nn.init.xavier_uniform_(self.moe.expert_centroids)

        # Process input
        output, _, _, _ = self.moe(self.inputs)

        # With zero weights, output should match input (residual connection)
        self.assertTrue(torch.allclose(output, self.inputs))

class TestLoadBalancing(unittest.TestCase):
    """Test the load balancing mechanisms of the MixtureOfExperts."""

    def setUp(self):
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Common parameters for tests
        self.batch_size = 16
        self.seq_len = 32
        self.d_model = 64
        self.d_expert = 128
        self.K = 2
        self.N_s = 2
        self.N_r = 8

        # Create sample input tensor
        self.inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

    def test_expert_balance_loss(self):
        """Test that the expert balance loss penalizes imbalanced routing."""
        # Create two MoE layers with different alpha1 values
        moe_balanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=1.0,  # High expert balance factor
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        moe_unbalanced = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,  # No expert balance factor
            alpha2=0.0,
            alpha3=0.0,
            D=2,
            M=2
        )

        # Create highly skewed inputs to test balancing
        skewed_inputs = torch.randn(self.batch_size, self.seq_len, self.d_model)

        # Force skewed routing by manipulating centroids
        with torch.no_grad():
            # Make first expert's centroid very similar to all inputs
            prototype = skewed_inputs.mean(dim=(0, 1))
            moe_unbalanced.expert_centroids[0] = prototype * 10

            # Copy the same centroids to the balanced MoE
            moe_balanced.expert_centroids.data.copy_(moe_unbalanced.expert_centroids.data)

        # Process with both MoEs
        _, unbalanced_loss, _, _ = moe_unbalanced(skewed_inputs)
        _, balanced_loss, _, _ = moe_balanced(skewed_inputs)

        # The balanced MoE should produce a higher loss to penalize imbalance
        self.assertGreater(balanced_loss.item(), unbalanced_loss.item())

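    # For reference, the expert-level balance loss defined in the DeepSeekMoE
    # paper (arXiv:2401.06066), which this implementation presumably follows, is
    #
    #     L_ExpBal = alpha1 * sum_i f_i * P_i
    #
    # where, over the T tokens in a batch, f_i is the (N_r / K)-scaled fraction
    # of tokens routed to expert i and P_i is the mean gate probability assigned
    # to expert i. Concentrating all tokens on one expert maximizes this product,
    # which is why the alpha1=1.0 model above should report the larger loss.
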
    def test_device_balance_loss(self):
        """Test that the device balance loss works as expected."""
        # Create MoE with high device balance factor
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=1.0,  # High device balance factor
            alpha3=0.0,
            D=2,  # Two devices
            M=2
        )

        # Process input
        _, _, device_loss, _ = moe(self.inputs)

        # Check that device loss is calculated and non-zero
        self.assertGreater(device_loss.item(), 0.0)

    def test_communication_balance_loss(self):
        """Test that the communication balance loss works as expected."""
        # Create MoE with high communication balance factor
        moe = MixtureOfExperts(
            d_model=self.d_model,
            d_expert=self.d_expert,
            K=self.K,
            N_s=self.N_s,
            N_r=self.N_r,
            alpha1=0.0,
            alpha2=0.0,
            alpha3=1.0,  # High communication balance factor
            D=2,  # Two devices
            M=1   # Limited to one device
        )

        # Process input
        _, _, _, commu_loss = moe(self.inputs)

        # Check that communication loss is calculated and non-zero
        self.assertGreater(commu_loss.item(), 0.0)

if __name__ == '__main__':
    unittest.main()
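# The suite can be run directly (e.g. `python src/tests/test_moe.py`) or through
# unittest discovery from the repository root, e.g. `python -m unittest discover -s src`.
# The exact invocation depends on how the package is laid out, so treat these
# commands as a sketch rather than the project's documented workflow.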