Spaces:
Running
Running
fix: remove breakpoint
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -922,11 +922,8 @@ def main():
|
|
922 |
eval_metrics.append(metrics)
|
923 |
|
924 |
# normalize eval metrics
|
925 |
-
breakpoint()
|
926 |
eval_metrics = get_metrics(eval_metrics)
|
927 |
-
breakpoint()
|
928 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
929 |
-
breakpoint()
|
930 |
|
931 |
# log metrics
|
932 |
wandb_log(eval_metrics, step=global_step, prefix="eval")
|
|
|
922 |
eval_metrics.append(metrics)
|
923 |
|
924 |
# normalize eval metrics
|
|
|
925 |
eval_metrics = get_metrics(eval_metrics)
|
|
|
926 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
|
|
927 |
|
928 |
# log metrics
|
929 |
wandb_log(eval_metrics, step=global_step, prefix="eval")
|