baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
1.87 kB
"""
General utility functions for machine learning.
"""
import abc
import math
import numpy as np
class ScalarSchedule(metaclass=abc.ABCMeta):
    """Abstract interface for a scalar quantity that varies with time ``t``.

    Concrete schedules must override :meth:`get_value`; the class itself
    cannot be instantiated.
    """

    @abc.abstractmethod
    def get_value(self, t):
        """Return the schedule's value at time ``t``."""
class ConstantSchedule(ScalarSchedule):
    """Schedule whose value never changes with time.

    :param value: the value returned for every ``t``.
    """

    def __init__(self, value):
        self._value = value

    def get_value(self, t):
        """Return the stored constant; ``t`` is accepted but ignored."""
        constant = self._value
        return constant
class LinearSchedule(ScalarSchedule):
    """
    Linearly interpolate and then stop at a final value.

    The value moves linearly from ``init_value`` (at ``t == 0``) to
    ``final_value`` (at ``t == ramp_duration``) and is held at
    ``final_value`` for all later ``t``. Only the upper end is clamped:
    for ``t < 0`` the line is extrapolated past ``init_value``.

    :param init_value: value at ``t == 0``.
    :param final_value: value reached at the end of the ramp and held after.
    :param ramp_duration: length of the ramp. A non-positive duration means
        there is no ramp: the schedule is at ``final_value`` for every ``t``.
    """

    def __init__(
            self,
            init_value,
            final_value,
            ramp_duration,
    ):
        self._init_value = init_value
        self._final_value = final_value
        self._ramp_duration = ramp_duration

    def get_value(self, t):
        """Return the scheduled value at time ``t``."""
        if self._ramp_duration <= 0:
            # A zero-length ramp is already finished; dividing by it would
            # raise ZeroDivisionError (and a negative one would extrapolate
            # away from final_value instead of stopping there).
            progress = 1.0
        else:
            progress = min(1.0, t * 1.0 / self._ramp_duration)
        return (
            self._init_value
            + (self._final_value - self._init_value) * progress
        )
class IntLinearSchedule(LinearSchedule):
    """
    Same as LinearSchedule, but the output is converted to an int.

    The conversion uses ``int()``, which truncates toward zero rather than
    rounding to the nearest integer (e.g. 2.9 -> 2, -2.9 -> -2).
    """

    def get_value(self, t):
        """Return the linearly scheduled value at ``t``, truncated to an int."""
        return int(super().get_value(t))
class PiecewiseLinearSchedule(ScalarSchedule):
    """
    Piecewise-linear schedule defined by (time, value) breakpoints.

    ``x_values`` are the breakpoint times and ``y_values`` are the schedule's
    values at those times. Between breakpoints the value is linearly
    interpolated. Before the first breakpoint it is held at ``y_values[0]``,
    and after the last it is held at ``y_values[-1]`` (``np.interp``
    semantics).

    :param x_values: breakpoint times; must be increasing. ``np.interp`` does
        not check this and returns meaningless results otherwise.
    :param y_values: values at the corresponding times, same length as
        ``x_values``.
    """

    def __init__(
            self,
            x_values,
            y_values,
    ):
        self._x_values = x_values
        self._y_values = y_values

    def get_value(self, t):
        """Return the interpolated value at time ``t``.

        A scalar ``t`` yields a numpy float; an array-like ``t`` yields an
        array of values, one per entry.
        """
        return np.interp(t, self._x_values, self._y_values)
class IntPiecewiseLinearSchedule(PiecewiseLinearSchedule):
    """PiecewiseLinearSchedule whose output is converted with ``int()``."""

    def get_value(self, t):
        """Return the interpolated value at ``t``, truncated toward zero."""
        interpolated = super().get_value(t)
        return int(interpolated)
def none_to_infty(bounds):
    """Replace missing bounds with infinities.

    :param bounds: ``None``, or a two-element ``(lower, upper)`` iterable
        whose entries may individually be ``None``.
    :return: a ``(lower, upper)`` tuple in which a missing lower bound is
        ``-math.inf`` and a missing upper bound is ``math.inf``.
    """
    lower, upper = (None, None) if bounds is None else bounds
    return (
        -math.inf if lower is None else lower,
        math.inf if upper is None else upper,
    )