Spaces:
Running
Running
fix: state.step type
Browse files- dev/seq2seq/run_seq2seq_flax.py +14 -12
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -416,7 +416,7 @@ def wandb_log(metrics, step=None, prefix=None):
|
|
416 |
f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
|
417 |
}
|
418 |
if step is not None:
|
419 |
-
log_metrics["train/step"] =
|
420 |
wandb.log(log_metrics)
|
421 |
|
422 |
|
@@ -846,7 +846,7 @@ def main():
|
|
846 |
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
847 |
)
|
848 |
logger.info(
|
849 |
-
f" Total train batch size (w. parallel &
|
850 |
)
|
851 |
logger.info(f" Total global steps = {total_steps}")
|
852 |
logger.info(f" Total optimization steps = {total_optimization_steps}")
|
@@ -854,7 +854,7 @@ def main():
|
|
854 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
855 |
|
856 |
# set default x-axis as 'train/step'
|
857 |
-
wandb_log({}, step=state.step)
|
858 |
wandb.define_metric("*", step_metric="train/step")
|
859 |
|
860 |
# add interesting config parameters
|
@@ -893,7 +893,7 @@ def main():
|
|
893 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
894 |
|
895 |
# log metrics
|
896 |
-
wandb_log(eval_metrics, step=state.step, prefix="eval")
|
897 |
|
898 |
# Print metrics and update progress bar
|
899 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
@@ -956,7 +956,7 @@ def main():
|
|
956 |
)
|
957 |
# save some space
|
958 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
959 |
-
c.cleanup(
|
960 |
|
961 |
wandb.run.log_artifact(artifact)
|
962 |
|
@@ -972,7 +972,8 @@ def main():
|
|
972 |
|
973 |
for epoch in epochs:
|
974 |
# ======================== Training ================================
|
975 |
-
|
|
|
976 |
|
977 |
# Create sampling rng
|
978 |
rng, input_rng = jax.random.split(rng)
|
@@ -994,19 +995,20 @@ def main():
|
|
994 |
total=steps_per_epoch,
|
995 |
):
|
996 |
state, train_metric = p_train_step(state, batch)
|
|
|
997 |
|
998 |
-
if
|
999 |
# log metrics
|
1000 |
-
wandb_log(unreplicate(train_metric), step=
|
1001 |
|
1002 |
-
if training_args.eval_steps and
|
1003 |
run_evaluation()
|
1004 |
|
1005 |
-
if
|
1006 |
-
run_save_model(state,
|
1007 |
|
1008 |
# log final train metrics
|
1009 |
-
wandb_log(unreplicate(train_metric), step=
|
1010 |
|
1011 |
train_metric = unreplicate(train_metric)
|
1012 |
epochs.write(
|
|
|
416 |
f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
|
417 |
}
|
418 |
if step is not None:
|
419 |
+
log_metrics["train/step"] = step
|
420 |
wandb.log(log_metrics)
|
421 |
|
422 |
|
|
|
846 |
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
847 |
)
|
848 |
logger.info(
|
849 |
+
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
850 |
)
|
851 |
logger.info(f" Total global steps = {total_steps}")
|
852 |
logger.info(f" Total optimization steps = {total_optimization_steps}")
|
|
|
854 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
855 |
|
856 |
# set default x-axis as 'train/step'
|
857 |
+
wandb_log({}, step=unreplicate(state.step))
|
858 |
wandb.define_metric("*", step_metric="train/step")
|
859 |
|
860 |
# add interesting config parameters
|
|
|
893 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
894 |
|
895 |
# log metrics
|
896 |
+
wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
|
897 |
|
898 |
# Print metrics and update progress bar
|
899 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
|
|
956 |
)
|
957 |
# save some space
|
958 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
959 |
+
c.cleanup("5GB")
|
960 |
|
961 |
wandb.run.log_artifact(artifact)
|
962 |
|
|
|
972 |
|
973 |
for epoch in epochs:
|
974 |
# ======================== Training ================================
|
975 |
+
step = unreplicate(state.step)
|
976 |
+
wandb_log({"train/epoch": epoch}, step=step)
|
977 |
|
978 |
# Create sampling rng
|
979 |
rng, input_rng = jax.random.split(rng)
|
|
|
995 |
total=steps_per_epoch,
|
996 |
):
|
997 |
state, train_metric = p_train_step(state, batch)
|
998 |
+
step = unreplicate(state.step)
|
999 |
|
1000 |
+
if step % data_args.log_interval == 0 and jax.process_index() == 0:
|
1001 |
# log metrics
|
1002 |
+
wandb_log(unreplicate(train_metric), step=step, prefix="train")
|
1003 |
|
1004 |
+
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
1005 |
run_evaluation()
|
1006 |
|
1007 |
+
if step % data_args.save_model_steps == 0:
|
1008 |
+
run_save_model(state, step, epoch)
|
1009 |
|
1010 |
# log final train metrics
|
1011 |
+
wandb_log(unreplicate(train_metric), step=step, prefix="train")
|
1012 |
|
1013 |
train_metric = unreplicate(train_metric)
|
1014 |
epochs.write(
|