import torch
from torch.optim.lr_scheduler import LRScheduler


class LinearSchedulerWithWarmup(LRScheduler):
    """Linearly increases the learning rate from 0 to the base LR over
    ``num_warmup_steps``, then linearly decays it to 0 by ``num_training_steps``."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        def scheduler_fn(current_step):
            # Warmup phase: scale the base LR linearly from 0 up to 1.
            if current_step < self.num_warmup_steps:
                return current_step / max(1, self.num_warmup_steps)
            # Decay phase: scale linearly from 1 down to 0 over the remaining steps.
            return max(
                0.0,
                float(self.num_training_steps - current_step)
                / float(max(1, self.num_training_steps - self.num_warmup_steps)),
            )

        # ``last_epoch`` counts scheduler steps, i.e. the current training step here.
        return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]


class LinearScheduler(LRScheduler):
    """Linearly decays the learning rate from the base LR to 0 over
    ``num_training_steps``, with no warmup."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        self.num_training_steps = num_training_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        def scheduler_fn(current_step):
            # Linear decay from 1 down to 0 over the full training run.
            return max(
                0.0,
                float(self.num_training_steps - current_step)
                / float(max(1, self.num_training_steps)),
            )

        return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]
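
# A minimal usage sketch (illustrative, not part of the original module): the toy
# model, optimizer choice, and step counts below are assumptions. Both schedulers
# are meant to be stepped once per optimizer step, not once per epoch.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = LinearSchedulerWithWarmup(
        optimizer, num_warmup_steps=100, num_training_steps=1_000
    )

    for step in range(1_000):
        loss = model(torch.randn(8, 10)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Advances last_epoch, which get_lr() treats as the current training step.
        scheduler.step()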