nesterus
moved contents of presentations repo
d90acf0
raw
history blame
9.15 kB
import math
import torch
from einops import rearrange
from tqdm import tqdm
from .utils import get_tensor_items
def get_named_beta_schedule(schedule_name, timesteps):
    """Build the per-step noise variances (betas) for a named schedule.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: number of diffusion steps; also the length of the result.

    Returns:
        A 1-D ``torch.float32`` tensor with ``timesteps`` entries.

    Raises:
        ValueError: if ``schedule_name`` is not one of the supported names.
    """
    if schedule_name == "linear":
        # The 1e-4 .. 2e-2 endpoints are rescaled by 1000 / timesteps so that
        # shorter or longer schedules use proportionally larger/smaller betas.
        scale = 1000 / timesteps
        return torch.linspace(
            scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float32
        )

    if schedule_name == "cosine":
        def alpha_bar(t):
            # Squared-cosine cumulative signal level with a 0.008 offset.
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at 0.999 so
        # the last steps (where alpha_bar approaches 0) never reach beta == 1.
        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    # Previously an unknown name fell through and returned None, which only
    # surfaced later as an opaque AttributeError in the caller.
    raise ValueError(f"unknown beta schedule: {schedule_name!r}")
class BaseDiffusion:
    """Gaussian diffusion schedule plus a classifier-free-guided sampler.

    The constructor precomputes 1-D schedule tensors from ``betas``; the
    methods look values up per timestep through ``get_tensor_items``.

    NOTE(review): ``get_tensor_items`` is imported from ``.utils`` and is not
    visible here. It is presumably a gather of a 1-D schedule at indices ``t``,
    reshaped to broadcast against the given shape -- confirm against ``.utils``.
    """

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        """Precompute schedule tensors.

        Args:
            betas: 1-D tensor of per-step noise variances.
            percentile: if not ``None``, quantile level used by
                ``process_x_start`` for per-sample dynamic clipping.
            gen_noise: callable ``tensor -> noise`` used by ``q_sample`` when
                no explicit noise is supplied.
        """
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with 1 in front. Built on
        # the same device as ``betas`` so the concatenation cannot fail on a
        # device mismatch.
        self.alphas_cumprod_prev = torch.cat([
            torch.ones(1, dtype=betas.dtype, device=betas.device),
            self.alphas_cumprod[:-1],
        ])

        # calculate q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # calculate q(x_{t-1} | x_t, x_0)
        one_minus_cumprod = 1. - self.alphas_cumprod
        one_minus_cumprod_prev = 1. - self.alphas_cumprod_prev
        self.posterior_mean_coef_1 = torch.sqrt(self.alphas_cumprod_prev) * betas / one_minus_cumprod
        self.posterior_mean_coef_2 = torch.sqrt(alphas) * one_minus_cumprod_prev / one_minus_cumprod
        self.posterior_variance = betas * one_minus_cumprod_prev / one_minus_cumprod

        # The first entry has the factor (1 - alphas_cumprod_prev[0]) == 0, so
        # its log is unusable; it is replaced by the second entry. A schedule
        # with a single step has no second entry (this used to raise
        # IndexError), so fall back to a tiny floor there.
        if self.num_timesteps > 1:
            clipped_variance = torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        else:
            clipped_variance = self.posterior_variance.clamp(min=1e-20)
        self.posterior_log_variance = torch.log(clipped_variance)

        self.percentile = percentile
        # Multiplier applied to timesteps before they are passed to the model
        # in ``text_guidance``.
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise
        # Not read anywhere in this class; kept for external users.
        self.jump_length = 3

    @staticmethod
    def _double_batch(tensor):
        """Concatenate ``tensor`` with itself along dim 0, for any rank >= 1."""
        return tensor.repeat(2, *([1] * (tensor.dim() - 1)))

    def process_x_start(self, x_start):
        """Clip a predicted ``x_0`` into ``[-1, 1]``.

        Without ``percentile`` this is a plain clamp. With it, each sample is
        clamped to its own ``percentile`` quantile of absolute values (never
        below 1) and then divided by that quantile.
        """
        if self.percentile is None:
            return torch.clip(x_start, -1., 1.)

        bs, ndims = x_start.shape[0], len(x_start.shape[1:])
        quantile = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.percentile,
            dim=-1
        )
        # Floor at 1 so values already inside [-1, 1] are not scaled up.
        quantile = torch.clip(quantile, min=1.)
        # One scalar per sample, shaped to broadcast over the remaining dims.
        quantile = quantile.reshape(bs, *((1,) * ndims))
        return torch.clip(x_start, -quantile, quantile) / quantile

    def get_x_start(self, x, t, noise):
        """Return ``(x - sqrt(1 - a_t) * noise) / sqrt(a_t)``, the ``x_0`` implied by ``noise``."""
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        return (x - sqrt_one_minus_alphas_cumprod * noise) / sqrt_alphas_cumprod

    def get_noise(self, x, t, x_start):
        """Return ``(x - sqrt(a_t) * x_start) / sqrt(1 - a_t)``, the noise implied by ``x_start``."""
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        return (x - sqrt_alphas_cumprod * x_start) / sqrt_one_minus_alphas_cumprod

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt(a_t) * x_start + sqrt(1 - a_t) * noise``.

        ``noise`` defaults to ``self.gen_noise(x_start)``.
        """
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        return sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return ``(mean, variance, log_variance)`` of ``q(x_{t-1} | x_t, x_0)``."""
        posterior_mean_coef_1 = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
        posterior_mean_coef_2 = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
        posterior_mean = posterior_mean_coef_1 * x_start + posterior_mean_coef_2 * x_t
        posterior_variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
        posterior_log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def q_posterior_variance(self, t, prev_t, shape, eta=1., ):
        """Return the sampling noise scale for a jump from ``t`` to ``prev_t``.

        Despite the name, the result is a square root:
        ``sqrt(eta * (1 - a_t / a_prev) * (1 - a_prev) / (1 - a_t))``.
        ``eta`` sits inside the root, so it scales the variance rather than
        the standard deviation.
        """
        alphas_cumprod = get_tensor_items(self.alphas_cumprod, t, shape)
        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, shape)
        return torch.sqrt(
            eta * (1. - alphas_cumprod / prev_alphas_cumprod) * (1. - prev_alphas_cumprod) / (1. - alphas_cumprod)
        )

    def text_guidance(
        self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
        uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
    ):
        """Predict noise with classifier-free guidance in one batched model call.

        The batch is doubled: the first half uses ``context``, the second half
        the unconditional context. When ``uncondition_context`` is ``None`` a
        zero context is built whose first token is ``null_embedding`` and whose
        mask enables only that token.

        Returns:
            ``(w + 1) * cond - w * uncond`` with ``w = guidance_weight_text``.
        """
        # Rank-agnostic doubling (identical to repeat(2, 1, 1, 1) for 4-D
        # input), matching how the other methods handle arbitrary rank.
        large_x = self._double_batch(x)
        large_t = t.repeat(2).to(x.dtype)

        if uncondition_context is None:
            uncondition_context = torch.zeros_like(context)
            uncondition_context_mask = torch.zeros_like(context_mask)
            uncondition_context[:, 0] = null_embedding
            uncondition_context_mask[:, 0] = 1
        large_context = torch.cat([context, uncondition_context])
        large_context_mask = torch.cat([context_mask, uncondition_context_mask])

        if mask is not None:
            mask = self._double_batch(mask)
        if masked_latent is not None:
            masked_latent = self._double_batch(masked_latent)

        # A 9-channel input layer gets the mask and masked latent concatenated
        # on the channel axis; both must be provided by the caller in that case.
        if model.in_layer.in_channels == 9:
            large_x = torch.cat([large_x, mask, masked_latent], dim=1)

        pred_large_noise = model(large_x, large_t * self.time_scale, large_context, large_context_mask.bool())
        pred_noise, uncond_pred_noise = torch.chunk(pred_large_noise, 2)
        return (guidance_weight_text + 1.) * pred_noise - guidance_weight_text * uncond_pred_noise

    def p_mean_variance(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Compute the mean and noise scale of the step from ``t`` to ``prev_t``.

        Returns:
            ``(pred_mean, pred_var)`` where ``pred_var`` is the value returned
            by ``q_posterior_variance`` (a square root, see that method).
        """
        pred_noise = self.text_guidance(
            model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            negative_context, negative_context_mask, mask, masked_latent
        )
        # Clip the implied x_0, then recompute the noise so both stay consistent.
        pred_x_start = self.get_x_start(x, t, pred_noise)
        pred_x_start = self.process_x_start(pred_x_start)
        pred_noise = self.get_noise(x, t, pred_x_start)

        pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)
        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)

        # Floating-point rounding (or eta > 1) can push this slightly below
        # zero; clamp so the square root cannot produce NaN.
        direction_scale = torch.sqrt(torch.clamp(1. - prev_alphas_cumprod - pred_var ** 2, min=0.))
        pred_mean = torch.sqrt(prev_alphas_cumprod) * pred_x_start
        pred_mean += direction_scale * pred_noise
        return pred_mean, pred_var

    @torch.no_grad()
    def p_sample(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Draw one sample at ``prev_t`` given ``x`` at ``t``.

        No noise is added for batch entries whose ``prev_t`` is 0.
        """
        bs = x.shape[0]
        ndims = len(x.shape[1:])
        pred_mean, pred_var = self.p_mean_variance(
            model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
            negative_context=negative_context, negative_context_mask=negative_context_mask,
            mask=mask, masked_latent=masked_latent
        )
        # NOTE(review): standard normal noise is used here, not self.gen_noise
        # -- confirm this is intended.
        noise = torch.randn_like(x)
        nonzero_step = (prev_t != 0).reshape(bs, *((1,) * ndims))
        return pred_mean + nonzero_step * pred_var * noise

    @torch.no_grad()
    def p_sample_loop(
        self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
    ):
        """Run the full sampling loop starting from Gaussian noise.

        Args:
            shape: shape of the generated tensor; ``shape[0]`` is the batch size.
            times: sequence of timesteps in the order they are visited; a
                final step to 0 is appended.
            gan: if true, each step re-noises the current tensor to ``time``
                with ``q_sample`` and takes the ``x_0`` implied by a single
                unguided model call, instead of calling ``p_sample``.

        Returns:
            The tensor after the last step.
        """
        img = torch.randn(*shape, device=device)
        # list(...) so tuples/ranges work and array-like inputs are not
        # silently broadcast-added to [0] instead of extended.
        schedule = list(times) + [0]
        steps = list(zip(schedule[:-1], schedule[1:]))

        for cur_time, prev_time in tqdm(steps):
            cur_t = torch.tensor([cur_time] * shape[0], device=device)
            if gan:
                x_t = self.q_sample(img, cur_t)
                pred_noise = model(x_t, cur_t.type(x_t.dtype), context, context_mask.bool())
                img = self.get_x_start(x_t, cur_t, pred_noise)
            else:
                prev_t = torch.tensor([prev_time] * shape[0], device=device)
                img = self.p_sample(
                    model, img, cur_t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
                    negative_context=negative_context, negative_context_mask=negative_context_mask,
                    mask=mask, masked_latent=masked_latent
                )
        return img
def get_diffusion(conf):
    """Create a ``BaseDiffusion`` from a config object.

    ``conf.schedule_params`` is expanded as keyword arguments to
    ``get_named_beta_schedule``; ``conf.diffusion_params`` is expanded as the
    extra keyword arguments of ``BaseDiffusion``.
    """
    schedule = get_named_beta_schedule(**conf.schedule_params)
    return BaseDiffusion(schedule, **conf.diffusion_params)