boris committed on
Commit 4d55db6
1 Parent(s): dcbf091

fix: accumulation vs lr

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +2 -2
seq2seq/run_seq2seq_flax.py CHANGED
@@ -673,12 +673,12 @@ def main():
             grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
             grads = jax.lax.pmean(grads, "batch")
             new_state = state.apply_gradients(
-                grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
+                grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
             )
             return new_state
 
         new_state = jax.lax.cond(
-            state.step % training_args.gradient_accumulation_steps == 0,
+            (state.step + 1) % training_args.gradient_accumulation_steps == 0,
             lambda _: update_fn(),
             lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
             None,
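
For context, a minimal, self-contained sketch of the accumulation pattern this diff touches, assuming a 0-based step counter and plain SGD in place of the real optimizer; names like `n_accum` and `maybe_update` are illustrative, not the repository's code. It shows both sides of the fix: the update should land on the last micro-step of each accumulation window, and `optimizer_step` must advance on each real update, or a learning-rate schedule keyed on it stays frozen (the "accumulation vs lr" of the commit message).

# Minimal sketch (illustrative assumptions, not the repository's train step).
import jax
import jax.numpy as jnp

n_accum = 4  # stands in for training_args.gradient_accumulation_steps

def maybe_update(step, grad_accum, params, optimizer_step):
    # Fire an optimizer update on every n_accum-th micro-step.
    def update_fn(_):
        # Average the gradients accumulated over the window.
        grads = jax.tree_util.tree_map(lambda g: g / n_accum, grad_accum)
        # Plain SGD stands in for the real optimizer here.
        new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
        # Reset the accumulator and advance optimizer_step, so an LR
        # schedule keyed on optimizer_step actually moves (the `+ 1`
        # this commit adds).
        zeros = jax.tree_util.tree_map(jnp.zeros_like, grad_accum)
        return new_params, zeros, optimizer_step + 1

    def skip_fn(_):
        # Keep accumulating; nothing else changes.
        return params, grad_accum, optimizer_step

    # With a 0-based step, micro-steps 0..n_accum-1 form one window, so the
    # update belongs at (step + 1) % n_accum == 0; the old check
    # `step % n_accum == 0` fired at step 0, before anything had accumulated.
    return jax.lax.cond((step + 1) % n_accum == 0, update_fn, skip_fn, None)

params = {"w": jnp.ones(3)}
grad_accum = {"w": jnp.full(3, 2.0)}  # pretend sum of gradients over 4 micro-steps
params, grad_accum, opt_step = maybe_update(
    jnp.int32(3), grad_accum, params, jnp.int32(0)
)  # step 3 is the 4th micro-step, so the update fires and opt_step becomes 1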