boris committed on
Commit 61f888f
2 Parent(s): ecafe5e 833a2d5

Merge pull request #9 from borisdayma/feat--wandb-search

Files changed (2)
  1. seq2seq/do_run.sh +1 -0
  2. seq2seq/run_seq2seq_flax.py +32 -12
seq2seq/do_run.sh CHANGED
@@ -6,5 +6,6 @@ python run_seq2seq_flax.py \
     --per_device_train_batch_size 24 \
     --per_device_eval_batch_size 24 \
     --preprocessing_num_workers 48 \
+    --warmup_steps 1000 \
     --do_train \
     --do_eval \
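Note: the new `--warmup_steps 1000` flag is consumed by the training script's learning-rate schedule. As a rough sketch only (not the repo's exact create_learning_rate_fn), a warmup-then-linear-decay schedule of this kind can be assembled from optax primitives:

import optax

def make_lr_schedule(learning_rate, num_train_steps, warmup_steps):
    # Ramp the learning rate from 0 to its peak over `warmup_steps` updates...
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
    )
    # ...then decay it linearly back to 0 over the remaining steps.
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0.0,
        transition_steps=num_train_steps - warmup_steps,
    )
    # join_schedules switches from warmup to decay at the `warmup_steps` boundary.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])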
seq2seq/run_seq2seq_flax.py CHANGED
@@ -215,6 +215,13 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    log_interval: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "Log training metrics to Weights & Biases every `log_interval` "
+            "training steps."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
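Note: like the other DataTrainingArguments fields, the new log_interval entry becomes a --log_interval command-line flag through transformers.HfArgumentParser. A hypothetical, stripped-down illustration (LoggingArguments is not a class in this repo):

from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class LoggingArguments:
    log_interval: Optional[int] = field(
        default=5, metadata={"help": "Log training metrics every `log_interval` steps."}
    )

parser = HfArgumentParser(LoggingArguments)
# The metadata "help" string becomes the flag's --help text.
(logging_args,) = parser.parse_args_into_dataclasses(["--log_interval", "40"])
print(logging_args.log_interval)  # -> 40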
@@ -307,12 +314,12 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
     train_metrics = get_metrics(train_metrics)
     for key, vals in train_metrics.items():
-        tag = f"train_{key}"
+        tag = f"train_epoch/{key}"
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
     for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval_{metric_name}", value, step)
+        summary_writer.scalar(f"eval/{metric_name}", value, step)
 
 
 def create_learning_rate_fn(
@@ -616,17 +623,24 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
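Note: the reason a single `optimizer` variable can back either branch is that optax.adafactor and optax.adamw both return a GradientTransformation with the same init/update interface. A small standalone sketch (toy parameters, not the script's model):

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}
tx = optax.adafactor(learning_rate=1e-3)  # drop-in alternative: optax.adamw(1e-3)
opt_state = tx.init(params)

loss_fn = lambda p: jnp.sum(p["w"] ** 2)
grads = jax.grad(loss_fn)(params)

# The same two calls apply no matter which optimizer was chosen above,
# which is why TrainState.create only needs the generic `tx=optimizer`.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)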
@@ -718,6 +732,7 @@ def main():
 
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    global_step = 0
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
@@ -730,11 +745,16 @@
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
         steps_per_epoch = len(train_dataset) // train_batch_size
         # train
-        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+            global_step += 1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
 
+            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
+                for k, v in unreplicate(train_metric).items():
+                    wandb.log({f"train/{k}": jax.device_get(v)}, step=global_step)
+
         train_time += time.time() - train_start
 
         train_metric = unreplicate(train_metric)
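Note: the per-step logging added above passes an explicit step to wandb.log so the W&B charts stay aligned with the trainer's global step, guards on jax.process_index() == 0 so only the main host logs, and uses jax.device_get to pull metric values off the accelerator. A hedged, standalone illustration (the project name and loss values are made up):

import jax
import wandb

wandb.init(project="seq2seq-flax", job_type="train")  # project name is an assumption
for global_step, loss in enumerate([2.31, 2.17, 2.05, 1.98, 1.90], start=1):
    # Mirrors the script's `log_interval` default of 5 and its main-process guard.
    if global_step % 5 == 0 and jax.process_index() == 0:
        wandb.log({"train/loss": jax.device_get(loss)}, step=global_step)
wandb.finish()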
 