Update modeling_codegen.py
Browse filesHello, this small change ensures that the labels are on the correct device.
- modeling_codegen.py +1 -0
modeling_codegen.py
CHANGED
@@ -713,6 +713,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
|
|
713 |
|
714 |
loss = None
|
715 |
if labels is not None:
|
|
|
716 |
# Shift so that tokens < n predict n
|
717 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
718 |
shift_labels = labels[..., 1:].contiguous()
|
|
|
713 |
|
714 |
loss = None
|
715 |
if labels is not None:
|
716 |
+
labels = labels.to(lm_logits.device)
|
717 |
# Shift so that tokens < n predict n
|
718 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
719 |
shift_labels = labels[..., 1:].contiguous()
|