|
""" |
|
Abstract SDE classes, Reverse SDE, and VE/VP SDEs. |
|
|
|
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py |
|
""" |
|
import abc |
|
import warnings |
|
import math |
|
import scipy.special as sc |
|
import numpy as np |
|
from geco.util.tensors import batch_broadcast |
|
import torch |
|
|
|
from geco.util.registry import Registry |
|
|
|
|
|
SDERegistry = Registry("SDE") |
|
|
|
|
|
class SDE(abc.ABC):

    """SDE abstract class. Functions are designed for a mini-batch of inputs."""



    def __init__(self, N):

        """Construct an SDE.



        Args:

            N: number of discretization time steps.

        """

        super().__init__()

        self.N = N



    @property

    @abc.abstractmethod

    def T(self):

        """End time of the SDE."""

        pass



    @abc.abstractmethod

    def sde(self, x, t, *args):

        """Return a pair `(drift, diffusion)` of the forward SDE at time `t`,
        i.e. dx = drift * dt + diffusion * dw (see `discretize` for how the
        pair is consumed)."""

        pass



    @abc.abstractmethod

    def marginal_prob(self, x, t, *args):

        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""

        pass



    @abc.abstractmethod

    def prior_sampling(self, shape, *args):

        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""

        pass



    @abc.abstractmethod

    def prior_logp(self, z):

        """Compute log-density of the prior distribution.



        Useful for computing the log-likelihood via probability flow ODE.



        Args:

            z: latent code

        Returns:

            log probability density

        """

        pass



    @staticmethod

    @abc.abstractmethod

    def add_argparse_args(parent_parser):

        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """

        pass



    def discretize(self, x, t, y, stepsize):

        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.



        Useful for reverse diffusion sampling and probability flow sampling.

        Defaults to Euler-Maruyama discretization.



        Args:

            x: a torch tensor

            t: a torch float representing the time step (from 0 to `self.T`)

            y: conditioner tensor forwarded to `self.sde`

            stepsize: the discretization step size dt



        Returns:

            f, G: drift increment f = drift * dt and noise scale
                G = diffusion * sqrt(dt)

        """

        dt = stepsize



        drift, diffusion = self.sde(x, t, y)

        f = drift * dt

        # sqrt(dt) is the Euler-Maruyama scaling of the Gaussian noise term.

        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))

        return f, G



    def reverse(oself, score_model, probability_flow=False):

        """Create the reverse-time SDE/ODE.

        The first parameter is named `oself` (outer self) so the inner class
        below can use the conventional `self` without shadowing it.



        Args:

            score_model: A function that takes x, t and y and returns the score.

            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

        """

        # Capture the forward SDE's state in closure variables for RSDE below.

        N = oself.N

        T = oself.T

        sde_fn = oself.sde

        discretize_fn = oself.discretize





        # The reverse SDE/ODE is built as a subclass of the forward SDE class.

        class RSDE(oself.__class__):

            def __init__(self):

                # NOTE: deliberately does not call super().__init__();
                # forward-SDE state is reached via the closure variables above.

                self.N = N

                self.probability_flow = probability_flow



            @property

            def T(self):

                return T



            def sde(self, x, t, *args):

                """Create the drift and diffusion functions for the reverse SDE/ODE."""

                # NOTE(review): `rsde_parts` is not defined in this base class
                # -- presumably supplied by the concrete subclass; TODO confirm.

                rsde_parts = self.rsde_parts(x, t, *args)

                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]

                return total_drift, diffusion



            def discretize(self, x, t, y, m, stepsize):

                """Create discretized iteration rules for the reverse diffusion sampler."""

                f, G = discretize_fn(x, t, y, stepsize)

                # For a complex-valued noise scale only the imaginary part is
                # kept. NOTE(review): rationale not evident here -- confirm.

                if torch.is_complex(G):

                    G = G.imag

                # Reverse drift f - G^2 * score; halved for the probability
                # flow ODE. The [:, None, None, None] broadcast assumes 4-D x
                # (batch, ...) -- TODO confirm against callers.

                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y, m) * (0.5 if self.probability_flow else 1.)

                # The probability flow ODE is deterministic (zero diffusion).

                rev_G = torch.zeros_like(G) if self.probability_flow else G

                return rev_f, rev_G



        return RSDE()



    @abc.abstractmethod

    def copy(self):

        pass
|
|
|
|
|
@SDERegistry.register("bbed")
class BBED(SDE):
    """Brownian Bridge with Exploding Diffusion coefficient (BBED) SDE.

    Forward process (with conditioner y):

        dx = (y - x) / (Tc - t) dt + sqrt(theta) * k^t dw

    i.e. a Brownian bridge pulling x towards y on t in [0, Tc), with a
    diffusion coefficient that grows exponentially in t.

    Fix vs. previous revision: ``T`` and ``Tc`` were plain methods of the
    form ``def T(self): return self.T`` which recurse infinitely if ever
    resolved; they only appeared to work because ``__init__`` shadowed them
    with same-named instance attributes. They are now proper properties
    (with setters, so ``self.T = ...`` assignments keep working), correctly
    implementing the abstract ``SDE.T`` property.
    """

    @staticmethod
    def add_argparse_args(parser):
        """Add BBED hyperparameters to `parser` and return it."""
        parser.add_argument("--sde-n", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")
        parser.add_argument("--T_sampling", type=float, default=0.999, help="The T so that t < T during sampling in the train step.")
        parser.add_argument("--k", type=float, default=2.6, help="base factor for diffusion term")
        parser.add_argument("--theta", type=float, default=0.52, help="root scale factor for diffusion term.")
        return parser

    def __init__(self, T_sampling, k, theta, N=1000, **kwargs):
        """Construct an Brownian Bridge with Exploding Diffusion Coefficient SDE with parameterization as in the paper.
        dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw

        Args:
            T_sampling: end time T used during sampling (t < T < Tc = 1).
            k: base of the exponential diffusion coefficient.
            theta: scale factor of the diffusion coefficient.
            N: number of discretization time steps.
        """
        super().__init__(N)
        self.k = k
        self.logk = np.log(self.k)
        self.theta = theta
        self.N = N
        # Ei(-2*log k): constant part of the closed-form variance (see _std).
        self.Eilog = sc.expi(-2*self.logk)
        self.T = T_sampling  # stored via the property setter below
        self.Tc = 1  # bridge terminal time; the drift is singular at t == Tc

    def copy(self):
        """Return a new BBED instance with identical hyperparameters."""
        return BBED(self.T, self.k, self.theta, N=self.N)

    @property
    def T(self):
        """End time used for sampling (implements the abstract `SDE.T`)."""
        return self._T

    @T.setter
    def T(self, value):
        self._T = value

    @property
    def Tc(self):
        """Terminal time of the Brownian bridge."""
        return self._Tc

    @Tc.setter
    def Tc(self, value):
        self._Tc = value

    def sde(self, x, t, y):
        """Drift and diffusion of the forward SDE at time t."""
        drift = (y - x)/(self.Tc - t)
        sigma = (self.k) ** t
        diffusion = sigma * np.sqrt(self.theta)
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Mean of p_t(x | x0, y): linear interpolation from x0 to y.

        Assumes 4-D x0/y with batch dimension first -- TODO confirm.
        """
        time = (t/self.Tc)[:, None, None, None]
        mean = x0*(1-time) + y*time
        return mean

    def _std(self, t):
        """Closed-form standard deviation of p_t(x | x0, y).

        Uses the exponential integral Ei (scipy.special.expi), so the
        computation round-trips through NumPy and is moved back to t's
        device afterwards.
        """
        t_np = t.cpu().detach().numpy()
        Eis = sc.expi(2*(t_np-1)*self.logk) - self.Eilog
        h = 2*self.k**2*self.logk
        var = (self.k**(2*t_np)-1+t_np) + h*(1-t_np)*Eis
        var = torch.tensor(var).to(device=t.device)*(1-t)*self.theta
        return torch.sqrt(var)

    def marginal_prob(self, x0, t, y):
        """Mean and std of the marginal distribution p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ p_T(x | y) = N(y, std(T)^2); returns (x_T, z)."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(self.T*torch.ones((y.shape[0],), device=y.device))
        z = torch.randn_like(y)
        x_T = y + z * std[:, None, None, None]
        return x_T, z

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for BBED not yet implemented!")
|
|