File size: 2,041 Bytes
6faf7e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import math


class LearningRateDecay:
    """Linear warmup followed by inverse-square-root decay.

    With ``step = global_step + 1`` the schedule evaluates

        lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)

    which grows linearly while ``step < warmup_steps``, equals ``lr`` when
    ``step == warmup_steps`` and decays proportionally to ``step**-0.5``
    afterwards.

    Args:
        lr: Peak learning rate, reached when ``step == warmup_steps``.
        warmup_steps: Number of steps over which the rate ramps up.
    """

    def __init__(self, lr=0.002, warmup_steps=4000.0) -> None:
        self.lr = lr
        self.warmup_steps = warmup_steps

    def __call__(self, global_step) -> float:
        """Return the learning rate for the zero-based ``global_step``.

        The minimum is taken with ``np.minimum``, so the result is a NumPy
        scalar (or an array when ``global_step`` is an array).
        """
        # Shift by one so that the very first step (0) gets a non-zero rate.
        step = global_step + 1.0
        warmup = self.warmup_steps

        peak_scale = self.lr * warmup ** 0.5
        ramp_up = step * warmup ** -1.5  # smaller term while step < warmup
        decay = step ** -0.5  # smaller term once step > warmup

        return peak_scale * np.minimum(ramp_up, decay)

class SquareRootScheduler:
    """Step-wise inverse-square-root learning-rate decay.

    The rate is held constant within each block of ``update_interval`` steps
    and equals ``lr * (k + 1) ** -0.5`` in the k-th block (k starting at 0).

    Args:
        lr: Learning rate used during the first block of steps.
        update_interval: Number of steps per block. Defaults to 1000.

    Raises:
        ValueError: If ``update_interval`` is not positive.
    """

    def __init__(self, lr=0.1, update_interval=1000):
        if update_interval <= 0:
            raise ValueError(
                f"update_interval must be positive, got {update_interval!r}"
            )
        self.lr = lr
        self.update_interval = update_interval

    def __call__(self, global_step):
        """Return the learning rate for ``global_step``."""
        # Floor division groups steps into blocks that share one rate.
        num_updates = global_step // self.update_interval
        return self.lr * pow(num_updates + 1.0, -0.5)


class CosineScheduler:
    """Cosine-annealed learning rate with an optional linear warmup.

    * ``global_step < warmup_steps``: the rate rises linearly from
      ``warmup_begin_lr`` towards ``base_lr``.
    * ``warmup_steps <= global_step < max_update``: the rate follows half a
      cosine period from ``base_lr`` down to ``final_lr``.
    * ``global_step >= max_update``: the rate stays at ``final_lr``.

    Args:
        max_update: Step at which the cosine decay reaches ``final_lr``.
        base_lr: Rate at the start of the cosine decay (end of warmup).
        final_lr: Rate at and after ``max_update``.
        warmup_steps: Number of linear warmup steps (0 disables warmup).
        warmup_begin_lr: Rate at step 0 of the warmup.

    Attributes:
        base_lr: The most recently returned post-warmup rate; equals the
            constructor's ``base_lr`` until the first post-warmup call.
    """

    def __init__(
        self, max_update, base_lr=0.02, final_lr=0, warmup_steps=0, warmup_begin_lr=0
    ):
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.warmup_begin_lr = warmup_begin_lr
        # Length of the cosine decay window (steps after warmup).
        self.max_steps = self.max_update - self.warmup_steps
        # Defined up front so the attribute always exists, even when the
        # scheduler is first queried past ``max_update`` (e.g. on resume).
        self.base_lr = base_lr

    def get_warmup_lr(self, global_step):
        """Return the linearly interpolated warmup rate for ``global_step``.

        Divides by ``warmup_steps``, so it must only be used when
        ``warmup_steps`` is non-zero (``__call__`` guarantees this for
        non-negative steps).
        """
        increase = (
            (self.base_lr_orig - self.warmup_begin_lr)
            * float(global_step)
            / float(self.warmup_steps)
        )
        return self.warmup_begin_lr + increase

    def __call__(self, global_step):
        """Return the learning rate for ``global_step``."""
        if global_step < self.warmup_steps:
            return self.get_warmup_lr(global_step)

        if global_step < self.max_update:
            # Here warmup_steps <= global_step < max_update, so max_steps > 0
            # and the division below is always safe.
            elapsed = global_step - self.warmup_steps
            cosine = math.cos(math.pi * elapsed / self.max_steps)
            self.base_lr = (
                self.final_lr
                + (self.base_lr_orig - self.final_lr) * (1 + cosine) / 2
            )
        else:
            # The cosine reaches final_lr exactly at max_update; clamp there
            # instead of relying on state left behind by an earlier call.
            self.base_lr = float(self.final_lr)
        return self.base_lr

def adjust_learning_rate(optimizer, global_step):
    """Set every parameter group's learning rate for ``global_step``.

    The rate comes from a default-constructed ``LearningRateDecay`` schedule
    and is written to the ``"lr"`` entry of each item in
    ``optimizer.param_groups``.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` of mutable
            mappings (presumably a PyTorch-style optimizer; confirm against
            the caller).
        global_step: Zero-based training step passed to the schedule.

    Returns:
        The learning rate that was applied.
    """
    schedule = LearningRateDecay()
    new_lr = schedule(global_step=global_step)
    for group in optimizer.param_groups:
        group["lr"] = new_lr
    return new_lr