feat(train): more custom x-axis
tools/train/train.py  (+28 -35)
@@ -395,17 +395,16 @@ class TrainState(train_state.TrainState):
 
 
 class MetricsLogger:
-    def __init__(self,
-        self.step =
+    def __init__(self, step):
+        self.step = step
         self.time = time.perf_counter()
+        self.state_dict = {}
 
-    def
-        """
-
-
-
-            k.split("_")[-1]: getattr(state, k)
-            for k in ["epoch", "train_time", "train_samples"]
+    def update_state_metrics(self, state):
+        """Update internal state metrics (logged at each call to be used as x-axis)"""
+        self.state_dict = {
+            f'train/{k.split("_")[-1]}': getattr(state, k)
+            for k in ["step", "epoch", "train_time", "train_samples"]
         }
         # timing metrics
         new_step = int(state.step)
@@ -414,19 +413,15 @@ class MetricsLogger:
         time_per_step = (new_time - self.time) / (new_step - self.step)
         self.step = new_step
         self.time = new_time
-        state_dict["time_per_step"] = time_per_step
-        return {**metrics, **state_dict}
+        self.state_dict["train/time_per_step"] = time_per_step
 
-
-    def log(metrics, step=None, prefix=None):
+    def log(self, metrics, prefix=None):
         if jax.process_index() == 0:
             log_metrics = {
                 f"{prefix}/{k}" if prefix is not None else k: v
                 for k, v in metrics.items()
             }
-
-            log_metrics["train/step"] = step
-            wandb.log(log_metrics)
+            wandb.log({**log_metrics, **self.state_dict})
 
 
 def main():
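Taken together, the reworked MetricsLogger caches the state metrics and merges them into every wandb call so any of them can serve as an x-axis. A minimal runnable sketch of the intended usage, with a trimmed stand-in for the class above (the project name is invented and wandb runs offline, so no account is needed):

import time

import wandb


class MetricsLogger:
    # Trimmed stand-in for the class in the diff: state metrics are cached and
    # merged into every wandb.log call so they can be used as custom x-axes.
    def __init__(self, step):
        self.step = step
        self.time = time.perf_counter()
        self.state_dict = {}

    def update_state_metrics(self, state):
        # here `state` is a plain dict standing in for the real TrainState
        self.state_dict = {f"train/{k}": v for k, v in state.items()}

    def log(self, metrics, prefix=None):
        log_metrics = {f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()}
        wandb.log({**log_metrics, **self.state_dict})


run = wandb.init(project="metrics-logger-sketch", mode="offline")  # hypothetical project name
metrics_logger = MetricsLogger(step=0)
for step in range(1, 4):
    metrics_logger.update_state_metrics({"step": step, "epoch": 0, "samples": step * 256})
    metrics_logger.log({"loss": 1.0 / step}, prefix="train")
run.finish()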
@@ -878,9 +873,9 @@ def main():
         return state, metrics
 
     # Define eval fn
-    def eval_step(
+    def eval_step(state, batch):
         batch, labels = batch.pop("labels")
-        logits = model(**batch, params=params, train=False)[0]
+        logits = model(**batch, params=state.params, train=False)[0]
         loss = loss_fn(logits, labels)
         return loss
 
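The unchanged line "batch, labels = batch.pop("labels")" works because the batch is a flax FrozenDict (it is frozen before evaluation, see the later hunk): FrozenDict.pop does not mutate, it returns a pair of (new dict without the key, popped value). A toy illustration with invented data:

from flax.core.frozen_dict import freeze

batch = freeze({"input_ids": [1, 2, 3], "labels": [0, 1, 0]})  # toy batch
# FrozenDict.pop returns (dict-without-key, value) instead of mutating in place
batch, labels = batch.pop("labels")
print(batch)   # FrozenDict({'input_ids': [1, 2, 3]})
print(labels)  # [0, 1, 0]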
@@ -893,7 +888,7 @@ def main():
     )
     p_eval_step = pjit(
         eval_step,
-        in_axis_resources=(
+        in_axis_resources=(state_spec, batch_spec),
         out_axis_resources=None,
     )
 
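Here state_spec and batch_spec are the PartitionSpec trees built elsewhere in train.py. For reference, a self-contained sketch of the same older jax.experimental.pjit API with invented axis names and shapes (newer JAX releases renamed these arguments to in_shardings/out_shardings):

import jax
import numpy as np
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental.pjit import pjit

# one-dimensional mesh over whatever devices are available (works with a single CPU device)
mesh = maps.Mesh(np.array(jax.devices()), ("batch",))


def eval_step(params, batch):
    # toy loss: mean squared projection, just to have something to partition
    return ((batch @ params) ** 2).mean()


p_eval_step = pjit(
    eval_step,
    in_axis_resources=(P(), P("batch")),  # replicate params, shard batch along its first axis
    out_axis_resources=None,              # scalar loss replicated on every device
)

with maps.mesh(mesh.devices, mesh.axis_names):
    loss = p_eval_step(np.ones((4, 2)), np.ones((8, 4)))
    print(loss)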
@@ -913,10 +908,14 @@ def main():
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
     )
 
-
+    # init variables
+    last_time = time.perf_counter()
+    train_metrics = None
+    step = int(state.step)
+    metrics_logger = MetricsLogger(step)
+
     if jax.process_index() == 0:
         # set default x-axis as 'train/step'
-        metrics_logger.log({}, step=state.step)
         wandb.define_metric("*", step_metric="train/step")
 
     # add interesting config parameters
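The wandb.define_metric("*", step_metric="train/step") call is what switches every chart to the custom x-axis; since the logger now re-emits train/step (and the other state metrics) with each call, every point lands on that axis. A tiny offline illustration (project name invented):

import wandb

run = wandb.init(project="x-axis-sketch", mode="offline")  # offline: no account needed
wandb.define_metric("*", step_metric="train/step")  # plot every metric against train/step
for step in range(0, 300, 100):
    wandb.log({"train/step": step, "train/loss": 1.0 / (step + 1), "eval/loss": 2.0 / (step + 1)})
run.finish()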
@@ -950,7 +949,7 @@ def main():
             # freeze batch to pass safely to JAX transforms
             batch = freeze(batch)
             # accumulate losses async
-            eval_loss.append(p_eval_step(state
+            eval_loss.append(p_eval_step(state, batch))
 
         # get the mean of the loss
         eval_loss = jnp.stack(eval_loss)
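Because the per-batch losses are appended without blocking ("accumulate losses async"), the device queue is only flushed once they are stacked; the aggregation then reduces to something like the sketch below (the actual mean is taken on a line outside this hunk, and the loss values here are invented):

import jax.numpy as jnp

eval_loss = [jnp.array(0.92), jnp.array(1.07), jnp.array(0.99)]  # per-batch losses from p_eval_step
eval_loss = jnp.stack(eval_loss).mean()  # "get the mean of the loss"
eval_metrics = {"loss": eval_loss}
print(eval_metrics)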
@@ -958,7 +957,7 @@ def main():
         eval_metrics = {"loss": eval_loss}
 
         # log metrics
-        metrics_logger.log(eval_metrics,
+        metrics_logger.log(eval_metrics, prefix="eval")
 
         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -1036,16 +1035,12 @@ def main():
         )
         wandb.run.log_artifact(artifact_state)
 
-    # init variables
-    last_time = time.perf_counter()
-    train_metrics = None
-    step = int(state.step)
-
     with maps.mesh(mesh.devices, mesh.axis_names):
         for epoch in epochs:
             state.replace(epoch=epoch)
             # ======================== Training ================================
-            metrics_logger.
+            metrics_logger.update_state_metrics(state)
+            metrics_logger.log({})
 
             # Generate an epoch by shuffling sampling indices from the train dataset
             train_loader = dataset.dataloader(
@@ -1086,10 +1081,8 @@ def main():
                 step += 1
 
                 if step % training_args.logging_steps == 0 and jax.process_index() == 0:
-
-
-                    )
-                    metrics_logger.log(all_metrics, step=step, prefix="train")
+                    metrics_logger.update_state_metrics(state)
+                    metrics_logger.log(train_metrics, prefix="train")
 
                 eval_metrics = None
                 if step % training_args.eval_steps == 0:
@@ -1100,8 +1093,8 @@ def main():
 
             # log final train metrics
             if train_metrics is not None:
-
-                metrics_logger.log(
+                metrics_logger.update_state_metrics(state)
+                metrics_logger.log(train_metrics, prefix="train")
 
             epochs.write(
                 f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"