boris commited on
Commit
b75e0e9
1 Parent(s): a96f44d

fix: remove breakpoint

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +0 -3
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -922,11 +922,8 @@ def main():
922
  eval_metrics.append(metrics)
923
 
924
  # normalize eval metrics
925
- breakpoint()
926
  eval_metrics = get_metrics(eval_metrics)
927
- breakpoint()
928
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
929
- breakpoint()
930
 
931
  # log metrics
932
  wandb_log(eval_metrics, step=global_step, prefix="eval")
 
922
  eval_metrics.append(metrics)
923
 
924
  # normalize eval metrics
 
925
  eval_metrics = get_metrics(eval_metrics)
 
926
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
927
 
928
  # log metrics
929
  wandb_log(eval_metrics, step=global_step, prefix="eval")