"""This file contains code to run different learning rate schedulers.
Copyright (2024) Bytedance Ltd. and/or its affiliates
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Reference:
https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py
"""
import math
from enum import Enum
from typing import Optional, Union
import torch


class SchedulerType(Enum):
    COSINE = "cosine"


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
"""Creates a cosine learning rate schedule with warm-up and ending learning rate.
Args:
optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
num_warmup_steps: An integer, the number of steps for the warmup phase.
num_training_steps: An integer, the total number of training steps.
num_cycles : A float, the number of periods of the cosine function in a schedule (the default is to
just decrease from the max value to 0 following a half-cosine).
last_epoch: An integer, the index of the last epoch when resuming training.
base_lr: A float, the base learning rate.
end_lr: A float, the final learning rate.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

    def lr_lambda(current_step):
        # Linear warmup: scale the learning rate from 0 up to base_lr over num_warmup_steps steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay from base_lr toward end_lr over the remaining steps.
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        ratio = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        # LambdaLR multiplies the optimizer's initial lr by this value, so the
        # multiplier is expressed relative to base_lr.
        return (end_lr + (base_lr - end_lr) * ratio) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
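

# Illustrative usage sketch (added for clarity; not part of the original file).
# It builds the cosine schedule for a throwaway SGD optimizer and records the
# learning rate at each step, showing the linear warmup followed by the
# half-cosine decay toward `end_lr`. The helper name and the step counts are
# placeholders. The optimizer is constructed with lr == base_lr because
# LambdaLR scales that initial value by the multiplier returned above.
def _example_cosine_lr_values(num_steps: int = 10):
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-4)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=2,
        num_training_steps=num_steps,
        base_lr=1e-4,
        end_lr=1e-6,
    )
    lrs = []
    for _ in range(num_steps):
        optimizer.step()  # a real loop would also run forward/backward here
        lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()
    return lrs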


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
"""Retrieves a learning rate scheduler from the given name and optimizer.
Args:
name: A string or SchedulerType, the name of the scheduler to retrieve.
optimizer: torch.optim.Optimizer. The optimizer to use with the scheduler.
num_warmup_steps: An integer, the number of warmup steps.
num_training_steps: An integer, the total number of training steps.
base_lr: A float, the base learning rate.
end_lr: A float, the final learning rate.
Returns:
A instance of torch.optim.lr_scheduler.LambdaLR
Raises:
ValueError: If num_warmup_steps or num_training_steps is not provided.
"""
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        base_lr=base_lr,
        end_lr=end_lr,
    )
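

# Hedged end-to-end sketch (added for illustration; not in the original upload).
# "cosine" is the only name accepted by SchedulerType above; the model,
# optimizer, and step counts below are placeholder values.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = get_scheduler(
        "cosine",
        optimizer,
        num_warmup_steps=100,
        num_training_steps=1000,
        base_lr=1e-4,
        end_lr=0.0,
    )
    for step in range(5):
        optimizer.step()  # forward/backward omitted in this sketch
        scheduler.step()
        print(step, optimizer.param_groups[0]["lr"])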