|
import torch |
|
import numpy as np |
|
|
|
def edm_sampler(
    net, latents, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
):
    """Draw samples by integrating from high noise to zero noise with ``net``.

    The loop performs, per step, an optional noise increase ("churn"), an
    Euler step, and (on every step except the last) a 2nd-order correction
    that averages the slopes at both ends of the step. All arithmetic is
    done in ``torch.float64``.

    Args:
        net: Denoiser called as ``net(x, sigma)``. Must also expose
            ``sigma_min``, ``sigma_max`` and ``round_sigma(sigma)``.
        latents: Initial noise tensor; it is scaled by the first noise level.
        randn_like: Callable producing noise shaped like its argument.
        num_steps: Number of noise levels / loop iterations (must be >= 1).
        sigma_min: Lowest noise level (clamped to ``net.sigma_min``).
        sigma_max: Highest noise level (clamped to ``net.sigma_max``).
        rho: Exponent controlling the spacing of the noise levels.
        S_churn: Total churn budget; the per-step factor is
            ``min(S_churn / num_steps, sqrt(2) - 1)``.
        S_min: Lower bound of the noise range in which churn is applied.
        S_max: Upper bound of the noise range in which churn is applied.
        S_noise: Multiplier on the noise injected by churn.
        ret_all: If true, also return the per-step intermediate samples.

    Returns:
        The final sample, or ``(final_sample, all_x)`` when ``ret_all`` is
        true, where ``all_x[i]`` is the sample after step ``i`` divided by
        ``sqrt(t_next ** 2 + 1)``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization. With a single step the interpolation
    # fraction would be 0/0 (NaN schedule), so the divisor is floored at 1,
    # which yields the schedule [sigma_max, 0].
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    fraction = step_indices / max(num_steps - 1, 1)
    t_steps = (
        sigma_max ** (1 / rho)
        + fraction * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    # Append the terminal noise level t_N = 0.
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

    # The churn cap does not depend on the step, so compute it once.
    gamma_cap = min(S_churn / num_steps, np.sqrt(2) - 1)

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    all_x = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x_cur = x_next

        # Increase noise temporarily.
        gamma = gamma_cap if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where
        # t_next == 0 would make the slope undefined).
        if i < num_steps - 1:
            denoised = net(x_next, t_next).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        # Only keep intermediates when the caller asked for them; the
        # division already allocates a new tensor, so no clone is needed.
        if ret_all:
            all_x.append(x_next / (t_next ** 2 + 1).sqrt())

    if ret_all:
        return x_next, all_x

    return x_next
|
|
|
def edm_sampler_cond(
    net, latents, cond_points, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
):
    """Conditional sampler: integrate from high noise to zero noise with ``net``.

    Identical in structure to an unconditional Euler + 2nd-order-correction
    sampler, except that ``cond_points`` is forwarded unchanged to every
    network evaluation. All arithmetic is done in ``torch.float64``.

    Args:
        net: Denoiser called as ``net(x, sigma, cond_points)``. Must also
            expose ``sigma_min``, ``sigma_max`` and ``round_sigma(sigma)``.
        latents: Initial noise tensor; it is scaled by the first noise level.
        cond_points: Conditioning input passed as the third argument of
            ``net``; it is not inspected or modified here.
        randn_like: Callable producing noise shaped like its argument.
        num_steps: Number of noise levels / loop iterations (must be >= 1).
        sigma_min: Lowest noise level (clamped to ``net.sigma_min``).
        sigma_max: Highest noise level (clamped to ``net.sigma_max``).
        rho: Exponent controlling the spacing of the noise levels.
        S_churn: Total churn budget; the per-step factor is
            ``min(S_churn / num_steps, sqrt(2) - 1)``.
        S_min: Lower bound of the noise range in which churn is applied.
        S_max: Upper bound of the noise range in which churn is applied.
        S_noise: Multiplier on the noise injected by churn.
        ret_all: If true, also return the per-step intermediate samples.

    Returns:
        The final sample, or ``(final_sample, all_x)`` when ``ret_all`` is
        true, where ``all_x[i]`` is the sample after step ``i`` divided by
        ``sqrt(t_next ** 2 + 1)``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization. With a single step the interpolation
    # fraction would be 0/0 (NaN schedule), so the divisor is floored at 1,
    # which yields the schedule [sigma_max, 0].
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    fraction = step_indices / max(num_steps - 1, 1)
    t_steps = (
        sigma_max ** (1 / rho)
        + fraction * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    # Append the terminal noise level t_N = 0.
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

    # The churn cap does not depend on the step, so compute it once.
    gamma_cap = min(S_churn / num_steps, np.sqrt(2) - 1)

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    all_x = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x_cur = x_next

        # Increase noise temporarily.
        gamma = gamma_cap if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, cond_points).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where
        # t_next == 0 would make the slope undefined).
        if i < num_steps - 1:
            denoised = net(x_next, t_next, cond_points).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        # Only keep intermediates when the caller asked for them; the
        # division already allocates a new tensor, so no clone is needed.
        if ret_all:
            all_x.append(x_next / (t_next ** 2 + 1).sqrt())

    if ret_all:
        return x_next, all_x

    return x_next
|
|
|
|