boris committed on
Commit
5f28cd2
·
1 Parent(s): 7c4c287

feat(train): more custom x-axis

Browse files
Files changed (1) hide show
  1. tools/train/train.py +28 -35
tools/train/train.py CHANGED
@@ -395,17 +395,16 @@ class TrainState(train_state.TrainState):
395
 
396
 
397
  class MetricsLogger:
398
- def __init__(self, state):
399
- self.step = int(state.step)
400
  self.time = time.perf_counter()
 
401
 
402
- def get_all_train_metrics(self, train_metrics, state):
403
- """Make a dict of training metrics to be logged"""
404
- metrics = train_metrics
405
- # get state parameters
406
- state_dict = {
407
- k.split("_")[-1]: getattr(state, k)
408
- for k in ["epoch", "train_time", "train_samples"]
409
  }
410
  # timing metrics
411
  new_step = int(state.step)
@@ -414,19 +413,15 @@ class MetricsLogger:
414
  time_per_step = (new_time - self.time) / (new_step - self.step)
415
  self.step = new_step
416
  self.time = new_time
417
- state_dict["time_per_step"] = time_per_step
418
- return {**metrics, **state_dict}
419
 
420
- @staticmethod
421
- def log(metrics, step=None, prefix=None):
422
  if jax.process_index() == 0:
423
  log_metrics = {
424
  f"{prefix}/{k}" if prefix is not None else k: v
425
  for k, v in metrics.items()
426
  }
427
- if step is not None:
428
- log_metrics["train/step"] = step
429
- wandb.log(log_metrics)
430
 
431
 
432
  def main():
@@ -878,9 +873,9 @@ def main():
878
  return state, metrics
879
 
880
  # Define eval fn
881
- def eval_step(params, batch):
882
  batch, labels = batch.pop("labels")
883
- logits = model(**batch, params=params, train=False)[0]
884
  loss = loss_fn(logits, labels)
885
  return loss
886
 
@@ -893,7 +888,7 @@ def main():
893
  )
894
  p_eval_step = pjit(
895
  eval_step,
896
- in_axis_resources=(param_spec, batch_spec),
897
  out_axis_resources=None,
898
  )
899
 
@@ -913,10 +908,14 @@ def main():
913
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
914
  )
915
 
916
- metrics_logger = MetricsLogger(state)
 
 
 
 
 
917
  if jax.process_index() == 0:
918
  # set default x-axis as 'train/step'
919
- metrics_logger.log({}, step=state.step)
920
  wandb.define_metric("*", step_metric="train/step")
921
 
922
  # add interesting config parameters
@@ -950,7 +949,7 @@ def main():
950
  # freeze batch to pass safely to JAX transforms
951
  batch = freeze(batch)
952
  # accumulate losses async
953
- eval_loss.append(p_eval_step(state.params, batch))
954
 
955
  # get the mean of the loss
956
  eval_loss = jnp.stack(eval_loss)
@@ -958,7 +957,7 @@ def main():
958
  eval_metrics = {"loss": eval_loss}
959
 
960
  # log metrics
961
- metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
962
 
963
  # Print metrics and update progress bar
964
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -1036,16 +1035,12 @@ def main():
1036
  )
1037
  wandb.run.log_artifact(artifact_state)
1038
 
1039
- # init variables
1040
- last_time = time.perf_counter()
1041
- train_metrics = None
1042
- step = int(state.step)
1043
-
1044
  with maps.mesh(mesh.devices, mesh.axis_names):
1045
  for epoch in epochs:
1046
  state.replace(epoch=epoch)
1047
  # ======================== Training ================================
1048
- metrics_logger.log({"train/epoch": epoch}, step=state.step)
 
1049
 
1050
  # Generate an epoch by shuffling sampling indices from the train dataset
1051
  train_loader = dataset.dataloader(
@@ -1086,10 +1081,8 @@ def main():
1086
  step += 1
1087
 
1088
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
1089
- all_metrics = metrics_logger.get_all_train_metrics(
1090
- train_metrics, state
1091
- )
1092
- metrics_logger.log(all_metrics, step=step, prefix="train")
1093
 
1094
  eval_metrics = None
1095
  if step % training_args.eval_steps == 0:
@@ -1100,8 +1093,8 @@ def main():
1100
 
1101
  # log final train metrics
1102
  if train_metrics is not None:
1103
- all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
1104
- metrics_logger.log(all_metrics, step=step, prefix="train")
1105
 
1106
  epochs.write(
1107
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
 
395
 
396
 
397
  class MetricsLogger:
398
+ def __init__(self, step):
399
+ self.step = step
400
  self.time = time.perf_counter()
401
+ self.state_dict = {}
402
 
403
+ def update_state_metrics(self, state):
404
+ """Update internal state metrics (logged at each call to be used as x-axis)"""
405
+ self.state_dict = {
406
+ f'train/{k.split("_")[-1]}': getattr(state, k)
407
+ for k in ["step", "epoch", "train_time", "train_samples"]
 
 
408
  }
409
  # timing metrics
410
  new_step = int(state.step)
 
413
  time_per_step = (new_time - self.time) / (new_step - self.step)
414
  self.step = new_step
415
  self.time = new_time
416
+ self.state_dict["train/time_per_step"] = time_per_step
 
417
 
418
+ def log(self, metrics, prefix=None):
 
419
  if jax.process_index() == 0:
420
  log_metrics = {
421
  f"{prefix}/{k}" if prefix is not None else k: v
422
  for k, v in metrics.items()
423
  }
424
+ wandb.log({**log_metrics, **self.state_dict})
 
 
425
 
426
 
427
  def main():
 
873
  return state, metrics
874
 
875
  # Define eval fn
876
+ def eval_step(state, batch):
877
  batch, labels = batch.pop("labels")
878
+ logits = model(**batch, params=state.params, train=False)[0]
879
  loss = loss_fn(logits, labels)
880
  return loss
881
 
 
888
  )
889
  p_eval_step = pjit(
890
  eval_step,
891
+ in_axis_resources=(state_spec, batch_spec),
892
  out_axis_resources=None,
893
  )
894
 
 
908
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
909
  )
910
 
911
+ # init variables
912
+ last_time = time.perf_counter()
913
+ train_metrics = None
914
+ step = int(state.step)
915
+ metrics_logger = MetricsLogger(step)
916
+
917
  if jax.process_index() == 0:
918
  # set default x-axis as 'train/step'
 
919
  wandb.define_metric("*", step_metric="train/step")
920
 
921
  # add interesting config parameters
 
949
  # freeze batch to pass safely to JAX transforms
950
  batch = freeze(batch)
951
  # accumulate losses async
952
+ eval_loss.append(p_eval_step(state, batch))
953
 
954
  # get the mean of the loss
955
  eval_loss = jnp.stack(eval_loss)
 
957
  eval_metrics = {"loss": eval_loss}
958
 
959
  # log metrics
960
+ metrics_logger.log(eval_metrics, prefix="eval")
961
 
962
  # Print metrics and update progress bar
963
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
1035
  )
1036
  wandb.run.log_artifact(artifact_state)
1037
 
 
 
 
 
 
1038
  with maps.mesh(mesh.devices, mesh.axis_names):
1039
  for epoch in epochs:
1040
  state.replace(epoch=epoch)
1041
  # ======================== Training ================================
1042
+ metrics_logger.update_state_metrics(state)
1043
+ metrics_logger.log({})
1044
 
1045
  # Generate an epoch by shuffling sampling indices from the train dataset
1046
  train_loader = dataset.dataloader(
 
1081
  step += 1
1082
 
1083
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
1084
+ metrics_logger.update_state_metrics(state)
1085
+ metrics_logger.log(train_metrics, prefix="train")
 
 
1086
 
1087
  eval_metrics = None
1088
  if step % training_args.eval_steps == 0:
 
1093
 
1094
  # log final train metrics
1095
  if train_metrics is not None:
1096
+ metrics_logger.update_state_metrics(state)
1097
+ metrics_logger.log(train_metrics, prefix="train")
1098
 
1099
  epochs.write(
1100
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"