|
from functools import partial |
|
from typing import Callable |
|
|
|
|
|
def linear_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get linear warm up scheduler for LambdaLR.

    For ``step <= warm_up_steps`` the scale grows linearly from 0 to 1
    (``step / warm_up_steps``). Afterwards the scale is multiplied by 0.9
    once every ``reduce_lr_steps`` steps, i.e.
    ``0.9 ** (step // reduce_lr_steps)``. Note that the decay exponent is
    counted from step 0, not from the end of the warm up.

    A ``warm_up_steps`` of 0 disables the warm up entirely (the decay
    schedule applies from step 0), instead of dividing by zero.

    Args:
        step (int): global step
        warm_up_steps (int): number of steps for the linear warm up
        reduce_lr_steps (int): multiply the learning rate by 0.9 every
            ``reduce_lr_steps`` steps after the warm up

    .. code-block:: python

        >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    # Guard on warm_up_steps > 0: otherwise step == 0 == warm_up_steps would
    # enter this branch and raise ZeroDivisionError.
    if warm_up_steps > 0 and step <= warm_up_steps:
        lr_scale = step / warm_up_steps
    else:
        lr_scale = 0.9 ** (step // reduce_lr_steps)

    return lr_scale
|
|
|
|
|
def constant_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get constant (step-wise) warm up scheduler for LambdaLR.

    The learning rate scale is:

    * 0.001 for ``0 <= step < warm_up_steps``
    * 0.01  for ``warm_up_steps <= step < 2 * warm_up_steps``
    * 0.1   for ``2 * warm_up_steps <= step < 3 * warm_up_steps``
    * 1.0   otherwise

    Args:
        step (int): global step
        warm_up_steps (int): length of each of the three warm up stages
        reduce_lr_steps (int): unused; accepted only so that this function
            shares the signature of ``linear_warm_up`` and can be bound the
            same way with ``functools.partial``

    .. code-block:: python

        >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    # Each stage is checked with an explicit lower and upper bound, so any
    # step outside the three stages (including negative steps) falls through
    # to the full learning rate.
    if 0 <= step < warm_up_steps:
        return 0.001

    if warm_up_steps <= step < 2 * warm_up_steps:
        return 0.01

    if 2 * warm_up_steps <= step < 3 * warm_up_steps:
        return 0.1

    # Return a float to honour the ``-> float`` annotation.
    return 1.0
|
|
|
|
|
def get_lr_lambda(
    lr_lambda_type: str,
    **kwargs
) -> Callable:
    r"""Get learning rate lambda function for ``LambdaLR``.

    Args:
        lr_lambda_type (str): name of the schedule, either
            ``"constant_warm_up"`` or ``"linear_warm_up"``
        **kwargs: must contain ``warm_up_steps`` (int) and
            ``reduce_lr_steps`` (int), which are bound to the chosen
            schedule with ``functools.partial``. Other keys are ignored.

    Returns:
        lr_lambda_func (Callable): function mapping a global step to a
            learning rate scale

    Raises:
        NotImplementedError: if ``lr_lambda_type`` is not supported
        KeyError: if ``warm_up_steps`` or ``reduce_lr_steps`` is missing
            from ``kwargs``
    """
    # Pick the schedule first so both branches share a single partial() call.
    # The unsupported-type check runs before kwargs are read, as before.
    if lr_lambda_type == "constant_warm_up":
        schedule = constant_warm_up
    elif lr_lambda_type == "linear_warm_up":
        schedule = linear_warm_up
    else:
        raise NotImplementedError(
            f"Unsupported lr_lambda_type: {lr_lambda_type!r}; expected "
            '"constant_warm_up" or "linear_warm_up".'
        )

    lr_lambda_func = partial(
        schedule,
        warm_up_steps=kwargs["warm_up_steps"],
        reduce_lr_steps=kwargs["reduce_lr_steps"],
    )

    return lr_lambda_func
|
|