# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of common learning rate schedules.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
from six.moves import range | |
from six.moves import zip | |
import tensorflow.compat.v1 as tf | |


def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with burn-in period.

  In this schedule, the learning rate is fixed at burnin_learning_rate
  for a fixed period, before transitioning to a regular exponential
  decay schedule.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate. Note that this includes the number of burn-in steps.
    learning_rate_decay_factor: multiplicative factor by which to decay the
      learning rate.
    burnin_learning_rate: initial learning rate during the burn-in period. If
      0.0 (which is the default), then the burn-in learning rate is simply
      set to learning_rate_base.
    burnin_steps: number of steps to use the burn-in learning rate.
    min_learning_rate: the minimum learning rate.
    staircase: whether to use staircase decay.

  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.
  """
  if burnin_learning_rate == 0:
    burnin_learning_rate = learning_rate_base

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    post_burnin_learning_rate = tf.train.exponential_decay(
        learning_rate_base,
        global_step - burnin_steps,
        learning_rate_decay_steps,
        learning_rate_decay_factor,
        staircase=staircase)
    if callable(post_burnin_learning_rate):
      post_burnin_learning_rate = post_burnin_learning_rate()
    # Hold the burn-in rate for the first burnin_steps, then switch to the
    # exponentially decayed rate, clipped below by min_learning_rate.
    return tf.maximum(tf.where(
        tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
        tf.constant(burnin_learning_rate),
        post_burnin_learning_rate), min_learning_rate, name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
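

# Example usage (an illustrative sketch, not part of the original library; the
# hyperparameter values below are hypothetical). In graph mode the call
# returns a scalar learning-rate tensor directly; under eager execution it
# returns a no-arg callable instead.
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = exponential_decay_with_burnin(
#       global_step,
#       learning_rate_base=0.004,
#       learning_rate_decay_steps=10000,
#       learning_rate_decay_factor=0.95,
#       burnin_learning_rate=0.001,
#       burnin_steps=2000)
#   optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)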


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base for warmup_steps, then transitions to a cosine decay
  schedule.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.
    hold_base_rate_steps: Optional number of steps to hold base learning rate
      before decaying.

  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be larger or equal to '
                     'warmup_steps.')

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    # Cosine decay from learning_rate_base down to 0 over the steps that
    # follow the warm-up and hold phases.
    learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
        np.pi *
        (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
        ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
      # Hold learning_rate_base until warmup_steps + hold_base_rate_steps.
      learning_rate = tf.where(
          global_step > warmup_steps + hold_base_rate_steps,
          learning_rate, learning_rate_base)
    if warmup_steps > 0:
      if learning_rate_base < warmup_learning_rate:
        raise ValueError('learning_rate_base must be larger or equal to '
                         'warmup_learning_rate.')
      # Linear warm-up from warmup_learning_rate up to learning_rate_base.
      slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
      warmup_rate = slope * tf.cast(global_step,
                                    tf.float32) + warmup_learning_rate
      learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                               learning_rate)
    return tf.where(global_step > total_steps, 0.0, learning_rate,
                    name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
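

# Example usage (an illustrative sketch, not part of the original library; the
# values below are hypothetical): warm up linearly over the first 1000 steps,
# hold the base rate for another 500 steps, then follow the cosine decay down
# to zero at total_steps.
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = cosine_decay_with_warmup(
#       global_step,
#       learning_rate_base=0.04,
#       total_steps=100000,
#       warmup_learning_rate=0.004,
#       warmup_steps=1000,
#       hold_base_rate_steps=500)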


def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine grained control over learning rates. One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next. For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list of global steps at which to switch learning
      rates. This list is assumed to consist of increasing positive integers.
    rates: a list of (float) learning rates corresponding to intervals between
      the boundaries. The length of this list must be exactly
      len(boundaries) + 1.
    warmup: Whether to linearly interpolate learning rate for steps in
      [0, boundaries[0]].

  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
      3. boundaries[0] != 0
  """
  if any([b < 0 for b in boundaries]) or any(
      [not isinstance(b, int) for b in boundaries]):
    raise ValueError('boundaries must be a list of positive integers')
  if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any([not isinstance(r, float) for r in rates]):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')

  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')

  if warmup and boundaries:
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    # Treat every warm-up step as its own boundary so that the rate is
    # linearly interpolated between rates[0] and rates[1] over [0, boundaries[0]).
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    # Index of the last boundary that global_step has reached; selecting the
    # matching entry of rates via a one-hot inner product keeps the result a
    # tensor op.
    rate_index = tf.reduce_max(tf.where(
        tf.greater_equal(global_step, boundaries),
        list(range(num_boundaries)),
        [0] * num_boundaries))
    return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
                         name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
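

# Example usage (an illustrative sketch, not part of the original library):
# with boundaries=[5, 10] and rates=[0.1, 0.01, 0.001] the returned tensor is
# 0.1 for steps 0-4, 0.01 for steps 5-9, and 0.001 from step 10 onward; with
# warmup=True the rate would instead ramp linearly from 0.1 to 0.01 over
# steps 0-4.
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = manual_stepping(
#       global_step, boundaries=[5, 10], rates=[0.1, 0.01, 0.001])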