import torch


def ddpm_sampler(
    net,
    batch,
    conditioning_keys=None,
    scheduler=None,
    uncond_tokens=None,
    num_steps=1000,
    cfg_rate=0,
    generator=None,
    use_confidence_sampling=False,
    use_uncond_token=True,
    confidence_value=1.0,
    unconfidence_value=0.0,
):
    """Ancestral DDPM sampler with optional classifier-free guidance (CFG)
    and confidence conditioning.

    `scheduler` maps noise levels t in [0, 1] (t=1 at the start of sampling)
    to the cumulative signal retention gamma(t); `net` takes the batch dict
    and returns (noise_prediction, latents).
    """
    if scheduler is None:
        raise ValueError("Scheduler must be provided")
    x_cur = batch["y"].to(torch.float32)
    latents = batch["previous_latents"]
    if use_confidence_sampling:
        batch["confidence"] = (
            torch.ones(x_cur.shape[0], device=x_cur.device) * confidence_value
        )
    # Noise levels run from t=1 down to t=0, giving num_steps + 1 gammas and
    # num_steps (gamma_now, gamma_next) transitions.
    step_indices = torch.arange(
        num_steps + 1, dtype=torch.float32, device=x_cur.device
    )
    steps = 1 - step_indices / num_steps
    gammas = scheduler(steps)
    latents_cond = latents_uncond = latents
    # dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    dtype = torch.float32

    if cfg_rate > 0 and conditioning_keys is not None:
        # Build a doubled batch: conditional inputs in the first half,
        # unconditional (or low-confidence) inputs in the second half, so a
        # single forward pass yields both CFG branches.
        stacked_batch = {}
        for key in conditioning_keys:
            if f"{key}_mask" in batch:
                if use_confidence_sampling and not use_uncond_token:
                    stacked_batch[f"{key}_mask"] = torch.cat(
                        [batch[f"{key}_mask"], batch[f"{key}_mask"]], dim=0
                    )
                else:
                    if (
                        batch[f"{key}_mask"].shape[1]
                        > uncond_tokens[f"{key}_mask"].shape[1]
                    ):
                        # Pad the unconditional mask up to the batch's
                        # sequence length: False for boolean masks, -inf for
                        # additive masks.
                        uncond_mask = (
                            torch.zeros_like(batch[f"{key}_mask"])
                            if batch[f"{key}_mask"].dtype == torch.bool
                            else torch.ones_like(batch[f"{key}_mask"]) * -torch.inf
                        )
                        uncond_mask[:, : uncond_tokens[f"{key}_mask"].shape[1]] = (
                            uncond_tokens[f"{key}_mask"]
                        )
                    else:
                        # Pad the batch mask up to the unconditional
                        # sequence length.
                        uncond_mask = uncond_tokens[f"{key}_mask"]
                        batch[f"{key}_mask"] = torch.cat(
                            [
                                batch[f"{key}_mask"],
                                torch.zeros(
                                    batch[f"{key}_mask"].shape[0],
                                    uncond_tokens[f"{key}_embeddings"].shape[1]
                                    - batch[f"{key}_mask"].shape[1],
                                    device=batch[f"{key}_mask"].device,
                                    dtype=batch[f"{key}_mask"].dtype,
                                ),
                            ],
                            dim=1,
                        )
                    stacked_batch[f"{key}_mask"] = torch.cat(
                        [batch[f"{key}_mask"], uncond_mask], dim=0
                    )
            if f"{key}_embeddings" in batch:
                if use_confidence_sampling and not use_uncond_token:
                    stacked_batch[f"{key}_embeddings"] = torch.cat(
                        [
                            batch[f"{key}_embeddings"],
                            batch[f"{key}_embeddings"],
                        ],
                        dim=0,
                    )
                else:
                    # Zero-pad whichever of the two embedding tensors is
                    # shorter along the sequence dimension so they can be
                    # stacked along the batch dimension.
                    if (
                        batch[f"{key}_embeddings"].shape[1]
                        > uncond_tokens[f"{key}_embeddings"].shape[1]
                    ):
                        uncond_tokens[f"{key}_embeddings"] = torch.cat(
                            [
                                uncond_tokens[f"{key}_embeddings"],
                                torch.zeros(
                                    uncond_tokens[f"{key}_embeddings"].shape[0],
                                    batch[f"{key}_embeddings"].shape[1]
                                    - uncond_tokens[f"{key}_embeddings"].shape[1],
                                    uncond_tokens[f"{key}_embeddings"].shape[2],
                                    device=uncond_tokens[f"{key}_embeddings"].device,
                                ),
                            ],
                            dim=1,
                        )
                    elif (
                        batch[f"{key}_embeddings"].shape[1]
                        < uncond_tokens[f"{key}_embeddings"].shape[1]
                    ):
                        batch[f"{key}_embeddings"] = torch.cat(
                            [
                                batch[f"{key}_embeddings"],
                                torch.zeros(
                                    batch[f"{key}_embeddings"].shape[0],
                                    uncond_tokens[f"{key}_embeddings"].shape[1]
                                    - batch[f"{key}_embeddings"].shape[1],
                                    batch[f"{key}_embeddings"].shape[2],
                                    device=batch[f"{key}_embeddings"].device,
                                ),
                            ],
                            dim=1,
                        )
                    stacked_batch[f"{key}_embeddings"] = torch.cat(
                        [
                            batch[f"{key}_embeddings"],
                            uncond_tokens[f"{key}_embeddings"],
                        ],
                        dim=0,
                    )
            elif key not in batch:
                raise ValueError(f"Key {key} not in batch")
            else:
                # Raw (non-embedding) conditioning: tensors are concatenated,
                # lists are extended.
                if isinstance(batch[key], torch.Tensor):
                    if use_confidence_sampling and not use_uncond_token:
                        stacked_batch[key] = torch.cat(
                            [batch[key], batch[key]], dim=0
                        )
                    else:
                        stacked_batch[key] = torch.cat(
                            [batch[key], uncond_tokens], dim=0
                        )
                elif isinstance(batch[key], list):
                    if use_confidence_sampling and not use_uncond_token:
                        stacked_batch[key] = [*batch[key], *batch[key]]
                    else:
                        stacked_batch[key] = [*batch[key], *uncond_tokens]
                else:
                    raise ValueError(
                        "Conditioning must be a tensor or a list of tensors"
                    )
        if use_confidence_sampling:
            stacked_batch["confidence"] = torch.cat(
                [
                    torch.ones(x_cur.shape[0], device=x_cur.device)
                    * confidence_value,
                    torch.ones(x_cur.shape[0], device=x_cur.device)
                    * unconfidence_value,
                ],
                dim=0,
            )

    for step, (gamma_now, gamma_next) in enumerate(zip(gammas[:-1], gammas[1:])):
        with torch.cuda.amp.autocast(dtype=dtype):
            if cfg_rate > 0 and conditioning_keys is not None:
                stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
                stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
                stacked_batch["previous_latents"] = (
                    torch.cat([latents_cond, latents_uncond], dim=0)
                    if latents is not None
                    else None
                )
                denoised_all, latents_all = net(stacked_batch)
                denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
                latents_cond, latents_uncond = latents_all.chunk(2, dim=0)
                # Classifier-free guidance: extrapolate away from the
                # unconditional prediction.
                denoised = (
                    denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
                )
            else:
                batch["y"] = x_cur
                batch["gamma"] = gamma_now.expand(x_cur.shape[0])
                batch["previous_latents"] = latents
                denoised, latents = net(batch)
        # Predict x_0 from the noise estimate, clamp it to the data range,
        # then re-derive a noise estimate consistent with the clamped x_0.
        x_pred = (x_cur - torch.sqrt(1 - gamma_now) * denoised) / torch.sqrt(
            gamma_now
        )
        x_pred = torch.clamp(x_pred, -1, 1)
        noise_pred = (x_cur - torch.sqrt(gamma_now) * x_pred) / torch.sqrt(
            1 - gamma_now
        )
        # Ancestral DDPM update (Ho et al. 2020, Eq. 11): alpha_t is the
        # per-step retention ratio between consecutive cumulative gammas,
        # and the step variance is beta_t = 1 - alpha_t. Note that noise is
        # added at every step, including the last.
        log_alpha_t = torch.log(gamma_now) - torch.log(gamma_next)
        alpha_t = torch.clip(torch.exp(log_alpha_t), 0, 1)
        x_mean = torch.rsqrt(alpha_t) * (
            x_cur - torch.rsqrt(1 - gamma_now) * (1 - alpha_t) * noise_pred
        )
        var_t = 1 - alpha_t
        eps = torch.randn(x_cur.shape, device=x_cur.device, generator=generator)
        x_next = x_mean + torch.sqrt(var_t) * eps
        x_cur = x_next
    return x_cur.to(torch.float32)
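

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test of the
# unconditional path, assuming a toy network that returns
# (noise_prediction, latents) and a hypothetical cosine gamma schedule. The
# real `net` and `scheduler` used with this sampler may look quite different;
# the names below (`cosine_gamma`, `ToyNet`) are illustrative only.
if __name__ == "__main__":
    import math

    def cosine_gamma(t):
        # gamma(t) = cos(pi/2 * t)^2 maps t=1 (start of sampling) to ~0
        # signal retention and t=0 to ~1; clamping keeps the sampler's
        # sqrt/log/division steps finite at the endpoints.
        return torch.clamp(torch.cos(0.5 * math.pi * t) ** 2, 1e-4, 1 - 1e-4)

    class ToyNet(torch.nn.Module):
        def forward(self, batch):
            # Stand-in noise predictor: returns (denoised, latents) as
            # ddpm_sampler expects; a real model would condition on
            # batch["gamma"] and any conditioning keys.
            y = batch["y"]
            return 0.1 * y, y

    batch = {"y": torch.randn(2, 8), "previous_latents": None}
    sample = ddpm_sampler(ToyNet(), batch, scheduler=cosine_gamma, num_steps=10)
    print(sample.shape)  # torch.Size([2, 8])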