from functools import partial
from typing import Callable


def linear_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get linear warm up scheduler for LambdaLR.

    The learning rate scale ramps up linearly from 0 to 1 over the first
    ``warm_up_steps`` steps, then decays by a factor of 0.9 every
    ``reduce_lr_steps`` steps.

    Args:
        step (int): global step
        warm_up_steps (int): steps for warm up
        reduce_lr_steps (int): reduce learning rate by a factor of 0.9 every
            ``reduce_lr_steps`` steps

    .. code-block:: python

        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> scheduler = LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    if step <= warm_up_steps:
        # Ramp up linearly from 0 to 1 over the warm up period.
        lr_scale = step / warm_up_steps
    else:
        # Decay by a factor of 0.9 every reduce_lr_steps steps.
        lr_scale = 0.9 ** (step // reduce_lr_steps)

    return lr_scale
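

# A quick sketch of the schedule shape, assuming warm_up_steps=1000 and
# reduce_lr_steps=10000:
#   linear_warm_up(500, 1000, 10000)    -> 0.5   (half way through warm up)
#   linear_warm_up(1000, 1000, 10000)   -> 1.0   (warm up complete)
#   linear_warm_up(25000, 1000, 10000)  -> 0.81  (0.9 ** 2, decayed twice)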


def constant_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get constant warm up scheduler for LambdaLR.

    The learning rate scale steps through constant values of 0.001, 0.01,
    and 0.1 over the first three warm up periods, then stays at 1.

    Args:
        step (int): global step
        warm_up_steps (int): steps for warm up
        reduce_lr_steps (int): unused; kept so the signature matches
            ``linear_warm_up``

    .. code-block:: python

        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> scheduler = LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    if 0 <= step < warm_up_steps:
        lr_scale = 0.001
    elif warm_up_steps <= step < 2 * warm_up_steps:
        lr_scale = 0.01
    elif 2 * warm_up_steps <= step < 3 * warm_up_steps:
        lr_scale = 0.1
    else:
        lr_scale = 1.0

    return lr_scale
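

# A quick sketch of the schedule shape, assuming warm_up_steps=1000 (the
# reduce_lr_steps argument has no effect here):
#   constant_warm_up(500, 1000, 10000)   -> 0.001  (first warm up period)
#   constant_warm_up(1500, 1000, 10000)  -> 0.01   (second warm up period)
#   constant_warm_up(2500, 1000, 10000)  -> 0.1    (third warm up period)
#   constant_warm_up(5000, 1000, 10000)  -> 1.0    (full learning rate)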


def get_lr_lambda(
    lr_lambda_type: str,
    **kwargs
) -> Callable:
    r"""Get learning rate lambda function by name.

    Args:
        lr_lambda_type (str): "constant_warm_up" | "linear_warm_up"

    Returns:
        lr_lambda_func (Callable)
    """
    if lr_lambda_type == "constant_warm_up":
        lr_lambda_func = partial(
            constant_warm_up,
            warm_up_steps=kwargs["warm_up_steps"],
            reduce_lr_steps=kwargs["reduce_lr_steps"],
        )

    elif lr_lambda_type == "linear_warm_up":
        lr_lambda_func = partial(
            linear_warm_up,
            warm_up_steps=kwargs["warm_up_steps"],
            reduce_lr_steps=kwargs["reduce_lr_steps"],
        )

    else:
        raise NotImplementedError(
            "lr_lambda_type must be 'constant_warm_up' or 'linear_warm_up', "
            f"got '{lr_lambda_type}'."
        )

    return lr_lambda_func
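

if __name__ == "__main__":
    # Minimal usage sketch, assuming PyTorch is installed. The Linear model
    # and Adam optimizer below are placeholders for illustration only.
    import torch
    from torch.optim.lr_scheduler import LambdaLR

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    lr_lambda = get_lr_lambda(
        lr_lambda_type="linear_warm_up",
        warm_up_steps=1000,
        reduce_lr_steps=10000,
    )
    scheduler = LambdaLR(optimizer, lr_lambda)

    for step in range(5):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())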