Crystalcareai
committed on
Commit
•
5f967cc
1
Parent(s):
3ba04a1
Update modeling_gemmoe.py
Browse files - modeling_gemmoe.py +7 -10
modeling_gemmoe.py
CHANGED
@@ -1235,16 +1235,13 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
|
|
1235 |
logits = self.lm_head(hidden_states)
|
1236 |
logits = logits.float()
|
1237 |
|
1238 |
-
|
1239 |
-
|
1240 |
-
|
1241 |
-
|
1242 |
-
|
1243 |
-
|
1244 |
-
|
1245 |
-
loss = None
|
1246 |
-
if labels is not None:
|
1247 |
-
|
1248 |
loss = None
|
1249 |
if labels is not None:
|
1250 |
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
1235 |
logits = self.lm_head(hidden_states)
|
1236 |
logits = logits.float()
|
1237 |
|
1238 |
+
# Handle unused parameters
|
1239 |
+
if self.training:
|
1240 |
+
for expert in self.model.layers[-1].block_sparse_moe.experts:
|
1241 |
+
for param in expert.parameters():
|
1242 |
+
if param.requires_grad and param.grad is None:
|
1243 |
+
param.grad = torch.zeros_like(param)
|
1244 |
+
|
|
|
|
|
|
|
1245 |
loss = None
|
1246 |
if labels is not None:
|
1247 |
shift_logits = logits[..., :-1, :].contiguous()
|