boris committed
Commit c91ceb7
2 Parent(s): 193c88c 88c8e06

Merge pull request #122 from borisdayma/feat-acccum

Files changed (2)
  1. src/dalle_mini/data.py +41 -6
  2. tools/train/train.py +93 -46
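
Both diffs below rely on the same batch-size bookkeeping: a per-device batch is multiplied by the number of local devices, then by the gradient accumulation steps, then by the number of hosts. A minimal sketch of that arithmetic with illustrative values (none of these numbers are defaults set by this commit):

# Illustrative values only, not defaults from this commit.
per_device_batch_size = 5    # --per_device_train_batch_size
local_devices = 8            # jax.local_device_count()
accumulation_steps = 2       # --gradient_accumulation_steps
hosts = 2                    # jax.process_count()

train_batch_size = per_device_batch_size * local_devices     # 40: processed per accumulation step on one node
batch_size_per_node = train_batch_size * accumulation_steps  # 80: consumed by one node per optimizer update
batch_size_per_step = batch_size_per_node * hosts            # 160: total examples per optimizer update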
src/dalle_mini/data.py CHANGED
@@ -153,16 +153,24 @@ class Dataset:
                 ),
             )
 
-    def dataloader(self, split, batch_size, epoch=None):
+    def dataloader(
+        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
+    ):
+        num_devices = jax.local_device_count()
+
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-            batch_size: int,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
+            batch_size = (
+                per_device_batch_size * num_devices * gradient_accumulation_steps
+            )
             steps_per_epoch = len(dataset) // batch_size
 
             if rng is not None:
@@ -178,11 +186,20 @@ class Dataset:
             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
+                if gradient_accumulation_steps is not None:
+                    batch = jax.tree_map(
+                        lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
+                        batch,
+                    )
                 batch = shard(batch)
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset, split: str, batch_size: int, epoch: int
+            dataset: Dataset,
+            split: str,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
+            epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
@@ -199,8 +216,22 @@ class Dataset:
             for item in dataset:
                 for k, v in item.items():
                     batch[k].append(v)
-                if len(batch[keys[0]]) == batch_size:
+                # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
+                # (40, 3, 3) -> shard 8 x (5, 3, 3)
+                # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
+                if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
+                    gradient_accumulation_steps
+                    if gradient_accumulation_steps is not None
+                    else 1
+                ):
                     batch = {k: jnp.array(v) for k, v in batch.items()}
+                    if gradient_accumulation_steps is not None:
+                        batch = jax.tree_map(
+                            lambda x: x.reshape(
+                                (-1, per_device_batch_size) + x.shape[1:]
+                            ),
+                            batch,
+                        )
                     batch = shard(batch)
                     yield batch
                     batch = {k: [] for k in keys}
@@ -214,11 +245,15 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
+            return _dataloader_datasets_streaming(
+                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
+            )
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)
+            return _dataloader_datasets_non_streaming(
+                ds, per_device_batch_size, gradient_accumulation_steps, input_rng
+            )
 
     @property
     def length(self):
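
The inline comments in the streaming dataloader above can be checked with a small shape sketch (plain NumPy, not part of the commit): when gradient accumulation is on, the host batch is first regrouped into micro-batches of per_device_batch_size, and flax's shard then splits the leading axis across local devices so each device receives one micro-batch per accumulation step.

import numpy as np

per_device_batch_size, num_devices, accumulation_steps = 5, 8, 2

# Host-level batch: 5 * 8 * 2 = 80 items, each of illustrative shape (3, 3)
batch = np.zeros((per_device_batch_size * num_devices * accumulation_steps, 3, 3))

# Regroup into per-device micro-batches, as the new dataloader does
batch = batch.reshape((-1, per_device_batch_size) + batch.shape[1:])
print(batch.shape)  # (16, 5, 3, 3)

# shard() then splits the leading axis over the 8 local devices,
# giving each device a (2, 5, 3, 3) array: one micro-batch per accumulation step.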
tools/train/train.py CHANGED
@@ -277,8 +277,8 @@ class TrainingArguments:
         },
     )
 
-    num_train_epochs: float = field(
-        default=3.0, metadata={"help": "Total number of training epochs to perform."}
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
     )
     warmup_steps: int = field(
         default=0, metadata={"help": "Linear warmup over warmup_steps."}
@@ -310,12 +310,40 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The wandb entity to use (for teams)."},
+    )
+    wandb_project: str = field(
+        default="dalle-mini",
+        metadata={"help": "The name of the wandb project."},
+    )
+    wandb_job_type: str = field(
+        default="Seq2Seq",
+        metadata={"help": "The name of the wandb job type."},
+    )
+
+    assert_TPU_available: bool = field(
+        default=False,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if (
+            os.path.exists(self.output_dir)
+            and os.listdir(self.output_dir)
+            and self.do_train
+            and not self.overwrite_output_dir
+        ):
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty."
+                "Use --overwrite_output_dir to overcome."
+            )
 
 
 class TrainState(train_state.TrainState):
@@ -396,17 +424,6 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -433,14 +450,18 @@ def main():
     )
 
     logger.info(f"Local TPUs: {jax.local_device_count()}")
-    assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
+    logger.info(f"Global TPUs: {jax.device_count()}")
+    if training_args.assert_TPU_available:
+        assert (
+            jax.local_device_count() == 8
+        ), "TPUs in use, please check running processes"
 
     # Set up wandb run
     if jax.process_index() == 0:
         wandb.init(
-            entity="dalle-mini",
-            project="dalle-mini",
-            job_type="Seq2Seq",
+            entity=training_args.wandb_entity,
+            project=training_args.wandb_project,
+            job_type=training_args.wandb_job_type,
             config=parser.parse_args(),
         )
 
@@ -515,22 +536,19 @@ def main():
     rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
-    num_epochs = int(training_args.num_train_epochs)
+    num_epochs = training_args.num_train_epochs
     # batch size per node
     train_batch_size = (
-        int(training_args.per_device_train_batch_size) * jax.local_device_count()
-    )
-    batch_size_per_update = (
-        train_batch_size
-        * training_args.gradient_accumulation_steps
-        * jax.process_count()
+        training_args.per_device_train_batch_size * jax.local_device_count()
     )
+    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
-        int(training_args.per_device_eval_batch_size) * jax.local_device_count()
+        training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
-        len_train_dataset // (train_batch_size * jax.process_count())
+        len_train_dataset // batch_size_per_node
         if len_train_dataset is not None
         else None
     )
@@ -645,12 +663,6 @@ def main():
         clipping_threshold=training_args.max_grad_norm,
     )
 
-    # add gradient accumulation
-    if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.chain(
-            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
-        )
-
     # Setup train state
     state = TrainState.create(
         apply_fn=model.__call__,
@@ -673,22 +685,48 @@ def main():
     def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
-        def compute_loss(params, batch):
-            labels = batch.pop("labels")
+        def compute_loss(params, minibatch):
+            labels = minibatch.pop("labels")
             logits = state.apply_fn(
-                **batch, params=params, dropout_rng=dropout_rng, train=True
+                **minibatch, params=params, dropout_rng=dropout_rng, train=True
             )[0]
-            loss = loss_fn(logits, labels)
-            return loss
+            return loss_fn(logits, labels)
 
         grad_fn = jax.value_and_grad(compute_loss)
-        loss, grads = grad_fn(state.params, batch)
+
+        if training_args.gradient_accumulation_steps == 1:
+            minibatch = jax.tree_map(lambda x: x[0], batch)
+            loss, grads = grad_fn(state.params, minibatch)
+        else:
+
+            def _cumul_loss_grads(i, cumul_loss_grads):
+                minibatch = jax.tree_map(lambda x: x[i], batch)
+                return jax.tree_map(
+                    lambda x, y: x + y,
+                    cumul_loss_grads,
+                    grad_fn(state.params, minibatch),
+                )
+
+            init_loss_grads = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            loss, grads = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps,
+                jax.lax.fori_loop(
+                    0,
+                    training_args.gradient_accumulation_steps,
+                    _cumul_loss_grads,
+                    init_loss_grads,
+                ),
+            )
+
         grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
             dropout_rng=new_dropout_rng,
             train_time=state.train_time + delta_time,
-            train_samples=state.train_samples + train_batch_size * jax.process_count(),
+            train_samples=state.train_samples + batch_size_per_step,
         )
 
         metrics = {
@@ -711,19 +749,20 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
 
     logger.info("***** Running training *****")
    logger.info(f" Num examples = {len_train_dataset}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(
-        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+        f" Batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(f" Number of devices = {jax.device_count()}")
     logger.info(
-        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
+        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
     )
+    logger.info(f" Batch size per update = {batch_size_per_step}")
     logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
@@ -740,8 +779,9 @@ def main():
         {
             "len_train_dataset": len_train_dataset,
             "len_eval_dataset": len_eval_dataset,
-            "batch_size_per_update": batch_size_per_update,
+            "batch_size_per_step": batch_size_per_step,
             "num_params": num_params,
+            "num_devices": jax.device_count(),
         }
     )
 
@@ -752,7 +792,9 @@ def main():
        # ======================== Evaluating ==============================
        eval_metrics = []
        if training_args.do_eval:
-            eval_loader = dataset.dataloader("eval", eval_batch_size)
+            eval_loader = dataset.dataloader(
+                "eval", training_args.per_device_eval_batch_size
+            )
            eval_steps = (
                len_eval_dataset // eval_batch_size
                if len_eval_dataset is not None
@@ -869,7 +911,12 @@ def main():
        metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))

        # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size, epoch)
+        train_loader = dataset.dataloader(
+            "train",
+            training_args.per_device_train_batch_size,
+            training_args.gradient_accumulation_steps,
+            epoch,
+        )
        # train
        for batch in tqdm(
            train_loader,
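
For readers unfamiliar with the fori_loop pattern in the new train_step, here is a self-contained toy version (a made-up least-squares loss and synthetic data, not the dalle-mini model; it uses jax.tree_util.tree_map, which the jax.tree_map calls in the diff alias): per-micro-batch gradients are summed inside jax.lax.fori_loop and the accumulated (loss, grads) pytree is divided by the number of accumulation steps at the end.

import jax
import jax.numpy as jnp

accumulation_steps = 4
params = {"w": jnp.ones((3,))}

# Toy batch with a leading accumulation axis: (steps, micro_batch, features)
batch = {
    "x": jnp.arange(accumulation_steps * 2 * 3, dtype=jnp.float32).reshape(
        accumulation_steps, 2, 3
    ),
    "y": jnp.ones((accumulation_steps, 2)),
}

def compute_loss(params, minibatch):
    # Stand-in for the model's loss_fn
    pred = minibatch["x"] @ params["w"]
    return jnp.mean((pred - minibatch["y"]) ** 2)

grad_fn = jax.value_and_grad(compute_loss)

def _cumul_loss_grads(i, cumul_loss_grads):
    # Slice out micro-batch i and add its (loss, grads) to the running totals
    minibatch = jax.tree_util.tree_map(lambda x: x[i], batch)
    return jax.tree_util.tree_map(
        lambda x, y: x + y, cumul_loss_grads, grad_fn(params, minibatch)
    )

init_loss_grads = (
    jnp.zeros(()),                                   # running loss
    jax.tree_util.tree_map(jnp.zeros_like, params),  # running grads
)
loss, grads = jax.tree_util.tree_map(
    lambda x: x / accumulation_steps,
    jax.lax.fori_loop(0, accumulation_steps, _cumul_loss_grads, init_loss_grads),
)
print(loss, grads["w"])  # averaged loss and gradients over the 4 micro-batches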