|
import math |
|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
|
|
|
class WarmUpScheduler(_LRScheduler):
    """Linear warm-up followed by inverse-square-root decay.

    At step ``t`` (``self.last_epoch``) every parameter group is assigned::

        factor * feature_size ** -0.5 * min(t ** -0.5, t * warmup_steps ** -1.5)

    The two arguments of ``min`` are equal at ``t == warmup_steps``, so the
    rate rises linearly up to that step and decays as ``t ** -0.5`` after it.
    The optimizer's configured base learning rates do not enter the formula;
    only their count determines the length of the list ``get_lr`` returns.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of steps over which the rate ramps up. Must be
            positive.
        feature_size: Size whose inverse square root scales the whole
            schedule. Must be positive.
        factor: Constant multiplier applied to the computed rate.
        last_epoch: Index of the last step, forwarded to the parent class.

    Raises:
        ValueError: If ``warmup_steps`` or ``feature_size`` is not positive.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        feature_size: int,
        factor: float = 1.0,
        last_epoch=-1,
    ):
        # Fail fast: a zero value would otherwise raise ZeroDivisionError on
        # the first real step, and a negative one would turn the fractional
        # power into a complex number.
        if warmup_steps <= 0:
            raise ValueError(
                f"warmup_steps must be a positive integer, got {warmup_steps!r}"
            )
        if feature_size <= 0:
            raise ValueError(
                f"feature_size must be a positive integer, got {feature_size!r}"
            )

        # Attributes must exist before the parent constructor runs, because
        # it performs an initial step that calls get_lr().
        self.warmup_steps = warmup_steps
        self.feature_size = feature_size
        self.factor = factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current rate, repeated once per parameter group."""
        lr = self._compute_lr()
        return [lr] * len(self.base_lrs)

    def _compute_lr(self) -> float:
        """Evaluate the schedule at ``self.last_epoch``."""
        step = self.last_epoch

        # step ** -0.5 is undefined at zero, so the schedule starts at 0.0.
        if step <= 0:
            return 0.0

        scale = self.feature_size ** (-0.5)
        decay_term = step ** (-0.5)
        warmup_term = step * self.warmup_steps ** (-1.5)
        return scale * min(decay_term, warmup_term) * self.factor
|
|
|
|
|
class TriStateScheduler(_LRScheduler):
    """Three-phase schedule: cosine warm-up, constant hold, cosine decay.

    For each parameter group, ``eta_max`` is the group's base learning rate
    and ``eta_min`` is ``eta_max * factor``. With ``t = self.last_epoch``:

    1. ``0 <= t <= warmup_steps``: half-cosine rise from ``eta_min`` to
       ``eta_max``.
    2. ``warmup_steps < t <= warmup_steps + constant_steps``: the rate
       currently stored in the optimizer is kept unchanged.
    3. Afterwards: half-cosine fall from ``eta_max`` to ``eta_min`` that
       finishes at ``total_steps``. Stepping past ``total_steps`` keeps the
       rate at ``eta_min``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Step index at which the decay reaches ``eta_min``.
        warmup_steps: Length of the warm-up phase. ``0`` skips the warm-up.
        constant_steps: Length of the hold phase.
        factor: Ratio ``eta_min / eta_max``.
        last_epoch: Index of the last step, forwarded to the parent class.
    """

    def __init__(
        self,
        optimizer,
        total_steps: int,
        warmup_steps: int,
        constant_steps: int,
        factor: float = 0.3,
        last_epoch: int = -1,
    ):
        # Attributes must exist before the parent constructor runs, because
        # it performs an initial step that calls get_lr().
        self.warmup_steps = warmup_steps
        self.constant_steps = constant_steps
        self.total_steps = total_steps
        self.factor = factor

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scheduled learning rate for every parameter group."""
        # The bounds are derived lazily: base_lrs is populated by the parent
        # constructor, which then calls get_lr() before __init__ here could
        # compute them itself.
        if not hasattr(self, "eta_min"):
            self.eta_max = self.base_lrs.copy()
            self.eta_min = [eta_max * self.factor for eta_max in self.eta_max]

        return [
            self._compute_lr(group["lr"], eta_min, eta_max)
            for group, eta_min, eta_max in zip(
                self.optimizer.param_groups, self.eta_min, self.eta_max
            )
        ]

    def _compute_lr(self, prev_lr: float, eta_min: float, eta_max: float):
        """Compute one group's rate at ``self.last_epoch``.

        Args:
            prev_lr: Rate currently stored in the optimizer for this group.
            eta_min: Lower bound of the schedule for this group.
            eta_max: Upper bound of the schedule for this group.

        Returns:
            The learning rate to apply at the current step.
        """
        step = self.last_epoch
        hold_end = self.warmup_steps + self.constant_steps

        # Warm-up. The warmup_steps > 0 guard avoids dividing by zero when
        # the phase is empty; the schedule then starts in the hold phase.
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return eta_max - 0.5 * (eta_max - eta_min) * (
                1 + math.cos(math.pi * step / self.warmup_steps)
            )

        # Hold: keep whatever rate the optimizer currently carries.
        if step <= hold_end:
            return prev_lr

        # Decay. With no room left for a decay window, go straight to the
        # floor instead of dividing by zero.
        decay_steps = self.total_steps - hold_end
        if decay_steps <= 0:
            return eta_min

        # Clamp so the cosine never passes pi: without this the rate would
        # climb back up once training runs beyond total_steps.
        k = min(step - hold_end, decay_steps)
        return eta_min + 0.5 * (eta_max - eta_min) * (
            1 + math.cos(math.pi * k / decay_steps)
        )

    def state_dict(self) -> dict:
        """Return the scheduler state exactly as produced by the parent class."""
        return super().state_dict()
|
|