NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
6.19 kB
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Learning rate schedule classes."""
from typing import Mapping, Any, Union, Optional
import tensorflow as tf
class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linearly ramps the learning rate, then defers to another schedule.

  While `step < warmup_steps` the returned learning rate is

      warmup_learning_rate
      + step / warmup_steps * (final_warmup_lr - warmup_learning_rate)

  where `final_warmup_lr` is the wrapped schedule evaluated at `warmup_steps`
  (or the wrapped constant cast to float32). From `warmup_steps` onwards the
  wrapped schedule (or constant) is returned unchanged, so the wrapped
  schedule is overridden for the first `warmup_steps` steps.
  """

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               warmup_learning_rate: float,
               name: Optional[str] = None):
    """Wraps a learning rate schedule (or constant) with a linear warmup.

    Args:
      after_warmup_lr_sched: a
        tf.keras.optimizers.schedules.LearningRateSchedule or a constant
        learning rate to use once the warmup is over.
      warmup_steps: int. Number of warmup steps.
      warmup_learning_rate: floating point number. Learning rate at step 0,
        i.e. the starting point of the linear ramp.
      name: Optional, name of the warmup schedule. Only stored and reported
        by `get_config`.
    """
    super().__init__()
    self._name = name
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    self._after_warmup_lr_sched = after_warmup_lr_sched
    # The ramp ends at whatever the wrapped schedule yields at `warmup_steps`.
    self._final_warmup_lr = self._lr_after_warmup(warmup_steps)

  def _lr_after_warmup(self, step):
    """Evaluates the wrapped schedule at `step`, or casts the constant."""
    wrapped = self._after_warmup_lr_sched
    if isinstance(wrapped, tf.keras.optimizers.schedules.LearningRateSchedule):
      return wrapped(step)
    return tf.cast(wrapped, dtype=tf.float32)

  def __call__(self, step: int):
    """Returns the learning rate for `step`.

    Args:
      step: current training step; cast to float32 for the ramp computation.

    Returns:
      The linearly interpolated warmup learning rate when
      `step < warmup_steps`, otherwise the wrapped schedule's value at `step`
      (or the wrapped constant).
    """
    step_f32 = tf.cast(step, dtype=tf.float32)
    fraction_done = step_f32 / self._warmup_steps
    lr_span = self._final_warmup_lr - self._init_warmup_lr
    ramp_lr = self._init_warmup_lr + fraction_done * lr_span
    steady_lr = self._lr_after_warmup(step)
    return tf.cond(
        step_f32 < self._warmup_steps,
        true_fn=lambda: ramp_lr,
        false_fn=lambda: steady_lr)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments as a dictionary.

    A wrapped LearningRateSchedule is represented by its own `get_config()`
    result; a wrapped constant is stored as is.
    """
    wrapped = self._after_warmup_lr_sched
    if isinstance(wrapped, tf.keras.optimizers.schedules.LearningRateSchedule):
      wrapped = wrapped.get_config()  # pytype: disable=attribute-error
    return {
        "after_warmup_lr_sched": wrapped,
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name,
    }
class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies polynomial warmup schedule on a given learning rate decay schedule.

  While `step < warmup_steps` the returned learning rate is

      initial_learning_rate * (step / warmup_steps) ** power

  where `initial_learning_rate` is the wrapped schedule evaluated at
  `warmup_steps` (or the wrapped constant cast to float32). From
  `warmup_steps` onwards the wrapped schedule (or constant) is returned.
  """

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    """Wraps a learning rate schedule (or constant) with a polynomial warmup.

    Args:
      after_warmup_lr_sched: a
        tf.keras.optimizers.schedules.LearningRateSchedule or a constant
        learning rate to use once the warmup is over.
      warmup_steps: int. Number of warmup steps.
      power: exponent applied to the fraction of warmup completed; the
        default 1.0 gives a linear ramp starting from zero.
      name: name used for the `tf.name_scope` in `__call__` and reported by
        `get_config`.
    """
    super(PolynomialWarmUp, self).__init__()
    # The ramp ends at whatever the wrapped schedule yields at `warmup_steps`.
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
    else:
      self._initial_learning_rate = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)

    self._warmup_steps = warmup_steps
    self._power = power
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._name = name

  def __call__(self, step):
    """Returns the learning rate for `step`.

    Args:
      step: current training step; cast to float32 for the warmup computation.

    Returns:
      The polynomial warmup learning rate when `step < warmup_steps`,
      otherwise the wrapped schedule's value at `step` (or the wrapped
      constant).
    """
    with tf.name_scope(self._name or "PolynomialWarmUp") as name:
      # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
      # learning rate will be
      # `init_lr * (global_step / warmup_steps) ** power`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self._initial_learning_rate *
          tf.math.pow(warmup_percent_done, self._power))

      if isinstance(self._after_warmup_lr_sched,
                    tf.keras.optimizers.schedules.LearningRateSchedule):
        after_warmup_lr = self._after_warmup_lr_sched(step)
      else:
        after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: after_warmup_lr,
          name=name)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments as a dictionary.

    A wrapped LearningRateSchedule is represented by its own `get_config()`
    result; a wrapped constant is stored as is.
    """
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        # Fixed: previously read the misspelled `self._warmup_setps`, which
        # raised AttributeError on every call.
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name
    })
    return config