# Spaces: Running on Zero
import torch as th | |
import numpy as np | |
from functools import partial | |
def expand_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim, ...], data point

    Returns:
        A view of ``t`` with shape ``[batch_dim, 1, ..., 1]``, having as many
        dimensions as ``x``.
    """
    # One singleton axis for every non-batch axis of x.
    trailing_ones = (1,) * (x.dim() - 1)
    return t.view(t.size(0), *trailing_ones)
#################### Coupling Plans #################### | |
class ICPlan:
    """Linear coupling plan.

    Points on the path are ``x_t = alpha_t * x1 + sigma_t * x0`` (see
    ``compute_mu_t``). For this base plan ``alpha_t = t`` and
    ``sigma_t = 1 - t``; subclasses override the coefficient methods to obtain
    other schedules while reusing the conversions defined here.
    """

    def __init__(self, sigma=0.0):
        # Stored for API compatibility; no method of this class reads it.
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path.

        Returns:
            Tuple ``(alpha_t, d_alpha_t)``: the coefficient of ``x1`` and its
            time derivative.
        """
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path.

        Returns:
            Tuple ``(sigma_t, d_sigma_t)``: the coefficient of ``x0`` and its
            time derivative.
        """
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio ``d_alpha_t / alpha_t``.

        For the linear plan this is ``1 / t``, which is unbounded as t -> 0.
        """
        return 1 / t

    def compute_drift(self, x, t):
        """Compute the SDE drift and diffusion, always in score parametrization.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector

        Returns:
            Tuple ``(drift, diffusion)`` with
            ``drift = -(d_alpha_t / alpha_t) * x`` and
            ``diffusion = (d_alpha_t / alpha_t) * sigma_t**2 - sigma_t * d_sigma_t``.
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term. One of ``"constant"``,
                ``"SBDM"``, ``"sigma"``, ``"linear"``, ``"decreasing"`` or
                ``"increasing-decreasing"``. The historical misspelling
                ``"inccreasing-decreasing"`` is still accepted.
            norm: float, norm of the diffusion term

        Returns:
            ``norm`` itself for ``"constant"``; otherwise a value computed from
            ``t`` expanded to broadcast against ``x``.

        Raises:
            NotImplementedError: if ``form`` is not a known diffusion form.
        """
        t = expand_t_like_x(t, x)

        def _increasing_decreasing():
            return norm * th.sin(np.pi * t) ** 2

        # Entries are thunks so that only the requested form is evaluated;
        # e.g. "SBDM" runs a full drift computation the other forms don't need.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "increasing-decreasing": _increasing_decreasing,
            # Legacy misspelled key, kept so existing callers keep working.
            "inccreasing-decreasing": _increasing_decreasing,
        }
        try:
            compute = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented") from None
        return compute()

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction into a score.

        Uses ``score = ((alpha_t / d_alpha_t) * velocity - x) / var`` with
        ``var = sigma_t**2 - (alpha_t / d_alpha_t) * d_sigma_t * sigma_t``.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction into a noise (denoiser) prediction.

        Uses ``noise = ((alpha_t / d_alpha_t) * velocity - x) / var`` with
        ``var = (alpha_t / d_alpha_t) * d_sigma_t - sigma_t``.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        # Named "var" for symmetry with get_score_from_velocity; it is the
        # normalizing denominator, not a variance.
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction into a velocity.

        Computes ``velocity = diffusion * score - drift`` using the outputs of
        ``compute_drift``.

        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t: ``alpha_t * x1 + sigma_t * x0``."""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Compute x_t on the path.

        This plan adds no extra noise, so x_t is exactly the mean from
        ``compute_mu_t``; no random number generator is used.
        """
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the target vector field ``d_alpha_t * x1 + d_sigma_t * x0``.

        ``xt`` is accepted for interface compatibility but is not used.
        """
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return ``(t, xt, ut)``: the time, the path point, and its target velocity."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
class VPCPlan(ICPlan):
    """Class for VP path flow matching.

    ``log(alpha_t) = -0.25 * (1 - t)**2 * (sigma_max - sigma_min)
    - 0.5 * (1 - t) * sigma_min`` and ``sigma_t = sqrt(1 - alpha_t**2)``.
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        # Initialise base-class state (``self.sigma``) like the other plans do.
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    # These used to be lambdas stored on the instance. That made instances
    # unpicklable, and deep copies kept closures bound to the original object.
    # Regular methods keep the same ``self.log_mean_coeff(t)`` call syntax.
    def log_mean_coeff(self, t):
        """Return ``log(alpha_t)``."""
        return -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min

    def d_log_mean_coeff(self, t):
        """Return the time derivative of ``log_mean_coeff``."""
        return 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute the coefficient of x1 and its time derivative."""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        # d/dt exp(L(t)) = exp(L(t)) * L'(t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute the coefficient of x0 and its time derivative."""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        # d/dt sqrt(1 - exp(2L)) = -exp(2L) * 2L' / (2 * sigma_t); this
        # divides by zero where sigma_t == 0 (t == 1).
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute d_alpha_t / alpha_t in a numerically stable way.

        Returns ``d_log_mean_coeff(t)`` directly instead of dividing the two
        coefficients.
        """
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift and diffusion terms of the SDE.

        Returns:
            Tuple ``(-0.5 * beta_t * x, beta_t / 2)`` with
            ``beta_t = sigma_min + (1 - t) * (sigma_max - sigma_min)``.
        """
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2
class GVPCPlan(ICPlan):
    """Trigonometric coupling plan.

    Uses ``alpha_t = sin(pi * t / 2)`` and ``sigma_t = cos(pi * t / 2)``.
    Hence ``alpha_t**2 + sigma_t**2 == 1`` at every t.
    """

    def __init__(self, sigma=0.0):
        # No plan-specific state; defer entirely to the base initialiser.
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute coefficient of x1 and its time derivative."""
        angle = t * np.pi / 2
        half_pi = np.pi / 2
        return th.sin(angle), half_pi * th.cos(angle)

    def compute_sigma_t(self, t):
        """Compute coefficient of x0 and its time derivative."""
        angle = t * np.pi / 2
        half_pi = np.pi / 2
        return th.cos(angle), -half_pi * th.sin(angle)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute d_alpha_t / alpha_t in a numerically stable way.

        ``(pi / 2) * cos / sin`` is written as ``pi / (2 * tan)`` to avoid
        forming the two coefficients separately.
        """
        denominator = 2 * th.tan(t * np.pi / 2)
        return np.pi / denominator