import torch as th
import numpy as np
from functools import partial


def expand_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against x.

    Args:
      t: [batch_dim,], time vector
      x: [batch_dim, ...], data point
    Returns:
      t viewed as [batch_dim, 1, ..., 1] with the same rank as x.
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################

class ICPlan:
    """Linear coupling plan: x_t = t * x1 + (1 - t) * x0.

    Subclasses override compute_alpha_t / compute_sigma_t (and optionally
    compute_d_alpha_alpha_ratio_t / compute_drift); every conversion helper
    below is written in terms of those hooks.
    """

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Return (alpha_t, d_alpha_t): the data (x1) coefficient and its time derivative."""
        return t, 1

    def compute_sigma_t(self, t):
        """Return (sigma_t, d_sigma_t): the noise (x0) coefficient and its time derivative."""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t (1 / t for the linear plan; infinite at t == 0)."""
        return 1 / t

    def compute_drift(self, x, t):
        """Return the SDE terms used with the score parametrization.

        Returns:
          (-(d_alpha_t / alpha_t) * x,
           (d_alpha_t / alpha_t) * sigma_t**2 - sigma_t * d_sigma_t)
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion coefficient of the SDE.

        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, one of "constant", "SBDM", "sigma", "linear",
            "decreasing", "inccreasing-decreasing" (historical spelling,
            kept for compatibility) or its alias "increasing-decreasing".
          norm: float, scale of the diffusion term
        Returns:
          `norm` itself for "constant", otherwise a tensor broadcastable
          against x.
        Raises:
          NotImplementedError: if `form` is not a supported name.
        """
        t = expand_t_like_x(t, x)
        # Thunks: only the requested form is evaluated. Building the dict
        # eagerly computed every form (including a full compute_drift) per call.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }
        choices["increasing-decreasing"] = choices["inccreasing-decreasing"]

        try:
            compute = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented") from None

        return compute()

    def get_score_from_velocity(self, velocity, x, t):
        """Transform a velocity prediction into a score estimate.

        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        Returns:
          The score estimate, same shape as x.
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Transform a velocity prediction into a noise (x0) estimate.

        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        Returns:
          The estimate of x0, same shape as x.
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        # Not a variance: the coefficient multiplying x0 after eliminating x1.
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Transform a score prediction into a velocity.

        Args:
          score: [batch_dim, ...] shaped tensor; score model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        Returns:
          The velocity, same shape as x.
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def get_xstart_from_velocity(self, velocity, x_t, t, return_snr=False):
        """Transform a velocity prediction into an estimate of the data endpoint x1.

        With x_t = alpha_t * x1 + sigma_t * x0 and
        v = d_alpha_t * x1 + d_sigma_t * x0, eliminating x0 gives
            x1 = (sigma_t * v - d_sigma_t * x_t)
                 / (sigma_t * d_alpha_t - alpha_t * d_sigma_t).
        For the linear plan (d_alpha_t = 1, d_sigma_t = -1) this is exactly
        (sigma_t * v + x_t) / (alpha_t + sigma_t).

        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x_t: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
          return_snr: if True, also return (alpha_t / sigma_t) ** 2
        Returns:
          x_start, or (x_start, snr) when return_snr is True.
        """
        t = expand_t_like_x(t, x_t)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        # General closed form driven by the path hooks, so subclasses
        # (GVPCPlan, VPCPlan) get a correct x_start; the old linear-only
        # expression was wrong for them.
        x_start = (sigma_t * velocity - d_sigma_t * x_t) / (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
        if return_snr:
            snr = (alpha_t / sigma_t) ** 2
            return x_start, snr
        return x_start

    def compute_mu_t(self, t, x0, x1):
        """Return the path mean alpha_t * x1 + sigma_t * x0."""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Return x_t on the path; deterministic, equal to compute_mu_t."""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Return the target velocity d_alpha_t * x1 + d_sigma_t * x0 (xt is unused)."""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return (t, x_t, u_t) for a training step."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut


class VPCPlan(ICPlan):
    """Variance-preserving (VP) path for flow matching.

    alpha_t = exp(log_mean_coeff(t)) and sigma_t = sqrt(1 - alpha_t ** 2),
    with beta_t = sigma_min + (1 - t) * (sigma_max - sigma_min).
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    # Plain methods (previously lambdas stored on the instance, which made
    # VPCPlan objects unpicklable); call syntax is unchanged.
    def log_mean_coeff(self, t):
        """Return log(alpha_t)."""
        return -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min

    def d_log_mean_coeff(self, t):
        """Return the time derivative of log_mean_coeff, i.e. d_alpha_t / alpha_t."""
        return 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Return (alpha_t, d_alpha_t), the coefficient of x1 and its derivative."""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Return (sigma_t, d_sigma_t), the coefficient of x0 and its derivative."""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        # d/dt sqrt(1 - e^p) = -e^p * p' / (2 * sqrt(1 - e^p)); singular at t == 1.
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Numerically stable d_alpha_t / alpha_t (avoids dividing by a tiny alpha_t)."""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Return (-0.5 * beta_t * x, beta_t / 2) for the VP SDE."""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2


class GVPCPlan(ICPlan):
    """Trigonometric path: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2).

    alpha_t ** 2 + sigma_t ** 2 == 1 for every t.
    """

    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Return (alpha_t, d_alpha_t), the coefficient of x1 and its derivative."""
        alpha_t = th.sin(t * np.pi / 2)
        d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Return (sigma_t, d_sigma_t), the coefficient of x0 and its derivative."""
        sigma_t = th.cos(t * np.pi / 2)
        d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t = (pi / 2) / tan(pi t / 2); infinite at t == 0."""
        return np.pi / (2 * th.tan(t * np.pi / 2))