File size: 2,128 Bytes
f0e6b7a f280910 f0e6b7a f280910 f0e6b7a f280910 f0e6b7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
# Number of timesteps in the noise schedule when the caller does not pass one.
default_num_train_timesteps = 1000


@torch.no_grad()
def make_sigmas(
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=default_num_train_timesteps,
    device=None,
):
    """Build the per-timestep sigma table for a "scaled linear" beta schedule.

    The betas are the squares of ``num_train_timesteps`` values spaced
    linearly between ``sqrt(beta_start)`` and ``sqrt(beta_end)``.  With
    ``alpha_bar`` the running product of ``1 - beta``, the entry for each
    timestep is ``sqrt((1 - alpha_bar) / alpha_bar)``.

    Args:
        beta_start: Beta of the first timestep.
        beta_end: Beta of the last timestep.
        num_train_timesteps: Number of entries in the returned table.
        device: Device on which the table is created (forwarded to
            ``torch.linspace``).

    Returns:
        A 1-D ``float32`` tensor with ``num_train_timesteps`` entries.
    """
    sqrt_betas = torch.linspace(
        beta_start**0.5,
        beta_end**0.5,
        num_train_timesteps,
        dtype=torch.float32,
        device=device,
    )
    # Running product of (1 - beta): the cumulative alpha for each timestep.
    alpha_bar = (1.0 - sqrt_betas**2).cumprod(dim=0)
    # TODO - would be nice to use a direct expression for this
    return ((1 - alpha_bar) / alpha_bar) ** 0.5
@torch.no_grad()
def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights):
x_t = x_T
for i in range(len(timesteps) - 1, -1, -1):
t = timesteps[i].unsqueeze(0)
sigma = sigmas[t]
if i == 0:
eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
x_0_hat = x_t - sigma * eps_hat
else:
dt = sigmas[timesteps[i - 1]] - sigma
dx_by_dt = torch.zeros_like(x_t)
dx_by_dt_cur = torch.zeros_like(x_t)
for rk_step, rk_weight in rk_steps_weights:
dt_ = dt * rk_step
t_ = t + dt_
x_t_ = x_t + dx_by_dt_cur * dt_
eps_hat = eps_theta(x_t=x_t_, t=t_, sigma=sigma)
# TODO - note which specific ode this is the solution to and
# how input scaling does/doesn't effect the solution
# dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
dx_by_dt_cur = eps_hat
dx_by_dt += dx_by_dt_cur * rk_weight
x_t_minus_1 = x_t + dx_by_dt * dt
x_t = x_t_minus_1
return x_0_hat
def euler_ode_solver_diffusion_loop(*args, **kwargs):
    """Run ``rk_ode_solver_diffusion_loop`` with a single-stage (Euler) table.

    All arguments are forwarded unchanged; ``rk_steps_weights`` is supplied
    here and must not be passed by the caller.
    """
    return rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 1]])


def heun_ode_solver_diffusion_loop(*args, **kwargs):
    """Run ``rk_ode_solver_diffusion_loop`` with a two-stage (Heun) table.

    The stages sit at the start and end of each step, weighted 1/2 each.
    All arguments are forwarded unchanged; ``rk_steps_weights`` is supplied
    here and must not be passed by the caller.
    """
    return rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 0.5], [1, 0.5]])


def rk4_ode_solver_diffusion_loop(*args, **kwargs):
    """Run ``rk_ode_solver_diffusion_loop`` with a four-stage (RK4) table.

    The stages sit at the start, midpoint (twice) and end of each step,
    weighted 1/6, 1/3, 1/3 and 1/6.  All arguments are forwarded unchanged;
    ``rk_steps_weights`` is supplied here and must not be passed by the
    caller.
    """
    return rk_ode_solver_diffusion_loop(
        *args,
        **kwargs,
        rk_steps_weights=[[0, 1 / 6], [1 / 2, 1 / 3], [1 / 2, 1 / 3], [1, 1 / 6]],
    )
|