Crystalcareai committed on
Commit
45f7601
1 Parent(s): 2de5917

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +76 -80
modeling_gemmoe.py CHANGED
@@ -65,9 +65,82 @@ logger = logging.get_logger(__name__)
 
_CONFIG_FOR_DOC = "GemmoeConfig"
 
- class GemmoeDistributedDataParallel(nn.parallel.DistributedDataParallel):
-     def __init__(self, model, **kwargs):
-         super().__init__(model, find_unused_parameters=True, **kwargs)
+ def load_balancing_loss_func(
+     gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
+ ) -> float:
+     r"""
+     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+     See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+     function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+     experts is too unbalanced.
+
+     Args:
+         gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+             shape [batch_size X sequence_length, num_experts].
+         attention_mask (`torch.Tensor`, None):
+             The attention_mask used in forward function
+             shape [batch_size X sequence_length] if not None.
+         num_experts (`int`, *optional*):
+             Number of experts
+
+     Returns:
+         The auxiliary loss.
+     """
+     if gate_logits is None or not isinstance(gate_logits, tuple):
+         return 0
+
+     if isinstance(gate_logits, tuple):
+         compute_device = gate_logits[0].device
+         concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+     expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+     if attention_mask is None:
+         # Compute the percentage of tokens routed to each experts
+         tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+         # Compute the average probability of routing to these experts
+         router_prob_per_expert = torch.mean(routing_weights, dim=0)
+     else:
+         batch_size, sequence_length = attention_mask.shape
+         num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+         expert_attention_mask = (
+             attention_mask[None, :, :, None, None]
+             .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+             .reshape(-1, top_k, num_experts)
+             .to(compute_device)
+         )
+
+         # Compute the percentage of tokens routed to each experts
+         tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+             expert_attention_mask, dim=0
+         )
+
+         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+         router_per_expert_attention_mask = (
+             attention_mask[None, :, :, None]
+             .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+             .reshape(-1, num_experts)
+             .to(compute_device)
+         )
+
+         # Compute the average probability of routing to these experts
+         router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+             router_per_expert_attention_mask, dim=0
+         )
+
+     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+     return overall_loss * num_experts
+
+
 
def approx_gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
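For reference, a minimal sketch of exercising the relocated, now module-level load_balancing_loss_func. The import path is an assumption (the file ships as remote code, so the module is normally loaded via trust_remote_code), and the sizes are dummies; with uniform router logits the auxiliary loss collapses to top_k, which makes a quick sanity check.

import torch
from modeling_gemmoe import load_balancing_loss_func  # assumed import path

num_layers, tokens, num_experts, top_k = 4, 8, 8, 2
# One router-logit tensor per layer, each of shape [batch_size * sequence_length, num_experts].
gate_logits = tuple(torch.zeros(tokens, num_experts) for _ in range(num_layers))

aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
print(aux_loss)  # tensor(2.) -- uniform routing probabilities give exactly top_k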
@@ -164,76 +237,6 @@ class GemmoeMLP(nn.Module):
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
-     def load_balancing_loss_func(
-         self,
-         gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-     ) -> float:
-         r"""
-         Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
-
-         See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
-         function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
-         experts is too unbalanced.
-
-         Args:
-             gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
-                 Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
-                 shape [batch_size X sequence_length, num_experts].
-             attention_mask (`torch.Tensor`, None):
-                 The attention_mask used in forward function
-                 shape [batch_size X sequence_length] if not None.
-             num_experts (`int`, *optional*):
-                 Number of experts
-
-         Returns:
-             The auxiliary loss.
-         """
-         if gate_logits is None or not isinstance(gate_logits, tuple):
-             return 0
-
-         if isinstance(gate_logits, tuple):
-             compute_device = gate_logits[0].device
-             concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-
-         routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
-         _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-         expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-
-         if attention_mask is None:
-             # Compute the percentage of tokens routed to each experts
-             tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-             # Compute the average probability of routing to these experts
-             router_prob_per_expert = torch.mean(routing_weights, dim=0)
-         else:
-             batch_size, sequence_length = attention_mask.shape
-             num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
-
-             # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
-             expert_attention_mask = (
-                 attention_mask[None, :, :, None, None]
-                 .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
-                 .reshape(-1, 2, num_experts)
-                 .to(compute_device)
-             )
-             # Compute the percentage of tokens routed to each experts
-             tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-                 expert_attention_mask, dim=0
-             )
-
-             # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
-             router_per_expert_attention_mask = (
-                 attention_mask[None, :, :, None]
-                 .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
-                 .reshape(-1, num_experts)
-                 .to(compute_device)
-             )
-             # Compute the average probability of routing to these experts
-             router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-                 router_per_expert_attention_mask, dim=0
-             )
-
-         overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
-         return overall_loss * num_experts
 
def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
@@ -1153,13 +1156,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        # Initialize weights and apply final processing
        self.post_init()
 
-     def parallelize(self, device_map=None):
-         self.model = GemmoeDistributedDataParallel(
-             self.model,
-             device_ids=[torch.cuda.current_device()],
-             output_device=torch.cuda.current_device(),
-         )
-
    def get_input_embeddings(self):
        return self.model.embed_tokens
 
 
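With parallelize() and the GemmoeDistributedDataParallel subclass removed, multi-GPU training is left to the training script. A minimal sketch of the conventional replacement, assuming the process group is already initialised (e.g. by torchrun) and model is an instantiated GemmoeForCausalLM:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = torch.cuda.current_device()
ddp_model = DDP(
    model.to(local_rank),
    device_ids=[local_rank],
    output_device=local_rank,
    find_unused_parameters=True,  # the flag the removed subclass hard-coded
)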