"""This file contains code to run different learning rate schedulers. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
Reference: | |
https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py | |
""" | |
import math
from enum import Enum
from typing import Optional, Union

import torch


class SchedulerType(Enum):
    COSINE = "cosine"


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
"""Creates a cosine learning rate schedule with warm-up and ending learning rate. | |
Args: | |
optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate. | |
num_warmup_steps: An integer, the number of steps for the warmup phase. | |
num_training_steps: An integer, the total number of training steps. | |
num_cycles : A float, the number of periods of the cosine function in a schedule (the default is to | |
just decrease from the max value to 0 following a half-cosine). | |
last_epoch: An integer, the index of the last epoch when resuming training. | |
base_lr: A float, the base learning rate. | |
end_lr: A float, the final learning rate. | |
Return: | |
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. | |
""" | |

    def lr_lambda(current_step):
        # Linear warmup: ramp the multiplier from 0 to 1 over `num_warmup_steps` steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay: `ratio` falls from 1 to 0 over the remaining training steps.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        ratio = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        # LambdaLR multiplies the optimizer's lr (assumed to equal `base_lr`) by this
        # value, so the effective lr becomes end_lr + (base_lr - end_lr) * ratio.
        return (end_lr + (base_lr - end_lr) * ratio) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
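

# Sanity sketch (an illustrative addition, not part of the referenced open-muse
# module): inspect the multiplier that `lr_lambda` returns at a few steps, using
# a throwaway SGD optimizer since LambdaLR needs one to be constructed. The
# helper name `_cosine_multiplier_sketch` is made up for this example.
def _cosine_multiplier_sketch():
    dummy = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    scheduler = get_cosine_schedule_with_warmup(
        dummy, num_warmup_steps=10, num_training_steps=100
    )
    fn = scheduler.lr_lambdas[0]
    # Expected values: 0.5 at step 5 (mid-warmup), 1.0 at step 10 (peak),
    # 0.5 at step 55 (cosine midpoint), 0.0 at step 100 (end of training).
    return [fn(step) for step in (5, 10, 55, 100)]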


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
"""Retrieves a learning rate scheduler from the given name and optimizer. | |
Args: | |
name: A string or SchedulerType, the name of the scheduler to retrieve. | |
optimizer: torch.optim.Optimizer. The optimizer to use with the scheduler. | |
num_warmup_steps: An integer, the number of warmup steps. | |
num_training_steps: An integer, the total number of training steps. | |
base_lr: A float, the base learning rate. | |
end_lr: A float, the final learning rate. | |
Returns: | |
A instance of torch.optim.lr_scheduler.LambdaLR | |
Raises: | |
ValueError: If num_warmup_steps or num_training_steps is not provided. | |
""" | |
    # Raises ValueError if `name` does not match a SchedulerType value.
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        base_lr=base_lr,
        end_lr=end_lr,
    )
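

# Minimal usage sketch (an illustrative addition, not part of the referenced
# open-muse module). Assumptions: a toy linear model and plain SGD. Note that
# the optimizer's lr is set to `base_lr`, since LambdaLR multiplies that value
# by the schedule's multiplier at every step.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    base_lr = 1e-4
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)
    scheduler = get_scheduler(
        "cosine",
        optimizer,
        num_warmup_steps=10,
        num_training_steps=100,
        base_lr=base_lr,
        end_lr=0.0,
    )
    for _ in range(100):
        optimizer.step()  # a forward/backward pass would normally precede this
        scheduler.step()
    print(optimizer.param_groups[0]["lr"])  # 0.0 at the end of training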