import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

from ldm.util import default
from ldm.modules.diffusionmodules.util import extract_into_tensor
from .ddpm import DDPM


class LatentDiffusion(DDPM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # hardcoded: never clamp the denoised prediction to [-1, 1]
        self.clip_denoised = False

    def q_sample(self, x_start, t, noise=None):
        """Diffuse clean latents x_start to timestep t with the closed-form forward process:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    # Note: this class no longer supports DDPM (ancestral) sampling; use a DDIM or PLMS sampler instead.

    # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = #

    # def predict_start_from_noise(self, x_t, t, noise):
    #     return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
    #             extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)

    # def q_posterior(self, x_start, x_t, t):
    #     posterior_mean = (
    #         extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
    #         extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
    #     )
    #     posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
    #     posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    #     return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # def p_mean_variance(self, model, x, c, t):
    #     model_out = model(x, t, c)
    #     x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    #     if self.clip_denoised:
    #         x_recon.clamp_(-1., 1.)
    #     model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    #     return model_mean, posterior_variance, posterior_log_variance, x_recon

    # @torch.no_grad()
    # def p_sample(self, model, x, c, t):
    #     b, *_, device = *x.shape, x.device
    #     model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t)
    #     noise = torch.randn_like(x)
    #     # no noise when t == 0
    #     nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    #     return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0

    # @torch.no_grad()
    # def p_sample_loop(self, model, shape, c):
    #     device = self.betas.device
    #     b = shape[0]
    #     img = torch.randn(shape, device=device)
    #     iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps)
    #     for i in iterator:
    #         ts = torch.full((b,), i, device=device, dtype=torch.long)
    #         img, x0 = self.p_sample(model, img, c, ts)
    #     return img

    # @torch.no_grad()
    # def sample(self, model, shape, c, uc=None, guidance_scale=None):
    #     return self.p_sample_loop(model, shape, c)
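
# ------------------------------------------------------------------------- #
# Illustrative usage sketch (not part of the original module): it shows how
# the q_sample forward-noising step above would typically be called during
# training. The constructor arguments, tensor shapes, and the DDIMSampler
# import location below are assumptions about the surrounding codebase, not
# guarantees of its API.
#
#   from ldm.models.diffusion.ddim import DDIMSampler   # assumed import path
#
#   diffusion = LatentDiffusion(...)                     # built from a model config
#   x_start = torch.randn(4, 4, 64, 64)                  # batch of clean latents (assumed shape)
#   t = torch.randint(0, diffusion.num_timesteps, (4,), device=x_start.device)
#   x_noisy = diffusion.q_sample(x_start, t)             # forward-noised latents at step t
#
#   # Since DDPM sampling is disabled in this class, generation should go
#   # through a DDIM or PLMS sampler built on top of it, e.g. DDIMSampler(...).
# ------------------------------------------------------------------------- #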