fix: shampoo -> distributed shampoo
tools/train/train.py (+4 -4)
@@ -214,11 +214,11 @@ class TrainingArguments:
     )
     adafactor: bool = field(
         default=False,
-        metadata={"help": "
+        metadata={"help": "Use Adafactor instead of AdamW."},
     )
-    shampoo: bool = field(
+    distributed_shampoo: bool = field(
         default=False,
-        metadata={"help": "
+        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
     )
     weight_decay: float = field(
         default=None, metadata={"help": "Weight decay if we apply some."}
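These flags are plain dataclass fields, so in scripts of this family they typically surface on the command line through transformers.HfArgumentParser; renaming the field therefore also renames the CLI flag (--shampoo becomes --distributed_shampoo). A minimal sketch of that behavior, assuming HfArgumentParser wiring; the parser calls below are illustrative, not copied from train.py:

# Sketch: how the renamed field surfaces as a CLI flag. Assumes the script
# parses TrainingArguments with transformers.HfArgumentParser, as Flax
# training scripts in this family typically do; illustrative, not train.py.
from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class TrainingArguments:
    adafactor: bool = field(
        default=False,
        metadata={"help": "Use Adafactor instead of AdamW."},
    )
    distributed_shampoo: bool = field(
        default=False,
        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
    )

parser = HfArgumentParser(TrainingArguments)
# The old --shampoo flag no longer exists; the field rename changes the CLI too.
(args,) = parser.parse_args_into_dataclasses(["--distributed_shampoo"])
assert args.distributed_shampoo and not args.adafactor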
@@ -566,7 +566,7 @@ def main():
         weight_decay_mask=decay_mask_fn,
         clipping_threshold=training_args.max_grad_norm,
     )
-    elif training_args.shampoo:
+    elif training_args.distributed_shampoo:
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
         # - mask for weight decay is not implemented but we don't use it anyway
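The context lines in this hunk are the tail of an optax.adafactor(...) call, so the renamed elif selects Distributed Shampoo as an alternative optimizer. A minimal sketch of such a dispatch, assuming the optax-style distributed_shampoo implementation vendored next to the script; the parameter values, learning_rate_fn, decay_mask_fn, and the SimpleNamespace stand-in are illustrative assumptions, not the exact contents of train.py:

# Sketch of the optimizer dispatch around the renamed branch; all values
# below are illustrative assumptions, not the exact code of train.py.
from types import SimpleNamespace

import optax
from distributed_shampoo import GraftingType, distributed_shampoo  # vendored module (assumed)

# Stand-ins for objects train.py builds elsewhere.
training_args = SimpleNamespace(
    adafactor=False, distributed_shampoo=True,
    weight_decay=0.0, max_grad_norm=1.0, adam_beta1=0.9, adam_beta2=0.999,
)
learning_rate_fn = optax.linear_schedule(0.0, 5e-3, transition_steps=1_000)
decay_mask_fn = None  # train.py builds a real mask over the parameter tree

if training_args.adafactor:
    # The context lines in the hunk above are the tail of this call.
    optimizer = optax.adafactor(
        learning_rate=learning_rate_fn,
        weight_decay_rate=training_args.weight_decay,
        weight_decay_mask=decay_mask_fn,
        clipping_threshold=training_args.max_grad_norm,
    )
elif training_args.distributed_shampoo:
    # Values loosely follow the lingvo reference linked in the diff comment;
    # a weight-decay mask is not implemented here (see the note in the hunk).
    optimizer = distributed_shampoo(
        learning_rate_fn,
        block_size=1024,                  # partition large params into blocks
        beta1=0.9,
        beta2=0.999,
        weight_decay=0.0,
        start_preconditioning_step=1001,  # graft-only warmup before preconditioning
        preconditioning_compute_steps=10,
        batch_axis_name="batch",          # pmap axis for cross-device statistics
        graft_type=GraftingType.RMSPROP_NORMALIZED,
    )
else:
    optimizer = optax.adamw(
        learning_rate=learning_rate_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )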