boris commited on
Commit
19070ab
·
1 Parent(s): f0a53ac

feat: log everything through wandb

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +17 -53
seq2seq/run_seq2seq_flax.py CHANGED
@@ -57,7 +57,6 @@ from transformers import (
57
  FlaxBartForConditionalGeneration,
58
  HfArgumentParser,
59
  TrainingArguments,
60
- is_tensorboard_available,
61
  )
62
  from transformers.models.bart.modeling_flax_bart import *
63
  from transformers.file_utils import is_offline_mode
@@ -226,10 +225,10 @@ class DataTrainingArguments:
226
  "value if set."
227
  },
228
  )
229
- eval_interval: Optional[int] = field(
230
  default=400,
231
  metadata={
232
- "help": "Evaluation will be performed every eval_interval steps"
233
  },
234
  )
235
  log_model: bool = field(
@@ -324,19 +323,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
324
  yield batch
325
 
326
 
327
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
328
- summary_writer.scalar("train_time", train_time, step)
329
-
330
- train_metrics = get_metrics(train_metrics)
331
- for key, vals in train_metrics.items():
332
- tag = f"train_epoch/{key}"
333
- for i, val in enumerate(vals):
334
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
335
-
336
- for metric_name, value in eval_metrics.items():
337
- summary_writer.scalar(f"eval/{metric_name}", value, step)
338
-
339
-
340
  def create_learning_rate_fn(
341
  train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
342
  ) -> Callable[[int], jnp.array]:
@@ -351,6 +337,14 @@ def create_learning_rate_fn(
351
  return schedule_fn
352
 
353
 
 
 
 
 
 
 
 
 
354
  def main():
355
  # See all possible arguments in src/transformers/training_args.py
356
  # or by passing the --help flag to this script.
@@ -377,7 +371,6 @@ def main():
377
 
378
  # Set up wandb run
379
  wandb.init(
380
- sync_tensorboard=True,
381
  entity='wandb',
382
  project='hf-flax-dalle-mini',
383
  job_type='Seq2SeqVQGAN',
@@ -578,24 +571,6 @@ def main():
578
  result = {k: round(v, 4) for k, v in result.items()}
579
  return result
580
 
581
- # Enable tensorboard only on the master node
582
- has_tensorboard = is_tensorboard_available()
583
- if has_tensorboard and jax.process_index() == 0:
584
- try:
585
- from flax.metrics.tensorboard import SummaryWriter
586
-
587
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
588
- except ImportError as ie:
589
- has_tensorboard = False
590
- logger.warning(
591
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
592
- )
593
- else:
594
- logger.warning(
595
- "Unable to display metrics through TensorBoard because the package is not installed: "
596
- "Please run pip install tensorboard to enable."
597
- )
598
-
599
  # Initialize our training
600
  rng = jax.random.PRNGKey(training_args.seed)
601
  rng, dropout_rng = jax.random.split(rng)
@@ -774,10 +749,8 @@ def main():
774
  eval_metrics = get_metrics(eval_metrics)
775
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
776
 
777
- if jax.process_index() == 0:
778
- for k, v in eval_metrics.items():
779
- wandb.log({"eval/step": global_step})
780
- wandb.log({f"eval/{k}": jax.device_get(v)})
781
 
782
  # compute ROUGE metrics
783
  rouge_desc = ""
@@ -790,6 +763,7 @@ def main():
790
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
791
  epochs.write(desc)
792
  epochs.desc = desc
 
793
  return eval_metrics
794
 
795
  for epoch in epochs:
@@ -798,7 +772,6 @@ def main():
798
 
799
  # Create sampling rng
800
  rng, input_rng = jax.random.split(rng)
801
- train_metrics = []
802
 
803
  # Generate an epoch by shuffling sampling indices from the train dataset
804
  train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
@@ -808,32 +781,23 @@ def main():
808
  global_step +=1
809
  batch = next(train_loader)
810
  state, train_metric = p_train_step(state, batch)
811
- train_metrics.append(train_metric)
812
 
813
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
814
- print("logging train loss")
815
- for k, v in unreplicate(train_metric).items():
816
- wandb.log({"train/step": global_step})
817
- wandb.log({f"train/{k}": jax.device_get(v)})
818
 
819
- if global_step % data_args.eval_interval == 0 and jax.process_index() == 0:
820
  run_evaluation()
821
 
822
  train_time += time.time() - train_start
823
-
824
  train_metric = unreplicate(train_metric)
825
-
826
  epochs.write(
827
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
828
  )
829
 
 
830
  eval_metrics = run_evaluation()
831
 
832
- # Save metrics
833
- if has_tensorboard and jax.process_index() == 0:
834
- cur_step = epoch * (len(train_dataset) // train_batch_size)
835
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
836
-
837
  # save checkpoint after each epoch and push checkpoint to the hub
838
  if jax.process_index() == 0:
839
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
 
57
  FlaxBartForConditionalGeneration,
58
  HfArgumentParser,
59
  TrainingArguments,
 
60
  )
61
  from transformers.models.bart.modeling_flax_bart import *
62
  from transformers.file_utils import is_offline_mode
 
225
  "value if set."
226
  },
227
  )
228
+ eval_steps: Optional[int] = field(
229
  default=400,
230
  metadata={
231
+ "help": "Evaluation will be performed every eval_steps"
232
  },
233
  )
234
  log_model: bool = field(
 
323
  yield batch
324
 
325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  def create_learning_rate_fn(
327
  train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
328
  ) -> Callable[[int], jnp.array]:
 
337
  return schedule_fn
338
 
339
 
340
+ def wandb_log(metrics, step=None, prefix=None):
341
+ if jax.process_index() == 0:
342
+ log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
343
+ if step is not None:
344
+ log_metrics = {**metrics, 'train/step': step}
345
+ wandb.log(log_metrics)
346
+
347
+
348
  def main():
349
  # See all possible arguments in src/transformers/training_args.py
350
  # or by passing the --help flag to this script.
 
371
 
372
  # Set up wandb run
373
  wandb.init(
 
374
  entity='wandb',
375
  project='hf-flax-dalle-mini',
376
  job_type='Seq2SeqVQGAN',
 
571
  result = {k: round(v, 4) for k, v in result.items()}
572
  return result
573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  # Initialize our training
575
  rng = jax.random.PRNGKey(training_args.seed)
576
  rng, dropout_rng = jax.random.split(rng)
 
749
  eval_metrics = get_metrics(eval_metrics)
750
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
751
 
752
+ # log metrics
753
+ wandb_log(eval_metrics, step=global_step, prefix='eval')
 
 
754
 
755
  # compute ROUGE metrics
756
  rouge_desc = ""
 
763
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
764
  epochs.write(desc)
765
  epochs.desc = desc
766
+
767
  return eval_metrics
768
 
769
  for epoch in epochs:
 
772
 
773
  # Create sampling rng
774
  rng, input_rng = jax.random.split(rng)
 
775
 
776
  # Generate an epoch by shuffling sampling indices from the train dataset
777
  train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
 
781
  global_step +=1
782
  batch = next(train_loader)
783
  state, train_metric = p_train_step(state, batch)
 
784
 
785
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
786
+ # log metrics
787
+ wandb_log(unreplicate(train_metric), step=global_step, prefix='tran')
 
 
788
 
789
+ if global_step % data_args.eval_steps == 0:
790
  run_evaluation()
791
 
792
  train_time += time.time() - train_start
 
793
  train_metric = unreplicate(train_metric)
 
794
  epochs.write(
795
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
796
  )
797
 
798
+ # Final evaluation
799
  eval_metrics = run_evaluation()
800
 
 
 
 
 
 
801
  # save checkpoint after each epoch and push checkpoint to the hub
802
  if jax.process_index() == 0:
803
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))