""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import math from medomni.common.registry import registry @registry.register_lr_scheduler("linear_warmup_step_lr") class LinearWarmupStepLRScheduler: def __init__( self, optimizer, max_epoch, min_lr, init_lr, decay_rate=1, warmup_start_lr=-1, warmup_steps=0, **kwargs ): self.optimizer = optimizer self.max_epoch = max_epoch self.min_lr = min_lr self.decay_rate = decay_rate self.init_lr = init_lr self.warmup_steps = warmup_steps self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr def step(self, cur_epoch, cur_step): if cur_epoch == 0: warmup_lr_schedule( step=cur_step, optimizer=self.optimizer, max_step=self.warmup_steps, init_lr=self.warmup_start_lr, max_lr=self.init_lr, ) else: step_lr_schedule( epoch=cur_epoch, optimizer=self.optimizer, init_lr=self.init_lr, min_lr=self.min_lr, decay_rate=self.decay_rate, ) @registry.register_lr_scheduler("linear_warmup_cosine_lr") class LinearWarmupCosineLRScheduler: def __init__( self, optimizer, max_epoch, iters_per_epoch, min_lr, init_lr, warmup_steps=0, warmup_start_lr=-1, **kwargs ): self.optimizer = optimizer self.max_epoch = max_epoch self.iters_per_epoch = iters_per_epoch self.min_lr = min_lr self.init_lr = init_lr self.warmup_steps = warmup_steps self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr def step(self, cur_epoch, cur_step): total_cur_step = cur_epoch * self.iters_per_epoch + cur_step if total_cur_step < self.warmup_steps: warmup_lr_schedule( step=cur_step, optimizer=self.optimizer, max_step=self.warmup_steps, init_lr=self.warmup_start_lr, max_lr=self.init_lr, ) else: cosine_lr_schedule( epoch=total_cur_step, optimizer=self.optimizer, max_epoch=self.max_epoch * self.iters_per_epoch, init_lr=self.init_lr, min_lr=self.min_lr, ) def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): """Decay the learning rate""" lr = (init_lr - min_lr) * 0.5 * ( 1.0 + math.cos(math.pi * epoch / max_epoch) ) + min_lr for param_group in optimizer.param_groups: param_group["lr"] = lr def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): """Warmup the learning rate""" lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) for param_group in optimizer.param_groups: param_group["lr"] = lr def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): """Decay the learning rate""" lr = max(min_lr, init_lr * (decay_rate**epoch)) for param_group in optimizer.param_groups: param_group["lr"] = lr