|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
from collections.abc import Callable |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import PIL |
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import ( |
|
CLIPTextModel, |
|
CLIPTextModelWithProjection, |
|
CLIPTokenizer, |
|
) |
|
|
|
from diffusers import DiffusionPipeline |
|
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
|
from diffusers.loaders import ( |
|
FromSingleFileMixin, |
|
LoraLoaderMixin, |
|
StableDiffusionXLLoraLoaderMixin, |
|
TextualInversionLoaderMixin, |
|
) |
|
from diffusers.models import ( |
|
AutoencoderKL, |
|
ControlNetModel, |
|
MultiAdapter, |
|
T2IAdapter, |
|
UNet2DConditionModel, |
|
) |
|
from diffusers.models.attention_processor import ( |
|
AttnProcessor2_0, |
|
LoRAAttnProcessor2_0, |
|
LoRAXFormersAttnProcessor, |
|
XFormersAttnProcessor, |
|
) |
|
from diffusers.models.lora import adjust_lora_scale_text_encoder |
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel |
|
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin |
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.utils import ( |
|
PIL_INTERPOLATION, |
|
USE_PEFT_BACKEND, |
|
logging, |
|
replace_example_docstring, |
|
scale_lora_layers, |
|
unscale_lora_layers, |
|
) |
|
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor |
|
|
|
|
|
# Module-level logger obtained from the diffusers logging utility, keyed by this module's name.
logger = logging.get_logger(__name__)
|
|
|
# Usage example for the pipeline's documentation. The snippet is kept self-consistent:
# every name it uses (`ControlNetModel`, `mask_image`) is imported or defined inside the example itself.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import ControlNetModel, DiffusionPipeline, T2IAdapter
        >>> from diffusers.utils import load_image
        >>> from PIL import Image
        >>> from controlnet_aux.midas import MidasDetector

        >>> adapter = T2IAdapter.from_pretrained(
        ...     "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
        ... ).to("cuda")

        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "diffusers/controlnet-depth-sdxl-1.0",
        ...     torch_dtype=torch.float16,
        ...     variant="fp16",
        ...     use_safetensors=True
        ... ).to("cuda")

        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        ...     torch_dtype=torch.float16,
        ...     variant="fp16",
        ...     use_safetensors=True,
        ...     custom_pipeline="stable_diffusion_xl_adapter_controlnet_inpaint",
        ...     adapter=adapter,
        ...     controlnet=controlnet,
        ... ).to("cuda")

        >>> prompt = "a tiger sitting on a park bench"
        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

        >>> image = load_image(img_url).resize((1024, 1024))
        >>> mask_image = load_image(mask_url).resize((1024, 1024))

        >>> midas_depth = MidasDetector.from_pretrained(
        ...     "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
        ... ).to("cuda")

        >>> depth_image = midas_depth(
        ...     image, detect_resolution=512, image_resolution=1024
        ... )

        >>> strength = 0.4

        >>> generator = torch.manual_seed(42)

        >>> result_image = pipe(
        ...     image=image,
        ...     mask_image=mask_image,
        ...     adapter_image=depth_image,
        ...     control_image=depth_image,
        ...     controlnet_conditioning_scale=strength,
        ...     adapter_conditioning_scale=strength,
        ...     strength=0.7,
        ...     generator=generator,
        ...     prompt=prompt,
        ...     negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
        ...     num_inference_steps=50
        ... ).images[0]
        ```
"""
|
|
|
|
|
def _preprocess_adapter_image(image, height, width):
    """Turn an adapter conditioning image into a ``torch.Tensor``.

    * A ``torch.Tensor`` is handed back untouched.
    * A single PIL image is treated as a one-element list.
    * A list of PIL images is resized to ``(width, height)`` with Lanczos resampling, divided by 255 as
      float32 and moved to channels-first layout.
    * A list of tensors is stacked (3-D items) or concatenated along dim 0 (4-D items).
    * A list whose first element is neither a PIL image nor a tensor is returned as-is.

    Raises:
        ValueError: if the first tensor of a list is neither 3-D nor 4-D.
    """
    if isinstance(image, torch.Tensor):
        return image

    items = [image] if isinstance(image, PIL.Image.Image) else image
    first = items[0]

    if isinstance(first, PIL.Image.Image):
        resample = PIL_INTERPOLATION["lanczos"]
        resized = [np.array(pil_img.resize((width, height), resample=resample)) for pil_img in items]
        # Grow [h, w] or [h, w, c] into [1, h, w, c] so the arrays can be joined on the batch axis.
        batched = [arr[None, ..., None] if arr.ndim == 2 else arr[None, ...] for arr in resized]
        pixels = np.array(np.concatenate(batched, axis=0)).astype(np.float32) / 255.0
        return torch.from_numpy(pixels.transpose(0, 3, 1, 2))

    if isinstance(first, torch.Tensor):
        rank = first.ndim
        if rank == 3:
            return torch.stack(items, dim=0)
        if rank == 4:
            return torch.cat(items, dim=0)
        raise ValueError(
            f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
        )

    return items
|
|
|
|
|
def mask_pil_to_torch(mask, height, width):
    """Convert a PIL / NumPy mask (or a list of them) into a ``batch x 1 x H x W`` tensor.

    Args:
        mask: a ``PIL.Image.Image``, an ``np.ndarray``, or a list of either.
        height: target height used when resizing PIL masks.
        width: target width used when resizing PIL masks.

    Returns:
        ``torch.Tensor`` built with ``torch.from_numpy``. PIL masks are resized with Lanczos resampling,
        converted to mode ``"L"`` and scaled to ``[0, 1]`` as float32. NumPy masks are neither resized
        nor rescaled; they only gain leading batch and channel axes and keep their dtype.

    Raises:
        TypeError: raised by ``torch.from_numpy`` when ``mask`` is not one of the supported inputs.
    """
    # Plain class checks are used instead of `isinstance(mask, Union[...])`: instance checks against
    # `typing.Union` raise TypeError on Python < 3.10.
    if isinstance(mask, list):
        items = mask
    elif isinstance(mask, np.ndarray) or isinstance(mask, PIL.Image.Image):
        items = [mask]
    else:
        # Unsupported input: fall through so torch.from_numpy reports the bad type.
        items = None

    if items is not None:
        if isinstance(items[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in items], axis=0)
        elif isinstance(items[0], PIL.Image.Image):
            resized = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in items]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in resized], axis=0)
            mask = mask.astype(np.float32) / 255.0

    return torch.from_numpy(mask)
|
|
|
|
|
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
    """
    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
    ``image`` and ``1`` for the ``mask``.

    The ``image`` is converted to ``torch.float32``. PIL / NumPy images are additionally mapped from ``[0, 255]`` to
    ``[-1, 1]``; tensor images are only cast. The ``mask`` is binarized (``mask >= 0.5`` becomes ``1``, everything
    else ``0``) and keeps its dtype. Binarization is done out of place, so a mask tensor passed by the caller is
    never modified.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
        height (int): Target height used when resizing PIL inputs.
        width (int): Target width used when resizing PIL inputs.
        return_image (bool): When ``True`` the converted image is returned as a third element.

    Raises:
        ValueError: ``image`` or ``mask`` is ``None``, or a ``torch.Tensor`` mask is outside the ``[0, 1]`` range.
        AssertionError: tensor inputs do not end up 4-dimensional or their batch sizes differ.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not.

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image), or the triple (mask, masked_image, image) when
            ``return_image`` is set. ``masked_image`` is ``None`` when the image has 4 channels.
    """
    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            mask = mask_pil_to_torch(mask, height, width)

        if image.ndim == 3:
            image = image.unsqueeze(0)

        # Single mask without batch/channel dims: add both.
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        if mask.ndim == 3:
            if mask.shape[0] == 1:
                # Leading dim of size 1: add a new leading dim.
                mask = mask.unsqueeze(0)
            else:
                # Otherwise the leading dim is taken as the batch: insert the channel dim.
                mask = mask.unsqueeze(1)

        # Kept as asserts so the exception type seen by existing callers (AssertionError) does not change.
        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Out-of-place binarization: the unsqueeze calls above return views of the caller's tensor,
        # so an in-place write here would silently change the caller's mask.
        mask = (mask >= 0.5).to(dtype=mask.dtype)

        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # Plain class checks instead of `isinstance(image, Union[...])`, which raises TypeError on Python < 3.10.
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # Resize every image to the requested height and width before stacking.
            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        mask = mask_pil_to_torch(mask, height, width)
        mask = (mask >= 0.5).to(dtype=mask.dtype)

    if image.shape[1] == 4:
        # A 4-channel image is not masked here; masked_image is left as None.
        masked_image = None
    else:
        masked_image = image * (mask < 0.5)

    # The 2-tuple return is the default; the image is only appended on request.
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
|
|
|
|
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself whose per-sample standard deviation is matched to `noise_pred_text`.
    `guidance_rescale` is the blend weight: `0.0` leaves `noise_cfg` as it is, `1.0` uses the rescaled copy only.
    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    # Standard deviations are taken over every dimension except the leading one.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    text_sigma = noise_pred_text.std(dim=text_axes, keepdim=True)
    cfg_sigma = noise_cfg.std(dim=cfg_axes, keepdim=True)

    std_matched = noise_cfg * (text_sigma / cfg_sigma)

    return guidance_rescale * std_matched + (1 - guidance_rescale) * noise_cfg
|
|
|
|
|
class StableDiffusionXLControlNetAdapterInpaintPipeline( |
|
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, LoraLoaderMixin |
|
): |
|
r""" |
|
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter |
|
https://arxiv.org/abs/2302.08453 |
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the |
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) |
|
|
|
Args: |
|
adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): |
|
Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a |
|
list, the outputs from each Adapter are added together to create one combined additional conditioning. |
|
adapter_weights (`List[float]`, *optional*, defaults to None): |
|
List of floats representing the weight which will be multiply to each adapter's output before adding them |
|
together. |
|
vae ([`AutoencoderKL`]): |
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. |
|
text_encoder ([`CLIPTextModel`]): |
|
Frozen text-encoder. Stable Diffusion uses the text portion of |
|
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically |
|
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. |
|
tokenizer (`CLIPTokenizer`): |
|
Tokenizer of class |
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). |
|
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. |
|
scheduler ([`SchedulerMixin`]): |
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
|
safety_checker ([`StableDiffusionSafetyChecker`]): |
|
Classification module that estimates whether generated images could be considered offensive or harmful. |
|
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. |
|
feature_extractor ([`CLIPFeatureExtractor`]): |
|
Model that extracts features from generated images to be used as inputs for the `safety_checker`. |
|
requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): |
|
Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config |
|
of `stabilityai/stable-diffusion-xl-refiner-1-0`. |
|
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): |
|
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of |
|
`stabilityai/stable-diffusion-xl-base-1-0`. |
|
""" |
|
|
|
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    adapter: Union[T2IAdapter, MultiAdapter],
    controlnet: Union[ControlNetModel, MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
):
    """Register the pipeline components and derive the image-processing helpers.

    Args:
        vae: autoencoder; its ``config.block_out_channels`` determines ``vae_scale_factor``.
        text_encoder / text_encoder_2: the two text encoders.
        tokenizer / tokenizer_2: the tokenizers paired with the text encoders.
        unet: denoising network; its ``config.sample_size`` becomes ``default_sample_size``.
        adapter: a single adapter, a ``MultiAdapter``, or a list/tuple of adapters (wrapped in ``MultiAdapter``).
        controlnet: a single ControlNet, a ``MultiControlNetModel``, or a list/tuple of ControlNets
            (wrapped in ``MultiControlNetModel``).
        scheduler: scheduler registered alongside the other modules.
        requires_aesthetics_score: stored in the pipeline config.
        force_zeros_for_empty_prompt: stored in the pipeline config.
    """
    super().__init__()

    # Wrap plain sequences so that a single module object is registered for each role.
    # An adapter list gets the same treatment as a ControlNet list.
    if isinstance(adapter, (list, tuple)):
        adapter = MultiAdapter(adapter)

    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        adapter=adapter,
        controlnet=controlnet,
        scheduler=scheduler,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # Control images are converted to RGB and are not normalized.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.default_sample_size = self.unet.config.sample_size
|
|
|
|
|
def encode_prompt(
    self,
    prompt: Optional[Union[str, List[str]]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device; defaults to the pipeline's execution device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead, or the negative embeddings are zeros when the pipeline config sets
            `force_zeros_for_empty_prompt`.
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
            When provided, `pooled_prompt_embeds` must be provided as well.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from `negative_prompt`.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, they are generated from `prompt`.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. If not provided, they are generated from
            `negative_prompt`.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        Tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`,
        each repeated `num_images_per_prompt` times along the batch dimension. The negative entries are returned
        as passed in (possibly `None`) when `do_classifier_free_guidance` is `False`.

    Raises:
        ValueError: neither `prompt` nor `prompt_embeds` is given, `prompt_embeds` is given without
            `pooled_prompt_embeds`, or `negative_prompt` does not match the batch size of `prompt`.
        TypeError: `negative_prompt` and `prompt` have different types.
    """
    # Fail early with a clear message instead of an AttributeError on `None` further down.
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    device = device or self._execution_device

    # Record the LoRA scale on the pipeline and apply it to whichever text encoders are present.
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # When the first tokenizer / text encoder is absent, only the second pair is used.
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # The loop uses its own variable so that `prompt` still refers to the caller's prompt afterwards
        # (it is used below for the negative-prompt type check and error message).
        for prompt_text, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt_text = self.maybe_convert_prompt(prompt_text, tokenizer)

            text_inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt_text, padding="longest", return_tensors="pt").input_ids

            # Warn when truncation dropped part of the prompt.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # Overwritten on every iteration: the pooled output of the last encoder in the list wins.
            pooled_prompt_embeds = encoder_output[0]
            if clip_skip is None:
                hidden = encoder_output.hidden_states[-2]
            else:
                # The default index is -2, so skipping `clip_skip` more layers gives -(clip_skip + 2).
                hidden = encoder_output.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(hidden)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # Negative embeddings for classifier-free guidance.
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # Normalize a single string to one entry per prompt.
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        uncond_tokens = [negative_prompt, negative_prompt_2]

        # Negative prompts are padded/truncated to the sequence length of the positive embeddings.
        max_length = prompt_embeds.shape[1]
        negative_prompt_embeds_list = []
        for negative_text, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_text = self.maybe_convert_prompt(negative_text, tokenizer)

            uncond_input = tokenizer(
                negative_text,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_output = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )

            # As above, the pooled output of the last encoder wins. `clip_skip` is not applied here.
            negative_pooled_prompt_embeds = negative_output[0]
            negative_prompt_embeds_list.append(negative_output.hidden_states[-2])

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # Duplicate the embeddings for each image per prompt via repeat + view.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    # Undo the LoRA scaling applied to the text encoders (PEFT backend only).
    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
|
|
|
|
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that ``self.scheduler.step`` accepts.

    The signature of ``self.scheduler.step`` is inspected; ``eta`` and ``generator`` are included in the
    returned dict only when a parameter of that name exists.

    Returns:
        dict with zero, one or both of the keys ``"eta"`` and ``"generator"`` (in that order).
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
|
|
|
|
|
def check_image(self, image, prompt, prompt_embeds):
    """Validate the type of a conditioning image and its batch size against the prompt batch size.

    Args:
        image: a PIL image, ``np.ndarray``, ``torch.Tensor``, or a non-empty list of one of those.
        prompt: ``str`` or ``list`` of prompts, or ``None`` when ``prompt_embeds`` is used.
        prompt_embeds: pre-computed embeddings; ``shape[0]`` is used as the prompt batch size when
            ``prompt`` is neither a ``str`` nor a ``list``.

    Raises:
        TypeError: ``image`` is not one of the supported types (an empty list included).
        ValueError: the image batch size is not 1 and either differs from the prompt batch size or the
            prompt batch size cannot be determined.
    """
    image_is_pil = False
    if isinstance(image, (torch.Tensor, np.ndarray)):
        supported = True
    elif isinstance(image, list):
        # Lists are classified by their first element only. An empty list has none, so it is
        # rejected with the same TypeError as other unsupported inputs instead of an IndexError.
        supported = bool(image) and (
            isinstance(image[0], (torch.Tensor, np.ndarray)) or isinstance(image[0], PIL.Image.Image)
        )
    else:
        image_is_pil = isinstance(image, PIL.Image.Image)
        supported = image_is_pil

    if not supported:
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    # For everything except a single PIL image the batch size is simply `len(image)`.
    image_batch_size = 1 if image_is_pil else len(image)

    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        prompt_batch_size = None

    if image_batch_size != 1:
        if prompt_batch_size is None:
            # Replaces the UnboundLocalError this situation used to raise.
            raise ValueError(
                "Cannot determine the prompt batch size: `prompt` must be a `str` or `list`, or"
                " `prompt_embeds` must be provided."
            )
        if image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
|
|
|
|
|
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the user-facing call arguments; raises ``ValueError`` on the first problem found.

    Checks, in order: ``height``/``width`` divisible by 8; ``callback_steps`` a positive ``int`` when given;
    every entry of ``callback_on_step_end_tensor_inputs`` listed in ``self._callback_tensor_inputs``;
    prompts and their pre-computed embeddings not passed together; at least one of ``prompt`` /
    ``prompt_embeds`` given; prompt types; matching shapes of positive and negative embeddings;
    pooled embeddings present whenever the corresponding embeddings are.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        steps_ok = isinstance(callback_steps, int) and callback_steps > 0
        if not steps_ok:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    has_prompt = prompt is not None
    has_prompt_2 = prompt_2 is not None
    has_embeds = prompt_embeds is not None
    has_neg_embeds = negative_prompt_embeds is not None

    # Every branch below raises, so sequential guards behave like the equivalent if/elif chain.
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if has_prompt_2 and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not (isinstance(prompt, str) or isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if has_prompt_2 and not (isinstance(prompt_2, str) or isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and has_neg_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt_2 is not None and has_neg_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_neg_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if has_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if has_neg_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
|
|
|
def check_conditions(
    self,
    prompt,
    prompt_embeds,
    adapter_image,
    control_image,
    adapter_conditioning_scale,
    controlnet_conditioning_scale,
    control_guidance_start,
    control_guidance_end,
):
    """Validate the ControlNet and adapter conditioning inputs.

    Checks, in order: the control guidance window(s), the ControlNet image(s)
    and conditioning scale(s), then the adapter image(s) and conditioning
    scale(s). Individual images are validated through `self.check_image`.

    Args:
        prompt: Prompt(s), forwarded to `self.check_image`.
        prompt_embeds: Prompt embeddings, forwarded to `self.check_image`.
        adapter_image: Adapter condition; a list with one entry per adapter
            when `self.adapter` is a `MultiAdapter`.
        control_image: ControlNet condition; a list with one entry per net
            when `self.controlnet` is a `MultiControlNetModel`.
        adapter_conditioning_scale: `float` for a single adapter; for a
            `MultiAdapter` a list, if given, must match the adapter count.
        controlnet_conditioning_scale: `float` for a single ControlNet; for a
            `MultiControlNetModel` a list, if given, must match the net count.
        control_guidance_start: Scalar or list/tuple of start fractions.
        control_guidance_end: Scalar or list/tuple of end fractions.

    Raises:
        ValueError: On inconsistent lengths, nested lists, or guidance values
            outside ``0 <= start < end <= 1``.
        TypeError: On wrongly typed images/scales or an unsupported
            controlnet/adapter type.
    """
    nested_conditioning_msg = "A single batch of multiple conditionings are supported at the moment."

    # --- control guidance window -------------------------------------------
    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    # Unwrap each module independently: the controlnet being compiled says
    # nothing about the adapter, and a plain module has no `_orig_mod`.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
    adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter

    if isinstance(controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(controlnet.nets)} controlnets available. Make sure to provide {len(controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    # --- ControlNet image(s) and conditioning scale(s) -----------------------
    if isinstance(controlnet, ControlNetModel):
        self.check_image(control_image, prompt, prompt_embeds)
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet, MultiControlNetModel):
        if not isinstance(control_image, list):
            raise TypeError("For multiple controlnets: `control_image` must be type `list`")
        # A nested list (one list of images per net) is not supported.
        if any(isinstance(i, list) for i in control_image):
            raise ValueError(nested_conditioning_msg)
        if len(control_image) != len(controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(controlnet.nets)} ControlNets."
            )

        for image_ in control_image:
            self.check_image(image_, prompt, prompt_embeds)

        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError(nested_conditioning_msg)
            # Previously an unreachable `elif`; a wrong-length list slipped through.
            if len(controlnet_conditioning_scale) != len(controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        # Explicit raise instead of `assert False`, which is stripped under `python -O`.
        raise TypeError(f"Unsupported `controlnet` type: {type(controlnet)}.")

    # --- adapter image(s) and conditioning scale(s) --------------------------
    if isinstance(adapter, T2IAdapter):
        self.check_image(adapter_image, prompt, prompt_embeds)
        if not isinstance(adapter_conditioning_scale, float):
            raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.")
    elif isinstance(adapter, MultiAdapter):
        if not isinstance(adapter_image, list):
            raise TypeError("For multiple adapters: `adapter_image` must be type `list`")
        # A nested list (one list of images per adapter) is not supported.
        if any(isinstance(i, list) for i in adapter_image):
            raise ValueError(nested_conditioning_msg)
        if len(adapter_image) != len(adapter.adapters):
            # The message used to read `self.adapters.nets`, which does not
            # exist and turned this ValueError into an AttributeError.
            raise ValueError(
                f"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(adapter.adapters)} Adapters."
            )

        for image_ in adapter_image:
            self.check_image(image_, prompt, prompt_embeds)

        if isinstance(adapter_conditioning_scale, list):
            if any(isinstance(i, list) for i in adapter_conditioning_scale):
                raise ValueError(nested_conditioning_msg)
            # Previously an unreachable `elif`; a wrong-length list slipped through.
            if len(adapter_conditioning_scale) != len(adapter.adapters):
                raise ValueError(
                    "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of adapters"
                )
    else:
        # Explicit raise instead of `assert False`, which is stripped under `python -O`.
        raise TypeError(f"Unsupported `adapter` type: {type(adapter)}.")
|
|
|
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    add_noise=True,
    return_noise=False,
    return_image_latents=False,
):
    """Build the initial latents for the denoising loop.

    Args:
        batch_size: Effective batch size (first dimension of the latents).
        num_channels_latents: Number of latent channels.
        height: Image height in pixels; divided by `self.vae_scale_factor`.
        width: Image width in pixels; divided by `self.vae_scale_factor`.
        dtype: Dtype for generated noise and image latents.
        device: Device for generated noise and image latents.
        generator: One generator, or a list with exactly `batch_size` entries.
        latents: Optional pre-generated noise. When given and `add_noise` is
            true it is scaled by `self.scheduler.init_noise_sigma`.
        image: Optional image tensor. A tensor with 4 channels in dim 1 is
            used directly as latents; otherwise it is encoded with the VAE
            when image latents are needed.
        timestep: Timestep passed to `self.scheduler.add_noise` when
            `is_strength_max` is false.
        is_strength_max: When true, start from pure noise; otherwise from
            `image` latents with noise added.
        add_noise: When false, the image latents are used as the latents.
        return_noise: Append the noise tensor to the result.
        return_image_latents: Append the image latents to the result.

    Returns:
        A tuple ``(latents,)``, extended with ``noise`` and/or
        ``image_latents`` (in that order) when requested.

    Raises:
        ValueError: If the generator list length differs from `batch_size`,
            or if `image`/`timestep` are missing where they are required.
    """
    shape = (
        batch_size,
        num_channels_latents,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
            "However, either the image or the noise timestep has not been provided."
        )

    # Image latents are consumed by: the caller (return_image_latents), the
    # no-noise branch below, and the partial-strength noising branch.
    needs_image_latents = return_image_latents or not add_noise or (latents is None and not is_strength_max)
    if image is None and needs_image_latents:
        raise ValueError(
            "`image` must be provided when image latents are required (`return_image_latents=True`,"
            " `add_noise=False` or strength < 1)."
        )

    # Computed only when possible and needed. The original read `image.shape`
    # on `None` and repeated `image_latents` even when it was never assigned.
    image_latents = None
    if image is not None:
        if image.shape[1] == 4:
            # Already in latent space.
            image_latents = image.to(device=device, dtype=dtype)
        elif needs_image_latents:
            image = image.to(device=device, dtype=dtype)
            image_latents = self._encode_vae_image(image=image, generator=generator)

    if image_latents is not None:
        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

    if latents is None and add_noise:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # Pure noise at full strength, otherwise image latents + noise.
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # Only pure-noise latents are scaled by the scheduler's initial sigma.
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif add_noise:
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma
    else:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)

    outputs = (latents,)

    if return_noise:
        outputs += (noise,)

    if return_image_latents:
        outputs += (image_latents,)

    return outputs
|
|
|
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): |
|
dtype = image.dtype |
|
if self.vae.config.force_upcast: |
|
image = image.float() |
|
self.vae.to(dtype=torch.float32) |
|
|
|
if isinstance(generator, list): |
|
image_latents = [ |
|
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) |
|
for i in range(image.shape[0]) |
|
] |
|
image_latents = torch.cat(image_latents, dim=0) |
|
else: |
|
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) |
|
|
|
if self.vae.config.force_upcast: |
|
self.vae.to(dtype) |
|
|
|
image_latents = image_latents.to(dtype) |
|
image_latents = self.vae.config.scaling_factor * image_latents |
|
|
|
return image_latents |
|
|
|
def prepare_mask_latents(
    self,
    mask,
    masked_image,
    batch_size,
    height,
    width,
    dtype,
    device,
    generator,
    do_classifier_free_guidance,
):
    """Resize the mask to latent resolution and encode the masked image.

    The mask is interpolated to ``(height // vae_scale_factor,
    width // vae_scale_factor)``, repeated up to `batch_size` and duplicated
    when classifier-free guidance is on. `masked_image`, when given, is
    encoded with `self._encode_vae_image` and batched the same way.

    Returns:
        ``(mask, masked_image_latents)``; the second item is ``None`` when
        `masked_image` is ``None``.

    Raises:
        ValueError: If `batch_size` is not a multiple of the number of masks
            or of encoded masked images.
    """
    latent_size = (
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    # Resize before moving to the target device/dtype, as the original does.
    mask = torch.nn.functional.interpolate(mask, size=latent_size)
    mask = mask.to(device=device, dtype=dtype)

    if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

    if do_classifier_free_guidance:
        mask = torch.cat([mask] * 2)

    if masked_image is None:
        return mask, None

    prepared_image = masked_image.to(device=device, dtype=dtype)
    masked_image_latents = self._encode_vae_image(prepared_image, generator=generator)

    encoded_count = masked_image_latents.shape[0]
    if encoded_count < batch_size:
        if batch_size % encoded_count != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // encoded_count, 1, 1, 1)

    if do_classifier_free_guidance:
        masked_image_latents = torch.cat([masked_image_latents] * 2)

    # Align device/dtype with the other model inputs.
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

    return mask, masked_image_latents
|
|
|
|
|
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """Select the scheduler timesteps that will actually be run.

    Without `denoising_start`, the first
    ``num_inference_steps - int(num_inference_steps * strength)`` steps are
    skipped. With `denoising_start`, `strength` is ignored and only the
    timesteps below the matching training-timestep cutoff are kept.
    `device` is accepted but not used here.

    Returns:
        ``(timesteps, number_of_steps)``.
    """
    scheduler = self.scheduler

    if denoising_start is None:
        # Strength decides how much of the schedule's head is skipped.
        steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
        first_step = max(num_inference_steps - steps_to_run, 0)
        return scheduler.timesteps[first_step * scheduler.order :], num_inference_steps - first_step

    # An explicit starting point replaces strength entirely.
    schedule = scheduler.timesteps
    train_steps = scheduler.config.num_train_timesteps
    cutoff = int(round(train_steps - (denoising_start * train_steps)))

    remaining = (schedule < cutoff).sum().item()
    if scheduler.order == 2 and remaining % 2 == 0:
        # For order-2 schedulers an even count is bumped by one; presumably so
        # the slice does not begin between the two evaluations of a single
        # step -- confirm against the scheduler's timestep layout.
        remaining += 1

    # Keep the tail of the schedule.
    return schedule[-remaining:], remaining
|
|
|
def _get_add_time_ids( |
|
self, |
|
original_size, |
|
crops_coords_top_left, |
|
target_size, |
|
aesthetic_score, |
|
negative_aesthetic_score, |
|
dtype, |
|
text_encoder_projection_dim=None, |
|
): |
|
if self.config.requires_aesthetics_score: |
|
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) |
|
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) |
|
else: |
|
add_time_ids = list(original_size + crops_coords_top_left + target_size) |
|
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) |
|
|
|
passed_add_embed_dim = ( |
|
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim |
|
) |
|
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features |
|
|
|
if ( |
|
expected_add_embed_dim > passed_add_embed_dim |
|
and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim |
|
): |
|
raise ValueError( |
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." |
|
) |
|
elif ( |
|
expected_add_embed_dim < passed_add_embed_dim |
|
and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim |
|
): |
|
raise ValueError( |
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." |
|
) |
|
elif expected_add_embed_dim != passed_add_embed_dim: |
|
raise ValueError( |
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." |
|
) |
|
|
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype) |
|
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) |
|
|
|
return add_time_ids, add_neg_time_ids |
|
|
|
|
|
def upcast_vae(self):
    """Move the VAE to float32, keeping some decoder parts in the old dtype.

    After the cast, if the first attention of the decoder mid-block uses one
    of the processors listed below, `post_quant_conv`, `decoder.conv_in` and
    `decoder.mid_block` are moved back to the dtype the VAE had on entry.
    """
    vae = self.vae
    previous_dtype = vae.dtype
    vae.to(dtype=torch.float32)

    low_precision_ok = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        LoRAXFormersAttnProcessor,
        LoRAAttnProcessor2_0,
    )
    mid_block_processor = vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_block_processor, low_precision_ok):
        return

    # With these processors the attention block need not stay in float32,
    # so these modules go back to the original dtype.
    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
        module.to(previous_dtype)
|
|
|
|
|
def _default_height_width(self, height, width, image): |
|
|
|
|
|
|
|
while isinstance(image, list): |
|
image = image[0] |
|
|
|
if height is None: |
|
if isinstance(image, PIL.Image.Image): |
|
height = image.height |
|
elif isinstance(image, torch.Tensor): |
|
height = image.shape[-2] |
|
|
|
|
|
height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor |
|
|
|
if width is None: |
|
if isinstance(image, PIL.Image.Image): |
|
width = image.width |
|
elif isinstance(image, torch.Tensor): |
|
width = image.shape[-1] |
|
|
|
|
|
width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor |
|
|
|
return height, width |
|
|
|
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess and batch a ControlNet conditioning image.

    The image goes through `self.control_image_processor.preprocess`, is
    cast to float32, repeated along dim 0, moved to `device`/`dtype`, and
    duplicated when classifier-free guidance is on and `guess_mode` is off.

    Returns:
        The prepared image tensor.
    """
    processed = self.control_image_processor.preprocess(image, height=height, width=width)
    processed = processed.to(dtype=torch.float32)

    # A single image is spread over the whole batch; otherwise each image is
    # repeated once per requested image per prompt.
    copies = batch_size if processed.shape[0] == 1 else num_images_per_prompt
    processed = processed.repeat_interleave(copies, dim=0)
    processed = processed.to(device=device, dtype=dtype)

    if not do_classifier_free_guidance or guess_mode:
        return processed
    return torch.cat([processed] * 2)
|
|
|
@torch.no_grad() |
|
@replace_example_docstring(EXAMPLE_DOC_STRING) |
|
def __call__( |
|
self, |
|
prompt: Optional[Union[str, List[str]]] = None, |
|
prompt_2: Optional[Union[str, List[str]]] = None, |
|
image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, |
|
mask_image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, |
|
adapter_image: PipelineImageInput = None, |
|
control_image: PipelineImageInput = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
strength: float = 0.9999, |
|
num_inference_steps: int = 50, |
|
denoising_start: Optional[float] = None, |
|
denoising_end: Optional[float] = None, |
|
guidance_scale: float = 5.0, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
negative_prompt_2: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[Union[torch.Tensor]] = None, |
|
prompt_embeds: Optional[torch.Tensor] = None, |
|
negative_prompt_embeds: Optional[torch.Tensor] = None, |
|
pooled_prompt_embeds: Optional[torch.Tensor] = None, |
|
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, |
|
callback_steps: int = 1, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
guidance_rescale: float = 0.0, |
|
original_size: Optional[Tuple[int, int]] = None, |
|
crops_coords_top_left: Optional[Tuple[int, int]] = (0, 0), |
|
target_size: Optional[Tuple[int, int]] = None, |
|
adapter_conditioning_scale: Optional[Union[float, List[float]]] = 1.0, |
|
cond_tau: float = 1.0, |
|
aesthetic_score: float = 6.0, |
|
negative_aesthetic_score: float = 2.5, |
|
controlnet_conditioning_scale=1.0, |
|
guess_mode: bool = False, |
|
control_guidance_start=0.0, |
|
control_guidance_end=1.0, |
|
): |
|
r""" |
|
Function invoked when calling the pipeline for generation. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. |
|
instead. |
|
prompt_2 (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is |
|
used in both text-encoders |
|
image (`PIL.Image.Image`): |
|
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will |
|
be masked out with `mask_image` and repainted according to `prompt`. |
|
mask_image (`PIL.Image.Image`): |
|
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be |
|
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted |
|
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) |
|
instead of 3, so the expected shape would be `(B, H, W, 1)`. |
|
adapter_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): |
|
The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the |
|
type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be |
|
accepted as an image. The control image is automatically resized to fit the output image. |
|
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: |
|
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): |
|
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is |
|
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be |
|
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height |
|
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in |
|
`init`, images must be passed as a list such that each element of the list can be correctly batched for |
|
input to a single ControlNet. |
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
|
The height in pixels of the generated image. |
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
|
The width in pixels of the generated image. |
|
strength (`float`, *optional*, defaults to 1.0): |
|
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a |
|
starting point and more noise is added the higher the `strength`. The number of denoising steps depends |
|
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising |
|
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 |
|
essentially ignores `image`. |
|
num_inference_steps (`int`, *optional*, defaults to 50): |
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
|
expense of slower inference. |
|
denoising_start (`float`, *optional*): |
|
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be |
|
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and |
|
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, |
|
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline |
|
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image |
|
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). |
|
denoising_end (`float`, *optional*): |
|
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be |
|
completed before it is intentionally prematurely terminated. As a result, the returned sample will |
|
still retain a substantial amount of noise as determined by the discrete timesteps selected by the |
|
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a |
|
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image |
|
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) |
|
guidance_scale (`float`, *optional*, defaults to 5.0): |
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen |
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
|
usually at the expense of lower image quality. |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass |
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is |
|
less than `1`). |
|
negative_prompt_2 (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and |
|
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders |
|
num_images_per_prompt (`int`, *optional*, defaults to 1): |
|
The number of images to generate per prompt. |
|
eta (`float`, *optional*, defaults to 0.0): |
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
|
[`schedulers.DDIMScheduler`], will be ignored for others. |
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) |
|
to make generation deterministic. |
|
latents (`torch.Tensor`, *optional*): |
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
|
tensor will ge generated by sampling using the supplied random `generator`. |
|
prompt_embeds (`torch.Tensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
|
provided, text embeddings will be generated from `prompt` input argument. |
|
negative_prompt_embeds (`torch.Tensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
|
argument. |
|
pooled_prompt_embeds (`torch.Tensor`, *optional*): |
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. |
|
If not provided, pooled text embeddings will be generated from `prompt` input argument. |
|
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): |
|
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` |
|
input argument. |
|
output_type (`str`, *optional*, defaults to `"pil"`): |
|
The output format of the generate image. Choose between |
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`] |
|
instead of a plain tuple. |
|
callback (`Callable`, *optional*): |
|
A function that will be called every `callback_steps` steps during inference. The function will be |
|
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. |
|
callback_steps (`int`, *optional*, defaults to 1): |
|
The frequency at which the `callback` function will be called. If not specified, the callback will be |
|
called at every step. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
|
`self.processor` in |
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
|
guidance_rescale (`float`, *optional*, defaults to 0.7): |
|
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are |
|
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of |
|
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). |
|
Guidance rescale factor should fix overexposure when using zero terminal SNR. |
|
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
|
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. |
|
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as |
|
explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): |
|
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position |
|
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting |
|
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
|
For most cases, `target_size` should be set to the desired height and width of the generated image. If |
|
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in |
|
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): |
|
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the |
|
residual in the original unet. If multiple adapters are specified in init, you can set the |
|
corresponding scale as a list. |
|
adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): |
|
The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the |
|
residual in the original unet. If multiple adapters are specified in init, you can set the |
|
corresponding scale as a list. |
|
aesthetic_score (`float`, *optional*, defaults to 6.0): |
|
Used to simulate an aesthetic score of the generated image by influencing the positive text condition. |
|
Part of SDXL's micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
negative_aesthetic_score (`float`, *optional*, defaults to 2.5): |
|
Part of SDXL's micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to |
|
simulate an aesthetic score of the generated image by influencing the negative text condition. |
|
Examples: |
|
|
|
Returns: |
|
[`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: |
|
[`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a |
|
`tuple`. When returning a tuple, the first element is a list with the generated images. |
|
""" |
|
|
|
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet |
|
adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter |
|
height, width = self._default_height_width(height, width, adapter_image) |
|
device = self._execution_device |
|
|
|
if isinstance(adapter, MultiAdapter): |
|
adapter_input = [] |
|
for one_image in adapter_image: |
|
one_image = _preprocess_adapter_image(one_image, height, width) |
|
one_image = one_image.to(device=device, dtype=adapter.dtype) |
|
adapter_input.append(one_image) |
|
else: |
|
adapter_input = _preprocess_adapter_image(adapter_image, height, width) |
|
adapter_input = adapter_input.to(device=device, dtype=adapter.dtype) |
|
|
|
original_size = original_size or (height, width) |
|
target_size = target_size or (height, width) |
|
|
|
|
|
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): |
|
control_guidance_start = len(control_guidance_end) * [control_guidance_start] |
|
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): |
|
control_guidance_end = len(control_guidance_start) * [control_guidance_end] |
|
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): |
|
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 |
|
control_guidance_start, control_guidance_end = ( |
|
mult * [control_guidance_start], |
|
mult * [control_guidance_end], |
|
) |
|
|
|
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): |
|
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) |
|
if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float): |
|
adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.nets) |
|
|
|
|
|
self.check_inputs( |
|
prompt, |
|
prompt_2, |
|
height, |
|
width, |
|
callback_steps, |
|
negative_prompt=negative_prompt, |
|
negative_prompt_2=negative_prompt_2, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
pooled_prompt_embeds=pooled_prompt_embeds, |
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, |
|
) |
|
|
|
self.check_conditions( |
|
prompt, |
|
prompt_embeds, |
|
adapter_image, |
|
control_image, |
|
adapter_conditioning_scale, |
|
controlnet_conditioning_scale, |
|
control_guidance_start, |
|
control_guidance_end, |
|
) |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
|
|
( |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
pooled_prompt_embeds, |
|
negative_pooled_prompt_embeds, |
|
) = self.encode_prompt( |
|
prompt=prompt, |
|
prompt_2=prompt_2, |
|
device=device, |
|
num_images_per_prompt=num_images_per_prompt, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
negative_prompt=negative_prompt, |
|
negative_prompt_2=negative_prompt_2, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
pooled_prompt_embeds=pooled_prompt_embeds, |
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, |
|
) |
|
|
|
|
|
def denoising_value_valid(dnv):
    """Tell whether `dnv` can be used as a fractional denoising boundary.

    Only `float` instances lying strictly inside the open interval (0, 1)
    qualify. Anything else (`None`, integers, strings, 0.0, 1.0, NaN or
    out-of-range floats) is rejected.
    """
    # Reject non-floats up front so the range comparison below never sees them.
    if not isinstance(dnv, float):
        return False
    return 0 < dnv < 1
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps, num_inference_steps = self.get_timesteps( |
|
num_inference_steps, |
|
strength, |
|
device, |
|
denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, |
|
) |
|
|
|
if num_inference_steps < 1: |
|
raise ValueError( |
|
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" |
|
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." |
|
) |
|
|
|
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
|
|
|
is_strength_max = strength == 1.0 |
|
|
|
|
|
mask, masked_image, init_image = prepare_mask_and_masked_image( |
|
image, mask_image, height, width, return_image=True |
|
) |
|
|
|
|
|
num_channels_latents = self.vae.config.latent_channels |
|
num_channels_unet = self.unet.config.in_channels |
|
return_image_latents = num_channels_unet == 4 |
|
|
|
add_noise = denoising_start is None |
|
latents_outputs = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
image=init_image, |
|
timestep=latent_timestep, |
|
is_strength_max=is_strength_max, |
|
add_noise=add_noise, |
|
return_noise=True, |
|
return_image_latents=return_image_latents, |
|
) |
|
|
|
if return_image_latents: |
|
latents, noise, image_latents = latents_outputs |
|
else: |
|
latents, noise = latents_outputs |
|
|
|
|
|
mask, masked_image_latents = self.prepare_mask_latents( |
|
mask, |
|
masked_image, |
|
batch_size * num_images_per_prompt, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
do_classifier_free_guidance, |
|
) |
|
|
|
|
|
if num_channels_unet == 9: |
|
|
|
num_channels_mask = mask.shape[1] |
|
num_channels_masked_image = masked_image_latents.shape[1] |
|
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: |
|
raise ValueError( |
|
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" |
|
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" |
|
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" |
|
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" |
|
" `pipeline.unet` or your `mask_image` or `image` input." |
|
) |
|
elif num_channels_unet != 4: |
|
raise ValueError( |
|
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
if isinstance(adapter, MultiAdapter): |
|
adapter_state = adapter(adapter_input, adapter_conditioning_scale) |
|
for k, v in enumerate(adapter_state): |
|
adapter_state[k] = v |
|
else: |
|
adapter_state = adapter(adapter_input) |
|
for k, v in enumerate(adapter_state): |
|
adapter_state[k] = v * adapter_conditioning_scale |
|
if num_images_per_prompt > 1: |
|
for k, v in enumerate(adapter_state): |
|
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) |
|
if do_classifier_free_guidance: |
|
for k, v in enumerate(adapter_state): |
|
adapter_state[k] = torch.cat([v] * 2, dim=0) |
|
|
|
|
|
if isinstance(controlnet, ControlNetModel): |
|
control_image = self.prepare_control_image( |
|
image=control_image, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_images_per_prompt, |
|
num_images_per_prompt=num_images_per_prompt, |
|
device=device, |
|
dtype=controlnet.dtype, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
guess_mode=guess_mode, |
|
) |
|
elif isinstance(controlnet, MultiControlNetModel): |
|
control_images = [] |
|
|
|
for control_image_ in control_image: |
|
control_image_ = self.prepare_control_image( |
|
image=control_image_, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_images_per_prompt, |
|
num_images_per_prompt=num_images_per_prompt, |
|
device=device, |
|
dtype=controlnet.dtype, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
guess_mode=guess_mode, |
|
) |
|
|
|
control_images.append(control_image_) |
|
|
|
control_image = control_images |
|
else: |
|
raise ValueError(f"{controlnet.__class__} is not supported.") |
|
|
|
|
|
controlnet_keep = [] |
|
for i in range(len(timesteps)): |
|
keeps = [ |
|
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) |
|
for s, e in zip(control_guidance_start, control_guidance_end) |
|
] |
|
if isinstance(self.controlnet, MultiControlNetModel): |
|
controlnet_keep.append(keeps) |
|
else: |
|
controlnet_keep.append(keeps[0]) |
|
|
|
|
|
add_text_embeds = pooled_prompt_embeds |
|
if self.text_encoder_2 is None: |
|
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) |
|
else: |
|
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim |
|
|
|
add_time_ids, add_neg_time_ids = self._get_add_time_ids( |
|
original_size, |
|
crops_coords_top_left, |
|
target_size, |
|
aesthetic_score, |
|
negative_aesthetic_score, |
|
dtype=prompt_embeds.dtype, |
|
text_encoder_projection_dim=text_encoder_projection_dim, |
|
) |
|
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) |
|
|
|
if do_classifier_free_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
|
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) |
|
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) |
|
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) |
|
|
|
prompt_embeds = prompt_embeds.to(device) |
|
add_text_embeds = add_text_embeds.to(device) |
|
add_time_ids = add_time_ids.to(device) |
|
|
|
|
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) |
|
|
|
|
|
if ( |
|
denoising_end is not None |
|
and denoising_start is not None |
|
and denoising_value_valid(denoising_end) |
|
and denoising_value_valid(denoising_start) |
|
and denoising_start >= denoising_end |
|
): |
|
raise ValueError( |
|
f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " |
|
+ f" {denoising_end} when using type float." |
|
) |
|
elif denoising_end is not None and denoising_value_valid(denoising_end): |
|
discrete_timestep_cutoff = int( |
|
round( |
|
self.scheduler.config.num_train_timesteps |
|
- (denoising_end * self.scheduler.config.num_train_timesteps) |
|
) |
|
) |
|
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) |
|
timesteps = timesteps[:num_inference_steps] |
|
|
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
if num_channels_unet == 9: |
|
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) |
|
|
|
|
|
added_cond_kwargs = { |
|
"text_embeds": add_text_embeds, |
|
"time_ids": add_time_ids, |
|
} |
|
|
|
if i < int(num_inference_steps * cond_tau): |
|
down_block_additional_residuals = [state.clone() for state in adapter_state] |
|
else: |
|
down_block_additional_residuals = None |
|
|
|
|
|
|
|
|
|
latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
|
|
|
|
latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t) |
|
|
|
|
|
if guess_mode and do_classifier_free_guidance: |
|
|
|
control_model_input = latents |
|
control_model_input = self.scheduler.scale_model_input(control_model_input, t) |
|
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] |
|
controlnet_added_cond_kwargs = { |
|
"text_embeds": add_text_embeds.chunk(2)[1], |
|
"time_ids": add_time_ids.chunk(2)[1], |
|
} |
|
else: |
|
control_model_input = latent_model_input_controlnet |
|
controlnet_prompt_embeds = prompt_embeds |
|
controlnet_added_cond_kwargs = added_cond_kwargs |
|
|
|
if isinstance(controlnet_keep[i], list): |
|
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] |
|
else: |
|
controlnet_cond_scale = controlnet_conditioning_scale |
|
if isinstance(controlnet_cond_scale, list): |
|
controlnet_cond_scale = controlnet_cond_scale[0] |
|
cond_scale = controlnet_cond_scale * controlnet_keep[i] |
|
down_block_res_samples, mid_block_res_sample = self.controlnet( |
|
control_model_input, |
|
t, |
|
encoder_hidden_states=controlnet_prompt_embeds, |
|
controlnet_cond=control_image, |
|
conditioning_scale=cond_scale, |
|
guess_mode=guess_mode, |
|
added_cond_kwargs=controlnet_added_cond_kwargs, |
|
return_dict=False, |
|
) |
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
added_cond_kwargs=added_cond_kwargs, |
|
return_dict=False, |
|
down_intrablock_additional_residuals=down_block_additional_residuals, |
|
down_block_additional_residuals=down_block_res_samples, |
|
mid_block_additional_residual=mid_block_res_sample, |
|
)[0] |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
if do_classifier_free_guidance and guidance_rescale > 0.0: |
|
|
|
noise_pred = rescale_noise_cfg( |
|
noise_pred, |
|
noise_pred_text, |
|
guidance_rescale=guidance_rescale, |
|
) |
|
|
|
|
|
latents = self.scheduler.step( |
|
noise_pred, |
|
t, |
|
latents, |
|
**extra_step_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
if num_channels_unet == 4: |
|
init_latents_proper = image_latents |
|
if do_classifier_free_guidance: |
|
init_mask, _ = mask.chunk(2) |
|
else: |
|
init_mask = mask |
|
|
|
if i < len(timesteps) - 1: |
|
noise_timestep = timesteps[i + 1] |
|
init_latents_proper = self.scheduler.add_noise( |
|
init_latents_proper, |
|
noise, |
|
torch.tensor([noise_timestep]), |
|
) |
|
|
|
latents = (1 - init_mask) * init_latents_proper + init_mask * latents |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
callback(i, t, latents) |
|
|
|
|
|
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: |
|
self.upcast_vae() |
|
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) |
|
|
|
if output_type != "latent": |
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
else: |
|
image = latents |
|
return StableDiffusionXLPipelineOutput(images=image) |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type) |
|
|
|
|
|
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
|
self.final_offload_hook.offload() |
|
|
|
if not return_dict: |
|
return (image,) |
|
|
|
return StableDiffusionXLPipelineOutput(images=image) |
|
|