File size: 8,531 Bytes
a93e458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Learning rate decay and weight decay incr functions."""
import math
from megatron import print_rank_0
class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay.

    Every call to ``step`` advances an internal step counter, recomputes the
    learning rate (linear warmup followed by a decay schedule) and the weight
    decay (an increment schedule from ``start_wd`` to ``end_wd``), and writes
    both into every parameter group of ``optimizer``, scaled by the group's
    optional ``lr_mult`` / ``wd_mult`` entries.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dict-like
            groups whose ``'lr'`` and ``'weight_decay'`` entries are overwritten.
        max_lr: peak learning rate, reached at the end of warmup.
        min_lr: floor learning rate; used once ``lr_decay_steps`` is passed.
        lr_warmup_steps: number of steps of linear warmup from 0 to ``max_lr``.
        lr_decay_steps: step at which decay ends; must exceed ``lr_warmup_steps``.
        lr_decay_style: ``'constant'``, ``'linear'``, ``'cosine'`` or
            ``'inverse-square-root'``.
        start_wd: initial weight decay.
        end_wd: final weight decay (``>= start_wd``).
        wd_incr_steps: number of steps over which weight decay moves from
            ``start_wd`` to ``end_wd``.
        wd_incr_style: ``'constant'``, ``'linear'`` or ``'cosine'``.
        use_checkpoint_opt_param_scheduler: when loading a state dict, accept
            the checkpoint's hyper-parameters without comparing them to the
            constructor's; when False they must match.
        override_opt_param_scheduler: when loading a state dict, keep the
            constructor's hyper-parameters and ignore the checkpoint's.
            Mutually exclusive with ``use_checkpoint_opt_param_scheduler``.
    """

    def __init__(self, optimizer, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
                 override_opt_param_scheduler=False):

        # Class values.
        self.optimizer = optimizer

        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        assert self.lr_decay_steps > 0
        # Guarantees a non-empty decay window, so get_lr never divides by 0.
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                'use-checkpoint are set.'

        # Set the learning rate
        self.step(0)
        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))

    def get_wd(self):
        """Weight decay incr functions.

        Returns the weight decay for the current ``num_steps``: ``end_wd``
        once ``wd_incr_steps`` has been passed, otherwise ``start_wd`` moved
        toward ``end_wd`` linearly or along a half cosine, per
        ``wd_incr_style``.

        Raises:
            ValueError: ``wd_incr_style`` is not supported.
        """
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == 'constant':
            assert self.start_wd == self.end_wd
            return self.end_wd

        # An empty increment window means the schedule is already complete;
        # treat it as fully incremented instead of dividing by zero.
        if self.wd_incr_steps == 0:
            incr_ratio = 1.0
        else:
            incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == 'linear':
            coeff = incr_ratio
        elif self.wd_incr_style == 'cosine':
            # Half cosine rising from 0 (ratio 0) to 1 (ratio 1).
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise ValueError('{} weight decay increment style is not supported.'.format(
                self.wd_incr_style))

        return self.start_wd + coeff * delta_wd

    def get_lr(self):
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4

        Returns the learning rate for the current ``num_steps``:

        * linear warmup from 0 to ``max_lr`` over ``lr_warmup_steps``;
        * then ``max_lr`` forever for the ``'constant'`` style;
        * ``min_lr`` once ``num_steps`` exceeds ``lr_decay_steps``;
        * otherwise ``max_lr * sqrt(warmup / step)`` clipped below at
          ``min_lr`` for ``'inverse-square-root'``, or a linear / cosine
          interpolation from ``max_lr`` down to ``min_lr`` across the decay
          window ``(lr_warmup_steps, lr_decay_steps]``.

        Raises:
            ValueError: ``lr_decay_style`` is not supported.
        """
        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.lr_warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr

        # If we are done with the warmup period, use the decay style.
        if self.lr_decay_style == 'inverse-square-root':
            # Clamp to 1 so neither factor can be zero.
            warmup_steps = max(self.lr_warmup_steps, 1)
            num_steps = max(self.num_steps, 1)
            lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5)
            return max(self.min_lr, lr)

        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.lr_decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.lr_decay_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise ValueError('{} decay style is not supported.'.format(
                self.lr_decay_style))

        return self.min_lr + coeff * delta_lr

    def step(self, increment):
        """Set lr for all parameters groups.

        Advances ``num_steps`` by ``increment``, then writes the scheduled
        learning rate and weight decay (times each group's optional
        ``lr_mult`` / ``wd_mult``, default 1.0) into every param group.
        """
        self.num_steps += increment
        new_lr = self.get_lr()
        new_wd = self.get_wd()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)

    def state_dict(self):
        """Return the scheduler's hyper-parameters and step counter as a dict."""
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
            'num_steps': self.num_steps,
            'lr_decay_style': self.lr_decay_style,
            'lr_decay_steps': self.lr_decay_steps,
            'min_lr': self.min_lr,
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
            'wd_incr_steps': self.wd_incr_steps
        }
        return state_dict

    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them.

        Returns ``cls_value`` when overriding. Otherwise returns ``sd_value``,
        first asserting that it equals ``cls_value`` unless
        ``use_checkpoint_opt_param_scheduler`` is set.
        """
        if self.override_opt_param_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, \
                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value

    @staticmethod
    def _get_legacy_key(sd, *keys):
        """Return ``sd[k]`` for the first of ``keys`` present in ``sd``.

        Older checkpoints stored some entries under other names; the last key
        is the current name and is indexed directly, so a missing entry still
        raises ``KeyError``.
        """
        for key in keys[:-1]:
            if key in sd:
                return sd[key]
        return sd[keys[-1]]

    def load_state_dict(self, sd):
        """Restore scheduler state from ``sd`` (as produced by ``state_dict``).

        Hyper-parameters are resolved through ``_check_and_set``. Legacy key
        names (``start_lr``, ``warmup_iter``, ``warmup_steps``, ``end_iter``,
        ``decay_steps``, ``decay_style``, ``num_iters``) are accepted, and the
        weight-decay fields are optional. The step counter is set to the
        checkpoint's value, and the optimizer's param groups are refreshed only
        after every schedule parameter (including weight decay) is restored.
        """
        max_lr_ = self._get_legacy_key(sd, 'start_lr', 'max_lr')
        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                          'learning rate')

        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                          'minimum learning rate')

        lr_warmup_steps_ = self._get_legacy_key(
            sd, 'warmup_iter', 'warmup_steps', 'lr_warmup_steps')
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
                                                   lr_warmup_steps_,
                                                   'warmup iterations')

        lr_decay_steps_ = self._get_legacy_key(
            sd, 'end_iter', 'decay_steps', 'lr_decay_steps')
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
                                                  'total number of iterations')

        lr_decay_style_ = self._get_legacy_key(sd, 'decay_style', 'lr_decay_style')
        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
                                                  lr_decay_style_,
                                                  'learning rate decay style')

        # Weight-decay fields are absent from older checkpoints.
        if 'start_wd' in sd:
            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                "start weight decay")
            self.end_wd = self._check_and_set(self.end_wd,
                                              sd['end_wd'],
                                              "end weight decay")
            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     "total number of weight decay iterations")
            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     "weight decay incr style")

        # Set (rather than add to) the counter so loading is idempotent, and
        # only now push lr/wd into the param groups: every schedule parameter,
        # including weight decay, has been restored at this point.
        num_steps = self._get_legacy_key(sd, 'num_iters', 'num_steps')
        self.num_steps = num_steps
        self.step(increment=0)
|