feat(train): use MultiSteps for gradient accumulation
tools/train/train.py CHANGED (+2 −4)
@@ -647,9 +647,7 @@ def main():
 
     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.chain(
-            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
-        )
+        optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
 
     # Setup train state
     state = TrainState.create(
@@ -693,7 +691,7 @@ def main():
 
         metrics = {
             "loss": loss,
-            "learning_rate": learning_rate_fn(state.step),
+            "learning_rate": learning_rate_fn(state.step // training_args.gradient_accumulation_steps),
         }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
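For context, optax.MultiSteps wraps the inner optimizer and averages gradients over every_k_schedule micro-batches, emitting zero updates in between, which replaces the previous optax.chain / optax.apply_every plumbing. A minimal, self-contained sketch of the pattern follows; the toy parameters, loss function, and k = 4 are illustrative assumptions, not values taken from train.py.

import jax
import jax.numpy as jnp
import optax

k = 4  # stands in for training_args.gradient_accumulation_steps (assumed value)
# Inner optimizer with a schedule: linear decay over 100 *optimizer* steps.
base = optax.adamw(learning_rate=optax.linear_schedule(1e-3, 0.0, 100))
opt = optax.MultiSteps(base, every_k_schedule=k)

params = {"w": jnp.ones(3)}  # toy parameters, for illustration only
opt_state = opt.init(params)

def loss_fn(p, x):
    return jnp.sum((p["w"] * x) ** 2)

for x in jnp.arange(1.0, 9.0):  # 8 micro-batches
    grads = jax.grad(loss_fn)(params, x)
    # MultiSteps averages incoming grads and returns zero updates until
    # k micro-batches have been seen; only then does the inner AdamW
    # (and its learning-rate schedule) advance by one real step.
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

This is also why the logged learning rate changes: state.step advances once per micro-batch, while the inner optimizer's schedule advances only once every gradient_accumulation_steps micro-batches, so the schedule must be evaluated at state.step // training_args.gradient_accumulation_steps to report the rate actually in effect.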