File size: 2,470 Bytes
5381499 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import math
from torch.optim.lr_scheduler import _LRScheduler
class WarmUpScheduler(_LRScheduler):
    """Linear warm-up followed by inverse-square-root decay.

    For a step count ``s`` (``self.last_epoch``) the learning rate is::

        factor * feature_size ** -0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)

    The second ``min`` argument grows linearly with ``s`` and is the smaller
    one while ``s < warmup_steps``; afterwards the ``s ** -0.5`` term takes
    over, so the rate peaks at ``s == warmup_steps``.

    The optimizer's own base learning rates are not used: every parameter
    group receives the same computed value.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        feature_size: int,
        factor: float = 1.0,
        last_epoch=-1,
    ):
        """Store the schedule parameters and initialise the base scheduler.

        Args:
            optimizer: Optimizer whose parameter groups are scheduled.
            warmup_steps: Step count at which the schedule switches from the
                linear term to the inverse-square-root term.
            feature_size: Value whose inverse square root scales the schedule.
            factor: Constant multiplier applied to the final learning rate.
            last_epoch: Forwarded unchanged to the base scheduler.
        """
        # The schedule parameters are assigned before delegating to the base
        # class, mirroring the original ordering.
        self.factor = factor
        self.feature_size = feature_size
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate, repeated once per base LR."""
        shared_lr = self._compute_lr()
        return [shared_lr for _ in self.base_lrs]

    def _compute_lr(self):
        """Evaluate the schedule at ``self.last_epoch``.

        Returns ``0.0`` at step 0, which also avoids raising zero to a
        negative power.
        """
        step = self.last_epoch
        if step == 0:
            return 0.0

        scale = self.feature_size ** (-0.5)
        decay_term = step ** (-0.5)
        warmup_term = step * self.warmup_steps ** (-1.5)
        return scale * min(decay_term, warmup_term) * self.factor
class TriStateScheduler(_LRScheduler):
    """Three-stage schedule: cosine warm-up, constant hold, cosine decay.

    With ``eta_max`` equal to a group's base learning rate and
    ``eta_min = eta_max * factor``:

    1. Warm-up (``0 <= step <= warmup_steps``): half-cosine ramp from
       ``eta_min`` up to ``eta_max``.
    2. Constant (next ``constant_steps`` steps): the group's current learning
       rate is held, which is ``eta_max`` right after the warm-up.
    3. Decay (until ``total_steps``): half-cosine ramp from ``eta_max`` down
       to ``eta_min``; once ``total_steps`` is passed the rate stays at
       ``eta_min``.

    A stage whose length is zero is skipped instead of dividing by zero.
    """

    def __init__(
        self,
        optimizer,
        total_steps: int,
        warmup_steps: int,
        constant_steps: int,
        factor: float = 0.3,
        last_epoch: int = -1,
    ):
        """Store the stage lengths and initialise the base scheduler.

        Args:
            optimizer: Optimizer whose parameter groups are scheduled.
            total_steps: Step at which the decay stage reaches ``eta_min``.
            warmup_steps: Length of the warm-up stage (0 disables it).
            constant_steps: Length of the constant stage.
            factor: Ratio ``eta_min / eta_max`` applied to each base LR.
            last_epoch: Forwarded unchanged to the base scheduler.
        """
        # These must exist before the base constructor runs: it performs an
        # initial step, which already calls get_lr().
        self.warmup_steps = warmup_steps
        self.constant_steps = constant_steps
        self.total_steps = total_steps
        self.factor = factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scheduled learning rate for every parameter group."""
        # base_lrs only exists once the base constructor has run, so the
        # bounds are derived lazily on the first call. They are kept as
        # attributes so they remain part of the scheduler's state.
        if not hasattr(self, "eta_min"):
            self.eta_max = self.base_lrs.copy()
            self.eta_min = [eta_max * self.factor for eta_max in self.eta_max]
        return [
            self._compute_lr(group["lr"], eta_min, eta_max)
            for group, eta_min, eta_max in zip(
                self.optimizer.param_groups, self.eta_min, self.eta_max
            )
        ]

    def _compute_lr(self, prev_lr: float, eta_min: float, eta_max: float):
        """Compute the learning rate of one group at ``self.last_epoch``.

        Args:
            prev_lr: Learning rate currently set on the parameter group.
            eta_min: Lower bound of the schedule for this group.
            eta_max: Upper bound of the schedule for this group.

        Returns:
            The learning rate for the current step.
        """
        step = self.last_epoch
        hold_end = self.warmup_steps + self.constant_steps

        # first stage -- skipped entirely when there is no warm-up, which
        # previously raised ZeroDivisionError on the very first step
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return eta_max - 0.5 * (eta_max - eta_min) * (
                1 + math.cos(math.pi * step / self.warmup_steps)
            )

        # second stage -- keep whatever the group currently uses
        if step <= hold_end:
            return prev_lr

        # third stage
        decay_steps = self.total_steps - hold_end
        if decay_steps <= 0:
            # No room for a decay ramp: go straight to the floor instead of
            # dividing by zero.
            return eta_min
        # Clamp the progress so the cosine never runs past pi; otherwise the
        # rate would climb back up after total_steps.
        k = min(step - hold_end, decay_steps)
        return eta_min + 0.5 * (eta_max - eta_min) * (
            1 + math.cos(math.pi * k / decay_steps)
        )

    def state_dict(self) -> dict:
        """Return the scheduler state exactly as produced by the base class."""
        return super().state_dict()
|