# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Learning rate schedule classes."""

from typing import Mapping, Any, Union, Optional

import tensorflow as tf


class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               warmup_learning_rate: float,
               name: Optional[str] = None):
    """Adds a linear warmup phase to a learning rate schedule.

    `warmup_learning_rate` is the learning rate at step 0; the final learning
    rate of the warmup period is the value that `after_warmup_lr_sched`
    produces at `warmup_steps`. During warmup, the learning rate increases
    linearly with the step:

      learning_rate = warmup_lr + step / warmup_steps
                      * (final_warmup_lr - warmup_lr)

    For the first `warmup_steps` steps, this warmup value overrides whatever
    `after_warmup_lr_sched` would otherwise return.

    Args:
      after_warmup_lr_sched: A
        `tf.keras.optimizers.schedules.LearningRateSchedule` or a constant
        learning rate to use after the warmup period.
      warmup_steps: Number of warmup steps.
      warmup_learning_rate: Initial learning rate for the warmup.
      name: Optional name of the warmup schedule.
    """
    super(LinearWarmup, self).__init__()
    self._name = name
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
    else:
      self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)

  def __call__(self, step: int):
    global_step = tf.cast(step, dtype=tf.float32)

    # Learning rate during warmup: linear interpolation between the initial
    # warmup rate and the wrapped schedule's value at `warmup_steps`.
    linear_warmup_lr = (
        self._init_warmup_lr + global_step / self._warmup_steps *
        (self._final_warmup_lr - self._init_warmup_lr))

    # Learning rate once warmup is finished.
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      after_warmup_lr = self._after_warmup_lr_sched(step)
    else:
      after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

    lr = tf.cond(global_step < self._warmup_steps,
                 lambda: linear_warmup_lr,
                 lambda: after_warmup_lr)
    return lr

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name
    })
    return config
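

# Illustrative usage (not part of the original module): a minimal sketch of
# wrapping a decay schedule with `LinearWarmup`. The decay schedule and the
# hyperparameter values below are assumptions for illustration only.
#
#   decay = tf.keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=1e-3, decay_steps=10000, end_learning_rate=0.0)
#   lr = LinearWarmup(after_warmup_lr_sched=decay,
#                     warmup_steps=1000,
#                     warmup_learning_rate=0.0)
#   optimizer = tf.keras.optimizers.SGD(learning_rate=lr)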


class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a polynomial warmup schedule on a given learning rate decay
  schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    super(PolynomialWarmUp, self).__init__()
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
    else:
      self._initial_learning_rate = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)

    self._warmup_steps = warmup_steps
    self._power = power
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PolynomialWarmUp") as name:
      # Implements polynomial warmup, i.e., if global_step < warmup_steps, the
      # learning rate will be `global_step/num_warmup_steps * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self._initial_learning_rate *
          tf.math.pow(warmup_percent_done, self._power))

      if isinstance(self._after_warmup_lr_sched,
                    tf.keras.optimizers.schedules.LearningRateSchedule):
        after_warmup_lr = self._after_warmup_lr_sched(step)
      else:
        after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: after_warmup_lr,
          name=name)

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name
    })
    return config
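

# Illustrative usage (not part of the original module): a minimal sketch of
# combining `PolynomialWarmUp` with a decay schedule and an optimizer. The
# decay schedule and the hyperparameter values below are assumptions for
# illustration only.
#
#   decay = tf.keras.optimizers.schedules.CosineDecay(
#       initial_learning_rate=1e-3, decay_steps=10000)
#   lr = PolynomialWarmUp(after_warmup_lr_sched=decay,
#                         warmup_steps=1000,
#                         power=1.0)
#   optimizer = tf.keras.optimizers.Adam(learning_rate=lr)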