shunxing1234 committed on
Commit 8008ee0
1 Parent(s): d022753

Upload ZEN/optimization.py

Files changed (1)
  1. ZEN/optimization.py +315 -0
ZEN/optimization.py ADDED
@@ -0,0 +1,315 @@
+ # coding=utf-8
+ # This file is derived from the code at
+ # https://github.com/huggingface/transformers/blob/master/transformers/optimization.py
+ #
+ # Original copyright notice:
+ #
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch optimization for BERT model."""
+
+ import math
+ import torch
+ from torch.optim import Optimizer
+ from torch.optim.optimizer import required
+ from torch.nn.utils import clip_grad_norm_
+ import logging
+ import abc
+ import sys
+
+ logger = logging.getLogger(__name__)
+
+ if sys.version_info >= (3, 4):
+     ABC = abc.ABC
+ else:
+     ABC = abc.ABCMeta('ABC', (), {})
+
+
+ class _LRSchedule(ABC):
+     """ Parent of all LRSchedules here. """
+     warn_t_total = False  # is set to True for schedules where progressing beyond t_total steps doesn't make sense
+
+     def __init__(self, warmup=0.002, t_total=-1, **kw):
+         """
+         :param warmup: what fraction of t_total steps will be used for linear warmup
+         :param t_total: how many training steps (updates) are planned
+         :param kw:
+         """
+         super(_LRSchedule, self).__init__(**kw)
+         if t_total < 0:
+             logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
+         if not 0.0 <= warmup < 1.0 and not warmup == -1:
+             raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
+         warmup = max(warmup, 0.)
+         self.warmup, self.t_total = float(warmup), float(t_total)
+         self.warned_for_t_total_at_progress = -1
+
+     def get_lr(self, step, nowarn=False):
+         """
+         :param step: which of t_total steps we're on
+         :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
+         :return: learning rate multiplier for current update
+         """
+         if self.t_total < 0:
+             return 1.
+         progress = float(step) / self.t_total
+         ret = self.get_lr_(progress)
+         # warning for exceeding t_total (only active with warmup_linear
+         if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
+             logger.warning(
+                 "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
+                 .format(ret, self.__class__.__name__))
+             self.warned_for_t_total_at_progress = progress
+         # end warning
+         return ret
+
+     @abc.abstractmethod
+     def get_lr_(self, progress):
+         """
+         :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
+         :return: learning rate multiplier for current update
+         """
+         return 1.
+
+
+ class ConstantLR(_LRSchedule):
+     def get_lr_(self, progress):
+         return 1.
+
+
+ class WarmupCosineSchedule(_LRSchedule):
+     """
+     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
+     Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
+     If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
+     """
+     warn_t_total = True
+
+     def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
+         """
+         :param warmup: see LRSchedule
+         :param t_total: see LRSchedule
+         :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
+         :param kw:
+         """
+         super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
+         self.cycles = cycles
+
+     def get_lr_(self, progress):
+         if progress < self.warmup:
+             return progress / self.warmup
+         else:
+             progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
+             return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
+
+
+ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
+     """
+     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
+     If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
+     learning rate (with hard restarts).
+     """
+
+     def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+         super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+         assert (cycles >= 1.)
+
+     def get_lr_(self, progress):
+         if progress < self.warmup:
+             return progress / self.warmup
+         else:
+             progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
+             ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
+             return ret
+
+
+ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
+     """
+     All training progress is divided in `cycles` (default=1.) parts of equal length.
+     Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
+     followed by a learning rate decreasing from 1. to 0. following a cosine curve.
+     """
+
+     def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+         assert (warmup * cycles < 1.)
+         warmup = warmup * cycles if warmup >= 0 else warmup
+         super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles,
+                                                                      **kw)
+
+     def get_lr_(self, progress):
+         progress = progress * self.cycles % 1.
+         if progress < self.warmup:
+             return progress / self.warmup
+         else:
+             progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
+             ret = 0.5 * (1. + math.cos(math.pi * progress))
+             return ret
+
+
+ class WarmupConstantSchedule(_LRSchedule):
+     """
+     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
+     Keeps learning rate equal to 1. after warmup.
+     """
+
+     def get_lr_(self, progress):
+         if progress < self.warmup:
+             return progress / self.warmup
+         return 1.
+
+
+ class WarmupLinearSchedule(_LRSchedule):
+     """
+     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
+     Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
+     """
+     warn_t_total = True
+
+     def get_lr_(self, progress):
+         if progress < self.warmup:
+             return progress / self.warmup
+         return max((progress - 1.) / (self.warmup - 1.), 0.)
+
+
+ SCHEDULES = {
+     None: ConstantLR,
+     "none": ConstantLR,
+     "warmup_cosine": WarmupCosineSchedule,
+     "warmup_constant": WarmupConstantSchedule,
+     "warmup_linear": WarmupLinearSchedule
+ }
+
+
+ class BertAdam(Optimizer):
+     """Implements BERT version of Adam algorithm with weight decay fix.
+     Params:
+         lr: learning rate
+         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
+         t_total: total number of training steps for the learning
+             rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
+         schedule: schedule to use for the warmup (see above).
+             Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
+             If `None` or `'none'`, learning rate is always kept constant.
+             Default : `'warmup_linear'`
+         b1: Adams b1. Default: 0.9
+         b2: Adams b2. Default: 0.999
+         e: Adams epsilon. Default: 1e-6
+         weight_decay: Weight decay. Default: 0.01
+         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
+     """
+
+     def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
+                  b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
+         if lr is not required and lr < 0.0:
+             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
+         if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
+             raise ValueError("Invalid schedule parameter: {}".format(schedule))
+         if not 0.0 <= b1 < 1.0:
+             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
+         if not 0.0 <= b2 < 1.0:
+             raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
+         if not e >= 0.0:
+             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
+         # initialize schedule object
+         if not isinstance(schedule, _LRSchedule):
+             schedule_type = SCHEDULES[schedule]
+             schedule = schedule_type(warmup=warmup, t_total=t_total)
+         else:
+             if warmup != -1 or t_total != -1:
+                 logger.warning(
+                     "warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
+                     "Please specify custom warmup and t_total in _LRSchedule object.")
+         defaults = dict(lr=lr, schedule=schedule,
+                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
+                         max_grad_norm=max_grad_norm)
+         super(BertAdam, self).__init__(params, defaults)
+
+     def get_lr(self):
+         lr = []
+         for group in self.param_groups:
+             for p in group['params']:
+                 state = self.state[p]
+                 if len(state) == 0:
+                     return [0]
+                 lr_scheduled = group['lr']
+                 lr_scheduled *= group['schedule'].get_lr(state['step'])
+                 lr.append(lr_scheduled)
+         return lr
+
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['next_m'] = torch.zeros_like(p.data)
+                     # Exponential moving average of squared gradient values
+                     state['next_v'] = torch.zeros_like(p.data)
+
+                 next_m, next_v = state['next_m'], state['next_v']
+                 beta1, beta2 = group['b1'], group['b2']
+
+                 # Add grad clipping
+                 if group['max_grad_norm'] > 0:
+                     clip_grad_norm_(p, group['max_grad_norm'])
+
+                 # Decay the first and second moment running average coefficient
+                 # In-place operations to update the averages at the same time
+                 next_m.mul_(beta1).add_(1 - beta1, grad)
+                 next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                 update = next_m / (next_v.sqrt() + group['e'])
+
+                 # Just adding the square of the weights to the loss function is *not*
+                 # the correct way of using L2 regularization/weight decay with Adam,
+                 # since that will interact with the m and v parameters in strange ways.
+                 #
+                 # Instead we want to decay the weights in a manner that doesn't interact
+                 # with the m/v parameters. This is equivalent to adding the square
+                 # of the weights to the loss with plain (non-momentum) SGD.
+                 if group['weight_decay'] > 0.0:
+                     update += group['weight_decay'] * p.data
+
+                 lr_scheduled = group['lr']
+                 lr_scheduled *= group['schedule'].get_lr(state['step'])
+
+                 update_with_lr = lr_scheduled * update
+                 p.data.add_(-update_with_lr)
+
+                 state['step'] += 1
+
+                 # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
+                 # No bias correction
+                 # bias_correction1 = 1 - beta1 ** state['step']
+                 # bias_correction2 = 1 - beta2 ** state['step']
+
+         return loss
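
For orientation, here is a minimal usage sketch of the file added above. It is not part of this commit: the import path assumes the repository root is on PYTHONPATH so the file is importable as ZEN.optimization, the model, learning rate, warmup fraction, and step counts are illustrative placeholders, and it assumes a PyTorch version that still accepts the older in-place add_/addcmul_ overloads used inside step().

# Hypothetical usage sketch -- illustrative only, not part of the uploaded file.
import torch
from ZEN.optimization import BertAdam, WarmupLinearSchedule  # assumed import path

model = torch.nn.Linear(10, 2)   # stand-in model, purely illustrative
num_train_steps = 1000           # assumed total number of updates

# The schedule returns a multiplier on the base lr: it ramps linearly to 1 over the
# first 10% of steps, then decays linearly back towards 0 at t_total.
schedule = WarmupLinearSchedule(warmup=0.1, t_total=num_train_steps)
for step in (0, 50, 100, 500, 999):
    print(step, schedule.get_lr(step))

# BertAdam accepts either the schedule name as a string (it builds the object itself) ...
optimizer = BertAdam(model.parameters(), lr=5e-5,
                     warmup=0.1, t_total=num_train_steps,
                     schedule='warmup_linear')
# ... or a pre-built _LRSchedule object, in which case warmup/t_total live on the schedule.
optimizer = BertAdam(model.parameters(), lr=5e-5, schedule=schedule)

for _ in range(3):               # tiny loop just to exercise step()
    loss = model(torch.randn(4, 10)).pow(2).mean()   # dummy loss
    loss.backward()
    optimizer.step()             # warmup-scaled lr, per-parameter clipping, decoupled weight decay
    optimizer.zero_grad()

Note that BertAdam.get_lr() returns [0] until the first call to step(), since the per-parameter state (including the step counter) is only created there.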