feat: log everything through wandb
seq2seq/run_seq2seq_flax.py  CHANGED  (+17 -53)
@@ -57,7 +57,6 @@ from transformers import (
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
-    is_tensorboard_available,
 )
 from transformers.models.bart.modeling_flax_bart import *
 from transformers.file_utils import is_offline_mode
@@ -226,10 +225,10 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-
+    eval_steps: Optional[int] = field(
         default=400,
         metadata={
-            "help": "Evaluation will be performed every
+            "help": "Evaluation will be performed every eval_steps"
         },
     )
     log_model: bool = field(
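Note: a minimal sketch of how the new eval_steps field is consumed, assuming the HfArgumentParser pattern this script already uses; the command-line value below is illustrative.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    eval_steps: Optional[int] = field(
        default=400,
        metadata={"help": "Evaluation will be performed every eval_steps"},
    )


parser = HfArgumentParser(DataTrainingArguments)
# `--eval_steps 200` on the command line overrides the default of 400
(data_args,) = parser.parse_args_into_dataclasses(args=["--eval_steps", "200"])
print(data_args.eval_steps)  # 200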
@@ -324,19 +323,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     yield batch


-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
-    summary_writer.scalar("train_time", train_time, step)
-
-    train_metrics = get_metrics(train_metrics)
-    for key, vals in train_metrics.items():
-        tag = f"train_epoch/{key}"
-        for i, val in enumerate(vals):
-            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-    for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval/{metric_name}", value, step)
-
-
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
 ) -> Callable[[int], jnp.array]:
@@ -351,6 +337,14 @@ def create_learning_rate_fn(
     return schedule_fn


+def wandb_log(metrics, step=None, prefix=None):
+    if jax.process_index() == 0:
+        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k, v in metrics.items()}
+        if step is not None:
+            log_metrics = {**log_metrics, 'train/step': step}
+        wandb.log(log_metrics)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
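Note: wandb_log collects everything into a single wandb.log payload, so related metrics land on the same W&B step. A standalone sketch of the prefixing logic, with the wandb and JAX calls stubbed out and made-up values:

def build_payload(metrics, step=None, prefix=None):
    # mirrors wandb_log: optional 'prefix/' namespacing, optional step column
    payload = {f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()}
    if step is not None:
        payload = {**payload, "train/step": step}
    return payload


print(build_payload({"loss": 2.31, "learning_rate": 3e-4}, step=100, prefix="train"))
# {'train/loss': 2.31, 'train/learning_rate': 0.0003, 'train/step': 100}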
@@ -377,7 +371,6 @@ def main():

     # Set up wandb run
     wandb.init(
-        sync_tensorboard=True,
         entity='wandb',
         project='hf-flax-dalle-mini',
         job_type='Seq2SeqVQGAN',
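Note: with sync_tensorboard=True removed, nothing reaches W&B implicitly anymore. A sketch of the resulting run lifecycle, reusing the entity/project/job_type from the hunk above; the logged value is made up:

import wandb

run = wandb.init(
    entity="wandb",
    project="hf-flax-dalle-mini",
    job_type="Seq2SeqVQGAN",
)
wandb.log({"train/loss": 2.31})  # every metric now goes through an explicit call
run.finish()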
@@ -578,24 +571,6 @@ def main():
         result = {k: round(v, 4) for k, v in result.items()}
         return result

-    # Enable tensorboard only on the master node
-    has_tensorboard = is_tensorboard_available()
-    if has_tensorboard and jax.process_index() == 0:
-        try:
-            from flax.metrics.tensorboard import SummaryWriter
-
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-        except ImportError as ie:
-            has_tensorboard = False
-            logger.warning(
-                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
-            )
-    else:
-        logger.warning(
-            "Unable to display metrics through TensorBoard because the package is not installed: "
-            "Please run pip install tensorboard to enable."
-        )
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
@@ -774,10 +749,8 @@ def main():
         eval_metrics = get_metrics(eval_metrics)
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

-
-
-        wandb.log({"eval/step": global_step})
-        wandb.log({f"eval/{k}": jax.device_get(v)})
+        # log metrics
+        wandb_log(eval_metrics, step=global_step, prefix='eval')

         # compute ROUGE metrics
         rouge_desc = ""
@@ -790,6 +763,7 @@ def main():
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
         epochs.write(desc)
         epochs.desc = desc
+
         return eval_metrics

     for epoch in epochs:
@@ -798,7 +772,6 @@ def main():

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []

         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
@@ -808,32 +781,23 @@ def main():
             global_step +=1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
-            train_metrics.append(train_metric)

             if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
-
-
-                wandb.log({"train/step": global_step})
-                wandb.log({f"train/{k}": jax.device_get(v)})
+                # log metrics
+                wandb_log(unreplicate(train_metric), step=global_step, prefix='train')

-            if global_step % data_args.
+            if global_step % data_args.eval_steps == 0:
                 run_evaluation()

         train_time += time.time() - train_start
-
         train_metric = unreplicate(train_metric)
-
         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
         )

+        # Final evaluation
         eval_metrics = run_evaluation()

-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
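Note: the loop now runs on two independent cadences: metrics every log_interval steps, evaluation every eval_steps steps. A toy check of that cadence (eval_steps=400 per the diff; log_interval's default is not visible here, 40 is assumed):

log_interval, eval_steps = 40, 400  # eval_steps default from the diff; log_interval assumed

steps = range(1, 1201)
logged = [s for s in steps if s % log_interval == 0]
evaluated = [s for s in steps if s % eval_steps == 0]
print(len(logged), evaluated)  # 30 [400, 800, 1200]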