import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gc
import torch
import numpy as np
from glob import glob

import PIL
from PIL import Image

from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging, PIL_INTERPOLATION
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from safetensors.torch import load_file

from .lyrasd_vae_model import LyraSdVaeModel
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # Rescale the guided prediction so its per-sample std matches the text-conditioned prediction
    # (fixes over-exposure).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    # Mix the rescaled and original predictions by `guidance_rescale` to avoid "flat"-looking images.
    noise_cfg = guidance_rescale * noise_pred_rescaled + \
        (1 - guidance_rescale) * noise_cfg
    return noise_cfg


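# ControlNet-conditioned SDXL text-to-image pipeline. It subclasses the Lyra base pipeline for device,
# dtype and weight handling and diffusers' StableDiffusionXLPipeline for prompt encoding and latent
# utilities, and drives a Lyra-accelerated UNet that consumes channels-last (NHWC) tensors.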
class LyraSdXLControlnetTxt2ImgPipeline(LyraSDXLPipelineBase, StableDiffusionXLPipeline):
    device = torch.device("cpu")
    dtype = torch.float32

    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
        self.register_to_config(force_zeros_for_empty_prompt=True)

        super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)

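    # Preprocess a ControlNet conditioning image to the target resolution and convert it to the
    # channels-last (NHWC) layout expected by the Lyra UNet. Unlike the diffusers ControlNet pipeline,
    # the image is not repeated per prompt or duplicated for classifier-free guidance here.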
    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        image = self.control_image_processor.preprocess(image, height, width)
        image = image.permute(0, 2, 3, 1)
        image = image.to(device=device, dtype=dtype)

        return image

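    # Mirror of diffusers' `_execution_device`: when accelerate CPU-offload hooks are attached to the
    # UNet, return the device the hooks will execute on instead of the pipeline's nominal device.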
    @property
    def _execution_device(self):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

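    # Build SDXL's additional ("aug") conditioning embedding: the size/crop time ids are projected with
    # add_time_proj, concatenated with the pooled text embedding, and passed through the given
    # add_embedding MLP so the result can be added to the timestep embedding inside the UNet.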
    def _get_aug_emb(self, add_embedding, time_ids, text_embeds, dtype):
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(dtype)
        aug_emb = add_embedding(add_embeds)
        return aug_emb

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        controlnet_names: Optional[List[str]] = None,
        controlnet_images: Optional[List[PIL.Image.Image]] = None,
        controlnet_scale: Optional[List[float]] = None,
        guess_mode: bool = False,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
    ):
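        # 0. Default height / width to the model's native resolution and fill in the SDXL
        #    size-conditioning defaults.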
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

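        # 1. Check inputs. Raise an error if prompts, embeddings or callback arguments are malformed.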
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

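        # 2. Define call parameters.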
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Classifier-free guidance is used whenever guidance_scale > 1.
        do_classifier_free_guidance = guidance_scale > 1.0

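        # 3. Encode the prompts with both SDXL text encoders (optionally scaling any text-encoder LoRA).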
        text_encoder_lora_scale = (
            cross_attention_kwargs.get(
                "scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

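        # 4. Prepare the ControlNet conditioning images (resized to the target resolution, NHWC, on-device).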
        control_images = []
        for image_ in controlnet_images:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=prompt_embeds.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )
            control_images.append(image_)

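        # Expand each per-ControlNet scale to one scale per residual block. In guess mode the scales
        # ramp up logarithmically from 0.1 to 1.0 across the blocks, as in the diffusers ControlNet pipeline.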
        control_scales = []
        scales = [1.0, ] * 10
        if guess_mode:
            scales = torch.logspace(-1, 0, 10).tolist()

        for scale in controlnet_scale:
            scales_ = [d * scale for d in scales]
            control_scales.append(scales_)

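        # 5. Prepare timesteps.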
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

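        # 6. Prepare the initial latent noise (or reuse user-provided latents).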
        num_channels_latents = self.unet_in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

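        # 7. Prepare extra kwargs for the scheduler step; `prepare_extra_step_kwargs` only forwards
        #    `eta` / `generator` to schedulers whose `step()` actually accepts them.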
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

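        # 8. Prepare SDXL's added conditioning: the pooled text embedding plus the
        #    (original_size, crops_coords_top_left, target_size) micro-conditioning time ids.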
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = list(
            original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

        if do_classifier_free_guidance:
            # For classifier-free guidance the unconditional and conditional embeddings are run in a
            # single batch, so concatenate the negative embeddings in front of the positive ones.
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1)

        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0)

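        # 9. Optionally stop denoising early: with `denoising_end` in (0, 1) the base model only runs the
        #    first fraction of the schedule (e.g. when a refiner model finishes the remaining steps).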
        if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

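        # Precompute the added conditioning ("aug") embeddings once for the UNet and for every requested
        # ControlNet, since they do not change across denoising steps.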
        aug_emb = self._get_aug_emb(
            self.add_embedding, add_time_ids, add_text_embeds, prompt_embeds.dtype)

        controlnet_aug_embs = []
        for controlnet_name in controlnet_names:
            controlnet_aug_embs.append(self._get_aug_emb(self.controlnet_add_embedding[controlnet_name],
                                                         add_time_ids, add_text_embeds, prompt_embeds.dtype))

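        # 10. Denoising loop: the Lyra UNet takes channels-last (NHWC) latents and returns NHWC noise
        #     predictions, so the latents are permuted before the forward pass and permuted back after it.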
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Duplicate the latents so the unconditional and conditional branches of
                # classifier-free guidance run in a single batch.
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)
                latent_model_input = latent_model_input.permute(
                    0, 2, 3, 1).contiguous()

                # Predict the noise residual with the ControlNet-conditioned Lyra UNet
                # (channels-last in, channels-last out, so permute back to NCHW).
                noise_pred = self.unet.forward(
                    latent_model_input, prompt_embeds, t, aug_emb,
                    controlnet_names, control_images, controlnet_aug_embs, control_scales, guess_mode).permute(0, 3, 1, 2)

                # Perform classifier-free guidance.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # Update the progress bar and invoke the user callback.
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

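        # 11. Decode the latents with the Lyra VAE (undoing the VAE scaling factor first) and
        #     post-process to the requested output type.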
        image = self.vae.decode(1 / self.vae.scaling_factor * latents)
        image = self.image_processor.postprocess(
            image, output_type=output_type)

        # Offload the last model back to the CPU if model CPU offload was enabled.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image
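
# A minimal usage sketch (illustrative only). It assumes the SDXL and ControlNet weights have already
# been loaded through the LyraSDXLPipelineBase helpers (not shown here); the ControlNet name "canny",
# the image path and the prompt are placeholder values.
#
#   pipe = LyraSdXLControlnetTxt2ImgPipeline(device=torch.device("cuda"), dtype=torch.float16)
#   # ... load SDXL / ControlNet weights via the LyraSDXLPipelineBase loading helpers ...
#   images = pipe(
#       prompt="a photo of an astronaut riding a horse",
#       height=1024,
#       width=1024,
#       num_inference_steps=30,
#       guidance_scale=7.0,
#       controlnet_names=["canny"],
#       controlnet_images=[Image.open("control.png")],
#       controlnet_scale=[0.8],
#   )
#   images[0].save("out.png")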