boris commited on
Commit
69cf636
1 Parent(s): 77657e6

feat: use optax for gradient accumulation

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +25 -43
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -280,9 +280,7 @@ class DataTrainingArguments:
280
 
281
 
282
  class TrainState(train_state.TrainState):
283
- dropout_rng: jnp.ndarray
284
- grad_accum: jnp.ndarray
285
- optimizer_step: int
286
 
287
  def replicate(self):
288
  return jax_utils.replicate(self).replace(
@@ -502,9 +500,8 @@ def main():
502
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
503
  training_state = json.load(f)
504
  step = training_state["step"]
505
- optimizer_step = step // training_args.gradient_accumulation_steps
506
 
507
- return step, optimizer_step, opt_state
508
 
509
  # Set up wandb run
510
  wandb.init(
@@ -512,6 +509,7 @@ def main():
512
  project="dalle-mini",
513
  job_type="Seq2Seq",
514
  config=parser.parse_args(),
 
515
  )
516
 
517
  # set default x-axis as 'train/step'
@@ -722,7 +720,7 @@ def main():
722
  train_batch_size = (
723
  int(training_args.per_device_train_batch_size) * jax.device_count()
724
  )
725
- total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
726
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
727
  if data_args.streaming:
728
  len_train_dataset = data_args.len_train
@@ -743,12 +741,12 @@ def main():
743
  len_eval_dataset = len(eval_dataset)
744
  steps_per_epoch = len_train_dataset // train_batch_size
745
  total_steps = steps_per_epoch * num_epochs
746
- total_optimization_steps = (len_train_dataset // total_batch_size) * num_epochs
747
 
748
  # Create learning rate schedule
749
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
750
  len_train_dataset,
751
- total_batch_size,
752
  training_args.num_train_epochs,
753
  training_args.warmup_steps,
754
  training_args.learning_rate,
@@ -783,11 +781,11 @@ def main():
783
  # We use the default parameters here to initialize adafactor,
784
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
785
  optimizer = optax.adafactor(
786
- learning_rate=linear_decay_lr_schedule_fn,
787
  )
788
  else:
789
  optimizer = optax.adamw(
790
- learning_rate=linear_decay_lr_schedule_fn,
791
  b1=training_args.adam_beta1,
792
  b2=training_args.adam_beta2,
793
  eps=training_args.adam_epsilon,
@@ -795,21 +793,24 @@ def main():
795
  mask=decay_mask_fn,
796
  )
797
 
 
 
 
 
 
 
798
  # Setup train state
799
  state = TrainState.create(
800
  apply_fn=model.__call__,
801
  params=model.params,
802
  tx=optimizer,
803
  dropout_rng=dropout_rng,
804
- grad_accum=jax.tree_map(jnp.zeros_like, model.params),
805
- optimizer_step=0,
806
  )
807
  if model_args.from_checkpoint is not None:
808
- # restore optimizer state, step and optimizer_step
809
- step, optimizer_step, opt_state = restore_state(state, artifact_dir)
810
- state = state.replace(
811
- step=step, optimizer_step=optimizer_step, opt_state=opt_state
812
- )
813
 
814
  # label smoothed cross entropy
815
  def loss_fn(logits, labels):
@@ -821,7 +822,7 @@ def main():
821
  def train_step(state, batch):
822
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
823
 
824
- def compute_loss(params):
825
  labels = batch.pop("labels")
826
  logits = state.apply_fn(
827
  **batch, params=params, dropout_rng=dropout_rng, train=True
@@ -830,35 +831,16 @@ def main():
830
  return loss
831
 
832
  grad_fn = jax.value_and_grad(compute_loss)
833
- loss, grads = grad_fn(state.params)
834
- grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
835
-
836
- def update_fn():
837
- grads = jax.tree_map(
838
- lambda x: x / training_args.gradient_accumulation_steps, grad_accum
839
- )
840
- grads = jax.lax.pmean(grads, "batch")
841
- new_state = state.apply_gradients(
842
- grads=grads,
843
- grad_accum=jax.tree_map(jnp.zeros_like, grads),
844
- optimizer_step=state.optimizer_step + 1,
845
- )
846
- return new_state
847
-
848
- new_state = jax.lax.cond(
849
- (state.step + 1) % training_args.gradient_accumulation_steps == 0,
850
- lambda _: update_fn(),
851
- lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
852
- None,
853
- )
854
 
855
  metrics = {
856
  "loss": loss,
857
- "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step),
858
  }
859
  metrics = jax.lax.pmean(metrics, axis_name="batch")
860
-
861
- return new_state.replace(dropout_rng=new_dropout_rng), metrics
862
 
863
  # Define eval fn
864
  def eval_step(params, batch):
 
280
 
281
 
282
  class TrainState(train_state.TrainState):
283
+ dropout_rng: jnp.ndarray = None
 
 
284
 
285
  def replicate(self):
286
  return jax_utils.replicate(self).replace(
 
500
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
501
  training_state = json.load(f)
502
  step = training_state["step"]
 
503
 
504
+ return step, opt_state
505
 
506
  # Set up wandb run
507
  wandb.init(
 
509
  project="dalle-mini",
510
  job_type="Seq2Seq",
511
  config=parser.parse_args(),
512
+ save_code=True,
513
  )
514
 
515
  # set default x-axis as 'train/step'
 
720
  train_batch_size = (
721
  int(training_args.per_device_train_batch_size) * jax.device_count()
722
  )
723
+ batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
724
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
725
  if data_args.streaming:
726
  len_train_dataset = data_args.len_train
 
741
  len_eval_dataset = len(eval_dataset)
742
  steps_per_epoch = len_train_dataset // train_batch_size
743
  total_steps = steps_per_epoch * num_epochs
744
+ total_optimization_steps = (len_train_dataset // batch_size_per_update) * num_epochs
745
 
746
  # Create learning rate schedule
747
+ learning_rate_fn = create_learning_rate_fn(
748
  len_train_dataset,
749
+ train_batch_size,
750
  training_args.num_train_epochs,
751
  training_args.warmup_steps,
752
  training_args.learning_rate,
 
781
  # We use the default parameters here to initialize adafactor,
782
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
783
  optimizer = optax.adafactor(
784
+ learning_rate=learning_rate_fn,
785
  )
786
  else:
787
  optimizer = optax.adamw(
788
+ learning_rate=learning_rate_fn,
789
  b1=training_args.adam_beta1,
790
  b2=training_args.adam_beta2,
791
  eps=training_args.adam_epsilon,
 
793
  mask=decay_mask_fn,
794
  )
795
 
796
+ # add gradient accumulation
797
+ if training_args.gradient_accumulation_steps > 1:
798
+ optimizer = optax.chain(
799
+ optax.apply_every(training_args.gradient_accumulation_steps), optimizer
800
+ )
801
+
802
  # Setup train state
803
  state = TrainState.create(
804
  apply_fn=model.__call__,
805
  params=model.params,
806
  tx=optimizer,
807
  dropout_rng=dropout_rng,
 
 
808
  )
809
  if model_args.from_checkpoint is not None:
810
+ # restore optimizer state and step
811
+ step, opt_state = restore_state(state, artifact_dir)
812
+ state = state.replace(step=step, opt_state=opt_state)
813
+ # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
 
814
 
815
  # label smoothed cross entropy
816
  def loss_fn(logits, labels):
 
822
  def train_step(state, batch):
823
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
824
 
825
+ def compute_loss(params, batch):
826
  labels = batch.pop("labels")
827
  logits = state.apply_fn(
828
  **batch, params=params, dropout_rng=dropout_rng, train=True
 
831
  return loss
832
 
833
  grad_fn = jax.value_and_grad(compute_loss)
834
+ loss, grads = grad_fn(state.params, batch)
835
+ grads = jax.lax.pmean(grads, "batch")
836
+ state = state.apply_gradients(grads=grads)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837
 
838
  metrics = {
839
  "loss": loss,
840
+ "learning_rate": learning_rate_fn(state.step),
841
  }
842
  metrics = jax.lax.pmean(metrics, axis_name="batch")
843
+ return state.replace(dropout_rng=new_dropout_rng), metrics
 
844
 
845
  # Define eval fn
846
  def eval_step(params, batch):