"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.

Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings
import math
import scipy.special as sc
import numpy as np
from geco.util.tensors import batch_broadcast
import torch

from geco.util.registry import Registry


SDERegistry = Registry("SDE")


class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the (drift, diffusion) pair of the forward SDE at (x, t)."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code

        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)
            y: conditioning input forwarded to `self.sde`
            stepsize: the step size `dt` used for the Euler-Maruyama update

        Returns:
            f, G: the drift increment `drift * dt` and the noise scale
                `diffusion * sqrt(dt)`.
        """
        dt = stepsize
        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        # torch.as_tensor avoids the copy/UserWarning torch.tensor emits when
        # `dt` is already a tensor; for Python floats the result is identical.
        G = diffusion * torch.sqrt(torch.as_tensor(dt, device=t.device))
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t, y and m and returns the score
                (it is called as `score_model(x, t, y, m)`).
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

        Returns:
            An instance of a subclass of `oself.__class__` whose `sde` and
            `discretize` describe the reverse-time process.
        """
        # `oself` (not `self`) so the nested class's methods can use `self`
        # for the reverse process while still closing over the forward one.
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def rsde_parts(self, x, t, y, m):
                """Compute the pieces of the reverse SDE/ODE at (x, t).

                The reverse drift is `f - g^2 * score` (scaled by 0.5 for the
                probability-flow ODE, whose diffusion is zero), mirroring the
                update used in `discretize` below.

                Returns:
                    dict with keys "total_drift", "diffusion", "sde_drift",
                    "sde_diffusion", "score_drift" and "score".
                """
                sde_drift, sde_diffusion = sde_fn(x, t, y)
                # Same complex handling as `discretize`.
                if torch.is_complex(sde_diffusion):
                    sde_diffusion = sde_diffusion.imag
                score = score_model(x, t, y, m)
                # assumes a 4-D x with the batch on axis 0, like `discretize`
                score_drift = -sde_diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                return {
                    "total_drift": sde_drift + score_drift,
                    "diffusion": diffusion,
                    "sde_drift": sde_drift,
                    "sde_diffusion": sde_diffusion,
                    "score_drift": score_drift,
                    "score": score,
                }

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def discretize(self, x, t, y, m, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t, y, stepsize)
                if torch.is_complex(G):
                    G = G.imag
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y, m) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new SDE instance with the same parameters."""
        pass


@SDERegistry.register("bbed")
class BBED(SDE):
    @staticmethod
    def add_argparse_args(parser):
        # NOTE(review): argparse stores "--sde-n" under `sde_n`, which __init__
        # absorbs via **kwargs, so N stays at its default (1000) unless the
        # caller maps sde_n -> N explicitly. Confirm against the caller.
        parser.add_argument("--sde-n", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")
        parser.add_argument("--T_sampling", type=float, default=0.999, help="The T so that t < T during sampling in the train step.")
        parser.add_argument("--k", type=float, default = 2.6, help="base factor for diffusion term")
        parser.add_argument("--theta", type=float, default = 0.52, help="root scale factor for diffusion term.")
        return parser

    def __init__(self, T_sampling, k, theta, N=1000, **kwargs):
        """Construct a Brownian Bridge with Exploding Diffusion Coefficient SDE with parameterization as in the paper.

        dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw

        Args:
            T_sampling: end time used for sampling in the train step and at inference.
            k: base of the exploding diffusion coefficient k^t.
            theta: scale; the diffusion coefficient is sqrt(theta) * k^t.
            N: number of discretization time steps.
            **kwargs: ignored.
        """
        super().__init__(N)
        self.k = k
        self.logk = np.log(self.k)
        self.theta = theta
        # Ei(-2 ln k): the t = 0 value of the exponential-integral term in _std.
        self.Eilog = sc.expi(-2*self.logk)
        self.T = T_sampling  # for sampling in train step and inference
        self.Tc = 1  # for constructing the SDE, don't change this (_std assumes Tc == 1)

    @property
    def T(self):
        """End time used for prior sampling (the `T_sampling` constructor argument)."""
        return self._T

    @T.setter
    def T(self, value):
        self._T = value

    def copy(self):
        return BBED(self.T, self.k, self.theta, N=self.N)

    def sde(self, x, t, y):
        """Return drift (y - x) / (Tc - t) and diffusion sqrt(theta) * k**t."""
        remaining = self.Tc - t
        if torch.is_tensor(remaining) and remaining.ndim > 0:
            # Per-sample times must broadcast over the batch (leading) axis of
            # x, not its trailing axis.
            remaining = remaining.reshape(-1, *([1] * (x.ndim - 1)))
        drift = (y - x) / remaining
        sigma = (self.k) ** t
        diffusion = sigma * np.sqrt(self.theta)
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Linear interpolation between x0 (at t=0) and y (at t=Tc)."""
        # Broadcast per-sample times over the batch (leading) axis of x0.
        time = (t / self.Tc).reshape(-1, *([1] * (x0.ndim - 1)))
        mean = x0*(1-time) + y*time
        return mean

    def _std(self, t):
        """Marginal standard deviation at times `t`.

        var(t) = theta * (1-t) * [(k^{2t} - 1 + t)
                 + 2 k^2 ln(k) (1-t) (Ei(2(t-1) ln k) - Ei(-2 ln k))]

        The result has the dtype of `t` when `t` is floating point.
        """
        t_np = t.cpu().detach().numpy()
        Eis = sc.expi(2*(t_np-1)*self.logk) - self.Eilog
        h = 2*self.k**2*self.logk
        with np.errstate(invalid="ignore"):
            bridge = h*(1-t_np)*Eis
        # At t == 1, Ei(0) = -inf turns (1-t)*Ei(...) into 0*inf = NaN; the
        # term's limit there is 0 (and the outer (1-t) factor zeroes var anyway).
        bridge = np.where(t_np == 1, np.zeros_like(bridge), bridge)
        var = (self.k**(2*t_np)-1+t_np) + bridge
        var = torch.tensor(var).to(device=t.device)*(1-t)*self.theta
        std = torch.sqrt(var)
        if t.is_floating_point():
            # The numpy computation yields float64; don't silently promote callers.
            std = std.to(t.dtype)
        return std

    def marginal_prob(self, x0, t, y):
        """Return (mean, std) of the marginal distribution p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T = y + std(T) * z with z ~ N(0, I); returns (x_T, z)."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(self.T*torch.ones((y.shape[0],), device=y.device))
        z = torch.randn_like(y)
        x_T = y + z * std[:, None, None, None]
        return x_T, z

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for BBED not yet implemented!")