# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import math

from diffusers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

from diffusion.utils.logger import get_root_logger


def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
    """Build a learning-rate scheduler from the training config."""
    if not config.get("lr_schedule_args", None):
        config.lr_schedule_args = dict()
    if config.get("lr_warmup_steps", None):
        config["num_warmup_steps"] = config.get("lr_warmup_steps")  # for compatibility with old version

    logger = get_root_logger()
    logger.info(
        f"Lr schedule: {config.lr_schedule}, "
        + ",".join([f"{key}:{value}" for key, value in config.lr_schedule_args.items()])
        + "."
    )

    if config.lr_schedule == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    elif config.lr_schedule == "constant":
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
        )
    elif config.lr_schedule == "cosine_decay_to_constant":
        assert lr_scale_ratio >= 1
        lr_scheduler = get_cosine_decay_to_constant_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            final_lr=1 / lr_scale_ratio,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    else:
        raise RuntimeError(f"Unrecognized lr schedule {config.lr_schedule}.")
    return lr_scheduler


def get_cosine_decay_to_constant_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    final_lr: float = 0.0,
    num_decay: float = 0.667,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule with a linear warmup, followed by a cosine decay, followed by a constant lr.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        final_lr (`float`, *optional*, defaults to 0.0):
            The final constant lr after the cosine decay, expressed as a fraction of the base lr.
        num_decay (`float`, *optional*, defaults to 0.667):
            The fraction of `num_training_steps` covered by the warmup and cosine-decay phases;
            the remaining steps hold the lr constant at `final_lr`.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of cosine cycles used during the decay phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        # Linear warmup from 0 up to the base lr.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        num_decay_steps = int(num_training_steps * num_decay)
        # After the decay window, hold the lr constant at `final_lr`.
        if current_step > num_decay_steps:
            return final_lr
        # Cosine decay from the base lr down to `final_lr`.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
        return (
            max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) * (1 - final_lr)
            + final_lr
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
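

# Minimal usage sketch (illustrative only): it shows how the warmup -> cosine decay ->
# constant schedule above can be driven step by step. The optimizer, step counts, and
# lr values here are hypothetical placeholders, not values taken from any training config.
if __name__ == "__main__":
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = get_cosine_decay_to_constant_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=10,
        num_training_steps=100,
        final_lr=0.1,  # constant tail at 10% of the base lr
    )
    for step in range(100):
        optimizer.step()
        scheduler.step()
        if step % 20 == 0:
            # get_last_lr() returns one lr per param group; there is a single group here.
            print(f"step {step:3d}: lr = {scheduler.get_last_lr()[0]:.2e}")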