# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""lr_schedule.py"""
import torch
from typing import Dict


def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_name: str, base_lr: float, scheduler_cfg: Dict):
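    """Build a per-step LR scheduler for `optimizer`: 'cosine' (linear warmup + cosine
    annealing), 'legacy' (T5-style inverse-sqrt then linear decay), or 'constant'."""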

    if scheduler_name.lower() == 'cosine':
        from torch.optim.lr_scheduler import (
            SequentialLR,
            LinearLR,
            CosineAnnealingLR,
        )

        # Warmup: ramp the LR linearly from 0.5 * base_lr up to base_lr over `warmup_steps`.
        scheduler1 = LinearLR(
            optimizer,
            start_factor=0.5,
            end_factor=1,
            total_iters=scheduler_cfg["warmup_steps"],
            last_epoch=-1,
        )

        # Annealing: cosine decay from base_lr down to `final_cosine` over the remaining steps.
        scheduler2 = CosineAnnealingLR(
            optimizer,
            T_max=scheduler_cfg["total_steps"] - scheduler_cfg["warmup_steps"],
            eta_min=scheduler_cfg["final_cosine"],
        )

        # Chain the two phases, switching at `warmup_steps`.
        lr_scheduler = SequentialLR(optimizer,
                                    schedulers=[scheduler1, scheduler2],
                                    milestones=[scheduler_cfg["warmup_steps"]])
    elif scheduler_name.lower() == 'legacy':
        import math
        from torch.optim.lr_scheduler import (
            SequentialLR,
            LinearLR,
            LambdaLR,
        )

        msg = "Using the T5 legacy LR schedule; it is independent of optim.base_lr."
        print(msg)

        # T5-style schedule: inverse-sqrt decay for the first 90% of training,
        # then a linear ramp down to zero over the final 10%.
        num_steps_optimizer1 = math.ceil(scheduler_cfg["total_steps"] * 0.9)
        iters_left_for_optimizer2 = scheduler_cfg["total_steps"] - num_steps_optimizer1

        # Inverse-sqrt phase: lr(step) = min(base_lr, 1 / sqrt(step)) for step > 0.
        # LambdaLR expects a multiplicative factor, hence the division by base_lr
        # (the factor is 1.0 at step 0).
        scheduler1 = LambdaLR(optimizer, lambda step: min(base_lr, 1.0 / math.sqrt(step)) / base_lr
                              if step else base_lr / base_lr)

        # Linear phase: continue from the last inverse-sqrt value and decay to zero.
        scheduler2 = LinearLR(optimizer,
                              start_factor=(min(base_lr, 1.0 / math.sqrt(num_steps_optimizer1)) / base_lr),
                              end_factor=0,
                              total_iters=iters_left_for_optimizer2,
                              last_epoch=-1)

        lr_scheduler = SequentialLR(
            optimizer,
            schedulers=[scheduler1, scheduler2],
            milestones=[num_steps_optimizer1],
        )
    elif scheduler_name.lower() == 'constant':
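        # Constant LR, built with the HuggingFace `transformers` scheduler factory.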
        from transformers import get_scheduler
        lr_scheduler = get_scheduler(
            name=scheduler_name.lower(),
            optimizer=optimizer,
        )
    else:
        raise NotImplementedError(f"Unsupported scheduler_name: {scheduler_name}")

    return lr_scheduler


def extra_stats(args, model, optimizer):
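    """Collect optional logging stats: the global L2 norm of the model weights and the current LR."""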
    stats = {}

    if args.logging.weights_l2:
        # Global L2 norm over all parameters (sqrt of the summed squared per-parameter norms).
        weights_l2 = sum(p.detach().norm(2).item()**2 for p in model.parameters())**0.5
        stats['weights_l2'] = weights_l2

    cur_lr = optimizer.param_groups[0]['lr']
    stats['lr'] = cur_lr

    return stats
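

if __name__ == "__main__":
    # Minimal usage sketch (not part of YourMT3 training code): the toy model,
    # learning rate, and step counts below are illustrative assumptions; the
    # scheduler_cfg keys mirror the lookups performed in get_lr_scheduler().
    from types import SimpleNamespace

    model = torch.nn.Linear(8, 8)
    base_lr = 1e-3
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

    scheduler_cfg = {
        "warmup_steps": 100,  # length of the linear warmup phase
        "total_steps": 1000,  # total number of optimizer steps
        "final_cosine": 1e-6,  # eta_min reached at the end of the cosine phase
    }
    lr_scheduler = get_lr_scheduler(optimizer, 'cosine', base_lr, scheduler_cfg)

    # Both phases are step-wise: call .step() once per optimizer update.
    for _ in range(scheduler_cfg["total_steps"]):
        optimizer.step()
        lr_scheduler.step()

    # extra_stats() expects an args object exposing a `logging.weights_l2` flag.
    args = SimpleNamespace(logging=SimpleNamespace(weights_l2=True))
    print(extra_stats(args, model, optimizer))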