Spaces:
Running
Running
fix: wandb logging with sync_tensorboard
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -755,7 +755,8 @@ def main():
|
|
755 |
|
756 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
757 |
for k, v in unreplicate(train_metric).items():
|
758 |
-
wandb.log({
|
|
|
759 |
|
760 |
train_time += time.time() - train_start
|
761 |
|
|
|
755 |
|
756 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
757 |
for k, v in unreplicate(train_metric).items():
|
758 |
+
wandb.log({"train/step": global_step})
|
759 |
+
wandb.log({f"train/{k}": jax.device_get(v)})
|
760 |
|
761 |
train_time += time.time() - train_start
|
762 |
|