|
|
|
"""Various sampling methods."""
|
|
from scipy import integrate
|
|
import torch
|
|
|
|
from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
|
|
from .correctors import Corrector, CorrectorRegistry
|
|
|
|
|
|
# Public API of this module. Every name listed here has to exist in the module
# namespace, otherwise `from <module> import *` fails with AttributeError.
# The previous entry 'get_sampler' named a function this file does not define;
# the sampler factories that are defined here are exported instead.
__all__ = [
    'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
    'get_pc_sampler', 'get_ode_sampler', 'get_sb_sampler',
]
|
|
|
|
|
|
def to_flattened_numpy(x):
    """Return the values of torch tensor `x` as a one-dimensional numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before it is converted.
    """
    host_array = x.detach().cpu().numpy()
    return host_array.reshape(-1)
|
|
|
|
|
|
def from_flattened_numpy(x, shape):
    """Reshape the numpy array `x` to `shape` and wrap it as a torch tensor."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
|
|
|
|
|
|
def get_pc_sampler(
    predictor_name, corrector_name, sde, score_fn, y,
    denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
    intermediate=False, **kwargs
):
    """Build a Predictor-Corrector (PC) sampling function.

    Args:
        predictor_name: Name under which the predictor class is looked up in `PredictorRegistry`.
        corrector_name: Name under which the corrector class is looked up in `CorrectorRegistry`.
        sde: The SDE object. Its `T`, `N` and `prior_sampling` members are used here, and it is
            handed to the predictor and corrector constructors.
        score_fn: Score function (typically a learned model), handed to the predictor and corrector.
        y: A `torch.Tensor` that conditions the prior sample and every update step. Its device is
            used for all tensors created by the sampler.
        denoise: If `True`, return the second output (`xt_mean`) of the last predictor update
            instead of the first one (`xt`).
        eps: A `float`. Last time value of the time grid, which runs from `sde.T` down to `eps`.
        snr: Passed to the corrector constructor as `snr`.
        corrector_steps: Passed to the corrector constructor as `n_steps`.
        probability_flow: Passed to the predictor constructor as `probability_flow`.
        intermediate: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        A zero-argument function that runs the sampler and returns the tuple
        `(samples, ns)` with `ns = sde.N * (corrector.n_steps + 1)`.
    """
    predictor_type = PredictorRegistry.get_by_name(predictor_name)
    corrector_type = CorrectorRegistry.get_by_name(corrector_name)
    predictor = predictor_type(sde, score_fn, probability_flow=probability_flow)
    corrector = corrector_type(sde, score_fn, snr=snr, n_steps=corrector_steps)

    def pc_sampler():
        """Run the PC sampling loop and return `(samples, ns)`."""
        with torch.no_grad():
            device = y.device
            xt = sde.prior_sampling(y.shape, y).to(device)
            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
            batch_ones = torch.ones(y.shape[0], device=device)
            final_index = len(timesteps) - 1

            for i in range(sde.N):
                t = timesteps[i]
                # Every step spans the gap to the next grid point; the last
                # step uses the final grid value itself as its step size.
                stepsize = timesteps[-1] if i == final_index else t - timesteps[i + 1]
                vec_t = batch_ones * t
                # Corrector first, then predictor, both at the current time.
                xt, xt_mean = corrector.update_fn(xt, y, vec_t)
                xt, xt_mean = predictor.update_fn(xt, y, vec_t, stepsize)

            samples = xt_mean if denoise else xt
            return samples, sde.N * (corrector.n_steps + 1)

    return pc_sampler
|
|
|
|
|
|
def get_ode_sampler(
    sde, score_fn, y, inverse_scaler=None,
    denoise=True, rtol=1e-5, atol=1e-5,
    method='RK45', eps=3e-2, device='cuda', **kwargs
):
    """Probability flow ODE sampler with the black-box ODE solver.

    Args:
        sde: The SDE object. Its `T`, `prior_sampling` and `reverse` members are used here.
        score_fn: A function (typically learned model) that predicts the score.
        y: A `torch.Tensor` that conditions the prior sample and the drift. Its shape
            is the shape of the returned samples.
        inverse_scaler: Optional callable applied to the final samples. Skipped if `None`.
        denoise: If `True`, apply one `ReverseDiffusionPredictor` update at time `eps`
            to the final samples and keep its second output.
        rtol: A `float` number. The relative tolerance level of the ODE solver.
        atol: A `float` number. The absolute tolerance level of the ODE solver.
        method: A `str`. The algorithm used for the black-box ODE solver.
            See the documentation of `scipy.integrate.solve_ivp`.
        eps: A `float` number. The ODE is integrated from `sde.T` down to `eps`.
        device: PyTorch device on which the tensors are placed.
        **kwargs: Accepted but not used; solver options are given to the returned
            sampling function instead.

    Returns:
        A sampling function `ode_sampler(z=None, **kwargs)` that returns samples and
        the number of function evaluations during sampling.
    """
    predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    rsde = sde.reverse(score_fn, probability_flow=True)

    def denoise_update_fn(x):
        """Apply one predictor update at time `eps` and return its second output."""
        vec_eps = torch.ones(x.shape[0], device=x.device) * eps
        # NOTE(review): the PC sampler in this file calls `update_fn` with an
        # additional `stepsize` argument -- confirm that
        # `ReverseDiffusionPredictor.update_fn` accepts this three-argument form.
        _, x = predictor.update_fn(x, y, vec_eps)
        return x

    def drift_fn(x, y, t):
        """Get the drift function of the reverse-time SDE."""
        return rsde.sde(x, y, t)[0]

    def ode_sampler(z=None, **kwargs):
        """The probability flow ODE sampler with black-box ODE solver.

        Args:
            z: If present, integrate starting from the latent code `z` instead of a
                fresh prior sample. It must hold as many elements as `y`, because
                the solver state is reshaped to `y.shape`.
            **kwargs: Extra options forwarded to `scipy.integrate.solve_ivp`.

        Returns:
            samples, number of function evaluations.
        """
        with torch.no_grad():
            # Honour a caller-supplied starting point; otherwise draw from the prior.
            if z is None:
                x = sde.prior_sampling(y.shape, y).to(device)
            else:
                x = z.to(device)

            def ode_func(t, x):
                # The solver works on flat numpy vectors; convert to a tensor
                # for the drift evaluation and back again afterwards.
                x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
                vec_t = torch.ones(y.shape[0], device=x.device) * t
                drift = drift_fn(x, y, vec_t)
                return to_flattened_numpy(drift)

            solution = integrate.solve_ivp(
                ode_func, (sde.T, eps), to_flattened_numpy(x),
                rtol=rtol, atol=atol, method=method, **kwargs
            )
            nfe = solution.nfev
            # The last column of `solution.y` is the state at the final time `eps`.
            x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)

            if denoise:
                x = denoise_update_fn(x)

            if inverse_scaler is not None:
                x = inverse_scaler(x)
            return x, nfe

    return ode_sampler
|
|
|
|
def get_sb_sampler(sde, model, y, eps=1e-4, n_steps=50, sampler_type="ode", **kwargs):
    """Create an SB sampler of the requested type.

    Both samplers walk a uniform time grid of `sde.N + 1` points from `sde.T`
    down to `eps` and call `model(xt, y, time)` once per step.

    Args:
        sde: The SDE object. Its `T`, `N`, `eps` and `_sigmas_alphas` members are used.
        model: Callable invoked as `model(xt, y, time)`; its output is blended into
            the next state as the current estimate.
        y: A `torch.Tensor` used as conditioning and as the starting point. It is
            indexed as a 4-D tensor. The "sde" sampler starts from channel 0 only,
            the "ode" sampler from the full tensor.
        eps: A `float`. Final time of the time grid.
        n_steps: Kept for backward compatibility and not used. The number of steps
            is always `sde.N`.
        sampler_type: Either "ode" (deterministic update) or "sde" (update with
            added Gaussian noise on every step but the last).
        **kwargs: Accepted but not used.

    Returns:
        A zero-argument sampling function returning `(samples, sde.N)`, i.e. the
        samples and the number of model evaluations that were performed.

    Raises:
        ValueError: If `sampler_type` is neither "ode" nor "sde".
    """
    if sampler_type not in ("sde", "ode"):
        raise ValueError("Invalid type. Choose 'ode' or 'sde'.")

    def _per_item(weight):
        """Reshape a per-batch-item weight vector so it broadcasts over 4-D tensors."""
        return weight[:, None, None, None]

    def _time_grid(xt):
        """Return the descending time grid and a ones vector with one entry per batch item."""
        time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)
        batch_ones = torch.ones(xt.shape[0], device=xt.device)
        return time_steps, batch_ones

    def sde_sampler():
        """The SB-SDE sampler function."""
        with torch.no_grad():
            xt = y[:, [0], :, :]
            time_steps, batch_ones = _time_grid(xt)

            # Coefficients at the initial time; only sigma and alpha are needed here.
            sigma_prev, _, _, alpha_prev, _, _ = sde._sigmas_alphas(time_steps[0] * batch_ones)

            for t in time_steps[1:]:
                time = t * batch_ones
                sigma_t, _, _, alpha_t, _, _ = sde._sigmas_alphas(time)

                current_estimate = model(xt, y, time)

                # sde.eps guards the divisions against a zero denominator.
                weight_prev = alpha_t * sigma_t**2 / (alpha_prev * sigma_prev**2 + sde.eps)
                tmp = 1 - sigma_t**2 / (sigma_prev**2 + sde.eps)
                weight_estimate = alpha_t * tmp
                weight_z = alpha_t * sigma_t * torch.sqrt(tmp)

                # Drawn on every step, including the last, so that the random
                # number stream is consumed identically whether or not it is used.
                z_norm = torch.randn_like(xt)

                # No noise is injected on the final step.
                if t == time_steps[-1]:
                    weight_z = 0.0
                else:
                    weight_z = _per_item(weight_z)

                xt = (
                    _per_item(weight_prev) * xt
                    + _per_item(weight_estimate) * current_estimate
                    + weight_z * z_norm
                )

                alpha_prev, sigma_prev = alpha_t, sigma_t

            # One model evaluation per grid interval.
            return xt, sde.N

    def ode_sampler():
        """The SB-ODE sampler function."""
        with torch.no_grad():
            xt = y
            time_steps, batch_ones = _time_grid(xt)

            # Coefficients at the initial time.
            sigma_prev, _, sigma_bar_prev, alpha_prev, _, _ = sde._sigmas_alphas(
                time_steps[0] * batch_ones
            )

            for t in time_steps[1:]:
                time = t * batch_ones
                sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, _ = sde._sigmas_alphas(time)

                current_estimate = model(xt, y, time)

                # sde.eps guards the divisions against a zero denominator.
                weight_prev = alpha_t * sigma_t * sigma_bart / (alpha_prev * sigma_prev * sigma_bar_prev + sde.eps)
                weight_estimate = (
                    alpha_t
                    / (sigma_T**2 + sde.eps)
                    * (sigma_bart**2 - sigma_bar_prev * sigma_t * sigma_bart / (sigma_prev + sde.eps))
                )
                weight_prior_mean = (
                    alpha_t
                    / (alpha_T * sigma_T**2 + sde.eps)
                    * (sigma_t**2 - sigma_prev * sigma_t * sigma_bart / (sigma_bar_prev + sde.eps))
                )

                xt = (
                    _per_item(weight_prev) * xt
                    + _per_item(weight_estimate) * current_estimate
                    + _per_item(weight_prior_mean) * y
                )

                alpha_prev, sigma_prev, sigma_bar_prev = alpha_t, sigma_t, sigma_bart

            # One model evaluation per grid interval.
            return xt, sde.N

    return sde_sampler if sampler_type == "sde" else ode_sampler
|
|
|