Spaces:
Running
Running
import torch | |
class SigmoidScheduler:
    """Schedule built by flipping and normalising a sigmoid.

    With ``x(t) = (start + (end - start) * t) / tau``,
    ``v_start = sigmoid(start / tau)`` and ``v_end = sigmoid(end / tau)``::

        f(t) = (v_end - sigmoid(x(t))) / (v_end - v_start)

    so ``f(0) = 1`` and ``f(1) = 0`` before the result is clamped to
    ``[clip_min, 1]``. ``start`` must differ from ``end``; otherwise the
    normalising denominator is zero.

    Args:
        start: Sigmoid input (before division by ``tau``) at ``t = 0``.
        end: Sigmoid input (before division by ``tau``) at ``t = 1``.
        tau: Divisor applied to the sigmoid argument.
        clip_min: Lower bound applied to the value returned by ``__call__``.

    All methods accept a tensor ``t`` (returned values follow its shape) or a
    Python number / sequence, which is converted to a default-dtype tensor.
    """

    def __init__(self, start=-3, end=3, tau=1, clip_min=1e-9):
        self.start = start
        self.end = end
        self.tau = tau
        self.clip_min = clip_min
        # Sigmoid values at both ends of the input range; they normalise the
        # schedule so that it equals exactly 1 at t = 0 and 0 at t = 1.
        self.v_start = torch.sigmoid(torch.tensor(self.start / self.tau))
        self.v_end = torch.sigmoid(torch.tensor(self.end / self.tau))

    @staticmethod
    def _as_tensor(t):
        """Return ``t`` as a tensor; tensors are passed through unchanged."""
        if torch.is_tensor(t):
            return t
        # Use the default float dtype so integer inputs such as ``t=0`` do not
        # become integer tensors.
        return torch.as_tensor(t, dtype=torch.get_default_dtype())

    def _arg(self, t):
        """Sigmoid argument: map ``t`` linearly from [0, 1] to [start, end], then / tau."""
        return (t * (self.end - self.start) + self.start) / self.tau

    def __call__(self, t) -> torch.Tensor:
        """Return the schedule value ``f(t)`` clamped to ``[clip_min, 1]``."""
        t = self._as_tensor(t)
        output = (self.v_end - torch.sigmoid(self._arg(t))) / (
            self.v_end - self.v_start
        )
        return torch.clamp(output, min=self.clip_min, max=1.0)

    def derivative(self, t) -> torch.Tensor:
        """Return ``df/dt`` of the unclamped schedule (the clamp is ignored)."""
        t = self._as_tensor(t)
        sig = torch.sigmoid(self._arg(t))
        # Chain rule: sigmoid'(x) = sig * (1 - sig), dx/dt = (end - start) / tau;
        # the leading minus comes from the flipped sigmoid in __call__.
        return (
            -(self.end - self.start)
            * sig
            * (1 - sig)
            / (self.tau * (self.v_end - self.v_start))
        )

    def alpha(self, t) -> torch.Tensor:
        """Return ``-f'(t) / (f(t) + 1e-6)``, i.e. roughly ``-d/dt log f(t)``.

        The ``1e-6`` keeps the division finite where ``f(t)`` is clamped near 0.
        """
        t = self._as_tensor(t)
        return -self.derivative(t) / (1e-6 + self.__call__(t))
class LinearScheduler:
    """Linear schedule ``f(t) = start + (end - start) * t`` clamped to ``[clip_min, 1]``.

    With the defaults (``start=1``, ``end=0``) it goes from 1 at ``t = 0`` down
    to ``clip_min`` at ``t = 1``.

    Args:
        start: Schedule value at ``t = 0`` (before clamping).
        end: Schedule value at ``t = 1`` (before clamping).
        clip_min: Lower bound applied to the value returned by ``__call__``.

    All methods accept a tensor ``t`` or a Python number / sequence, which is
    converted to a default-dtype tensor.
    """

    def __init__(self, start=1, end=0, clip_min=1e-9):
        self.start = start
        self.end = end
        self.clip_min = clip_min

    @staticmethod
    def _as_tensor(t):
        """Return ``t`` as a tensor; tensors are passed through unchanged."""
        if torch.is_tensor(t):
            return t
        return torch.as_tensor(t, dtype=torch.get_default_dtype())

    def __call__(self, t) -> torch.Tensor:
        """Return the schedule value ``f(t)`` clamped to ``[clip_min, 1]``."""
        t = self._as_tensor(t)
        output = (self.end - self.start) * t + self.start
        return torch.clamp(output, min=self.clip_min, max=1.0)

    def derivative(self, t) -> torch.Tensor:
        """Return the constant slope ``end - start`` as a 0-dim tensor on ``t``'s device.

        The clamp applied in ``__call__`` is ignored.
        """
        t = self._as_tensor(t)
        # float() so that integer start/end (the defaults) still yield a
        # floating-point derivative instead of an int64 tensor; building the
        # tensor directly on t.device avoids a CPU allocation plus copy.
        return torch.tensor(float(self.end - self.start), device=t.device)

    def alpha(self, t) -> torch.Tensor:
        """Return ``-f'(t) / (f(t) + 1e-6)``, i.e. roughly ``-d/dt log f(t)``.

        The ``1e-6`` keeps the division finite where ``f(t)`` is clamped near 0.
        """
        t = self._as_tensor(t)
        return -self.derivative(t) / (1e-6 + self.__call__(t))
class CosineScheduler:
    """Normalised cosine-power schedule.

    With ``x(t) = (start + (end - start) * t) * pi / 2``,
    ``v_start = cos(start * pi / 2) ** (2 * tau)`` and
    ``v_end = cos(end * pi / 2) ** (2 * tau)``::

        f(t) = (cos(x(t)) ** (2 * tau) - v_end) / (v_start - v_end)

    so ``f(0) = 1`` and ``f(1) = 0`` before the result is clamped to
    ``[clip_min, 1]``. With the defaults this is ``cos(pi * t / 2) ** 2``.

    NOTE(review): with ``start=1`` (the default), float32 rounding makes
    ``cos(start * pi / 2)`` slightly negative (about -4.4e-8). Raising that to
    a non-integer power (``2 * tau`` not an integer, e.g. ``tau=0.75``) gives
    NaN in ``v_start``, and the NaN then propagates to every output. Confirm
    which ``tau`` values are intended before relying on fractional ones.

    Args:
        start: Cosine input, in units of ``pi / 2``, at ``t = 0``.
        end: Cosine input, in units of ``pi / 2``, at ``t = 1``.
        tau: Half the exponent applied to the cosine.
        clip_min: Lower bound applied to the value returned by ``__call__``.

    All methods accept a tensor ``t`` or a Python number / sequence, which is
    converted to a default-dtype tensor.
    """

    def __init__(
        self,
        start: float = 1,
        end: float = 0,
        tau: float = 1.0,
        clip_min: float = 1e-9,
    ):
        self.start = start
        self.end = end
        self.tau = tau
        self.clip_min = clip_min
        # Cosine powers at both ends; they normalise the schedule to [0, 1].
        self.v_start = torch.cos(torch.tensor(self.start) * torch.pi / 2) ** (
            2 * self.tau
        )
        self.v_end = torch.cos(torch.tensor(self.end) * torch.pi / 2) ** (2 * self.tau)

    @staticmethod
    def _as_tensor(t):
        """Return ``t`` as a tensor; tensors are passed through unchanged."""
        if torch.is_tensor(t):
            return t
        return torch.as_tensor(t, dtype=torch.get_default_dtype())

    def _angle(self, t):
        """Cosine argument: map ``t`` from [0, 1] to [start, end] * pi / 2."""
        return (t * (self.end - self.start) + self.start) * torch.pi / 2

    def __call__(self, t: "torch.Tensor | float") -> torch.Tensor:
        """Return the schedule value ``f(t)`` clamped to ``[clip_min, 1]``."""
        t = self._as_tensor(t)
        output = (
            torch.cos(self._angle(t)) ** (2 * self.tau) - self.v_end
        ) / (self.v_start - self.v_end)
        return torch.clamp(output, min=self.clip_min, max=1.0)

    def derivative(self, t: "torch.Tensor | float") -> torch.Tensor:
        """Return ``df/dt`` of the unclamped schedule (the clamp is ignored)."""
        t = self._as_tensor(t)
        x = self._angle(t)
        cos_x = torch.cos(x)
        # Chain rule: d/dt cos(x)^(2 tau) = -2 tau * cos(x)^(2 tau - 1) * sin(x) * dx/dt
        # with dx/dt = (end - start) * pi / 2. Only ONE power of cos(x) appears.
        return (
            -2
            * self.tau
            * (self.end - self.start)
            * torch.pi
            / 2
            * cos_x ** (2 * self.tau - 1)
            * torch.sin(x)
            / (self.v_start - self.v_end)
        )

    def alpha(self, t: "torch.Tensor | float") -> torch.Tensor:
        """Return ``-f'(t) / (f(t) + 1e-6)``, i.e. roughly ``-d/dt log f(t)``.

        The ``1e-6`` keeps the division finite where ``f(t)`` is clamped near 0.
        """
        t = self._as_tensor(t)
        return -self.derivative(t) / (1e-6 + self.__call__(t))
class CosineSchedulerSimple:
    """Cosine-squared schedule ``f(t) = cos((t + ns) / (1 + ds) * pi / 2) ** 2``.

    ``ns`` shifts ``t`` and ``ds`` rescales it before the cosine is applied.
    The output is neither normalised nor clamped.

    Args:
        ns: Offset added to ``t``.
        ds: ``t + ns`` is divided by ``1 + ds``.

    All methods accept a tensor ``t`` or a Python number / sequence, which is
    converted to a default-dtype tensor.
    """

    def __init__(self, ns: float = 0.0002, ds: float = 0.00025):
        self.ns = ns
        self.ds = ds

    @staticmethod
    def _as_tensor(t):
        """Return ``t`` as a tensor; tensors are passed through unchanged."""
        if torch.is_tensor(t):
            return t
        return torch.as_tensor(t, dtype=torch.get_default_dtype())

    def _angle(self, t):
        """Cosine argument ``(t + ns) / (1 + ds) * pi / 2``."""
        return ((t + self.ns) / (1 + self.ds)) * torch.pi / 2

    def __call__(self, t: "torch.Tensor | float") -> torch.Tensor:
        """Return ``f(t)``."""
        t = self._as_tensor(t)
        return torch.cos(self._angle(t)) ** 2

    def derivative(self, t: "torch.Tensor | float") -> torch.Tensor:
        """Return ``df/dt``."""
        t = self._as_tensor(t)
        x = self._angle(t)
        # d/dt cos(x)^2 = -2 cos(x) sin(x) * dx/dt, with dx/dt = pi / (2 (1 + ds)).
        return -torch.pi * torch.cos(x) * torch.sin(x) / (1 + self.ds)