Ashaar / poetry_diacritizer /util /learning_rates.py
Ababababababbababa's picture
Duplicate from arbml/Ashaar
6faf7e7
raw
history blame
No virus
2.04 kB
import numpy as np
import math
class LearningRateDecay:
    """Noam-style schedule: linear warmup, then inverse-square-root decay.

    With a one-based step ``s`` the rate is::

        lr * warmup_steps ** 0.5 * min(s * warmup_steps ** -1.5, s ** -0.5)

    The first term of ``min`` grows linearly with ``s``. The second shrinks
    as ``s ** -0.5``. They cross at ``s == warmup_steps``, where the rate
    peaks at (approximately) ``lr``.
    """

    def __init__(self, lr=0.002, warmup_steps=4000.0) -> None:
        # Peak learning rate, reached when the one-based step equals warmup_steps.
        self.lr = lr
        # Length of the linear ramp-up, in steps.
        self.warmup_steps = warmup_steps

    def __call__(self, global_step) -> float:
        """Return the learning rate for the zero-based ``global_step``.

        ``np.minimum`` is used, so an array of steps yields an array of rates.
        """
        # Convert to a one-based step so step 0 never evaluates 0 ** -0.5.
        current = global_step + 1.0
        warmup = self.warmup_steps
        peak_scale = self.lr * warmup ** 0.5
        ramp_up = current * warmup ** -1.5
        decay = current ** -0.5
        return peak_scale * np.minimum(ramp_up, decay)
class SquareRootScheduler:
    """Piecewise-constant inverse-square-root decay.

    The step counter is grouped into buckets of ``decay_every`` steps.
    The rate is ``lr / sqrt(bucket + 1)``, so it equals ``lr`` for the
    first bucket and stays constant within each bucket.
    """

    def __init__(self, lr=0.1, decay_every=1000):
        """Create the scheduler.

        Args:
            lr: Learning rate used for the first ``decay_every`` steps.
            decay_every: Number of steps per decay bucket. Defaults to 1000,
                the value that used to be hard-coded, so existing callers
                are unaffected.

        Raises:
            ValueError: If ``decay_every`` is not positive (it is used as a
                floor-division divisor).
        """
        if decay_every <= 0:
            raise ValueError("decay_every must be positive")
        self.lr = lr
        self.decay_every = decay_every

    def __call__(self, global_step):
        """Return the learning rate for ``global_step`` (expected to be >= 0)."""
        bucket = global_step // self.decay_every
        return self.lr * pow(bucket + 1.0, -0.5)
class CosineScheduler:
    """Linear warmup followed by half-cosine annealing to ``final_lr``.

    Phases, by ``global_step``:

    * ``global_step < warmup_steps``: linear ramp from ``warmup_begin_lr``
      towards ``base_lr``.
    * ``warmup_steps <= global_step < max_update``: half-cosine from
      ``base_lr`` down to ``final_lr``.
    * ``global_step >= max_update``: constant ``final_lr``.

    ``self.base_lr`` holds the most recently computed post-warmup rate.
    """

    def __init__(
        self, max_update, base_lr=0.02, final_lr=0, warmup_steps=0, warmup_begin_lr=0
    ):
        """Create the scheduler.

        Args:
            max_update: Step at which the schedule reaches ``final_lr``.
            base_lr: Rate at the end of warmup / start of the cosine phase.
            final_lr: Rate from ``max_update`` onwards.
            warmup_steps: Length of the linear warmup phase.
            warmup_begin_lr: Rate at step 0 of the warmup.
        """
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.warmup_begin_lr = warmup_begin_lr
        # Length of the cosine phase.
        self.max_steps = self.max_update - self.warmup_steps
        # Initialised here so the attribute exists before any post-warmup call.
        self.base_lr = base_lr

    def get_warmup_lr(self, global_step):
        """Linearly interpolate from ``warmup_begin_lr`` towards ``base_lr``."""
        increase = (
            (self.base_lr_orig - self.warmup_begin_lr)
            * float(global_step)
            / float(self.warmup_steps)
        )
        return self.warmup_begin_lr + increase

    def __call__(self, global_step):
        """Return the learning rate for ``global_step``."""
        if global_step < self.warmup_steps:
            return self.get_warmup_lr(global_step)
        if global_step >= self.max_update:
            # The schedule is over: pin to final_lr. Do not reuse the value
            # cached by an earlier call. That value may not exist yet (first
            # call past the end) or may be stale if max_update was skipped.
            # Taking this branch for every step >= max_update also means the
            # cosine branch below only runs when max_steps > 0, so it never
            # divides by zero.
            self.base_lr = self.final_lr
            return self.base_lr
        self.base_lr = (
            self.final_lr
            + (self.base_lr_orig - self.final_lr)
            * (
                1
                + math.cos(
                    math.pi * (global_step - self.warmup_steps) / self.max_steps
                )
            )
            / 2
        )
        return self.base_lr
def adjust_learning_rate(optimizer, global_step):
    """Write the default ``LearningRateDecay`` rate into every param group.

    Args:
        optimizer: Object with a ``param_groups`` iterable of mutable
            mappings. Presumably a torch optimizer; confirm with the caller.
        global_step: Zero-based training step passed to the schedule.

    Returns:
        The learning rate assigned to each group.
    """
    # A default-configured schedule is built on every call; it holds no state.
    schedule = LearningRateDecay()
    new_lr = schedule(global_step=global_step)
    groups = optimizer.param_groups
    for group in groups:
        group["lr"] = new_lr
    return new_lr