# YourMT3/amt/src/model/lr_scheduler.py
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""lr_schedule.py"""
import torch
from typing import Dict, Optional
def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_name: str, base_lr: float, scheduler_cfg: Dict):
    """Build a learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: Optimizer whose learning rate will be scheduled.
        scheduler_name: Schedule type, case-insensitive. One of:
            'cosine'   — linear warmup (0.5x -> 1x over ``warmup_steps``)
                         followed by cosine annealing down to ``final_cosine``.
            'legacy'   — T5-style inverse-sqrt decay for the first 90% of
                         ``total_steps``, then linear decay to 0. Independent
                         of ``base_lr`` (a notice is printed).
            'constant' — constant LR via ``transformers.get_scheduler``.
        base_lr: Base learning rate; only used to normalize the 'legacy'
            schedule's LR factors.
        scheduler_cfg: Dict with keys 'warmup_steps' and 'total_steps';
            'cosine' additionally requires 'final_cosine' (the eta_min of the
            cosine phase).

    Returns:
        A ``torch.optim.lr_scheduler`` scheduler (or a ``transformers``
        scheduler for 'constant').

    Raises:
        NotImplementedError: If ``scheduler_name`` is not one of the
            supported schedules.
    """
    name = scheduler_name.lower()  # normalize once instead of per-branch
    if name == 'cosine':
        from torch.optim.lr_scheduler import (
            SequentialLR,
            LinearLR,
            CosineAnnealingLR,
        )
        # Phase 1: linear warmup from 0.5x to 1x of the optimizer's LR.
        scheduler1 = LinearLR(
            optimizer,
            start_factor=0.5,
            end_factor=1,
            total_iters=scheduler_cfg["warmup_steps"],
            last_epoch=-1,
        )
        # Phase 2: cosine anneal over the remaining steps down to final_cosine.
        scheduler2 = CosineAnnealingLR(
            optimizer,
            T_max=scheduler_cfg["total_steps"] - scheduler_cfg["warmup_steps"],
            eta_min=scheduler_cfg["final_cosine"],
        )
        lr_scheduler = SequentialLR(optimizer,
                                    schedulers=[scheduler1, scheduler2],
                                    milestones=[scheduler_cfg["warmup_steps"]])
    elif name == 'legacy':
        import math
        from torch.optim.lr_scheduler import (
            SequentialLR,
            LinearLR,
            LambdaLR,
        )
        msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr"
        print(msg)
        # First 90% of training: inverse-sqrt decay; last 10%: linear to 0.
        num_steps_optimizer1 = math.ceil(scheduler_cfg["total_steps"] * 0.9)
        iters_left_for_optimizer2 = scheduler_cfg["total_steps"] - num_steps_optimizer1
        # Factor is min(base_lr, 1/sqrt(step)) / base_lr; at step 0 the factor is 1
        # (guard avoids division by zero in 1/sqrt(0)).
        scheduler1 = LambdaLR(optimizer, lambda step: min(base_lr, 1.0 / math.sqrt(step)) / base_lr
                              if step else base_lr / base_lr)
        # Continue from where the inverse-sqrt phase left off, decaying to 0.
        scheduler2 = LinearLR(optimizer,
                              start_factor=(min(base_lr, 1.0 / math.sqrt(num_steps_optimizer1)) / base_lr),
                              end_factor=0,
                              total_iters=iters_left_for_optimizer2,
                              last_epoch=-1)
        lr_scheduler = SequentialLR(
            optimizer,
            schedulers=[scheduler1, scheduler2],
            milestones=[num_steps_optimizer1],
        )
    elif name == 'constant':
        from transformers import get_scheduler
        lr_scheduler = get_scheduler(
            name=name,
            optimizer=optimizer,
        )
    else:
        raise NotImplementedError(f"Unsupported scheduler_name: {scheduler_name!r} "
                                  "(expected 'cosine', 'legacy', or 'constant')")
    return lr_scheduler
def extra_stats(args, model, optimizer):
    """Collect auxiliary training statistics.

    Reports the global L2 norm of all model parameters when
    ``args.logging.weights_l2`` is enabled, and always reports the current
    learning rate of the optimizer's first parameter group.
    """
    metrics = {}
    if args.logging.weights_l2:
        squared_total = 0.0
        for param in model.parameters():
            squared_total += param.detach().norm(2).item() ** 2
        metrics['weights_l2'] = squared_total ** 0.5
    metrics['lr'] = optimizer.param_groups[0]['lr']
    return metrics