boris committed on
Commit
9bf9397
1 Parent(s): 6523a6d

fix: log train_metric only if defined

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +10 -9
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -336,10 +336,7 @@ def main():
336
  # Set the verbosity to info of the Transformers logger (on main process only):
337
  logger.info(f"Training/evaluation parameters {training_args}")
338
 
339
- # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
340
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
341
- # (the dataset will be downloaded automatically from the datasets Hub).
342
- #
343
  if data_args.train_file is not None or data_args.validation_file is not None:
344
  data_files = {
345
  "train": data_args.train_file,
@@ -826,7 +823,10 @@ def main():
826
  temp_dir=True, # avoid issues with being in a repository
827
  )
828
 
 
829
  last_time = time.perf_counter()
 
 
830
  for epoch in epochs:
831
  state.replace(epoch=jax_utils.replicate(epoch))
832
  # ======================== Training ================================
@@ -871,12 +871,13 @@ def main():
871
  run_save_model(state, eval_metrics)
872
 
873
  # log final train metrics
874
- train_metric = get_metrics(train_metric)
875
- wandb_log(train_metric, step=step, prefix="train")
 
876
 
877
- epochs.write(
878
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
879
- )
880
 
881
  # Final evaluation
882
  eval_metrics = run_evaluation()
 
336
  # Set the verbosity to info of the Transformers logger (on main process only):
337
  logger.info(f"Training/evaluation parameters {training_args}")
338
 
339
+ # Load dataset
 
 
 
340
  if data_args.train_file is not None or data_args.validation_file is not None:
341
  data_files = {
342
  "train": data_args.train_file,
 
823
  temp_dir=True, # avoid issues with being in a repository
824
  )
825
 
826
+ # init variables
827
  last_time = time.perf_counter()
828
+ train_metric = None
829
+
830
  for epoch in epochs:
831
  state.replace(epoch=jax_utils.replicate(epoch))
832
  # ======================== Training ================================
 
871
  run_save_model(state, eval_metrics)
872
 
873
  # log final train metrics
874
+ if train_metric is not None:
875
+ train_metric = get_metrics(train_metric)
876
+ wandb_log(train_metric, step=step, prefix="train")
877
 
878
+ epochs.write(
879
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
880
+ )
881
 
882
  # Final evaluation
883
  eval_metrics = run_evaluation()