Spaces:
Running
Running
fix: log train_metric only if defined
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -336,10 +336,7 @@ def main():
|
|
336 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
337 |
logger.info(f"Training/evaluation parameters {training_args}")
|
338 |
|
339 |
-
#
|
340 |
-
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
341 |
-
# (the dataset will be downloaded automatically from the datasets Hub).
|
342 |
-
#
|
343 |
if data_args.train_file is not None or data_args.validation_file is not None:
|
344 |
data_files = {
|
345 |
"train": data_args.train_file,
|
@@ -826,7 +823,10 @@ def main():
|
|
826 |
temp_dir=True, # avoid issues with being in a repository
|
827 |
)
|
828 |
|
|
|
829 |
last_time = time.perf_counter()
|
|
|
|
|
830 |
for epoch in epochs:
|
831 |
state.replace(epoch=jax_utils.replicate(epoch))
|
832 |
# ======================== Training ================================
|
@@ -871,12 +871,13 @@ def main():
|
|
871 |
run_save_model(state, eval_metrics)
|
872 |
|
873 |
# log final train metrics
|
874 |
-
train_metric
|
875 |
-
|
|
|
876 |
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
|
881 |
# Final evaluation
|
882 |
eval_metrics = run_evaluation()
|
|
|
336 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
337 |
logger.info(f"Training/evaluation parameters {training_args}")
|
338 |
|
339 |
+
# Load dataset
|
|
|
|
|
|
|
340 |
if data_args.train_file is not None or data_args.validation_file is not None:
|
341 |
data_files = {
|
342 |
"train": data_args.train_file,
|
|
|
823 |
temp_dir=True, # avoid issues with being in a repository
|
824 |
)
|
825 |
|
826 |
+
# init variables
|
827 |
last_time = time.perf_counter()
|
828 |
+
train_metric = None
|
829 |
+
|
830 |
for epoch in epochs:
|
831 |
state.replace(epoch=jax_utils.replicate(epoch))
|
832 |
# ======================== Training ================================
|
|
|
871 |
run_save_model(state, eval_metrics)
|
872 |
|
873 |
# log final train metrics
|
874 |
+
if train_metric is not None:
|
875 |
+
train_metric = get_metrics(train_metric)
|
876 |
+
wandb_log(train_metric, step=step, prefix="train")
|
877 |
|
878 |
+
epochs.write(
|
879 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
880 |
+
)
|
881 |
|
882 |
# Final evaluation
|
883 |
eval_metrics = run_evaluation()
|