Spaces:
Running
Running
fix: log correct metrics
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -340,7 +340,7 @@ def wandb_log(metrics, step=None, prefix=None):
|
|
340 |
if jax.process_index() == 0:
|
341 |
log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
342 |
if step is not None:
|
343 |
-
log_metrics = {**
|
344 |
wandb.log(log_metrics)
|
345 |
|
346 |
|
@@ -791,10 +791,13 @@ def main():
|
|
791 |
|
792 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
793 |
# log metrics
|
794 |
-
wandb_log(unreplicate(train_metric), step=global_step, prefix='
|
795 |
|
796 |
if global_step % training_args.eval_steps == 0:
|
797 |
run_evaluation()
|
|
|
|
|
|
|
798 |
|
799 |
train_time += time.time() - train_start
|
800 |
train_metric = unreplicate(train_metric)
|
|
|
340 |
if jax.process_index() == 0:
|
341 |
log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
342 |
if step is not None:
|
343 |
+
log_metrics = {**log_metrics, 'train/step': step}
|
344 |
wandb.log(log_metrics)
|
345 |
|
346 |
|
|
|
791 |
|
792 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
793 |
# log metrics
|
794 |
+
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
795 |
|
796 |
if global_step % training_args.eval_steps == 0:
|
797 |
run_evaluation()
|
798 |
+
|
799 |
+
# log final train metrics
|
800 |
+
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
801 |
|
802 |
train_time += time.time() - train_start
|
803 |
train_metric = unreplicate(train_metric)
|