Spaces:
Running
Running
File size: 10,261 Bytes
0b8359d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Schedule functions for controlling hparams over time."""
from abc import ABCMeta
from abc import abstractmethod
import math
from common import config_lib # brain coder
# The base class is created by calling ABCMeta directly so that the abstract
# methods are enforced on both Python 2 and Python 3. A class-level
# `__metaclass__ = ABCMeta` attribute is only honoured by Python 2; Python 3
# ignores it, which would leave `@abstractmethod` without any effect.
class Schedule(ABCMeta('_ScheduleBase', (object,), {})):
    """Schedule is a function which sets a hyperparameter's value over time.

    For example, a schedule can be used to decay an hparam, or oscillate it
    over time.

    This object is constructed with an instance of config_lib.Config (will be
    specific to each class implementation). For example if this is a decay
    schedule, the config may specify the rate of decay and decay start time.
    Then the object instance is called like a function, mapping global step
    (an integer counting how many calls to the train op have been made) to the
    hparam value.

    Properties of a schedule function f(t):
    0) Domain of t is the non-negative integers (t may be 0).
    1) Range of f is the reals.
    2) Schedule functions can assume that they will be called in time order.
       This allows schedules to be stateful.
    3) Schedule functions should be deterministic. Two schedule instances with
       the same config must always give the same value for each t, and
       regardless of what t's it was previously called on. Users may call f(t)
       on arbitrary (positive) time jumps. Essentially, multiple schedule
       instances used in replica training will behave the same.
    4) Duplicate successive calls on the same time are allowed.

    This class is abstract: it cannot be instantiated directly, and concrete
    subclasses must override both `__init__` and `__call__`.
    """

    @abstractmethod
    def __init__(self, config):
        """Construct this schedule with a config specific to each class impl.

        Subclasses may still chain up to this method with `super()`; it does
        nothing.

        Args:
          config: An instance of config_lib.Config.
        """
        pass

    @abstractmethod
    def __call__(self, global_step):
        """Map `global_step` to a value.

        `global_step` is an integer counting how many calls to the train op
        have been made across all replicas (hence why it is global).
        Implementations may assume calls to be made in time order, i.e.
        `global_step` now >= previous `global_step` values.

        Args:
          global_step: Non-negative integer.

        Returns:
          Hparam value at this step. A number.
        """
        pass
class ConstSchedule(Schedule):
    """Schedule that yields the same value at every step.

    config:
      const: The value returned for every step.

    f(t) = const.
    """

    def __init__(self, config):
        """Read the constant from `config.const`."""
        super(ConstSchedule, self).__init__(config)
        self.const = config.const

    def __call__(self, global_step):
        """Return the constant; `global_step` is accepted but not used."""
        del global_step  # The value does not depend on time.
        return self.const
class LinearDecaySchedule(Schedule):
    """Linear decay function.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is a linear function when start_time < t < end_time, with slope of
    (final - initial) / (end_time - start_time). f(t) = initial when
    t <= start_time. f(t) = final when t >= end_time; at those boundaries the
    configured value itself is returned rather than an interpolated one.

    If start_time == end_time, this becomes a step function: initial for
    t <= start_time and final afterwards.
    """

    def __init__(self, config):
        """Read the decay endpoints and precompute the slope.

        Args:
          config: Object exposing `initial`, `final`, `start_time` and
            `end_time` attributes.

        Raises:
          ValueError: If `end_time` is smaller than `start_time`.
        """
        super(LinearDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time
        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        # Linear interpolation.
        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)
        # With a zero-length window no interpolation ever happens (see
        # __call__), so the infinite slope is only a placeholder.
        self._slope = (
            self._diff / self._time_diff if self._time_diff > 0
            else float('inf'))

    def __call__(self, global_step):
        """Return the schedule value at `global_step`."""
        # The start check comes first so that a zero-length window still
        # yields `initial` at t == start_time == end_time.
        if global_step <= self.start_time:
            return self.initial
        # `>=` returns `final` itself at t == end_time instead of a
        # floating-point interpolation that may be off by rounding error.
        if global_step >= self.end_time:
            return self.final
        return self.initial + (global_step - self.start_time) * self._slope
class ExponentialDecaySchedule(Schedule):
    """Exponential decay function.

    See https://en.wikipedia.org/wiki/Exponential_decay.

    Use this decay function to decay over orders of magnitude. For example,
    to decay learning rate from 1e-2 to 1e-6. Exponential decay will decay
    the exponent linearly.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is an exponential decay function when start_time <= t <= end_time.
    The decay rate and amplitude are chosen so that f(t) = initial when
    t = start_time, and f(t) = final when t = end_time. f(t) is constant for
    t < start_time or t > end_time. initial and final must be positive
    values.

    If start_time == end_time, this becomes a step function.
    """

    def __init__(self, config):
        """Build the log-space linear schedule that backs this decay.

        Args:
          config: Object exposing `initial`, `final`, `start_time` and
            `end_time` attributes.

        Raises:
          ValueError: If `initial` or `final` is not positive.
        """
        super(ExponentialDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time
        if self.initial <= 0 or self.final <= 0:
            raise ValueError('initial and final must be positive numbers.')

        # Interpolate linearly between the logarithms of the endpoints;
        # __call__ maps the result back with exp().
        log_initial = math.log(self.initial)
        log_final = math.log(self.final)
        log_space_config = config_lib.Config(
            initial=log_initial,
            final=log_final,
            start_time=self.start_time,
            end_time=self.end_time)
        self._linear_fn = LinearDecaySchedule(log_space_config)

    def __call__(self, global_step):
        """Return exp() of the log-space linear schedule at `global_step`."""
        log_value = self._linear_fn(global_step)
        return math.exp(log_value)
class SmootherstepDecaySchedule(Schedule):
    """Smootherstep decay function.

    A sigmoidal like transition from initial to final values. A smoother
    transition than linear and exponential decays, hence the name.
    See https://en.wikipedia.org/wiki/Smoothstep.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is fully defined here:
    https://en.wikipedia.org/wiki/Smoothstep#Variations.

    f(t) is smooth, as in its first-derivative exists everywhere.
    f(t) = initial when t <= start_time and f(t) = final when t >= end_time;
    at those boundaries the configured value itself is returned.
    """

    def __init__(self, config):
        """Read the decay endpoints and precompute the value/time spans.

        Args:
          config: Object exposing `initial`, `final`, `start_time` and
            `end_time` attributes.

        Raises:
          ValueError: If `end_time` is smaller than `start_time`.
        """
        super(SmootherstepDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time
        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)

    def __call__(self, global_step):
        """Return the schedule value at `global_step`."""
        # The start check comes first so that a zero-length window still
        # yields `initial` at t == start_time == end_time, and so the
        # division below never sees a zero `_time_diff`.
        if global_step <= self.start_time:
            return self.initial
        # `>=` returns `final` itself at t == end_time instead of
        # `initial + float(final - initial)`, which may differ by rounding.
        if global_step >= self.end_time:
            return self.final
        x = (global_step - self.start_time) / self._time_diff

        # Smootherstep: 6x^5 - 15x^4 + 10x^3, scaled by the value span.
        return self.initial + x * x * x * (x * (x * 6 - 15) + 10) * self._diff
class HardOscillatorSchedule(Schedule):
    """Hard oscillator function.

    config:
      high: Max value of the oscillator. Value at constant plateaus.
      low: Min value of the oscillator. Value at constant valleys.
      start_time: Global step when oscillation starts. Constant before this.
      period: Width of one oscillation, i.e. number of steps over which the
        oscillation takes place.
      transition_fraction: Fraction of the period spent transitioning between
        high and low values. Half of that time is spent rising and half is
        spent falling. The remaining time is split evenly between staying
        constant at the high value and staying constant at the low value.
        transition_fraction = 1.0 means the entire period is spent rising and
        falling. transition_fraction = 0.0 means no time is spent rising and
        falling, i.e. the function jumps instantaneously between high and
        low.

    f(t) = high when t < start_time.
    f(t) is periodic when t >= start_time, with f(t + period) = f(t).
    f(t) is linear with positive slope when rising, and negative slope when
    falling. At the start of the period t0, f(t0) = high and begins to
    descend. At the middle of the period f is low and is constant until the
    ascension begins. f then rises from low to high and is constant again
    until the period repeats.

    Note: when transition_fraction is 0, f starts the period low and ends
    high.
    """

    def __init__(self, config):
        """Read the oscillator parameters and precompute the ramp slope.

        Args:
          config: Object exposing `high`, `low`, `start_time`, `period` and
            `transition_fraction` attributes.

        Raises:
          ValueError: If `transition_fraction` is outside [0, 1.0] or
            `period` is not positive.
        """
        super(HardOscillatorSchedule, self).__init__(config)
        self.high = config.high
        self.low = config.low
        self.start_time = config.start_time
        self.period = float(config.period)
        self.transition_fraction = config.transition_fraction
        self.half_transition_fraction = config.transition_fraction / 2.0

        if self.transition_fraction < 0 or self.transition_fraction > 1.0:
            raise ValueError('transition_fraction must be between 0 and 1.0')
        if self.period <= 0:
            raise ValueError('period must be positive')

        # Slope is expressed per unit of period position (a fraction in
        # [0, 1)), not per step. With no transition time the ramp branches
        # in __call__ are never taken, so the infinite slope is a placeholder.
        if self.half_transition_fraction > 0:
            amplitude = float(self.high - self.low)
            self._slope = amplitude / self.half_transition_fraction
        else:
            self._slope = float('inf')

    def __call__(self, global_step):
        """Return the oscillator value at `global_step`."""
        if global_step < self.start_time:
            return self.high

        # Position inside the current period as a fraction in [0, 1).
        phase = ((global_step - self.start_time) / self.period) % 1.0

        # First half of the period descends, second half ascends; measure
        # the offset from the start of whichever half we are in.
        ascending = phase >= 0.5
        offset = phase - 0.5 if ascending else phase
        on_ramp = offset < self.half_transition_fraction

        if ascending:
            if on_ramp:
                return self.low + offset * self._slope
            return self.high
        if on_ramp:
            return self.high - offset * self._slope
        return self.low
# Registry mapping each `fn` name accepted by `make_schedule` to the Schedule
# implementation it selects.
_NAME_TO_CONFIG = dict(
    const=ConstSchedule,
    linear_decay=LinearDecaySchedule,
    exp_decay=ExponentialDecaySchedule,
    smooth_decay=SmootherstepDecaySchedule,
    hard_osc=HardOscillatorSchedule,
)
def make_schedule(config):
    """Schedule factory.

    Given `config` containing a `fn` property, a Schedule implementation is
    instantiated with `config`. See `_NAME_TO_CONFIG` for `fn` options.

    Args:
      config: Config with a `fn` option that specifies which Schedule
        implementation to use. `config` is passed into the constructor.

    Returns:
      A Schedule impl instance.

    Raises:
      KeyError: If `config.fn` is not a key of `_NAME_TO_CONFIG`. The message
        names the offending value and lists the valid options.
    """
    fn_name = config.fn
    try:
        schedule_class = _NAME_TO_CONFIG[fn_name]
    except KeyError:
        # Keep the exception type callers may already catch, but say what
        # was requested and what is available.
        raise KeyError(
            'Unknown schedule fn %r. Valid options: %s'
            % (fn_name, ', '.join(sorted(_NAME_TO_CONFIG))))
    return schedule_class(config)
|