Fast-GeCo / geco /sdes.py
anonymous9a7b
1
d4c980e
"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings
import math
import scipy.special as sc
import numpy as np
from geco.util.tensors import batch_broadcast
import torch
from geco.util.registry import Registry
# Registry of SDE implementations. Classes in this module add themselves under
# a string key with the `@SDERegistry.register("<name>")` decorator.
SDERegistry = Registry("SDE")
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the pair `(drift, diffusion)` of the SDE at state `x` and time `t`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code

        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch tensor holding the time step (from 0 to `self.T`)
            y: extra argument forwarded unchanged to `self.sde`
            stepsize: the time increment `dt` used for this step

        Returns:
            f, G: `drift * dt` and `diffusion * sqrt(dt)`
        """
        dt = stepsize
        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y (plus any further
                arguments, e.g. `m` in `discretize`) and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

        Returns:
            An instance of a dynamically created subclass of this SDE's class
            whose `sde` and `discretize` describe the reverse-time process.
        """
        # Everything the reverse process needs is captured here in the closure,
        # because RSDE.__init__ deliberately does not run the parent constructor.
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def rsde_parts(self, x, t, *args):
                """Compute the individual terms of the reverse SDE/ODE.

                Returns a dict with the keys `total_drift`, `diffusion`,
                `sde_drift`, `sde_diffusion`, `score_drift` and `score`.
                """
                # NOTE(review): mirrors `discretize` below, where the forward
                # SDE receives only the first extra argument (`y`) while the
                # score model receives all of them (`y, m`) -- confirm this
                # split matches the intended call sites.
                sde_drift, sde_diffusion = sde_fn(x, t, *args[:1])
                score = score_model(x, t, *args)
                # The probability flow ODE uses half of the score term and no noise.
                scale = 0.5 if self.probability_flow else 1.
                score_drift = -sde_diffusion[:, None, None, None] ** 2 * score * scale
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                return {
                    "total_drift": sde_drift + score_drift,
                    "diffusion": diffusion,
                    "sde_drift": sde_drift,
                    "sde_diffusion": sde_diffusion,
                    "score_drift": score_drift,
                    "score": score,
                }

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def discretize(self, x, t, y, m, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t, y, stepsize)
                # When G is complex only its imaginary part is kept.
                # NOTE(review): the reason is not visible in this file.
                if torch.is_complex(G):
                    G = G.imag
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y, m) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new SDE object configured like this one."""
        pass
@SDERegistry.register("bbed")
class BBED(SDE):
    """Brownian Bridge with Exploding Diffusion coefficient (BBED) SDE.

    dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw
    """

    @staticmethod
    def add_argparse_args(parser):
        """Add the BBED command-line options to `parser` and return it."""
        parser.add_argument("--sde-n", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")
        parser.add_argument("--T_sampling", type=float, default=0.999, help="The T so that t < T during sampling in the train step.")
        parser.add_argument("--k", type=float, default = 2.6, help="base factor for diffusion term")
        parser.add_argument("--theta", type=float, default = 0.52, help="root scale factor for diffusion term.")
        return parser

    def __init__(self, T_sampling, k, theta, N=1000, **kwargs):
        """Construct an Brownian Bridge with Exploding Diffusion Coefficient SDE with parameterization as in the paper.

        dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw

        Args:
            T_sampling: end time used for sampling in the train step and at inference.
            k: base of the diffusion term.
            theta: scale of the diffusion term (enters as sqrt(theta)).
            N: number of discretization time steps.
            **kwargs: accepted and ignored.
        """
        super().__init__(N)  # stores N as self.N
        self.k = k
        self.logk = np.log(self.k)
        self.theta = theta
        # Constant term of the variance formula in `_std`, computed once.
        self.Eilog = sc.expi(-2*self.logk)
        self.T = T_sampling  # for sampling in train step and inference
        self.Tc = 1  # for constructing the SDE, dont change this

    def copy(self):
        """Return a new BBED with the same T, k, theta and N."""
        return BBED(self.T, self.k, self.theta, N=self.N)

    # `T` must be defined in this class: it overrides the abstract, read-only
    # property `SDE.T`, which makes the class instantiable and lets `__init__`
    # store the end time as a plain instance attribute. On instances that
    # attribute shadows this method, so the method itself is never called.
    def T(self):
        return self.T

    # Shadowed on instances by the attribute set in `__init__`.
    def Tc(self):
        return self.Tc

    def sde(self, x, t, y):
        """Return drift `(y - x)/(Tc - t)` and diffusion `sqrt(theta) * k**t`."""
        drift = (y - x)/(self.Tc - t)
        sigma = (self.k) ** t
        diffusion = sigma * np.sqrt(self.theta)
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Linear interpolation between `x0` (at t=0) and `y` (at t=Tc)."""
        # t is indexed as a 1-D tensor and broadcast against 4-D inputs.
        time = (t/self.Tc)[:, None, None, None]
        mean = x0*(1-time) + y*time
        return mean

    def _std(self, t):
        """Standard deviation returned by `marginal_prob` for time(s) `t`.

        The variance is evaluated in NumPy because it needs the exponential
        integral `scipy.special.expi`. The result is returned on `t`'s device
        and, for floating-point `t`, in `t`'s dtype.
        """
        t_np = t.cpu().detach().numpy()
        Eis = sc.expi(2*(t_np-1)*self.logk) - self.Eilog
        h = 2*self.k**2*self.logk
        var = (self.k**(2*t_np)-1+t_np) + h*(1-t_np)*Eis
        var = torch.tensor(var).to(device=t.device)
        if t.is_floating_point():
            # `logk`/`Eilog` are float64 NumPy scalars. Under NumPy 2 promotion
            # rules they upcast a float32 `t_np` to float64, which would leak
            # into every tensor derived from this std. Pin the dtype to `t`'s.
            var = var.to(dtype=t.dtype)
        var = var*(1-t)*self.theta
        return torch.sqrt(var)

    def marginal_prob(self, x0, t, y):
        """Return `(mean, std)` of the state at time `t` given `x0` and `y`."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample `x_T = y + std(T) * z` with standard normal `z`.

        Args:
            shape: requested shape; ignored (with a warning) if it differs from `y.shape`.
            y: tensor the sample is centred on; indexed as 4-D with batch first.

        Returns:
            x_T, z: the sample and the noise that produced it.
        """
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(self.T*torch.ones((y.shape[0],), device=y.device))
        z = torch.randn_like(y)
        x_T = y + z * std[:, None, None, None]
        return x_T, z

    def prior_logp(self, z):
        """Not implemented for BBED; always raises `NotImplementedError`."""
        raise NotImplementedError("prior_logp for BBED not yet implemented!")