Spaces:
Running
Running
fix: typo
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -216,7 +216,7 @@ class DataTrainingArguments:
|
|
216 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
217 |
)
|
218 |
log_interval: Optional[int] = field(
|
219 |
-
default=
|
220 |
metadata={
|
221 |
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
222 |
"value if set."
|
@@ -753,7 +753,7 @@ def main():
|
|
753 |
|
754 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
755 |
for k, v in unreplicate(train_metric).items():
|
756 |
-
wandb.log(f{'train/{k}': jax.device_get(v), step=global_step)
|
757 |
|
758 |
train_time += time.time() - train_start
|
759 |
|
|
|
216 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
217 |
)
|
218 |
log_interval: Optional[int] = field(
|
219 |
+
default=5,
|
220 |
metadata={
|
221 |
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
222 |
"value if set."
|
|
|
753 |
|
754 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
755 |
for k, v in unreplicate(train_metric).items():
|
756 |
+
wandb.log(f{'train/{k}': jax.device_get(v)}, step=global_step)
|
757 |
|
758 |
train_time += time.time() - train_start
|
759 |
|