|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""All functions and modules related to model definition.
|
|
"""
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
from ...sdes import OUVESDE, OUVPSDE
|
|
|
|
|
|
# Registry mapping a model name to its registered model class.
_MODELS = {}
|
|
|
|
|
|
def register_model(cls=None, *, name=None):
    """Register a model class in the module-level `_MODELS` registry.

    Usable both as a bare decorator (`@register_model`) and as a
    parameterised one (`@register_model(name='my_model')`).

    Args:
      cls: The class to register. `None` when the decorator is invoked with
        keyword arguments only.
      name: Optional registry key. Defaults to the class's `__name__`.

    Returns:
      The registered class when `cls` is given, otherwise a decorator that
      registers the class it is applied to and returns it unchanged.

    Raises:
      ValueError: If the registry key is already taken.
    """

    def _add_to_registry(model_cls):
        key = model_cls.__name__ if name is None else name
        if key in _MODELS:
            raise ValueError(f'Already registered model with name: {key}')
        _MODELS[key] = model_cls
        return model_cls

    # Without a class we were called as `register_model(name=...)`, so hand
    # back the decorator itself; otherwise register right away.
    return _add_to_registry if cls is None else _add_to_registry(cls)
|
|
|
|
|
|
def get_model(name):
    """Look up a registered model class by name.

    Args:
      name: Key under which the class was stored in `_MODELS`.

    Returns:
      The object stored in the registry for `name`.

    Raises:
      KeyError: If nothing is registered under `name`.
    """
    model_cls = _MODELS[name]
    return model_cls
|
|
|
|
|
|
def get_sigmas(sigma_min, sigma_max, num_scales):
    """Get sigmas --- a geometric sequence of noise levels.

    The levels are evenly spaced in log-space and ordered from `sigma_max`
    (first entry) down to `sigma_min` (last entry).

    Args:
      sigma_min: Smallest noise level (last element of the result).
      sigma_max: Largest noise level (first element of the result).
      num_scales: Number of noise levels to generate.

    Returns:
      sigmas: a numpy array of `num_scales` noise levels.
    """
    log_start = np.log(sigma_max)
    log_stop = np.log(sigma_min)
    log_sigmas = np.linspace(log_start, log_stop, num_scales)
    return np.exp(log_sigmas)
|
|
|
|
|
|
def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    Args:
      config: Object exposing `config.model.beta_min`, `config.model.beta_max`
        and `config.model.num_scales`.

    Returns:
      A dict with the float64 arrays `betas`, `alphas`, `alphas_cumprod`,
      `sqrt_alphas_cumprod` and `sqrt_1m_alphas_cumprod`, plus the scalars
      `beta_min`, `beta_max` and `num_diffusion_timesteps`.
    """
    # NOTE(review): the schedule length is fixed at 1000 while the endpoints
    # are divided by `num_scales`; confirm the two are meant to coincide.
    num_diffusion_timesteps = 1000

    model_cfg = config.model
    beta_start = model_cfg.beta_min / model_cfg.num_scales
    beta_end = model_cfg.beta_max / model_cfg.num_scales

    # Linear beta schedule between the two rescaled endpoints.
    betas = np.linspace(
        beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    last_step = num_diffusion_timesteps - 1
    return dict(
        betas=betas,
        alphas=alphas,
        alphas_cumprod=alphas_cumprod,
        sqrt_alphas_cumprod=np.sqrt(alphas_cumprod),
        sqrt_1m_alphas_cumprod=np.sqrt(1. - alphas_cumprod),
        beta_min=beta_start * last_step,
        beta_max=beta_end * last_step,
        num_diffusion_timesteps=num_diffusion_timesteps,
    )
|
|
|
|
|
|
def create_model(config):
    """Create the score model.

    Looks up the class registered under `config.model.name`, instantiates it
    with `config`, moves it to `config.device` and wraps it in
    `torch.nn.DataParallel`.

    Args:
      config: Object exposing `config.model.name` and `config.device`; it is
        also passed unchanged to the model constructor.

    Returns:
      The `torch.nn.DataParallel`-wrapped model.
    """
    model_cls = get_model(config.model.name)
    network = model_cls(config).to(config.device)
    return torch.nn.DataParallel(network)
|
|
|
|
|
|
def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model. Must provide `eval()`, `train()` and be callable
        as `model(x, labels)`.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function `model_fn(x, labels)`.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        The model is switched to the requested mode on every call, just
        before the forward pass.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps.
            Should be interpreted differently for different models.

        Returns:
          Whatever `model(x, labels)` returns.
        """
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
|
|
|
|
|
|
def get_score_fn(sde, model, train=False, continuous=False):
    """Wrap the model so that its output corresponds to a time-dependent score function.

    Args:
      sde: The forward SDE object; must be an `OUVPSDE` or `OUVESDE` instance.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly
        take continuous time steps.

    Returns:
      A score function `score_fn(x, t)`.

    Raises:
      NotImplementedError: If `sde` is neither an `OUVPSDE` nor an `OUVESDE`.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, OUVPSDE):

        def score_fn(x, t):
            # Conditioning labels: `t` scaled by the fixed constant 999 in the
            # continuous case, or by the SDE's own `N - 1` otherwise.
            labels = t * 999 if continuous else t * (sde.N - 1)
            model_out = model_fn(x, labels)

            if continuous:
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # Use the (truncated) labels as indices into the table.
                table = sde.sqrt_1m_alphas_cumprod.to(labels.device)
                std = table[labels.long()]

            # Divide the network output by std and flip its sign; the
            # indexing assumes `std` is 1-D and `x` is 4-D.
            return -model_out / std[:, None, None, None]

    elif isinstance(sde, OUVESDE):

        def score_fn(x, t):
            if continuous:
                # The model is conditioned on the second output of
                # `marginal_prob` directly.
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # The reversed time `sde.T - t` is scaled and rounded to an
                # integer label.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            return model_fn(x, labels)

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn
|
|
|
|
|
|
def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion.

    Args:
      x: A torch tensor.

    Returns:
      A 1-D numpy array holding the elements of `x`.
    """
    host_array = x.detach().cpu().numpy()
    return host_array.reshape(-1)
|
|
|
|
|
|
def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`.

    Args:
      x: A numpy array; it is reshaped with `x.reshape(shape)`.
      shape: Target shape of the resulting tensor.

    Returns:
      A torch tensor created by `torch.from_numpy` from the reshaped array.
    """
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)