Spaces:
Running
Running
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
"""Schedule functions for controlling hparams over time.""" | |
from abc import ABCMeta | |
from abc import abstractmethod | |
import math | |
from common import config_lib # brain coder | |
# `abstractmethod` only takes effect on classes whose metaclass is ABCMeta.
# A class-level `__metaclass__` attribute is honored by Python 2 but silently
# ignored by Python 3, so the base class is built through ABCMeta directly,
# which applies the metaclass on both versions.
_ScheduleBase = ABCMeta('_ScheduleBase', (object,), {})


class Schedule(_ScheduleBase):
    """Schedule is a function which sets a hyperparameter's value over time.

    For example, a schedule can be used to decay an hparam, or oscillate it over
    time.

    This object is constructed with an instance of config_lib.Config (will be
    specific to each class implementation). For example if this is a decay
    schedule, the config may specify the rate of decay and decay start time. Then
    the object instance is called like a function, mapping global step (an integer
    counting how many calls to the train op have been made) to the hparam value.

    Properties of a schedule function f(t):
    0) Domain of t is the non-negative integers (t may be 0).
    1) Range of f is the reals.
    2) Schedule functions can assume that they will be called in time order. This
       allows schedules to be stateful.
    3) Schedule functions should be deterministic. Two schedule instances with the
       same config must always give the same value for each t, regardless of
       which t's they were previously called on. Users may call f(t) after
       arbitrary (forward) time jumps. Essentially, multiple schedule instances
       used in replica training will behave the same.
    4) Duplicate successive calls on the same time are allowed.

    Subclasses must implement `__call__`; this base class cannot be instantiated
    directly.
    """

    def __init__(self, config):
        """Construct this schedule with a config specific to each class impl.

        Args:
          config: An instance of config_lib.Config.
        """
        pass

    @abstractmethod
    def __call__(self, global_step):
        """Map `global_step` to a value.

        `global_step` is an integer counting how many calls to the train op have
        been made across all replicas (hence why it is global). Implementations
        may assume calls to be made in time order, i.e. `global_step` now >=
        previous `global_step` values.

        Args:
          global_step: Non-negative integer.

        Returns:
          Hparam value at this step. A number.
        """
        pass
class ConstSchedule(Schedule):
    """Schedule that ignores time and always yields the same value.

    config:
      const: Value returned at every step.

    f(t) = const.
    """

    def __init__(self, config):
        super(ConstSchedule, self).__init__(config)
        self.const = config.const

    def __call__(self, global_step):
        del global_step  # Unused: the value does not depend on the step.
        return self.const
class LinearDecaySchedule(Schedule):
    """Linear decay function.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is a linear function when start_time <= t <= end_time, with slope of
    (final - initial) / (end_time - start_time). f(t) = initial
    when t <= start_time. f(t) = final when t >= end_time.

    If start_time == end_time, this becomes a step function.

    Raises (from the constructor):
      ValueError: if end_time < start_time.
    """

    def __init__(self, config):
        super(LinearDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        # Linear interpolation.
        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)
        # When start_time == end_time the interpolation branch in __call__ is
        # unreachable, so the infinite slope is never actually used.
        self._slope = (
            self._diff / self._time_diff if self._time_diff > 0 else float('inf'))

    def __call__(self, global_step):
        if global_step <= self.start_time:
            return self.initial
        # `>=` rather than `>` so that f(end_time) is exactly `final` as
        # documented; the interpolation formula below can miss it by float
        # rounding at that point.
        if global_step >= self.end_time:
            return self.final
        return self.initial + (global_step - self.start_time) * self._slope
class ExponentialDecaySchedule(Schedule):
    """Exponential decay function.

    See https://en.wikipedia.org/wiki/Exponential_decay.

    Suited to decaying across orders of magnitude, e.g. taking a learning rate
    from 1e-2 down to 1e-6: the logarithm of the value is interpolated
    linearly, so the exponent decays at a constant rate.

    config:
      initial: Decay starts from this value. Must be positive.
      final: Decay ends at this value. Must be positive.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) = exp(g(t)), where g is the linear decay from log(initial) to
    log(final) over [start_time, end_time]. Before start_time f stays at
    initial; after end_time it stays at final. If start_time == end_time,
    this becomes a step function.

    Raises (from the constructor):
      ValueError: if initial or final is not positive, or (via the underlying
        linear schedule) if end_time < start_time.
    """

    def __init__(self, config):
        super(ExponentialDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if any(value <= 0 for value in (self.initial, self.final)):
            raise ValueError('initial and final must be positive numbers.')

        # Work in log space: a linear ramp of log(value) becomes an exponential
        # ramp once mapped back through exp() in __call__.
        log_initial = math.log(self.initial)
        log_final = math.log(self.final)
        log_space_config = config_lib.Config(
            initial=log_initial,
            final=log_final,
            start_time=self.start_time,
            end_time=self.end_time)
        self._linear_fn = LinearDecaySchedule(log_space_config)

    def __call__(self, global_step):
        log_value = self._linear_fn(global_step)
        return math.exp(log_value)
class SmootherstepDecaySchedule(Schedule):
    """Smootherstep decay function.

    An S-shaped (sigmoid-like) transition from initial to final, smoother than
    the linear and exponential decays.
    See https://en.wikipedia.org/wiki/Smoothstep.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    With x = (t - start_time) / (end_time - start_time), the value inside the
    window is initial + (6x^5 - 15x^4 + 10x^3) * (final - initial); see
    https://en.wikipedia.org/wiki/Smoothstep#Variations.
    f(t) is smooth, as in its first derivative exists everywhere.

    Raises (from the constructor):
      ValueError: if end_time < start_time.
    """

    def __init__(self, config):
        super(SmootherstepDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)

    def __call__(self, global_step):
        if global_step <= self.start_time:
            return self.initial
        if global_step > self.end_time:
            return self.final

        progress = (global_step - self.start_time) / self._time_diff
        # Smootherstep polynomial 6x^5 - 15x^4 + 10x^3 in Horner-like form.
        weight = progress * progress * progress * (
            progress * (progress * 6 - 15) + 10)
        return self.initial + weight * self._diff
class HardOscillatorSchedule(Schedule):
    """Hard oscillator function.

    config:
      high: Max value of the oscillator. Value at constant plateaus.
      low: Min value of the oscillator. Value at constant valleys.
      start_time: Global step when oscillation starts. Constant before this.
      period: Width of one oscillation, i.e. number of steps over which the
          oscillation takes place. Must be positive.
      transition_fraction: Fraction of the period spent transitioning between
          high and low values, in [0, 1]. Half of this time is spent falling
          and half rising; the remaining time is split evenly between the high
          and low plateaus. 1.0 means the entire period is spent rising and
          falling; 0.0 means the function jumps instantaneously between high
          and low.

    f(t) = high when t < start_time. For t >= start_time, f is periodic with
    f(t + period) = f(t). Writing p = ((t - start_time) / period) mod 1 and
    h = transition_fraction / 2, one period looks like:

      0       <= p < h       : falls linearly from high toward low.
      h       <= p < 0.5     : constant at low.
      0.5     <= p < 0.5 + h : rises linearly from low toward high.
      0.5 + h <= p < 1       : constant at high.

    Note: when transition_fraction is 0, f starts the period low and ends high.

    Raises (from the constructor):
      ValueError: if transition_fraction is outside [0, 1] or period <= 0.
    """

    def __init__(self, config):
        super(HardOscillatorSchedule, self).__init__(config)
        self.high = config.high
        self.low = config.low
        self.start_time = config.start_time
        self.period = float(config.period)
        self.transition_fraction = config.transition_fraction
        self.half_transition_fraction = config.transition_fraction / 2.0

        fraction = self.transition_fraction
        if fraction < 0 or fraction > 1.0:
            raise ValueError('transition_fraction must be between 0 and 1.0')
        if self.period <= 0:
            raise ValueError('period must be positive')

        if self.half_transition_fraction > 0:
            # Change in value per unit of period-phase while transitioning.
            self._slope = (
                float(self.high - self.low) / self.half_transition_fraction)
        else:
            # Zero-width transitions: the sloped branches in __call__ are
            # never taken, so this value is a placeholder.
            self._slope = float('inf')

    def __call__(self, global_step):
        if global_step < self.start_time:
            return self.high

        phase = ((global_step - self.start_time) / self.period) % 1.0
        edge = self.half_transition_fraction

        if phase >= 0.5:
            # Second half of the period: rise from low, then hold at high.
            offset = phase - 0.5
            if offset < edge:
                return self.low + offset * self._slope
            return self.high

        # First half of the period: fall from high, then hold at low.
        if phase < edge:
            return self.high - phase * self._slope
        return self.low
# Registry used by make_schedule: maps the `fn` name found in a schedule config
# to the Schedule subclass implementing it. (Despite the name, the values are
# classes, not configs.)
_NAME_TO_CONFIG = dict(
    const=ConstSchedule,
    linear_decay=LinearDecaySchedule,
    exp_decay=ExponentialDecaySchedule,
    smooth_decay=SmootherstepDecaySchedule,
    hard_osc=HardOscillatorSchedule,
)
def make_schedule(config):
    """Schedule factory.

    Given `config` containing a `fn` property, a Schedule implementation is
    instantiated with `config`. See `_NAME_TO_CONFIG` for `fn` options.

    Args:
      config: Config with a `fn` option that specifies which Schedule
        implementation to use. `config` is passed into the constructor.

    Returns:
      A Schedule impl instance.

    Raises:
      KeyError: If `config.fn` is not one of the registered schedule names.
        The message lists the valid names.
    """
    fn_name = config.fn
    if fn_name not in _NAME_TO_CONFIG:
        # Same exception type a plain dict lookup would raise, so callers that
        # catch KeyError keep working, but with an actionable message.
        raise KeyError(
            'Unknown schedule fn %r; expected one of: %s'
            % (fn_name, ', '.join(sorted(_NAME_TO_CONFIG))))
    schedule_class = _NAME_TO_CONFIG[fn_name]
    return schedule_class(config)