|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""lr_schedule.py""" |
|
import torch |
|
from typing import Dict, Optional |
|
|
|
|
|
def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_name: str, base_lr: float, scheduler_cfg: Dict): |
|
|
|
if scheduler_name.lower() == 'cosine': |
|
from torch.optim.lr_scheduler import ( |
|
SequentialLR, |
|
LinearLR, |
|
CosineAnnealingLR, |
|
) |
|
|
|
scheduler1 = LinearLR( |
|
optimizer, |
|
start_factor=0.5, |
|
end_factor=1, |
|
total_iters=scheduler_cfg["warmup_steps"], |
|
last_epoch=-1, |
|
) |
|
|
|
scheduler2 = CosineAnnealingLR( |
|
optimizer, |
|
T_max=scheduler_cfg["total_steps"] - scheduler_cfg["warmup_steps"], |
|
eta_min=scheduler_cfg["final_cosine"], |
|
) |
|
|
|
lr_scheduler = SequentialLR(optimizer, |
|
schedulers=[scheduler1, scheduler2], |
|
milestones=[scheduler_cfg["warmup_steps"]]) |
|
elif scheduler_name.lower() == 'legacy': |
|
import math |
|
from torch.optim.lr_scheduler import ( |
|
SequentialLR, |
|
LinearLR, |
|
LambdaLR, |
|
) |
|
|
|
msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr" |
|
print(msg) |
|
|
|
num_steps_optimizer1 = math.ceil(scheduler_cfg["total_steps"] * 0.9) |
|
iters_left_for_optimizer2 = scheduler_cfg["total_steps"] - num_steps_optimizer1 |
|
|
|
scheduler1 = LambdaLR(optimizer, lambda step: min(base_lr, 1.0 / math.sqrt(step)) / base_lr |
|
if step else base_lr / base_lr) |
|
|
|
scheduler2 = LinearLR(optimizer, |
|
start_factor=(min(base_lr, 1.0 / math.sqrt(num_steps_optimizer1)) / base_lr), |
|
end_factor=0, |
|
total_iters=iters_left_for_optimizer2, |
|
last_epoch=-1) |
|
|
|
lr_scheduler = SequentialLR( |
|
optimizer, |
|
schedulers=[scheduler1, scheduler2], |
|
milestones=[num_steps_optimizer1], |
|
) |
|
elif scheduler_name.lower() == 'constant': |
|
from transformers import get_scheduler |
|
lr_scheduler = get_scheduler( |
|
name=scheduler_name.lower(), |
|
optimizer=optimizer, |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
return lr_scheduler |
|
|
|
|
|
def extra_stats(args, model, optimizer): |
|
stats = {} |
|
|
|
if args.logging.weights_l2: |
|
weights_l2 = sum(p.detach().norm(2).item()**2 for p in model.parameters())**0.5 |
|
stats['weights_l2'] = weights_l2 |
|
|
|
cur_lr = optimizer.param_groups[0]['lr'] |
|
stats['lr'] = cur_lr |
|
|
|
return stats |
|
|