# -*- coding: utf-8 -*-
r"""
Schedulers
==============
    Learning Rate schedulers used to train Polos models.
"""
from argparse import Namespace

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class ConstantPolicy:
    """Policy for updating the LR of the ConstantLR scheduler.
    With this class LambdaLR objects became picklable.
    """

    def __call__(self, *args, **kwargs):
        # Multiplicative LR factor: always 1.0, i.e. a constant learning rate.
        return 1.0


class ConstantLR(LambdaLR):
    """
    Constant learning rate schedule

    Wrapper for the huggingface Constant LR Scheduler.
        https://huggingface.co/transformers/v2.1.1/main_classes/optimizer_schedules.html

    :param optimizer: torch.optim.Optimizer
    :param last_epoch: The index of the last epoch when resuming training. Default: -1.
    """

    def __init__(self, optimizer: Optimizer, last_epoch: int = -1) -> None:
        super().__init__(optimizer, ConstantPolicy(), last_epoch)

    @classmethod
    def from_hparams(
        cls, optimizer: Optimizer, hparams: Namespace, **kwargs
    ) -> LambdaLR:
        """ Initializes a constant learning rate scheduler. """
        return ConstantLR(optimizer)


class WarmupPolicy:
    """Policy for updating the LR of the WarmupConstant scheduler.
    With this class LambdaLR objects became picklable.
    """

    def __init__(self, warmup_steps: int) -> None:
        self.warmup_steps = warmup_steps

    def __call__(self, current_step: int) -> float:
        if current_step < self.warmup_steps:
            # Linear warmup: scale the LR from 0 up to its configured value.
            return float(current_step) / float(max(1.0, self.warmup_steps))
        # After warmup, keep the LR constant.
        return 1.0
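
# Illustrative multipliers for WarmupPolicy(warmup_steps=4), where the value 4
# is an arbitrary example rather than a project default:
#   step 0 -> 0.00, step 1 -> 0.25, step 2 -> 0.50, step 3 -> 0.75,
#   step 4 onwards -> 1.00 (constant).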


class WarmupConstant(LambdaLR):
    """
    Warmup Constant scheduler.
    1) Linearly increases the learning rate from 0 to its configured value over
       warmup_steps training steps.
    2) Keeps the learning rate constant afterwards.

    :param optimizer: torch.optim.Optimizer
    :param warmup_steps: Number of steps over which the learning rate is linearly
        increased from 0 to its configured value.
    :param last_epoch: The index of the last epoch when resuming training. Default: -1.
    """

    def __init__(
        self, optimizer: Optimizer, warmup_steps: int, last_epoch: int = -1
    ) -> None:
        super().__init__(optimizer, WarmupPolicy(warmup_steps), last_epoch)

    @classmethod
    def from_hparams(
        cls, optimizer: Optimizer, hparams: Namespace, **kwargs
    ) -> LambdaLR:
        """ Initializes a constant learning rate scheduler with warmup period. """
        return WarmupConstant(optimizer, hparams.warmup_steps)


class LinearWarmupPolicy:
    """Policy for updating the LR of the LinearWarmup scheduler.
    With this class LambdaLR objects became picklable.
    """

    def __init__(self, warmup_steps: int, num_training_steps: int) -> None:
        self.num_training_steps = num_training_steps
        self.warmup_steps = warmup_steps

    def __call__(self, current_step: int) -> float:
        if current_step < self.warmup_steps:
            # Linear warmup: scale the LR from 0 up to its configured value.
            return float(current_step) / float(max(1, self.warmup_steps))
        # Linear decay: scale the LR from its configured value down to 0 over
        # the remaining num_training_steps - warmup_steps steps.
        return max(
            0.0,
            float(self.num_training_steps - current_step)
            / float(max(1, self.num_training_steps - self.warmup_steps)),
        )
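
# Illustrative multipliers for LinearWarmupPolicy(warmup_steps=2,
# num_training_steps=10), where both values are arbitrary examples rather than
# project defaults:
#   step 0 -> 0.0, step 1 -> 0.5                    (linear warmup)
#   step 2 -> 1.0, step 6 -> 0.5, step 10+ -> 0.0   (linear decay to zero)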


class LinearWarmup(LambdaLR):
    """
    Creates a schedule with a learning rate that increases linearly over a warmup
    period and then decreases linearly to 0 over the remaining training steps.

    :param optimizer: torch.optim.Optimizer
    :param warmup_steps: Number of steps over which the learning rate is linearly
        increased from 0 to its configured value.
    :param num_training_steps: Total number of training steps; the learning rate is
        linearly decreased from its configured value to 0 over the remaining
        num_training_steps - warmup_steps steps.
    :param last_epoch: The index of the last epoch when resuming training. Default: -1.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
    ) -> None:
        super().__init__(
            optimizer, LinearWarmupPolicy(warmup_steps, num_training_steps), last_epoch
        )

    @classmethod
    def from_hparams(
        cls, optimizer: Optimizer, hparams: Namespace, num_training_steps: int
    ) -> LambdaLR:
        """ Initializes a learning rate scheduler with warmup period and decreasing period. """
        return LinearWarmup(optimizer, hparams.warmup_steps, num_training_steps)


# Maps scheduler names (as used in hparams) to their classes.
str2scheduler = {
    "linear_warmup": LinearWarmup,
    "constant": ConstantLR,
    "warmup_constant": WarmupConstant,
}
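

# A minimal smoke test, not part of the original module: it builds a throwaway
# single-parameter optimizer and prints the learning rate produced by
# LinearWarmup at each step. The values below (warmup_steps=2,
# num_training_steps=10, lr=1.0) are arbitrary assumptions chosen to make the
# warmup/decay shape easy to read.
if __name__ == "__main__":
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1.0)
    scheduler = str2scheduler["linear_warmup"].from_hparams(
        optimizer, Namespace(warmup_steps=2), num_training_steps=10
    )
    for step in range(10):
        optimizer.step()
        scheduler.step()
        print(f"step {step}: lr = {optimizer.param_groups[0]['lr']:.3f}")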