Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import torch | |
from einops import rearrange | |
from tqdm import tqdm | |
from .utils import get_tensor_items | |
def get_named_beta_schedule(schedule_name, timesteps):
    """Return the per-step beta (noise variance) schedule for a diffusion process.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: number of diffusion steps, i.e. the length of the result.

    Returns:
        A ``torch.float32`` tensor of shape ``(timesteps,)``.

    Raises:
        ValueError: if ``schedule_name`` is not one of the supported schedules.
    """
    if schedule_name == "linear":
        # Endpoints are specified for a 1000-step schedule; rescale them so a
        # schedule with a different number of steps spans a comparable range.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)

    if schedule_name == "cosine":
        def alpha_bar(t):
            # Cumulative signal level at normalized time t in [0, 1].
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at 0.999 so
        # that 1 - beta never reaches 0 on the final steps.
        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    # Previously fell through and returned None, which only failed later and
    # far from the cause; fail loudly at the source instead.
    raise ValueError(f"unknown beta schedule: {schedule_name!r}")
class BaseDiffusion:
    """Gaussian diffusion process with a DDIM-style sampler and classifier-free guidance.

    All per-timestep coefficient tables are precomputed from ``betas`` in
    ``__init__``. Methods look them up per batch element with
    ``get_tensor_items`` (from ``.utils``). That helper presumably gathers
    ``table[t]`` and reshapes it to broadcast against the given shape; TODO
    confirm.

    Args:
        betas: 1-D tensor of per-step noise variances; its length is the number
            of diffusion steps.
        percentile: if not ``None``, the quantile used by ``process_x_start``
            for per-sample dynamic thresholding of predicted ``x_0``. Otherwise
            ``x_0`` is clipped to ``[-1, 1]``.
        gen_noise: ``randn_like``-style callable that draws Gaussian noise
            shaped like its tensor argument.
    """

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Create the leading 1 on the same device as ``betas``. Otherwise
        # torch.cat fails when the schedule lives on an accelerator.
        self.alphas_cumprod_prev = torch.cat([
            torch.ones(1, dtype=betas.dtype, device=betas.device),
            self.alphas_cumprod[:-1],
        ])

        # calculate q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # calculate q(x_{t-1} | x_t, x_0)
        self.posterior_mean_coef_1 = torch.sqrt(self.alphas_cumprod_prev) * betas / (1. - self.alphas_cumprod)
        self.posterior_mean_coef_2 = torch.sqrt(alphas) * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        # posterior_variance[0] is exactly 0 (alphas_cumprod_prev[0] == 1), so
        # its log would be -inf. Reuse the value at index 1 for the first entry.
        self.posterior_log_variance = torch.log(
            torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
        )

        self.percentile = percentile
        # Maps this schedule's step indices onto a 1000-step time axis before
        # they are fed to the model in text_guidance.
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise
        # NOTE(review): not read anywhere in this class.
        self.jump_length = 3

    def process_x_start(self, x_start):
        """Bound a predicted ``x_0``.

        When ``self.percentile`` is set, each sample is clipped to its own
        ``percentile``-quantile of absolute values (never below 1) and then
        divided by that quantile. Otherwise values are clipped to ``[-1, 1]``.
        """
        bs, ndims = x_start.shape[0], len(x_start.shape[1:])
        if self.percentile is not None:
            quantile = torch.quantile(
                rearrange(x_start, 'b ... -> b (...)').abs(),
                self.percentile,
                dim=-1
            )
            quantile = torch.clip(quantile, min=1.)
            # One threshold per sample, broadcast over all non-batch dims.
            quantile = quantile.reshape(bs, *((1,) * ndims))
            return torch.clip(x_start, -quantile, quantile) / quantile
        else:
            return torch.clip(x_start, -1., 1.)

    def get_x_start(self, x, t, noise):
        """Invert ``q_sample``: x_0 = (x_t - sqrt(1 - abar_t) * noise) / sqrt(abar_t)."""
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        pred_x_start = (x - sqrt_one_minus_alphas_cumprod * noise) / sqrt_alphas_cumprod
        return pred_x_start

    def get_noise(self, x, t, x_start):
        """Invert ``q_sample`` for the noise: eps = (x_t - sqrt(abar_t) * x_0) / sqrt(1 - abar_t)."""
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        pred_noise = (x - sqrt_alphas_cumprod * x_start) / sqrt_one_minus_alphas_cumprod
        return pred_noise

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.

        ``noise`` is drawn with ``self.gen_noise`` when not supplied.
        """
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
        return x_t

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return (mean, variance, log-variance) of q(x_{t-1} | x_t, x_0).

        Not used by the DDIM-style sampler in this class.
        """
        posterior_mean_coef_1 = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
        posterior_mean_coef_2 = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
        posterior_mean = posterior_mean_coef_1 * x_start + posterior_mean_coef_2 * x_t
        posterior_variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
        posterior_log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def q_posterior_variance(self, t, prev_t, shape, eta=1.):
        """Return the sampling noise scale for a ``t -> prev_t`` step.

        Despite the name, this is a standard deviation (note the ``sqrt``).
        ``p_mean_variance`` squares it, and ``p_sample`` multiplies it directly
        with noise. ``eta=0`` gives a deterministic step.
        """
        alphas_cumprod = get_tensor_items(self.alphas_cumprod, t, shape)
        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, shape)
        # NOTE(review): eta sits inside the sqrt, so the result scales with
        # sqrt(eta). The DDIM paper scales sigma linearly by eta. The two agree
        # for eta in {0, 1}.
        posterior_variance = torch.sqrt(
            eta * (1. - alphas_cumprod / prev_alphas_cumprod) * (1. - prev_alphas_cumprod) / (1. - alphas_cumprod)
        )
        return posterior_variance

    def text_guidance(
        self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
        uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
    ):
        """Predict noise with classifier-free guidance.

        The model runs once on a doubled batch (conditional half, then
        unconditional half). The halves are combined as
        ``(w + 1) * eps_cond - w * eps_uncond`` with ``w = guidance_weight_text``.

        If ``uncondition_context`` is ``None``, an unconditional context is
        built. It is all zeros except position 0 of each sequence, which is set
        to ``null_embedding`` with mask 1.

        When ``model.in_layer.in_channels == 9``, ``mask`` and
        ``masked_latent`` are concatenated to the input on dim 1 (presumably
        for inpainting). Assumes ``x`` is 4-D (batch, C, H, W); TODO confirm.
        """
        large_x = x.repeat(2, 1, 1, 1)
        large_t = t.repeat(2).to(x.dtype)
        if uncondition_context is None:
            uncondition_context = torch.zeros_like(context)
            uncondition_context_mask = torch.zeros_like(context_mask)
            uncondition_context[:, 0] = null_embedding
            uncondition_context_mask[:, 0] = 1
        large_context = torch.cat([context, uncondition_context])
        large_context_mask = torch.cat([context_mask, uncondition_context_mask])
        if mask is not None:
            mask = mask.repeat(2, 1, 1, 1)
        if masked_latent is not None:
            masked_latent = masked_latent.repeat(2, 1, 1, 1)
        if model.in_layer.in_channels == 9:
            large_x = torch.cat([large_x, mask, masked_latent], dim=1)
        pred_large_noise = model(large_x, large_t * self.time_scale, large_context, large_context_mask.bool())
        pred_noise, uncond_pred_noise = torch.chunk(pred_large_noise, 2)
        pred_noise = (guidance_weight_text + 1.) * pred_noise - guidance_weight_text * uncond_pred_noise
        return pred_noise

    def p_mean_variance(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Return (mean, sigma) of the DDIM-style step from ``t`` to ``prev_t``.

        The guided noise prediction is converted to ``x_0`` and bounded with
        ``process_x_start``. The noise is then re-derived from the bounded
        ``x_0``, and the two are recombined at noise level ``prev_t``.
        """
        pred_noise = self.text_guidance(
            model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            negative_context, negative_context_mask, mask, masked_latent
        )
        pred_x_start = self.get_x_start(x, t, pred_noise)
        pred_x_start = self.process_x_start(pred_x_start)
        pred_noise = self.get_noise(x, t, pred_x_start)

        pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)
        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)
        pred_mean = torch.sqrt(prev_alphas_cumprod) * pred_x_start
        # Clamp at 0. Float round-off can make this slightly negative, and the
        # sqrt of a negative value would turn the whole sample into NaN.
        pred_mean += torch.sqrt(torch.clamp(1. - prev_alphas_cumprod - pred_var ** 2, min=0.)) * pred_noise
        return pred_mean, pred_var

    def p_sample(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Take one sampling step from ``t`` to ``prev_t`` and return the new sample."""
        bs = x.shape[0]
        ndims = len(x.shape[1:])
        pred_mean, pred_var = self.p_mean_variance(
            model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
            negative_context=negative_context, negative_context_mask=negative_context_mask,
            mask=mask, masked_latent=masked_latent
        )
        # Use the configured noise generator, consistent with q_sample.
        noise = self.gen_noise(x)
        # No noise is added on a step that lands on prev_t == 0. This is a
        # separate name so it does not shadow the inpainting ``mask`` argument.
        nonzero_mask = (prev_t != 0).reshape(bs, *((1,) * ndims))
        sample = pred_mean + nonzero_mask * pred_var * noise
        return sample

    def p_sample_loop(
        self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
    ):
        """Generate samples of ``shape`` starting from Gaussian noise.

        Steps through consecutive pairs of ``times`` followed by a final step
        to 0. ``times`` is presumably a descending sequence of step indices;
        TODO confirm. With ``gan=True``, each step re-noises the current
        ``x_0`` estimate to level ``time`` and predicts ``x_0`` with a single
        unguided model call. Otherwise ``p_sample`` is used.
        """
        img = torch.randn(*shape, device=device)
        # list() lets callers pass any sequence (tuple, range), not only lists.
        steps = list(times) + [0]
        step_pairs = list(zip(steps[:-1], steps[1:]))
        for time, prev_time in tqdm(step_pairs):
            time = torch.tensor([time] * shape[0], device=device)
            if gan:
                x_t = self.q_sample(img, time)
                # NOTE(review): unlike text_guidance, the timestep is not
                # multiplied by self.time_scale here. This is identical when
                # num_timesteps == 1000; confirm it is intended otherwise.
                pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
                img = self.get_x_start(x_t, time, pred_noise)
            else:
                prev_time = torch.tensor([prev_time] * shape[0], device=device)
                img = self.p_sample(
                    model, img, time, prev_time, context, context_mask, null_embedding, guidance_weight_text, eta,
                    negative_context=negative_context, negative_context_mask=negative_context_mask,
                    mask=mask, masked_latent=masked_latent
                )
        return img
def get_diffusion(conf):
    """Construct a ``BaseDiffusion`` from a configuration object.

    ``conf.schedule_params`` is unpacked as keyword arguments into
    ``get_named_beta_schedule``. The resulting betas, together with the
    keyword arguments in ``conf.diffusion_params``, are passed to the
    ``BaseDiffusion`` constructor.
    """
    return BaseDiffusion(
        get_named_beta_schedule(**conf.schedule_params),
        **conf.diffusion_params,
    )