# -*- coding: utf-8 -*- r""" Schedulers ============== Leraning Rate schedulers used to train Polos models. """ from argparse import Namespace from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR class ConstantPolicy: """Policy for updating the LR of the ConstantLR scheduler. With this class LambdaLR objects became picklable. """ def __call__(self, *args, **kwargs): return 1 class ConstantLR(LambdaLR): """ Constant learning rate schedule Wrapper for the huggingface Constant LR Scheduler. https://huggingface.co/transformers/v2.1.1/main_classes/optimizer_schedules.html :param optimizer: torch.optim.Optimizer :param last_epoch: """ def __init__(self, optimizer: Optimizer, last_epoch: int = -1) -> None: super(ConstantLR, self).__init__(optimizer, ConstantPolicy(), last_epoch) @classmethod def from_hparams( cls, optimizer: Optimizer, hparams: Namespace, **kwargs ) -> LambdaLR: """ Initializes a constant learning rate scheduler. """ return ConstantLR(optimizer) class WarmupPolicy: """Policy for updating the LR of the WarmupConstant scheduler. With this class LambdaLR objects became picklable. """ def __init__(self, warmup_steps): self.warmup_steps = warmup_steps def __call__(self, current_step): if current_step < self.warmup_steps: return float(current_step) / float(max(1.0, self.warmup_steps)) return 1.0 class WarmupConstant(LambdaLR): """ Warmup Linear scheduler. 1) Linearly increases learning rate from 0 to 1 over warmup_steps training steps. 2) Keeps the learning rate constant afterwards. :param optimizer: torch.optim.Optimizer :param warmup_steps: Linearly increases learning rate from 0 to 1 over warmup_steps. :param last_epoch: """ def __init__( self, optimizer: Optimizer, warmup_steps: int, last_epoch: int = -1 ) -> None: super(WarmupConstant, self).__init__( optimizer, WarmupPolicy(warmup_steps), last_epoch ) @classmethod def from_hparams( cls, optimizer: Optimizer, hparams: Namespace, **kwargs ) -> LambdaLR: """ Initializes a constant learning rate scheduler with warmup period. """ return WarmupConstant(optimizer, hparams.warmup_steps) class LinearWarmupPolicy: """Policy for updating the LR of the LinearWarmup scheduler. With this class LambdaLR objects became picklable. """ def __init__(self, warmup_steps, num_training_steps): self.num_training_steps = num_training_steps self.warmup_steps = warmup_steps def __call__(self, current_step): if current_step < self.warmup_steps: return float(current_step) / float(max(1, self.warmup_steps)) return max( 0.0, float(self.num_training_steps - current_step) / float(max(1, self.num_training_steps - self.warmup_steps)), ) class LinearWarmup(LambdaLR): """ Create a schedule with a learning rate that decreases linearly after linearly increasing during a warmup period. :param optimizer: torch.optim.Optimizer :param warmup_steps: Linearly increases learning rate from 0 to 1*learning_rate over warmup_steps. :param num_training_steps: Linearly decreases learning rate from 1*learning_rate to 0. over remaining t_total - warmup_steps steps. :param last_epoch: """ def __init__( self, optimizer: Optimizer, warmup_steps: int, num_training_steps: int, last_epoch: int = -1, ) -> None: super(LinearWarmup, self).__init__( optimizer, LinearWarmupPolicy(warmup_steps, num_training_steps), last_epoch ) @classmethod def from_hparams( cls, optimizer: Optimizer, hparams: Namespace, num_training_steps: int ) -> LambdaLR: """ Initializes a learning rate scheduler with warmup period and decreasing period. """ return LinearWarmup(optimizer, hparams.warmup_steps, num_training_steps) str2scheduler = { "linear_warmup": LinearWarmup, "constant": ConstantLR, "warmup_constant": WarmupConstant, }