multimodalart HF staff hyoungwoncho commited on
Commit
994e46a
·
verified ·
0 Parent(s):

Duplicate from hyoungwoncho/sd_perturbed_attention_guidance

Browse files

Co-authored-by: Hyoungwon Cho <hyoungwoncho@users.noreply.huggingface.co>

Files changed (4) hide show
  1. .gitattributes +35 -0
  2. README.md +75 -0
  3. pipeline.py +1485 -0
  4. sd_pag_demo.ipynb +0 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ pipeline_tag: unconditional-image-generation
5
+ tags:
6
+ - Diffusion Models
7
+ - Stable Diffusion
8
+ - Perturbed-Attention Guidance
9
+ - PAG
10
+ ---
11
+
12
+ # Perturbed-Attention Guidance
13
+
14
+ ![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6601282b569b30694e67b886/27Lmuol8anwd6L6BLzyWf.jpeg)
15
+
16
+ [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
17
+
18
+ This repository is based on [Diffusers](https://huggingface.co/docs/diffusers/index). The pipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).
19
+
20
+ ## Quickstart
21
+
22
+ Loading Custom Piepline:
23
+
24
+ ```
25
+ from diffusers import StableDiffusionPipeline
26
+
27
+ pipe = StableDiffusionPipeline.from_pretrained(
28
+ "runwayml/stable-diffusion-v1-5",
29
+ custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
30
+ torch_dtype=torch.float16
31
+ )
32
+
33
+ device="cuda"
34
+ pipe = pipe.to(device)
35
+ ```
36
+
37
+ Sampling with PAG:
38
+
39
+ ```
40
+ output = pipe(
41
+ prompts,
42
+ width=512,
43
+ height=512,
44
+ num_inference_steps=50,
45
+ guidance_scale=0.0,
46
+ pag_scale=5.0,
47
+ pag_applied_layers_index=['m0']
48
+ ).images
49
+ ```
50
+
51
+ Sampling with PAG and CFG:
52
+
53
+ ```
54
+ output = pipe(
55
+ prompts,
56
+ width=512,
57
+ height=512,
58
+ num_inference_steps=50,
59
+ guidance_scale=4.0,
60
+ pag_scale=3.0,
61
+ pag_applied_layers_index=['m0']
62
+ ).images
63
+ ```
64
+
65
+ ## Parameters
66
+
67
+ guidance_scale : gudiance scale of CFG (ex: 7.5)
68
+
69
+ pag_scale : gudiance scale of PAG (ex: 5.0)
70
+
71
+ pag_applied_layers_index : index of the layer to apply perturbation (ex: ['m0'])
72
+
73
+ ## Stable Diffusion Demo
74
+
75
+ To join a demo of PAG on Stable Diffusion, run [sd_pag_demo.ipynb](https://huggingface.co/hyoungwoncho/sd_perturbed_attention_guidance/blob/main/sd_pag_demo.ipynb).
pipeline.py ADDED
@@ -0,0 +1,1485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation of StableDiffusionPAGPipeline
2
+
3
+ import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from packaging import version
9
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
10
+
11
+ from diffusers.configuration_utils import FrozenDict
12
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
13
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
14
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
15
+ from diffusers.models.attention_processor import FusedAttnProcessor2_0
16
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils import (
19
+ USE_PEFT_BACKEND,
20
+ deprecate,
21
+ logging,
22
+ replace_example_docstring,
23
+ scale_lora_layers,
24
+ unscale_lora_layers,
25
+ )
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
29
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
30
+
31
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+ EXAMPLE_DOC_STRING = """
37
+ Examples:
38
+ ```py
39
+ >>> import torch
40
+ >>> from diffusers import StableDiffusionPipeline
41
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
42
+ >>> pipe = pipe.to("cuda")
43
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
44
+ >>> image = pipe(prompt).images[0]
45
+ ```
46
+ """
47
+
48
+
49
+ class PAGIdentitySelfAttnProcessor:
50
+ r"""
51
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
52
+ """
53
+
54
+ def __init__(self):
55
+ if not hasattr(F, "scaled_dot_product_attention"):
56
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
57
+
58
+ def __call__(
59
+ self,
60
+ attn: Attention,
61
+ hidden_states: torch.FloatTensor,
62
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
63
+ attention_mask: Optional[torch.FloatTensor] = None,
64
+ temb: Optional[torch.FloatTensor] = None,
65
+ *args,
66
+ **kwargs,
67
+ ) -> torch.FloatTensor:
68
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
69
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
70
+ deprecate("scale", "1.0.0", deprecation_message)
71
+
72
+ residual = hidden_states
73
+ if attn.spatial_norm is not None:
74
+ hidden_states = attn.spatial_norm(hidden_states, temb)
75
+
76
+ input_ndim = hidden_states.ndim
77
+ if input_ndim == 4:
78
+ batch_size, channel, height, width = hidden_states.shape
79
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
80
+
81
+ # chunk
82
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
83
+
84
+ # original path
85
+ batch_size, sequence_length, _ = hidden_states_org.shape
86
+
87
+ if attention_mask is not None:
88
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
89
+ # scaled_dot_product_attention expects attention_mask shape to be
90
+ # (batch, heads, source_length, target_length)
91
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
92
+
93
+ if attn.group_norm is not None:
94
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
95
+
96
+ query = attn.to_q(hidden_states_org)
97
+ key = attn.to_k(hidden_states_org)
98
+ value = attn.to_v(hidden_states_org)
99
+
100
+ inner_dim = key.shape[-1]
101
+ head_dim = inner_dim // attn.heads
102
+
103
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
104
+
105
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
106
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
107
+
108
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
109
+ # TODO: add support for attn.scale when we move to Torch 2.1
110
+ hidden_states_org = F.scaled_dot_product_attention(
111
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
112
+ )
113
+
114
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
115
+ hidden_states_org = hidden_states_org.to(query.dtype)
116
+
117
+ # linear proj
118
+ hidden_states_org = attn.to_out[0](hidden_states_org)
119
+ # dropout
120
+ hidden_states_org = attn.to_out[1](hidden_states_org)
121
+
122
+ if input_ndim == 4:
123
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
124
+
125
+ # perturbed path (identity attention)
126
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
127
+
128
+ if attention_mask is not None:
129
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
130
+ # scaled_dot_product_attention expects attention_mask shape to be
131
+ # (batch, heads, source_length, target_length)
132
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
133
+
134
+ if attn.group_norm is not None:
135
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
136
+
137
+ value = attn.to_v(hidden_states_ptb)
138
+
139
+ hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
140
+ #hidden_states_ptb = value
141
+
142
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
143
+
144
+ # linear proj
145
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
146
+ # dropout
147
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
148
+
149
+ if input_ndim == 4:
150
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
151
+
152
+ # cat
153
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
154
+
155
+ if attn.residual_connection:
156
+ hidden_states = hidden_states + residual
157
+
158
+ hidden_states = hidden_states / attn.rescale_output_factor
159
+
160
+ return hidden_states
161
+
162
+
163
+ class PAGCFGIdentitySelfAttnProcessor:
164
+ r"""
165
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
166
+ """
167
+
168
+ def __init__(self):
169
+ if not hasattr(F, "scaled_dot_product_attention"):
170
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
171
+
172
+ def __call__(
173
+ self,
174
+ attn: Attention,
175
+ hidden_states: torch.FloatTensor,
176
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
177
+ attention_mask: Optional[torch.FloatTensor] = None,
178
+ temb: Optional[torch.FloatTensor] = None,
179
+ *args,
180
+ **kwargs,
181
+ ) -> torch.FloatTensor:
182
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
183
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
184
+ deprecate("scale", "1.0.0", deprecation_message)
185
+
186
+ residual = hidden_states
187
+ if attn.spatial_norm is not None:
188
+ hidden_states = attn.spatial_norm(hidden_states, temb)
189
+
190
+ input_ndim = hidden_states.ndim
191
+ if input_ndim == 4:
192
+ batch_size, channel, height, width = hidden_states.shape
193
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
194
+
195
+ # chunk
196
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
197
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
198
+
199
+ # original path
200
+ batch_size, sequence_length, _ = hidden_states_org.shape
201
+
202
+ if attention_mask is not None:
203
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
204
+ # scaled_dot_product_attention expects attention_mask shape to be
205
+ # (batch, heads, source_length, target_length)
206
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
207
+
208
+ if attn.group_norm is not None:
209
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
210
+
211
+ query = attn.to_q(hidden_states_org)
212
+ key = attn.to_k(hidden_states_org)
213
+ value = attn.to_v(hidden_states_org)
214
+
215
+ inner_dim = key.shape[-1]
216
+ head_dim = inner_dim // attn.heads
217
+
218
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
219
+
220
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
221
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
222
+
223
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
224
+ # TODO: add support for attn.scale when we move to Torch 2.1
225
+ hidden_states_org = F.scaled_dot_product_attention(
226
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
227
+ )
228
+
229
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
230
+ hidden_states_org = hidden_states_org.to(query.dtype)
231
+
232
+ # linear proj
233
+ hidden_states_org = attn.to_out[0](hidden_states_org)
234
+ # dropout
235
+ hidden_states_org = attn.to_out[1](hidden_states_org)
236
+
237
+ if input_ndim == 4:
238
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
239
+
240
+ # perturbed path (identity attention)
241
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
242
+
243
+ if attention_mask is not None:
244
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
245
+ # scaled_dot_product_attention expects attention_mask shape to be
246
+ # (batch, heads, source_length, target_length)
247
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
248
+
249
+ if attn.group_norm is not None:
250
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
251
+
252
+ value = attn.to_v(hidden_states_ptb)
253
+ hidden_states_ptb = value
254
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
255
+
256
+ # linear proj
257
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
258
+ # dropout
259
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
260
+
261
+ if input_ndim == 4:
262
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
263
+
264
+ # cat
265
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
266
+
267
+ if attn.residual_connection:
268
+ hidden_states = hidden_states + residual
269
+
270
+ hidden_states = hidden_states / attn.rescale_output_factor
271
+
272
+ return hidden_states
273
+
274
+
275
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
276
+ """
277
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
278
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
279
+ """
280
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
281
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
282
+ # rescale the results from guidance (fixes overexposure)
283
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
284
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
285
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
286
+ return noise_cfg
287
+
288
+
289
+ def retrieve_timesteps(
290
+ scheduler,
291
+ num_inference_steps: Optional[int] = None,
292
+ device: Optional[Union[str, torch.device]] = None,
293
+ timesteps: Optional[List[int]] = None,
294
+ **kwargs,
295
+ ):
296
+ """
297
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
298
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
299
+ Args:
300
+ scheduler (`SchedulerMixin`):
301
+ The scheduler to get timesteps from.
302
+ num_inference_steps (`int`):
303
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
304
+ `timesteps` must be `None`.
305
+ device (`str` or `torch.device`, *optional*):
306
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
307
+ timesteps (`List[int]`, *optional*):
308
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
309
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
310
+ must be `None`.
311
+ Returns:
312
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
313
+ second element is the number of inference steps.
314
+ """
315
+ if timesteps is not None:
316
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
317
+ if not accepts_timesteps:
318
+ raise ValueError(
319
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
320
+ f" timestep schedules. Please check whether you are using the correct scheduler."
321
+ )
322
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
323
+ timesteps = scheduler.timesteps
324
+ num_inference_steps = len(timesteps)
325
+ else:
326
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
327
+ timesteps = scheduler.timesteps
328
+ return timesteps, num_inference_steps
329
+
330
+
331
+ class StableDiffusionPipeline(
332
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
333
+ ):
334
+ r"""
335
+ Pipeline for text-to-image generation using Stable Diffusion.
336
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
337
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
338
+ The pipeline also inherits the following loading methods:
339
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
340
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
341
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
342
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
343
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
344
+ Args:
345
+ vae ([`AutoencoderKL`]):
346
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
347
+ text_encoder ([`~transformers.CLIPTextModel`]):
348
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
349
+ tokenizer ([`~transformers.CLIPTokenizer`]):
350
+ A `CLIPTokenizer` to tokenize text.
351
+ unet ([`UNet2DConditionModel`]):
352
+ A `UNet2DConditionModel` to denoise the encoded image latents.
353
+ scheduler ([`SchedulerMixin`]):
354
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
355
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
356
+ safety_checker ([`StableDiffusionSafetyChecker`]):
357
+ Classification module that estimates whether generated images could be considered offensive or harmful.
358
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
359
+ about a model's potential harms.
360
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
361
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
362
+ """
363
+
364
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
365
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
366
+ _exclude_from_cpu_offload = ["safety_checker"]
367
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
368
+
369
+ def __init__(
370
+ self,
371
+ vae: AutoencoderKL,
372
+ text_encoder: CLIPTextModel,
373
+ tokenizer: CLIPTokenizer,
374
+ unet: UNet2DConditionModel,
375
+ scheduler: KarrasDiffusionSchedulers,
376
+ safety_checker: StableDiffusionSafetyChecker,
377
+ feature_extractor: CLIPImageProcessor,
378
+ image_encoder: CLIPVisionModelWithProjection = None,
379
+ requires_safety_checker: bool = True,
380
+ ):
381
+ super().__init__()
382
+
383
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
384
+ deprecation_message = (
385
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
386
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
387
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
388
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
389
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
390
+ " file"
391
+ )
392
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
393
+ new_config = dict(scheduler.config)
394
+ new_config["steps_offset"] = 1
395
+ scheduler._internal_dict = FrozenDict(new_config)
396
+
397
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
398
+ deprecation_message = (
399
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
400
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
401
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
402
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
403
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
404
+ )
405
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
406
+ new_config = dict(scheduler.config)
407
+ new_config["clip_sample"] = False
408
+ scheduler._internal_dict = FrozenDict(new_config)
409
+
410
+ if safety_checker is None and requires_safety_checker:
411
+ logger.warning(
412
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
413
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
414
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
415
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
416
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
417
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
418
+ )
419
+
420
+ if safety_checker is not None and feature_extractor is None:
421
+ raise ValueError(
422
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
423
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
424
+ )
425
+
426
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
427
+ version.parse(unet.config._diffusers_version).base_version
428
+ ) < version.parse("0.9.0.dev0")
429
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
430
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
431
+ deprecation_message = (
432
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
433
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
434
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
435
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
436
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
437
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
438
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
439
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
440
+ " the `unet/config.json` file"
441
+ )
442
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
443
+ new_config = dict(unet.config)
444
+ new_config["sample_size"] = 64
445
+ unet._internal_dict = FrozenDict(new_config)
446
+
447
+ self.register_modules(
448
+ vae=vae,
449
+ text_encoder=text_encoder,
450
+ tokenizer=tokenizer,
451
+ unet=unet,
452
+ scheduler=scheduler,
453
+ safety_checker=safety_checker,
454
+ feature_extractor=feature_extractor,
455
+ image_encoder=image_encoder,
456
+ )
457
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
458
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
459
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
460
+
461
+ def enable_vae_slicing(self):
462
+ r"""
463
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
464
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
465
+ """
466
+ self.vae.enable_slicing()
467
+
468
+ def disable_vae_slicing(self):
469
+ r"""
470
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
471
+ computing decoding in one step.
472
+ """
473
+ self.vae.disable_slicing()
474
+
475
+ def enable_vae_tiling(self):
476
+ r"""
477
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
478
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
479
+ processing larger images.
480
+ """
481
+ self.vae.enable_tiling()
482
+
483
+ def disable_vae_tiling(self):
484
+ r"""
485
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
486
+ computing decoding in one step.
487
+ """
488
+ self.vae.disable_tiling()
489
+
490
+ def _encode_prompt(
491
+ self,
492
+ prompt,
493
+ device,
494
+ num_images_per_prompt,
495
+ do_classifier_free_guidance,
496
+ negative_prompt=None,
497
+ prompt_embeds: Optional[torch.FloatTensor] = None,
498
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
499
+ lora_scale: Optional[float] = None,
500
+ **kwargs,
501
+ ):
502
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
503
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
504
+
505
+ prompt_embeds_tuple = self.encode_prompt(
506
+ prompt=prompt,
507
+ device=device,
508
+ num_images_per_prompt=num_images_per_prompt,
509
+ do_classifier_free_guidance=do_classifier_free_guidance,
510
+ negative_prompt=negative_prompt,
511
+ prompt_embeds=prompt_embeds,
512
+ negative_prompt_embeds=negative_prompt_embeds,
513
+ lora_scale=lora_scale,
514
+ **kwargs,
515
+ )
516
+
517
+ # concatenate for backwards comp
518
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
519
+
520
+ return prompt_embeds
521
+
522
+ def encode_prompt(
523
+ self,
524
+ prompt,
525
+ device,
526
+ num_images_per_prompt,
527
+ do_classifier_free_guidance,
528
+ negative_prompt=None,
529
+ prompt_embeds: Optional[torch.FloatTensor] = None,
530
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
531
+ lora_scale: Optional[float] = None,
532
+ clip_skip: Optional[int] = None,
533
+ ):
534
+ r"""
535
+ Encodes the prompt into text encoder hidden states.
536
+ Args:
537
+ prompt (`str` or `List[str]`, *optional*):
538
+ prompt to be encoded
539
+ device: (`torch.device`):
540
+ torch device
541
+ num_images_per_prompt (`int`):
542
+ number of images that should be generated per prompt
543
+ do_classifier_free_guidance (`bool`):
544
+ whether to use classifier free guidance or not
545
+ negative_prompt (`str` or `List[str]`, *optional*):
546
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
547
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
548
+ less than `1`).
549
+ prompt_embeds (`torch.FloatTensor`, *optional*):
550
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
551
+ provided, text embeddings will be generated from `prompt` input argument.
552
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
553
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
554
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
555
+ argument.
556
+ lora_scale (`float`, *optional*):
557
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
558
+ clip_skip (`int`, *optional*):
559
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
560
+ the output of the pre-final layer will be used for computing the prompt embeddings.
561
+ """
562
+ # set lora scale so that monkey patched LoRA
563
+ # function of text encoder can correctly access it
564
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
565
+ self._lora_scale = lora_scale
566
+
567
+ # dynamically adjust the LoRA scale
568
+ if not USE_PEFT_BACKEND:
569
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
570
+ else:
571
+ scale_lora_layers(self.text_encoder, lora_scale)
572
+
573
+ if prompt is not None and isinstance(prompt, str):
574
+ batch_size = 1
575
+ elif prompt is not None and isinstance(prompt, list):
576
+ batch_size = len(prompt)
577
+ else:
578
+ batch_size = prompt_embeds.shape[0]
579
+
580
+ if prompt_embeds is None:
581
+ # textual inversion: process multi-vector tokens if necessary
582
+ if isinstance(self, TextualInversionLoaderMixin):
583
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
584
+
585
+ text_inputs = self.tokenizer(
586
+ prompt,
587
+ padding="max_length",
588
+ max_length=self.tokenizer.model_max_length,
589
+ truncation=True,
590
+ return_tensors="pt",
591
+ )
592
+ text_input_ids = text_inputs.input_ids
593
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
594
+
595
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
596
+ text_input_ids, untruncated_ids
597
+ ):
598
+ removed_text = self.tokenizer.batch_decode(
599
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
600
+ )
601
+ logger.warning(
602
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
603
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
604
+ )
605
+
606
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
607
+ attention_mask = text_inputs.attention_mask.to(device)
608
+ else:
609
+ attention_mask = None
610
+
611
+ if clip_skip is None:
612
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
613
+ prompt_embeds = prompt_embeds[0]
614
+ else:
615
+ prompt_embeds = self.text_encoder(
616
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
617
+ )
618
+ # Access the `hidden_states` first, that contains a tuple of
619
+ # all the hidden states from the encoder layers. Then index into
620
+ # the tuple to access the hidden states from the desired layer.
621
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
622
+ # We also need to apply the final LayerNorm here to not mess with the
623
+ # representations. The `last_hidden_states` that we typically use for
624
+ # obtaining the final prompt representations passes through the LayerNorm
625
+ # layer.
626
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
627
+
628
+ if self.text_encoder is not None:
629
+ prompt_embeds_dtype = self.text_encoder.dtype
630
+ elif self.unet is not None:
631
+ prompt_embeds_dtype = self.unet.dtype
632
+ else:
633
+ prompt_embeds_dtype = prompt_embeds.dtype
634
+
635
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
636
+
637
+ bs_embed, seq_len, _ = prompt_embeds.shape
638
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
639
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
640
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
641
+
642
+ # get unconditional embeddings for classifier free guidance
643
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
644
+ uncond_tokens: List[str]
645
+ if negative_prompt is None:
646
+ uncond_tokens = [""] * batch_size
647
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
648
+ raise TypeError(
649
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
650
+ f" {type(prompt)}."
651
+ )
652
+ elif isinstance(negative_prompt, str):
653
+ uncond_tokens = [negative_prompt]
654
+ elif batch_size != len(negative_prompt):
655
+ raise ValueError(
656
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
657
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
658
+ " the batch size of `prompt`."
659
+ )
660
+ else:
661
+ uncond_tokens = negative_prompt
662
+
663
+ # textual inversion: process multi-vector tokens if necessary
664
+ if isinstance(self, TextualInversionLoaderMixin):
665
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
666
+
667
+ max_length = prompt_embeds.shape[1]
668
+ uncond_input = self.tokenizer(
669
+ uncond_tokens,
670
+ padding="max_length",
671
+ max_length=max_length,
672
+ truncation=True,
673
+ return_tensors="pt",
674
+ )
675
+
676
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
677
+ attention_mask = uncond_input.attention_mask.to(device)
678
+ else:
679
+ attention_mask = None
680
+
681
+ negative_prompt_embeds = self.text_encoder(
682
+ uncond_input.input_ids.to(device),
683
+ attention_mask=attention_mask,
684
+ )
685
+ negative_prompt_embeds = negative_prompt_embeds[0]
686
+
687
+ if do_classifier_free_guidance:
688
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
689
+ seq_len = negative_prompt_embeds.shape[1]
690
+
691
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
692
+
693
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
694
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
695
+
696
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
697
+ # Retrieve the original scale by scaling back the LoRA layers
698
+ unscale_lora_layers(self.text_encoder, lora_scale)
699
+
700
+ return prompt_embeds, negative_prompt_embeds
701
+
702
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
703
+ dtype = next(self.image_encoder.parameters()).dtype
704
+
705
+ if not isinstance(image, torch.Tensor):
706
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
707
+
708
+ image = image.to(device=device, dtype=dtype)
709
+ if output_hidden_states:
710
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
711
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
712
+ uncond_image_enc_hidden_states = self.image_encoder(
713
+ torch.zeros_like(image), output_hidden_states=True
714
+ ).hidden_states[-2]
715
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
716
+ num_images_per_prompt, dim=0
717
+ )
718
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
719
+ else:
720
+ image_embeds = self.image_encoder(image).image_embeds
721
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
722
+ uncond_image_embeds = torch.zeros_like(image_embeds)
723
+
724
+ return image_embeds, uncond_image_embeds
725
+
726
+ def prepare_ip_adapter_image_embeds(
727
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
728
+ ):
729
+ if ip_adapter_image_embeds is None:
730
+ if not isinstance(ip_adapter_image, list):
731
+ ip_adapter_image = [ip_adapter_image]
732
+
733
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
734
+ raise ValueError(
735
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
736
+ )
737
+
738
+ image_embeds = []
739
+ for single_ip_adapter_image, image_proj_layer in zip(
740
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
741
+ ):
742
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
743
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
744
+ single_ip_adapter_image, device, 1, output_hidden_state
745
+ )
746
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
747
+ single_negative_image_embeds = torch.stack(
748
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
749
+ )
750
+
751
+ if self.do_classifier_free_guidance:
752
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
753
+ single_image_embeds = single_image_embeds.to(device)
754
+
755
+ image_embeds.append(single_image_embeds)
756
+ else:
757
+ image_embeds = ip_adapter_image_embeds
758
+ return image_embeds
759
+
760
+ def run_safety_checker(self, image, device, dtype):
761
+ if self.safety_checker is None:
762
+ has_nsfw_concept = None
763
+ else:
764
+ if torch.is_tensor(image):
765
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
766
+ else:
767
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
768
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
769
+ image, has_nsfw_concept = self.safety_checker(
770
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
771
+ )
772
+ return image, has_nsfw_concept
773
+
774
+ def decode_latents(self, latents):
775
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
776
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
777
+
778
+ latents = 1 / self.vae.config.scaling_factor * latents
779
+ image = self.vae.decode(latents, return_dict=False)[0]
780
+ image = (image / 2 + 0.5).clamp(0, 1)
781
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
782
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
783
+ return image
784
+
785
+ def prepare_extra_step_kwargs(self, generator, eta):
786
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
787
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
788
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
789
+ # and should be between [0, 1]
790
+
791
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
792
+ extra_step_kwargs = {}
793
+ if accepts_eta:
794
+ extra_step_kwargs["eta"] = eta
795
+
796
+ # check if the scheduler accepts generator
797
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
798
+ if accepts_generator:
799
+ extra_step_kwargs["generator"] = generator
800
+ return extra_step_kwargs
801
+
802
+ def check_inputs(
803
+ self,
804
+ prompt,
805
+ height,
806
+ width,
807
+ callback_steps,
808
+ negative_prompt=None,
809
+ prompt_embeds=None,
810
+ negative_prompt_embeds=None,
811
+ ip_adapter_image=None,
812
+ ip_adapter_image_embeds=None,
813
+ callback_on_step_end_tensor_inputs=None,
814
+ ):
815
+ if height % 8 != 0 or width % 8 != 0:
816
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
817
+
818
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
819
+ raise ValueError(
820
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
821
+ f" {type(callback_steps)}."
822
+ )
823
+ if callback_on_step_end_tensor_inputs is not None and not all(
824
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
825
+ ):
826
+ raise ValueError(
827
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
828
+ )
829
+
830
+ if prompt is not None and prompt_embeds is not None:
831
+ raise ValueError(
832
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
833
+ " only forward one of the two."
834
+ )
835
+ elif prompt is None and prompt_embeds is None:
836
+ raise ValueError(
837
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
838
+ )
839
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
840
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
841
+
842
+ if negative_prompt is not None and negative_prompt_embeds is not None:
843
+ raise ValueError(
844
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
845
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
846
+ )
847
+
848
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
849
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
850
+ raise ValueError(
851
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
852
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
853
+ f" {negative_prompt_embeds.shape}."
854
+ )
855
+
856
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
857
+ raise ValueError(
858
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
859
+ )
860
+
861
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
862
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
863
+ if isinstance(generator, list) and len(generator) != batch_size:
864
+ raise ValueError(
865
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
866
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
867
+ )
868
+
869
+ if latents is None:
870
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
871
+ else:
872
+ latents = latents.to(device)
873
+
874
+ # scale the initial noise by the standard deviation required by the scheduler
875
+ latents = latents * self.scheduler.init_noise_sigma
876
+ return latents
877
+
878
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
879
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
880
+ The suffixes after the scaling factors represent the stages where they are being applied.
881
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
882
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
883
+ Args:
884
+ s1 (`float`):
885
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
886
+ mitigate "oversmoothing effect" in the enhanced denoising process.
887
+ s2 (`float`):
888
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
889
+ mitigate "oversmoothing effect" in the enhanced denoising process.
890
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
891
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
892
+ """
893
+ if not hasattr(self, "unet"):
894
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
895
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
896
+
897
+ def disable_freeu(self):
898
+ """Disables the FreeU mechanism if enabled."""
899
+ self.unet.disable_freeu()
900
+
901
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
902
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
903
+ """
904
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
905
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
906
+ <Tip warning={true}>
907
+ This API is 🧪 experimental.
908
+ </Tip>
909
+ Args:
910
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
911
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
912
+ """
913
+ self.fusing_unet = False
914
+ self.fusing_vae = False
915
+
916
+ if unet:
917
+ self.fusing_unet = True
918
+ self.unet.fuse_qkv_projections()
919
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
920
+
921
+ if vae:
922
+ if not isinstance(self.vae, AutoencoderKL):
923
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
924
+
925
+ self.fusing_vae = True
926
+ self.vae.fuse_qkv_projections()
927
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
928
+
929
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
930
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
931
+ """Disable QKV projection fusion if enabled.
932
+ <Tip warning={true}>
933
+ This API is 🧪 experimental.
934
+ </Tip>
935
+ Args:
936
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
937
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
938
+ """
939
+ if unet:
940
+ if not self.fusing_unet:
941
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
942
+ else:
943
+ self.unet.unfuse_qkv_projections()
944
+ self.fusing_unet = False
945
+
946
+ if vae:
947
+ if not self.fusing_vae:
948
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
949
+ else:
950
+ self.vae.unfuse_qkv_projections()
951
+ self.fusing_vae = False
952
+
953
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
954
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
955
+ """
956
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
957
+ Args:
958
+ timesteps (`torch.Tensor`):
959
+ generate embedding vectors at these timesteps
960
+ embedding_dim (`int`, *optional*, defaults to 512):
961
+ dimension of the embeddings to generate
962
+ dtype:
963
+ data type of the generated embeddings
964
+ Returns:
965
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
966
+ """
967
+ assert len(w.shape) == 1
968
+ w = w * 1000.0
969
+
970
+ half_dim = embedding_dim // 2
971
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
972
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
973
+ emb = w.to(dtype)[:, None] * emb[None, :]
974
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
975
+ if embedding_dim % 2 == 1: # zero pad
976
+ emb = torch.nn.functional.pad(emb, (0, 1))
977
+ assert emb.shape == (w.shape[0], embedding_dim)
978
+ return emb
979
+
980
+ def pred_z0(self, sample, model_output, timestep):
981
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
982
+
983
+ beta_prod_t = 1 - alpha_prod_t
984
+ if self.scheduler.config.prediction_type == "epsilon":
985
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
986
+ elif self.scheduler.config.prediction_type == "sample":
987
+ pred_original_sample = model_output
988
+ elif self.scheduler.config.prediction_type == "v_prediction":
989
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
990
+ # predict V
991
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
992
+ else:
993
+ raise ValueError(
994
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
995
+ " or `v_prediction`"
996
+ )
997
+
998
+ return pred_original_sample
999
+
1000
+ def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
1001
+
1002
+ pred_z0 = self.pred_z0(latents, noise_pred, t)
1003
+ pred_x0 = self.vae.decode(
1004
+ pred_z0 / self.vae.config.scaling_factor,
1005
+ return_dict=False,
1006
+ generator=generator
1007
+ )[0]
1008
+ pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
1009
+ do_denormalize = [True] * pred_x0.shape[0]
1010
+ pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
1011
+
1012
+ return pred_x0
1013
+
1014
+ @property
1015
+ def guidance_scale(self):
1016
+ return self._guidance_scale
1017
+
1018
+ @property
1019
+ def guidance_rescale(self):
1020
+ return self._guidance_rescale
1021
+
1022
+ @property
1023
+ def clip_skip(self):
1024
+ return self._clip_skip
1025
+
1026
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1027
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1028
+ # corresponds to doing no classifier free guidance.
1029
+ @property
1030
+ def do_classifier_free_guidance(self):
1031
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1032
+
1033
+ @property
1034
+ def cross_attention_kwargs(self):
1035
+ return self._cross_attention_kwargs
1036
+
1037
+ @property
1038
+ def num_timesteps(self):
1039
+ return self._num_timesteps
1040
+
1041
+ @property
1042
+ def interrupt(self):
1043
+ return self._interrupt
1044
+
1045
+ @property
1046
+ def pag_scale(self):
1047
+ return self._pag_scale
1048
+
1049
+ @property
1050
+ def do_adversarial_guidance(self):
1051
+ return self._pag_scale > 0
1052
+
1053
+ @property
1054
+ def pag_adaptive_scaling(self):
1055
+ return self._pag_adaptive_scaling
1056
+
1057
+ @property
1058
+ def do_pag_adaptive_scaling(self):
1059
+ return self._pag_adaptive_scaling > 0
1060
+
1061
+ @property
1062
+ def pag_drop_rate(self):
1063
+ return self._pag_drop_rate
1064
+
1065
+ @property
1066
+ def pag_applied_layers(self):
1067
+ return self._pag_applied_layers
1068
+
1069
+ @property
1070
+ def pag_applied_layers_index(self):
1071
+ return self._pag_applied_layers_index
1072
+
1073
+
1074
+ @torch.no_grad()
1075
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1076
+ def __call__(
1077
+ self,
1078
+ prompt: Union[str, List[str]] = None,
1079
+ height: Optional[int] = None,
1080
+ width: Optional[int] = None,
1081
+ num_inference_steps: int = 50,
1082
+ timesteps: List[int] = None,
1083
+ guidance_scale: float = 7.5,
1084
+ pag_scale: float = 0.0,
1085
+ pag_adaptive_scaling: float = 0.0,
1086
+ pag_drop_rate: float = 0.5,
1087
+ pag_applied_layers: List[str] = ['down'], #['down', 'mid', 'up']
1088
+ pag_applied_layers_index: List[str] = ['d4'], #['d4', 'd5', 'm0']
1089
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1090
+ num_images_per_prompt: Optional[int] = 1,
1091
+ eta: float = 0.0,
1092
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1093
+ latents: Optional[torch.FloatTensor] = None,
1094
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1095
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1096
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1097
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1098
+ output_type: Optional[str] = "pil",
1099
+ return_dict: bool = True,
1100
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1101
+ guidance_rescale: float = 0.0,
1102
+ clip_skip: Optional[int] = None,
1103
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1104
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1105
+ **kwargs,
1106
+ ):
1107
+ r"""
1108
+ The call function to the pipeline for generation.
1109
+ Args:
1110
+ prompt (`str` or `List[str]`, *optional*):
1111
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1112
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1113
+ The height in pixels of the generated image.
1114
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1115
+ The width in pixels of the generated image.
1116
+ num_inference_steps (`int`, *optional*, defaults to 50):
1117
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1118
+ expense of slower inference.
1119
+ timesteps (`List[int]`, *optional*):
1120
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1121
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1122
+ passed will be used. Must be in descending order.
1123
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1124
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1125
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1126
+ negative_prompt (`str` or `List[str]`, *optional*):
1127
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1128
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1129
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1130
+ The number of images to generate per prompt.
1131
+ eta (`float`, *optional*, defaults to 0.0):
1132
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1133
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1134
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1135
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1136
+ generation deterministic.
1137
+ latents (`torch.FloatTensor`, *optional*):
1138
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1139
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1140
+ tensor is generated by sampling using the supplied random `generator`.
1141
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1142
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1143
+ provided, text embeddings are generated from the `prompt` input argument.
1144
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1145
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1146
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1147
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1148
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1149
+ Pre-generated image embeddings for IP-Adapter. If not
1150
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1151
+ output_type (`str`, *optional*, defaults to `"pil"`):
1152
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1153
+ return_dict (`bool`, *optional*, defaults to `True`):
1154
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1155
+ plain tuple.
1156
+ cross_attention_kwargs (`dict`, *optional*):
1157
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1158
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1159
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1160
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
1161
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
1162
+ using zero terminal SNR.
1163
+ clip_skip (`int`, *optional*):
1164
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1165
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1166
+ callback_on_step_end (`Callable`, *optional*):
1167
+ A function that calls at the end of each denoising steps during the inference. The function is called
1168
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1169
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1170
+ `callback_on_step_end_tensor_inputs`.
1171
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1172
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1173
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1174
+ `._callback_tensor_inputs` attribute of your pipeline class.
1175
+ Examples:
1176
+ Returns:
1177
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1178
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1179
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1180
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1181
+ "not-safe-for-work" (nsfw) content.
1182
+ """
1183
+
1184
+ callback = kwargs.pop("callback", None)
1185
+ callback_steps = kwargs.pop("callback_steps", None)
1186
+
1187
+ if callback is not None:
1188
+ deprecate(
1189
+ "callback",
1190
+ "1.0.0",
1191
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1192
+ )
1193
+ if callback_steps is not None:
1194
+ deprecate(
1195
+ "callback_steps",
1196
+ "1.0.0",
1197
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1198
+ )
1199
+
1200
+ # 0. Default height and width to unet
1201
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1202
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1203
+ # to deal with lora scaling and other possible forward hooks
1204
+
1205
+ # 1. Check inputs. Raise error if not correct
1206
+ self.check_inputs(
1207
+ prompt,
1208
+ height,
1209
+ width,
1210
+ callback_steps,
1211
+ negative_prompt,
1212
+ prompt_embeds,
1213
+ negative_prompt_embeds,
1214
+ ip_adapter_image,
1215
+ ip_adapter_image_embeds,
1216
+ callback_on_step_end_tensor_inputs,
1217
+ )
1218
+
1219
+ self._guidance_scale = guidance_scale
1220
+ self._guidance_rescale = guidance_rescale
1221
+ self._clip_skip = clip_skip
1222
+ self._cross_attention_kwargs = cross_attention_kwargs
1223
+ self._interrupt = False
1224
+
1225
+ self._pag_scale = pag_scale
1226
+ self._pag_adaptive_scaling = pag_adaptive_scaling
1227
+ self._pag_drop_rate = pag_drop_rate
1228
+ self._pag_applied_layers = pag_applied_layers
1229
+ self._pag_applied_layers_index = pag_applied_layers_index
1230
+
1231
+ # 2. Define call parameters
1232
+ if prompt is not None and isinstance(prompt, str):
1233
+ batch_size = 1
1234
+ elif prompt is not None and isinstance(prompt, list):
1235
+ batch_size = len(prompt)
1236
+ else:
1237
+ batch_size = prompt_embeds.shape[0]
1238
+
1239
+ device = self._execution_device
1240
+
1241
+ # 3. Encode input prompt
1242
+ lora_scale = (
1243
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1244
+ )
1245
+
1246
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1247
+ prompt,
1248
+ device,
1249
+ num_images_per_prompt,
1250
+ self.do_classifier_free_guidance,
1251
+ negative_prompt,
1252
+ prompt_embeds=prompt_embeds,
1253
+ negative_prompt_embeds=negative_prompt_embeds,
1254
+ lora_scale=lora_scale,
1255
+ clip_skip=self.clip_skip,
1256
+ )
1257
+
1258
+ # For classifier free guidance, we need to do two forward passes.
1259
+ # Here we concatenate the unconditional and text embeddings into a single batch
1260
+ # to avoid doing two forward passes
1261
+
1262
+ #cfg
1263
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1264
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1265
+ #pag
1266
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1267
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
1268
+ #both
1269
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1270
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
1271
+
1272
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1273
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1274
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
1275
+ )
1276
+
1277
+ # 4. Prepare timesteps
1278
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1279
+
1280
+ # 5. Prepare latent variables
1281
+ num_channels_latents = self.unet.config.in_channels
1282
+ latents = self.prepare_latents(
1283
+ batch_size * num_images_per_prompt,
1284
+ num_channels_latents,
1285
+ height,
1286
+ width,
1287
+ prompt_embeds.dtype,
1288
+ device,
1289
+ generator,
1290
+ latents,
1291
+ )
1292
+
1293
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1294
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1295
+
1296
+ # 6.1 Add image embeds for IP-Adapter
1297
+ added_cond_kwargs = (
1298
+ {"image_embeds": image_embeds}
1299
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
1300
+ else None
1301
+ )
1302
+
1303
+ # 6.2 Optionally get Guidance Scale Embedding
1304
+ timestep_cond = None
1305
+ if self.unet.config.time_cond_proj_dim is not None:
1306
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1307
+ timestep_cond = self.get_guidance_scale_embedding(
1308
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1309
+ ).to(device=device, dtype=latents.dtype)
1310
+
1311
+ # 7. Denoising loop
1312
+ if self.do_adversarial_guidance:
1313
+ down_layers = []
1314
+ mid_layers = []
1315
+ up_layers = []
1316
+ for name, module in self.unet.named_modules():
1317
+ if 'attn1' in name and 'to' not in name:
1318
+ layer_type = name.split('.')[0].split('_')[0]
1319
+ if layer_type == 'down':
1320
+ down_layers.append(module)
1321
+ elif layer_type == 'mid':
1322
+ mid_layers.append(module)
1323
+ elif layer_type == 'up':
1324
+ up_layers.append(module)
1325
+ else:
1326
+ raise ValueError(f"Invalid layer type: {layer_type}")
1327
+
1328
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1329
+ self._num_timesteps = len(timesteps)
1330
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1331
+ for i, t in enumerate(timesteps):
1332
+ if self.interrupt:
1333
+ continue
1334
+
1335
+ #cfg
1336
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1337
+ latent_model_input = torch.cat([latents] * 2)
1338
+ #pag
1339
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1340
+ latent_model_input = torch.cat([latents] * 2)
1341
+ #both
1342
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1343
+ latent_model_input = torch.cat([latents] * 3)
1344
+ #no
1345
+ else:
1346
+ latent_model_input = latents
1347
+
1348
+ # change attention layer in UNet if use PAG
1349
+ if self.do_adversarial_guidance:
1350
+
1351
+ if self.do_classifier_free_guidance:
1352
+ replace_processor = PAGCFGIdentitySelfAttnProcessor()
1353
+ else:
1354
+ replace_processor = PAGIdentitySelfAttnProcessor()
1355
+
1356
+ drop_layers = self.pag_applied_layers_index
1357
+ for drop_layer in drop_layers:
1358
+ try:
1359
+ if drop_layer[0] == 'd':
1360
+ down_layers[int(drop_layer[1])].processor = replace_processor
1361
+ elif drop_layer[0] == 'm':
1362
+ mid_layers[int(drop_layer[1])].processor = replace_processor
1363
+ elif drop_layer[0] == 'u':
1364
+ up_layers[int(drop_layer[1])].processor = replace_processor
1365
+ else:
1366
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1367
+ except IndexError:
1368
+ raise ValueError(
1369
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1370
+ )
1371
+
1372
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1373
+
1374
+ # predict the noise residual
1375
+ noise_pred = self.unet(
1376
+ latent_model_input,
1377
+ t,
1378
+ encoder_hidden_states=prompt_embeds,
1379
+ timestep_cond=timestep_cond,
1380
+ cross_attention_kwargs=self.cross_attention_kwargs,
1381
+ added_cond_kwargs=added_cond_kwargs,
1382
+ return_dict=False,
1383
+ )[0]
1384
+
1385
+ # perform guidance
1386
+
1387
+ # cfg
1388
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1389
+
1390
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1391
+
1392
+ delta = noise_pred_text - noise_pred_uncond
1393
+ noise_pred = noise_pred_uncond + self.guidance_scale * delta
1394
+
1395
+ # pag
1396
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1397
+
1398
+ noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
1399
+
1400
+ signal_scale = self.pag_scale
1401
+ if self.do_pag_adaptive_scaling:
1402
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
1403
+ if signal_scale<0:
1404
+ signal_scale = 0
1405
+
1406
+ noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
1407
+
1408
+ # both
1409
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1410
+
1411
+ noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
1412
+
1413
+ signal_scale = self.pag_scale
1414
+ if self.do_pag_adaptive_scaling:
1415
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
1416
+ if signal_scale<0:
1417
+ signal_scale = 0
1418
+
1419
+ noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb)
1420
+
1421
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1422
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1423
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1424
+
1425
+ # compute the previous noisy sample x_t -> x_t-1
1426
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1427
+
1428
+ if callback_on_step_end is not None:
1429
+ callback_kwargs = {}
1430
+ for k in callback_on_step_end_tensor_inputs:
1431
+ callback_kwargs[k] = locals()[k]
1432
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1433
+
1434
+ latents = callback_outputs.pop("latents", latents)
1435
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1436
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1437
+
1438
+ # call the callback, if provided
1439
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1440
+ progress_bar.update()
1441
+ if callback is not None and i % callback_steps == 0:
1442
+ step_idx = i // getattr(self.scheduler, "order", 1)
1443
+ callback(step_idx, t, latents)
1444
+
1445
+ if not output_type == "latent":
1446
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1447
+ 0
1448
+ ]
1449
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1450
+ else:
1451
+ image = latents
1452
+ has_nsfw_concept = None
1453
+
1454
+ if has_nsfw_concept is None:
1455
+ do_denormalize = [True] * image.shape[0]
1456
+ else:
1457
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1458
+
1459
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1460
+
1461
+ # Offload all models
1462
+ self.maybe_free_model_hooks()
1463
+
1464
+ if not return_dict:
1465
+ return (image, has_nsfw_concept)
1466
+
1467
+ # change attention layer in UNet if use PAG
1468
+ if self.do_adversarial_guidance:
1469
+ drop_layers = self.pag_applied_layers_index
1470
+ for drop_layer in drop_layers:
1471
+ try:
1472
+ if drop_layer[0] == 'd':
1473
+ down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1474
+ elif drop_layer[0] == 'm':
1475
+ mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1476
+ elif drop_layer[0] == 'u':
1477
+ up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1478
+ else:
1479
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1480
+ except IndexError:
1481
+ raise ValueError(
1482
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1483
+ )
1484
+
1485
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
sd_pag_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff