Crystalcareai commited on
Commit
5f967cc
1 Parent(s): 3ba04a1

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +7 -10
modeling_gemmoe.py CHANGED
@@ -1235,16 +1235,13 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1235
  logits = self.lm_head(hidden_states)
1236
  logits = logits.float()
1237
 
1238
- # Handle unused parameters
1239
- if self.training:
1240
- for expert in self.model.block_sparse_moe.experts:
1241
- for param in expert.parameters():
1242
- if param.requires_grad and param.grad is None:
1243
- param.grad = torch.zeros_like(param)
1244
-
1245
- loss = None
1246
- if labels is not None:
1247
-
1248
  loss = None
1249
  if labels is not None:
1250
  shift_logits = logits[..., :-1, :].contiguous()
 
1235
  logits = self.lm_head(hidden_states)
1236
  logits = logits.float()
1237
 
1238
+ # Handle unused parameters
1239
+ if self.training:
1240
+ for expert in self.model.layers[-1].block_sparse_moe.experts:
1241
+ for param in expert.parameters():
1242
+ if param.requires_grad and param.grad is None:
1243
+ param.grad = torch.zeros_like(param)
1244
+
 
 
 
1245
  loss = None
1246
  if labels is not None:
1247
  shift_logits = logits[..., :-1, :].contiguous()