File size: 5,610 Bytes
354a706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: the communication balance loss below is a simplified version.
# For the complete implementation (proper token-to-device mapping,
# device-limited routing, and more efficient calculations),
# please contact the author.

class Expert(nn.Module):
    """
    Position-wise feed-forward network used as a single expert.

    Two linear transformations with a ReLU activation in between:

        FFN(x) = max(0, x W1 + b1) W2 + b2

    Args:
        d_model: embedding dimension (e.g., 512); size of the last input axis.
        d_expert: expert (hidden) dimension (e.g., 256).
    """
    def __init__(self, d_model, d_expert):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert

        # Linear transformation y = xW + b
        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_model, bias=True)

        # Re-initialize both weight matrices with Xavier (Glorot) uniform
        # initialization; the biases keep nn.Linear's default initialization.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        """
        Apply the feed-forward network along the last axis of ``input``.

        Args:
            input: tensor of shape ``(..., d_model)``, e.g.
                ``(batch, seq, d_model)`` or a flattened ``(tokens, d_model)``.

        Returns:
            Tensor with the same shape as ``input``.

        Raises:
            AssertionError: if the last axis of ``input`` is not ``d_model``.
        """
        # Only the last axis matters to nn.Linear, so any number of leading
        # axes is accepted instead of requiring exactly (batch, seq, d_model).
        d_input = input.size(-1)
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # max(0, xW_1 + b_1)W_2 + b_2
        return self.fc2(F.relu(self.fc1(input)))

class MixtureOfExperts(nn.Module):
    r"""
    Mixture of Experts layer as in DeepSeek.

        MoE(x) = x + \sum_i Expert^s_i(x) + \sum_i gate_i(x; K) * Expert^r_i(x)

    Args:
        d_model: embedding dimension (e.g., 512)
        d_expert: expert dimension (e.g., 216)
        K: number of routed experts selected per token (top-K gate)
        N_s: number of shared experts
        N_r: number of routed experts
        alpha1: hyper-parameter; expert-level balance factor
        alpha2: hyper-parameter; device-level balance factor
        alpha3: hyper-parameter; communication balance factor
        D: number of devices for the distributed system; the routed experts
            are split into D contiguous groups for the device-level losses
        M: number of devices for Device-Limited Routing; in this class it
            only scales the communication balance loss, routing itself is
            not restricted to M devices

    Every routed expert is evaluated on every token and then weighted by the
    gate (which is zero for non-selected experts), so the computation is
    dense rather than sparse.
    """
    def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()

        assert D < N_r, "Number of partitions needs to be less than number of routed experts"
        assert M <= D, "Number of deviced for Device-Limited Routing needs to be less than number of total device"

        self.d_model = d_model
        self.d_expert = d_expert

        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3

        self.D = D  # number of devices available
        self.M = M  # for Device-Limited Routing

        # initialize shared experts and routed experts
        self.shared_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_s)
        ])

        self.routed_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert)
            for _ in range(N_r)
        ])

        # Centroids: learnable parameters, one vector per routed expert.
        # The random-normal values are overwritten by the Xavier init below.
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)  # [num_routed_experts, d_model]
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input):
        """
        Route tokens through the shared and routed experts.

        Args:
            input: tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tuple ``(output, expert_loss, device_loss, commu_loss)`` where
            ``output`` has the same shape as ``input`` and the three losses
            are scalar tensors already multiplied by alpha1/alpha2/alpha3.

        Side effects:
            Stores the raw token-centroid scores in ``self.similarities`` and
            the gate of the last call in ``self.last_gate``.

        Raises:
            AssertionError: if the last axis of ``input`` is not ``d_model``.
        """
        # check input and first FF layer dimension matching
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # Shared experts see every token with weight 1.
        shared_output = torch.zeros_like(input)
        for expert in self.shared_experts:
            shared_output += expert(input)  # [batch, seq, d_model]

        # Similarity between input tokens and expert centroids
        self.similarities = torch.matmul(input, self.expert_centroids.transpose(0, 1))  # [batch, seq, N_r]
        assert self.similarities.size(dim=-1) == self.N_r, (
            "last dimension of similarities must be the same as the number of routed expert"
        )
        affinity = F.softmax(self.similarities, dim=-1)  # [batch, seq, N_r]

        # Top-K gate: keep the K largest affinities per token, renormalize
        # them with a softmax and scatter them back into a dense tensor that
        # is zero for every non-selected expert.
        values, indexes = torch.topk(affinity, self.K)
        values = F.softmax(values, dim=-1)
        gate = torch.zeros_like(affinity).scatter_(2, indexes, values)  # [batch, seq, N_r]
        # kept for testing
        self.last_gate = gate

        routed_output = torch.zeros_like(input)
        for i, expert in enumerate(self.routed_experts):
            routed_output += gate[:, :, i].unsqueeze(-1) * expert(input)

        ## Auxiliary losses for load balance
        # T is the number of tokens (batch * seq). It is clamped to 1 so an
        # empty batch yields zero losses instead of a division by zero.
        T = max(batch_size * seq_length, 1)

        # Accumulate the statistics in at least float32 and in ONE dtype, so
        # that f and P can be multiplied for half, float and double inputs.
        stat_dtype = torch.promote_types(affinity.dtype, torch.float32)

        # Number of tokens routed to each expert. This is a count and
        # therefore carries no gradient.
        tokens_per_expert = torch.count_nonzero(gate, dim=(0, 1)).to(stat_dtype)  # [N_r]

        # Expert-level balance loss
        f = tokens_per_expert * (self.N_r / (self.K * T))
        P = affinity.to(stat_dtype).sum(dim=(0, 1)) / T  # differentiable
        expert_loss = self.alpha1 * torch.matmul(f, P)

        # Device-level balance loss: experts are split into D contiguous
        # groups. torch.stack (unlike torch.tensor) keeps the results on the
        # input's device and inside the autograd graph, so the loss can
        # actually be back-propagated through P.
        f1 = torch.stack([part.mean() for part in torch.tensor_split(f, self.D)])
        P1 = torch.stack([part.sum() for part in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.matmul(f1, P1)

        # Communication balance loss (simplified): counts token-to-expert
        # assignments per device group, so a token that selects two experts
        # in the same group is counted twice.
        assignments_per_device = torch.stack(
            [part.sum() for part in torch.tensor_split(tokens_per_expert, self.D)]
        )
        f2 = assignments_per_device * (self.D / (self.M * T))
        P2 = P1
        commu_loss = self.alpha3 * torch.matmul(f2, P2)

        return input + shared_output + routed_output, expert_loss, device_loss, commu_loss