boris commited on
Commit
1b757dc
·
1 Parent(s): 901ff72

feat: log more metrics

Browse files
Files changed (1) hide show
  1. tools/train/train.py +41 -20
tools/train/train.py CHANGED
@@ -331,14 +331,37 @@ def create_learning_rate_fn(
331
  return schedule_fn
332
 
333
 
334
- def wandb_log(metrics, step=None, prefix=None):
335
- if jax.process_index() == 0:
336
- log_metrics = {
337
- f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
 
 
 
 
 
 
 
 
338
  }
339
- if step is not None:
340
- log_metrics["train/step"] = step
341
- wandb.log(log_metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
 
344
  def main():
@@ -628,9 +651,10 @@ def main():
628
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
629
  )
630
 
 
631
  if jax.process_index() == 0:
632
  # set default x-axis as 'train/step'
633
- wandb_log({}, step=state.step)
634
  wandb.define_metric("*", step_metric="train/step")
635
 
636
  # add interesting config parameters
@@ -672,7 +696,9 @@ def main():
672
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
673
 
674
  # log metrics
675
- wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
 
 
676
 
677
  # Print metrics and update progress bar
678
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -772,7 +798,7 @@ def main():
772
  for epoch in epochs:
773
  state.replace(epoch=jax_utils.replicate(epoch))
774
  # ======================== Training ================================
775
- wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
776
 
777
  # Generate an epoch by shuffling sampling indices from the train dataset
778
  train_loader = dataset.dataloader("train", train_batch_size)
@@ -797,17 +823,12 @@ def main():
797
  step = unreplicate(state.step)
798
 
799
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
800
- # log metrics
801
- metrics = unreplicate(train_metrics)
802
- # log state parameters
803
- state_dict = {
804
- k.split("_")[-1]: unreplicate(getattr(state, k))
805
- for k in ["epoch", "train_time", "train_samples"]
806
- }
807
- wandb_log({**metrics, **state_dict}, step=step, prefix="train")
808
 
809
  eval_metrics = None
810
  if training_args.eval_steps and step % training_args.eval_steps == 0:
 
811
  eval_metrics = run_evaluation()
812
 
813
  if step % training_args.save_steps == 0:
@@ -815,8 +836,8 @@ def main():
815
 
816
  # log final train metrics
817
  if train_metrics is not None:
818
- train_metrics = unreplicate(train_metrics)
819
- wandb_log(train_metrics, step=step, prefix="train")
820
 
821
  epochs.write(
822
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
 
331
  return schedule_fn
332
 
333
 
334
+ class MetricsLogger:
335
+ def __init__(self, state):
336
+ self.step = state.step
337
+ self.time = time.perf_counter()
338
+
339
+ def get_all_train_metrics(self, train_metrics, state):
340
+ """Make a dict of training metrics to be logged"""
341
+ metrics = unreplicate(train_metrics)
342
+ # get state parameters
343
+ state_dict = {
344
+ k.split("_")[-1]: unreplicate(getattr(state, k))
345
+ for k in ["epoch", "train_time", "train_samples"]
346
  }
347
+ # timing metrics
348
+ new_step = int(unreplicate(state.step))
349
+ new_time = time.perf_counter()
350
+ time_per_step = (new_time - self.time) / (new_step - self.step)
351
+ self.step = new_step
352
+ self.time = new_time
353
+ return {**metrics, **state_dict, "time_per_step": time_per_step}
354
+
355
+ @staticmethod
356
+ def log(metrics, step=None, prefix=None):
357
+ if jax.process_index() == 0:
358
+ log_metrics = {
359
+ f"{prefix}/{k}" if prefix is not None else k: v
360
+ for k, v in metrics.items()
361
+ }
362
+ if step is not None:
363
+ log_metrics["train/step"] = step
364
+ wandb.log(log_metrics)
365
 
366
 
367
  def main():
 
651
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
652
  )
653
 
654
+ metrics_logger = MetricsLogger(state)
655
  if jax.process_index() == 0:
656
  # set default x-axis as 'train/step'
657
+ metrics_logger.log({}, step=state.step)
658
  wandb.define_metric("*", step_metric="train/step")
659
 
660
  # add interesting config parameters
 
696
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
697
 
698
  # log metrics
699
+ metrics_logger.log(
700
+ eval_metrics, step=unreplicate(state.step), prefix="eval"
701
+ )
702
 
703
  # Print metrics and update progress bar
704
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
798
  for epoch in epochs:
799
  state.replace(epoch=jax_utils.replicate(epoch))
800
  # ======================== Training ================================
801
+ metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
802
 
803
  # Generate an epoch by shuffling sampling indices from the train dataset
804
  train_loader = dataset.dataloader("train", train_batch_size)
 
823
  step = unreplicate(state.step)
824
 
825
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
826
+ all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
827
+ metrics_logger.log(all_metrics, step=step, prefix="train")
 
 
 
 
 
 
828
 
829
  eval_metrics = None
830
  if training_args.eval_steps and step % training_args.eval_steps == 0:
831
+ return
832
  eval_metrics = run_evaluation()
833
 
834
  if step % training_args.save_steps == 0:
 
836
 
837
  # log final train metrics
838
  if train_metrics is not None:
839
+ all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
840
+ metrics_logger.log(all_metrics, step=step, prefix="train")
841
 
842
  epochs.write(
843
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"