boris committed
Commit 4fa53a5
Parent: 193c88c

feat(train): use MultiSteps for gradient accumulation

Files changed (1):
  tools/train/train.py (+2 -4)
tools/train/train.py CHANGED
@@ -647,9 +647,7 @@ def main():
 
     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.chain(
-            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
-        )
+        optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
 
     # Setup train state
     state = TrainState.create(
@@ -693,7 +691,7 @@
 
         metrics = {
            "loss": loss,
-           "learning_rate": learning_rate_fn(state.step),
+           "learning_rate": learning_rate_fn(state.step // training_args.gradient_accumulation_steps),
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")
 
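The second hunk follows from the first: flax's TrainState.step increments on every micro-batch, while optax.MultiSteps only advances the inner optimizer (and hence its learning-rate schedule) once every gradient_accumulation_steps calls, so the logged rate is indexed by state.step // gradient_accumulation_steps.

For context, a minimal self-contained sketch of what optax.MultiSteps does, using a hypothetical toy loss and optax.adamw (neither is from this repo): gradients are averaged over every_k_schedule micro-steps and applied to the wrapped optimizer once per cycle, with all-zero updates in between.

import jax
import jax.numpy as jnp
import optax

k = 4  # hypothetical accumulation factor, stands in for gradient_accumulation_steps
base_optimizer = optax.adamw(learning_rate=1e-3)
# Wrap the base optimizer: average grads over k calls, step AdamW on the k-th.
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=k)

params = {"w": jnp.ones(3)}
opt_state = optimizer.init(params)

def loss_fn(params, x):
    return jnp.sum((params["w"] * x) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
for _ in range(k):
    grads = jax.grad(loss_fn)(params, x)
    # updates are all zeros until the k-th call, when the averaged
    # gradient is handed to the inner AdamW exactly once
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)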