feat(train): refactor learning rate params
tools/train/train.py  CHANGED  (+53 -35)
@@ -246,9 +246,29 @@ class TrainingArguments:
             },
         )
 
-    use_decay: bool = field(
+    lr_decay: str = field(
+        default=None,
+        metadata={
+            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
+        },
+    )
+    lr_transition_steps: int = field(
+        default=None,
+        metadata={
+            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
+        },
+    )
+    lr_decay_rate: float = field(
+        default=None,
+        metadata={
+            "help": "Decay rate associated with learning rate when using exponential decay."
+        },
+    )
+    lr_staircase: bool = field(
         default=False,
-        metadata={"help": "Whether to use decay in the learning rate scheduler."},
+        metadata={
+            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
+        },
     )
 
     num_train_epochs: float = field(
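The single `use_decay` flag is replaced by four finer-grained arguments: `lr_decay` selects the schedule (None, "linear" or "exponential"), while `lr_transition_steps`, `lr_decay_rate` and `lr_staircase` only matter for the exponential variant. Since these are dataclass fields declared with `field(metadata={"help": ...})`, they presumably surface as command-line flags through an HfArgumentParser-style parser; the sketch below illustrates that assumption with a stripped-down dataclass and made-up values, and is not code from the training script.

    # Minimal sketch (assumed parser wiring, hypothetical values) of how the new
    # learning-rate fields would be set from the command line.
    from dataclasses import dataclass, field

    from transformers import HfArgumentParser


    @dataclass
    class LrArgs:
        lr_decay: str = field(default=None)  # None, "linear" or "exponential"
        lr_transition_steps: int = field(default=None)  # exponential decay only
        lr_decay_rate: float = field(default=None)  # exponential decay only
        lr_staircase: bool = field(default=False)  # exponential decay only


    parser = HfArgumentParser(LrArgs)
    (args,) = parser.parse_args_into_dataclasses(
        ["--lr_decay", "exponential", "--lr_transition_steps", "5000", "--lr_decay_rate", "0.97"]
    )
    print(args)
    # LrArgs(lr_decay='exponential', lr_transition_steps=5000, lr_decay_rate=0.97, lr_staircase=False)

Leaving `--lr_decay` unset keeps the previous default behaviour of warmup only, with no decay.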
@@ -321,33 +341,6 @@ class TrainState(train_state.TrainState):
     )
 
 
-def create_learning_rate_fn(
-    num_warmup_steps: int,
-    learning_rate: float,
-    use_decay: bool,
-    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
-) -> Callable[[int], jnp.array]:
-    """Returns a linear warmup, linear_decay learning rate function."""
-    if use_decay:
-        assert (
-            num_train_steps is not None
-        ), "Learning rate with decay requires number of training steps"
-    warmup_fn = optax.linear_schedule(
-        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
-    )
-    if not use_decay:
-        return warmup_fn
-    decay_fn = optax.linear_schedule(
-        init_value=learning_rate,
-        end_value=0,
-        transition_steps=num_train_steps - num_warmup_steps,
-    )
-    schedule_fn = optax.join_schedules(
-        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
-    )
-    return schedule_fn
-
-
 class MetricsLogger:
     def __init__(self, state):
         self.step = state.step
@@ -541,12 +534,37 @@ def main():
     num_params = model.num_params
 
     # Create learning rate schedule
-    learning_rate_fn = create_learning_rate_fn(
-        training_args.warmup_steps,
-        training_args.learning_rate,
-        training_args.use_decay,
-        num_train_steps,
-    )
+    def create_learning_rate_fn() -> Callable[[int], jnp.array]:
+        """Create the learning rate function."""
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0,
+            end_value=training_args.learning_rate,
+            transition_steps=training_args.warmup_steps,
+        )
+        if training_args.lr_decay is None:
+            return warmup_fn
+        elif training_args.lr_decay == "linear":
+            assert (
+                num_train_steps is not None
+            ), "linear decay requires knowing the dataset length"
+            decay_fn = optax.linear_schedule(
+                init_value=training_args.learning_rate,
+                end_value=0,
+                transition_steps=num_train_steps - training_args.warmup_steps,
+            )
+        elif training_args.lr_decay == "exponential":
+            decay_fn = optax.exponential_decay(
+                init_value=training_args.learning_rate,
+                transition_steps=training_args.lr_transition_steps,
+                decay_rate=training_args.lr_decay_rate,
+                staircase=training_args.lr_staircase,
+            )
+        schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+        )
+        return schedule_fn
+
+    learning_rate_fn = create_learning_rate_fn()
 
     # We use Optax's "masking" functionality to not apply weight decay
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
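The schedule factory now lives inside `main()` as a zero-argument closure: rather than taking `num_warmup_steps`, `learning_rate`, `use_decay` and `num_train_steps` as parameters, it reads `training_args` and `num_train_steps` from the enclosing scope, and the warmup ramp is always joined to the selected decay at `warmup_steps` via `optax.join_schedules`. The snippet below is a standalone sketch of the resulting shape for the exponential branch, with hypothetical numbers; it is not taken from the training script.

    # Standalone sketch of the warmup + exponential-decay schedule built by the
    # new code path, with hypothetical values (5e-3 peak LR, 1000 warmup steps,
    # lr_transition_steps=5000, lr_decay_rate=0.5, lr_staircase=False).
    import optax

    learning_rate = 5e-3
    warmup_steps = 1000

    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
    )
    decay_fn = optax.exponential_decay(
        init_value=learning_rate,
        transition_steps=5000,
        decay_rate=0.5,
        staircase=False,
    )
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
    )

    # join_schedules switches to decay_fn at step >= warmup_steps and offsets its
    # step count by the boundary, so decay starts from the peak learning rate.
    for step in (0, 500, 1000, 6000, 11000):
        print(step, float(schedule_fn(step)))
    # 0 -> 0.0, 500 -> 2.5e-3, 1000 -> 5e-3 (end of warmup),
    # 6000 -> 2.5e-3, 11000 -> 1.25e-3 (the rate halves every 5000 steps)

With `lr_decay` left at None the closure returns the warmup schedule alone, which simply holds `learning_rate` constant once `transition_steps` is reached (that is how `optax.linear_schedule` behaves past its boundary); the "linear" branch keeps an assertion because it needs `num_train_steps` to know where to anneal the rate to zero.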