feat: add adafactor
seq2seq/run_seq2seq_flax.py  CHANGED  (+16, -9)
@@ -623,17 +623,24 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
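For readers who want to exercise the two optimizer paths outside the training script, below is a minimal, self-contained sketch of the same selection logic using optax directly. The flag name use_adafactor, the schedule values, and the toy parameters are illustrative assumptions and do not come from this repository; the actual script drives the choice through training_args.adafactor and uses linear_decay_lr_schedule_fn as its schedule.

import jax
import jax.numpy as jnp
import optax


def create_optimizer(use_adafactor: bool, learning_rate_fn):
    """Return an optax optimizer mirroring the branch added in this commit."""
    if use_adafactor:
        # Adafactor with optax defaults; only the learning-rate schedule is passed.
        return optax.adafactor(learning_rate=learning_rate_fn)
    # Otherwise AdamW, with placeholder hyperparameters standing in for training_args.
    return optax.adamw(
        learning_rate=learning_rate_fn,
        b1=0.9,
        b2=0.999,
        eps=1e-8,
        weight_decay=0.0,
    )


# Linear decay schedule standing in for linear_decay_lr_schedule_fn.
schedule = optax.linear_schedule(init_value=1e-3, end_value=0.0, transition_steps=1_000)

# Toy parameters and one dummy update step to show the optimizer wiring end to end.
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}
tx = create_optimizer(use_adafactor=True, learning_rate_fn=schedule)
opt_state = tx.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

In the script itself the chosen transformation is simply handed to TrainState.create via tx=optimizer, so the rest of the training loop is unchanged regardless of which branch is taken.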