Sanster commited on
Commit
4dbd536
·
1 Parent(s): 88874f9
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
PowerPaint_Brushnet/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "BrushNetModel",
3
+ "_diffusers_version": "0.27.2",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": 8,
9
+ "block_out_channels": [
10
+ 320,
11
+ 640,
12
+ 1280,
13
+ 1280
14
+ ],
15
+ "brushnet_conditioning_channel_order": "rgb",
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 5,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "cross_attention_dim": 768,
25
+ "down_block_types": [
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "DownBlock2D"
30
+ ],
31
+ "downsample_padding": 1,
32
+ "encoder_hid_dim": null,
33
+ "encoder_hid_dim_type": null,
34
+ "flip_sin_to_cos": true,
35
+ "freq_shift": 0,
36
+ "global_pool_conditions": false,
37
+ "in_channels": 4,
38
+ "layers_per_block": 2,
39
+ "mid_block_scale_factor": 1,
40
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "transformer_layers_per_block": 1,
49
+ "up_block_types": [
50
+ "UpBlock2D",
51
+ "CrossAttnUpBlock2D",
52
+ "CrossAttnUpBlock2D",
53
+ "CrossAttnUpBlock2D"
54
+ ],
55
+ "upcast_attention": false,
56
+ "use_linear_projection": false
57
+ }
PowerPaint_Brushnet/diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9c63d0b055c91cb098d303f83087090c09b4edd7848f1fedab313eff004f014
3
+ size 1772227696
PowerPaint_Brushnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:745a3babce414c8b765c57e86412544cecdbdb0601648900d10b482256babb76
3
+ size 3544366408
context-aware_result.png ADDED
image-outpainting_result.png ADDED
inpaint_result.png ADDED
main.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image, ImageFilter, ImageOps
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers.utils import load_image
7
+ from diffusers import DPMSolverMultistepScheduler
8
+ from safetensors.torch import load_model
9
+
10
+ from powerpaint_v2.BrushNet_CA import BrushNetModel
11
+ from powerpaint_v2.pipeline_PowerPaint_Brushnet_CA import (
12
+ StableDiffusionPowerPaintBrushNetPipeline,
13
+ )
14
+ from powerpaint_v2.power_paint_tokenizer import PowerPaintTokenizer
15
+ from powerpaint_v2.unet_2d_condition import UNet2DConditionModel
16
+
17
+
18
+ def task_to_prompt(control_type):
19
+ if control_type == "object-removal":
20
+ promptA = "P_ctxt"
21
+ promptB = "P_ctxt"
22
+ negative_promptA = "P_obj"
23
+ negative_promptB = "P_obj"
24
+ elif control_type == "context-aware":
25
+ promptA = "P_ctxt"
26
+ promptB = "P_ctxt"
27
+ negative_promptA = ""
28
+ negative_promptB = ""
29
+ elif control_type == "shape-guided":
30
+ promptA = "P_shape"
31
+ promptB = "P_ctxt"
32
+ negative_promptA = "P_shape"
33
+ negative_promptB = "P_ctxt"
34
+ elif control_type == "image-outpainting":
35
+ promptA = "P_ctxt"
36
+ promptB = "P_ctxt"
37
+ negative_promptA = "P_obj"
38
+ negative_promptB = "P_obj"
39
+ else:
40
+ promptA = "P_obj"
41
+ promptB = "P_obj"
42
+ negative_promptA = "P_obj"
43
+ negative_promptB = "P_obj"
44
+
45
+ return promptA, promptB, negative_promptA, negative_promptB
46
+
47
+
48
+ @torch.inference_mode()
49
+ def predict(
50
+ pipe,
51
+ input_image,
52
+ prompt,
53
+ fitting_degree,
54
+ ddim_steps,
55
+ scale,
56
+ negative_prompt,
57
+ task,
58
+ ):
59
+ promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(task)
60
+ print(task, promptA, promptB, negative_promptA, negative_promptB)
61
+ img = np.array(input_image["image"].convert("RGB"))
62
+
63
+ W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
64
+ H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
65
+ input_image["image"] = input_image["image"].resize((H, W))
66
+ input_image["mask"] = input_image["mask"].resize((H, W))
67
+
68
+ np_inpimg = np.array(input_image["image"])
69
+ np_inmask = np.array(input_image["mask"]) / 255.0
70
+
71
+ np_inpimg = np_inpimg * (1 - np_inmask)
72
+
73
+ input_image["image"] = Image.fromarray(np_inpimg.astype(np.uint8)).convert("RGB")
74
+
75
+ result = pipe(
76
+ promptA=promptA,
77
+ promptB=promptB,
78
+ promptU=prompt,
79
+ tradoff=fitting_degree,
80
+ tradoff_nag=fitting_degree,
81
+ image=input_image["image"].convert("RGB"),
82
+ mask=input_image["mask"].convert("RGB"),
83
+ num_inference_steps=ddim_steps,
84
+ brushnet_conditioning_scale=1.0,
85
+ negative_promptA=negative_promptA,
86
+ negative_promptB=negative_promptB,
87
+ negative_promptU=negative_prompt,
88
+ guidance_scale=scale,
89
+ width=H,
90
+ height=W,
91
+ ).images[0]
92
+ return result
93
+ m_img = (
94
+ input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=3))
95
+ )
96
+ m_img = np.asarray(m_img) / 255.0
97
+ img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
98
+ ours_np = np.asarray(result) / 255.0
99
+ ours_np = ours_np * m_img + (1 - m_img) * img_np
100
+ result_paste = Image.fromarray(np.uint8(ours_np * 255))
101
+ return result_paste
102
+
103
+
104
+ text_encoder_brushnet = CLIPTextModel.from_pretrained(
105
+ "text_encoder_brushnet",
106
+ variant="fp16",
107
+ torch_dtype=torch.float16,
108
+ )
109
+ unet = UNet2DConditionModel.from_pretrained(
110
+ "runwayml/stable-diffusion-v1-5",
111
+ subfolder="unet",
112
+ variant="fp16",
113
+ torch_dtype=torch.float16,
114
+ )
115
+ brushnet = BrushNetModel.from_pretrained(
116
+ "./PowerPaint_Brushnet",
117
+ variant="fp16",
118
+ torch_dtype=torch.float16,
119
+ )
120
+ pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
121
+ "runwayml/stable-diffusion-v1-5",
122
+ torch_dtype=torch.float16,
123
+ safety_checker=None,
124
+ unet=unet,
125
+ brushnet=brushnet,
126
+ text_encoder_brushnet=text_encoder_brushnet,
127
+ variant="fp16",
128
+ )
129
+ pipe.tokenizer = PowerPaintTokenizer(CLIPTokenizer.from_pretrained("./tokenizer"))
130
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
131
+ pipe = pipe.to("mps")
132
+
133
+
134
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
135
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
136
+ image = load_image(img_url).convert("RGB").resize((512, 512))
137
+ mask = load_image(mask_url).convert("RGB").resize((512, 512))
138
+
139
+
140
+ input_image = {"image": image, "mask": mask}
141
+ prompt = "Face of a fox sitting on a bench"
142
+ negative_prompt = "out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature"
143
+ fitting_degree = 1
144
+ steps = 30
145
+ tasks = [
146
+ {
147
+ "task": "object-removal",
148
+ "guidance_scale": 12,
149
+ "prompt": "empty scene blur",
150
+ "negative_prompt": "",
151
+ },
152
+ {
153
+ "task": "shape-guided",
154
+ "guidance_scale": 7.5,
155
+ "prompt": prompt,
156
+ "negative_prompt": negative_prompt,
157
+ },
158
+ {
159
+ "task": "context-aware",
160
+ "guidance_scale": 7.5,
161
+ "prompt": "empty secne",
162
+ "negative_prompt": negative_prompt,
163
+ },
164
+ {
165
+ "task": "inpaint",
166
+ "guidance_scale": 7.5,
167
+ "prompt": prompt,
168
+ "negative_prompt": negative_prompt,
169
+ },
170
+ {
171
+ "task": "image-outpainting",
172
+ "guidance_scale": 7.5,
173
+ "prompt": "empty scene",
174
+ "negative_prompt": negative_prompt,
175
+ },
176
+ ]
177
+
178
+ for task in tasks:
179
+ if task["task"] == "image-outpainting":
180
+ margin = 128
181
+ input_image["image"] = ImageOps.expand(
182
+ input_image["image"],
183
+ border=(margin, margin, margin, margin),
184
+ fill=(127, 127, 127),
185
+ )
186
+ outpaint_mask = np.zeros_like(np.asarray(input_image["mask"]))
187
+ input_image["mask"] = Image.fromarray(
188
+ cv2.copyMakeBorder(
189
+ outpaint_mask,
190
+ margin,
191
+ margin,
192
+ margin,
193
+ margin,
194
+ cv2.BORDER_CONSTANT,
195
+ value=(255, 255, 255),
196
+ )
197
+ )
198
+
199
+ result_image = predict(
200
+ pipe,
201
+ input_image,
202
+ task["prompt"],
203
+ fitting_degree,
204
+ steps,
205
+ task["guidance_scale"],
206
+ task["negative_prompt"],
207
+ task["task"],
208
+ )
209
+
210
+ result_image.save(f"{task['task']}_result.png")
object-removal_result.png ADDED
powerpaint_v2/BrushNet_CA.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.models.attention_processor import (
10
+ ADDED_KV_ATTENTION_PROCESSORS,
11
+ CROSS_ATTENTION_PROCESSORS,
12
+ AttentionProcessor,
13
+ AttnAddedKVProcessor,
14
+ AttnProcessor,
15
+ )
16
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
17
+ TimestepEmbedding, Timesteps
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from .unet_2d_blocks import (
20
+ CrossAttnDownBlock2D,
21
+ DownBlock2D,
22
+ get_down_block,
23
+ get_mid_block,
24
+ get_up_block
25
+ )
26
+
27
+ from .unet_2d_condition import UNet2DConditionModel
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ @dataclass
33
+ class BrushNetOutput(BaseOutput):
34
+ """
35
+ The output of [`BrushNetModel`].
36
+
37
+ Args:
38
+ up_block_res_samples (`tuple[torch.Tensor]`):
39
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
40
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
41
+ used to condition the original UNet's upsampling activations.
42
+ down_block_res_samples (`tuple[torch.Tensor]`):
43
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
44
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
45
+ used to condition the original UNet's downsampling activations.
46
+ mid_down_block_re_sample (`torch.Tensor`):
47
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
48
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
49
+ Output can be used to condition the original UNet's middle block activation.
50
+ """
51
+
52
+ up_block_res_samples: Tuple[torch.Tensor]
53
+ down_block_res_samples: Tuple[torch.Tensor]
54
+ mid_block_res_sample: torch.Tensor
55
+
56
+
57
+ class BrushNetModel(ModelMixin, ConfigMixin):
58
+ """
59
+ A BrushNet model.
60
+
61
+ Args:
62
+ in_channels (`int`, defaults to 4):
63
+ The number of channels in the input sample.
64
+ flip_sin_to_cos (`bool`, defaults to `True`):
65
+ Whether to flip the sin to cos in the time embedding.
66
+ freq_shift (`int`, defaults to 0):
67
+ The frequency shift to apply to the time embedding.
68
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
69
+ The tuple of downsample blocks to use.
70
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
71
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
72
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
73
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
74
+ The tuple of upsample blocks to use.
75
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
76
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
77
+ The tuple of output channels for each block.
78
+ layers_per_block (`int`, defaults to 2):
79
+ The number of layers per block.
80
+ downsample_padding (`int`, defaults to 1):
81
+ The padding to use for the downsampling convolution.
82
+ mid_block_scale_factor (`float`, defaults to 1):
83
+ The scale factor to use for the mid block.
84
+ act_fn (`str`, defaults to "silu"):
85
+ The activation function to use.
86
+ norm_num_groups (`int`, *optional*, defaults to 32):
87
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
88
+ in post-processing.
89
+ norm_eps (`float`, defaults to 1e-5):
90
+ The epsilon to use for the normalization.
91
+ cross_attention_dim (`int`, defaults to 1280):
92
+ The dimension of the cross attention features.
93
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
94
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
95
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
96
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
97
+ encoder_hid_dim (`int`, *optional*, defaults to None):
98
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
99
+ dimension to `cross_attention_dim`.
100
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
101
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
102
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
103
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
104
+ The dimension of the attention heads.
105
+ use_linear_projection (`bool`, defaults to `False`):
106
+ class_embed_type (`str`, *optional*, defaults to `None`):
107
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
108
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
109
+ addition_embed_type (`str`, *optional*, defaults to `None`):
110
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
111
+ "text". "text" will use the `TextTimeEmbedding` layer.
112
+ num_class_embeds (`int`, *optional*, defaults to 0):
113
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
114
+ class conditioning with `class_embed_type` equal to `None`.
115
+ upcast_attention (`bool`, defaults to `False`):
116
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
117
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
118
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
119
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
120
+ `class_embed_type="projection"`.
121
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
122
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
123
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
124
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
125
+ global_pool_conditions (`bool`, defaults to `False`):
126
+ TODO(Patrick) - unused parameter.
127
+ addition_embed_type_num_heads (`int`, defaults to 64):
128
+ The number of heads to use for the `TextTimeEmbedding` layer.
129
+ """
130
+
131
+ _supports_gradient_checkpointing = True
132
+
133
+ @register_to_config
134
+ def __init__(
135
+ self,
136
+ in_channels: int = 4,
137
+ conditioning_channels: int = 5,
138
+ flip_sin_to_cos: bool = True,
139
+ freq_shift: int = 0,
140
+ down_block_types: Tuple[str, ...] = (
141
+ "CrossAttnDownBlock2D",
142
+ "CrossAttnDownBlock2D",
143
+ "CrossAttnDownBlock2D",
144
+ "DownBlock2D",
145
+ ),
146
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
147
+ up_block_types: Tuple[str, ...] = (
148
+ "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
149
+ ),
150
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
151
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
152
+ layers_per_block: int = 2,
153
+ downsample_padding: int = 1,
154
+ mid_block_scale_factor: float = 1,
155
+ act_fn: str = "silu",
156
+ norm_num_groups: Optional[int] = 32,
157
+ norm_eps: float = 1e-5,
158
+ cross_attention_dim: int = 1280,
159
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
160
+ encoder_hid_dim: Optional[int] = None,
161
+ encoder_hid_dim_type: Optional[str] = None,
162
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
163
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
164
+ use_linear_projection: bool = False,
165
+ class_embed_type: Optional[str] = None,
166
+ addition_embed_type: Optional[str] = None,
167
+ addition_time_embed_dim: Optional[int] = None,
168
+ num_class_embeds: Optional[int] = None,
169
+ upcast_attention: bool = False,
170
+ resnet_time_scale_shift: str = "default",
171
+ projection_class_embeddings_input_dim: Optional[int] = None,
172
+ brushnet_conditioning_channel_order: str = "rgb",
173
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
174
+ global_pool_conditions: bool = False,
175
+ addition_embed_type_num_heads: int = 64,
176
+ ):
177
+ super().__init__()
178
+
179
+ # If `num_attention_heads` is not defined (which is the case for most models)
180
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
181
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
182
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
183
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
184
+ # which is why we correct for the naming here.
185
+ num_attention_heads = num_attention_heads or attention_head_dim
186
+
187
+ # Check inputs
188
+ if len(down_block_types) != len(up_block_types):
189
+ raise ValueError(
190
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
191
+ )
192
+
193
+ if len(block_out_channels) != len(down_block_types):
194
+ raise ValueError(
195
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
196
+ )
197
+
198
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
199
+ raise ValueError(
200
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
201
+ )
202
+
203
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
204
+ raise ValueError(
205
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
206
+ )
207
+
208
+ if isinstance(transformer_layers_per_block, int):
209
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
210
+
211
+ # input
212
+ conv_in_kernel = 3
213
+ conv_in_padding = (conv_in_kernel - 1) // 2
214
+ self.conv_in_condition = nn.Conv2d(
215
+ in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
216
+ padding=conv_in_padding
217
+ )
218
+
219
+ # time
220
+ time_embed_dim = block_out_channels[0] * 4
221
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
222
+ timestep_input_dim = block_out_channels[0]
223
+ self.time_embedding = TimestepEmbedding(
224
+ timestep_input_dim,
225
+ time_embed_dim,
226
+ act_fn=act_fn,
227
+ )
228
+
229
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
230
+ encoder_hid_dim_type = "text_proj"
231
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
232
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
233
+
234
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
235
+ raise ValueError(
236
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
237
+ )
238
+
239
+ if encoder_hid_dim_type == "text_proj":
240
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
241
+ elif encoder_hid_dim_type == "text_image_proj":
242
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
243
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
244
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
245
+ self.encoder_hid_proj = TextImageProjection(
246
+ text_embed_dim=encoder_hid_dim,
247
+ image_embed_dim=cross_attention_dim,
248
+ cross_attention_dim=cross_attention_dim,
249
+ )
250
+
251
+ elif encoder_hid_dim_type is not None:
252
+ raise ValueError(
253
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
254
+ )
255
+ else:
256
+ self.encoder_hid_proj = None
257
+
258
+ # class embedding
259
+ if class_embed_type is None and num_class_embeds is not None:
260
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
261
+ elif class_embed_type == "timestep":
262
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
263
+ elif class_embed_type == "identity":
264
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
265
+ elif class_embed_type == "projection":
266
+ if projection_class_embeddings_input_dim is None:
267
+ raise ValueError(
268
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
269
+ )
270
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
271
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
272
+ # 2. it projects from an arbitrary input dimension.
273
+ #
274
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
275
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
276
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
277
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
278
+ else:
279
+ self.class_embedding = None
280
+
281
+ if addition_embed_type == "text":
282
+ if encoder_hid_dim is not None:
283
+ text_time_embedding_from_dim = encoder_hid_dim
284
+ else:
285
+ text_time_embedding_from_dim = cross_attention_dim
286
+
287
+ self.add_embedding = TextTimeEmbedding(
288
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
289
+ )
290
+ elif addition_embed_type == "text_image":
291
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
292
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
293
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
294
+ self.add_embedding = TextImageTimeEmbedding(
295
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
296
+ )
297
+ elif addition_embed_type == "text_time":
298
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
299
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
300
+
301
+ elif addition_embed_type is not None:
302
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
303
+
304
+ self.down_blocks = nn.ModuleList([])
305
+ self.brushnet_down_blocks = nn.ModuleList([])
306
+
307
+ if isinstance(only_cross_attention, bool):
308
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
309
+
310
+ if isinstance(attention_head_dim, int):
311
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
312
+
313
+ if isinstance(num_attention_heads, int):
314
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
315
+
316
+ # down
317
+ output_channel = block_out_channels[0]
318
+
319
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
320
+ brushnet_block = zero_module(brushnet_block)
321
+ self.brushnet_down_blocks.append(brushnet_block)
322
+
323
+ for i, down_block_type in enumerate(down_block_types):
324
+ input_channel = output_channel
325
+ output_channel = block_out_channels[i]
326
+ is_final_block = i == len(block_out_channels) - 1
327
+
328
+ down_block = get_down_block(
329
+ down_block_type,
330
+ num_layers=layers_per_block,
331
+ transformer_layers_per_block=transformer_layers_per_block[i],
332
+ in_channels=input_channel,
333
+ out_channels=output_channel,
334
+ temb_channels=time_embed_dim,
335
+ add_downsample=not is_final_block,
336
+ resnet_eps=norm_eps,
337
+ resnet_act_fn=act_fn,
338
+ resnet_groups=norm_num_groups,
339
+ cross_attention_dim=cross_attention_dim,
340
+ num_attention_heads=num_attention_heads[i],
341
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
342
+ downsample_padding=downsample_padding,
343
+ use_linear_projection=use_linear_projection,
344
+ only_cross_attention=only_cross_attention[i],
345
+ upcast_attention=upcast_attention,
346
+ resnet_time_scale_shift=resnet_time_scale_shift,
347
+ )
348
+ self.down_blocks.append(down_block)
349
+
350
+ for _ in range(layers_per_block):
351
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
352
+ brushnet_block = zero_module(brushnet_block)
353
+ self.brushnet_down_blocks.append(brushnet_block)
354
+
355
+ if not is_final_block:
356
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
357
+ brushnet_block = zero_module(brushnet_block)
358
+ self.brushnet_down_blocks.append(brushnet_block)
359
+
360
+ # mid
361
+ mid_block_channel = block_out_channels[-1]
362
+
363
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
364
+ brushnet_block = zero_module(brushnet_block)
365
+ self.brushnet_mid_block = brushnet_block
366
+
367
+ self.mid_block = get_mid_block(
368
+ mid_block_type,
369
+ transformer_layers_per_block=transformer_layers_per_block[-1],
370
+ in_channels=mid_block_channel,
371
+ temb_channels=time_embed_dim,
372
+ resnet_eps=norm_eps,
373
+ resnet_act_fn=act_fn,
374
+ output_scale_factor=mid_block_scale_factor,
375
+ resnet_time_scale_shift=resnet_time_scale_shift,
376
+ cross_attention_dim=cross_attention_dim,
377
+ num_attention_heads=num_attention_heads[-1],
378
+ resnet_groups=norm_num_groups,
379
+ use_linear_projection=use_linear_projection,
380
+ upcast_attention=upcast_attention,
381
+ )
382
+
383
+ # count how many layers upsample the images
384
+ self.num_upsamplers = 0
385
+
386
+ # up
387
+ reversed_block_out_channels = list(reversed(block_out_channels))
388
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
389
+ reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
390
+ only_cross_attention = list(reversed(only_cross_attention))
391
+
392
+ output_channel = reversed_block_out_channels[0]
393
+
394
+ self.up_blocks = nn.ModuleList([])
395
+ self.brushnet_up_blocks = nn.ModuleList([])
396
+
397
+ for i, up_block_type in enumerate(up_block_types):
398
+ is_final_block = i == len(block_out_channels) - 1
399
+
400
+ prev_output_channel = output_channel
401
+ output_channel = reversed_block_out_channels[i]
402
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
403
+
404
+ # add upsample block for all BUT final layer
405
+ if not is_final_block:
406
+ add_upsample = True
407
+ self.num_upsamplers += 1
408
+ else:
409
+ add_upsample = False
410
+
411
+ up_block = get_up_block(
412
+ up_block_type,
413
+ num_layers=layers_per_block + 1,
414
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
415
+ in_channels=input_channel,
416
+ out_channels=output_channel,
417
+ prev_output_channel=prev_output_channel,
418
+ temb_channels=time_embed_dim,
419
+ add_upsample=add_upsample,
420
+ resnet_eps=norm_eps,
421
+ resnet_act_fn=act_fn,
422
+ resolution_idx=i,
423
+ resnet_groups=norm_num_groups,
424
+ cross_attention_dim=cross_attention_dim,
425
+ num_attention_heads=reversed_num_attention_heads[i],
426
+ use_linear_projection=use_linear_projection,
427
+ only_cross_attention=only_cross_attention[i],
428
+ upcast_attention=upcast_attention,
429
+ resnet_time_scale_shift=resnet_time_scale_shift,
430
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
431
+ )
432
+ self.up_blocks.append(up_block)
433
+ prev_output_channel = output_channel
434
+
435
+ for _ in range(layers_per_block + 1):
436
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
437
+ brushnet_block = zero_module(brushnet_block)
438
+ self.brushnet_up_blocks.append(brushnet_block)
439
+
440
+ if not is_final_block:
441
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
442
+ brushnet_block = zero_module(brushnet_block)
443
+ self.brushnet_up_blocks.append(brushnet_block)
444
+
445
+ @classmethod
446
+ def from_unet(
447
+ cls,
448
+ unet: UNet2DConditionModel,
449
+ brushnet_conditioning_channel_order: str = "rgb",
450
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
451
+ load_weights_from_unet: bool = True,
452
+ conditioning_channels: int = 5,
453
+ ):
454
+ r"""
455
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
456
+
457
+ Parameters:
458
+ unet (`UNet2DConditionModel`):
459
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
460
+ where applicable.
461
+ """
462
+ transformer_layers_per_block = (
463
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
464
+ )
465
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
466
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
467
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
468
+ addition_time_embed_dim = (
469
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
470
+ )
471
+
472
+ brushnet = cls(
473
+ in_channels=unet.config.in_channels,
474
+ conditioning_channels=conditioning_channels,
475
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
476
+ freq_shift=unet.config.freq_shift,
477
+ # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
478
+ down_block_types=["CrossAttnDownBlock2D",
479
+ "CrossAttnDownBlock2D",
480
+ "CrossAttnDownBlock2D",
481
+ "DownBlock2D", ],
482
+ # mid_block_type='MidBlock2D',
483
+ mid_block_type="UNetMidBlock2DCrossAttn",
484
+ # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
485
+ up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
486
+ only_cross_attention=unet.config.only_cross_attention,
487
+ block_out_channels=unet.config.block_out_channels,
488
+ layers_per_block=unet.config.layers_per_block,
489
+ downsample_padding=unet.config.downsample_padding,
490
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
491
+ act_fn=unet.config.act_fn,
492
+ norm_num_groups=unet.config.norm_num_groups,
493
+ norm_eps=unet.config.norm_eps,
494
+ cross_attention_dim=unet.config.cross_attention_dim,
495
+ transformer_layers_per_block=transformer_layers_per_block,
496
+ encoder_hid_dim=encoder_hid_dim,
497
+ encoder_hid_dim_type=encoder_hid_dim_type,
498
+ attention_head_dim=unet.config.attention_head_dim,
499
+ num_attention_heads=unet.config.num_attention_heads,
500
+ use_linear_projection=unet.config.use_linear_projection,
501
+ class_embed_type=unet.config.class_embed_type,
502
+ addition_embed_type=addition_embed_type,
503
+ addition_time_embed_dim=addition_time_embed_dim,
504
+ num_class_embeds=unet.config.num_class_embeds,
505
+ upcast_attention=unet.config.upcast_attention,
506
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
507
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
508
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
509
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
510
+ )
511
+
512
+ if load_weights_from_unet:
513
+ conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
514
+ conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
515
+ conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
516
+ brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
517
+ brushnet.conv_in_condition.bias = unet.conv_in.bias
518
+
519
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
520
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
521
+
522
+ if brushnet.class_embedding:
523
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
524
+
525
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
526
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
527
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
528
+
529
+ return brushnet.to(unet.dtype)
530
+
531
+ @property
532
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
533
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
534
+ r"""
535
+ Returns:
536
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
537
+ indexed by its weight name.
538
+ """
539
+ # set recursively
540
+ processors = {}
541
+
542
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
543
+ if hasattr(module, "get_processor"):
544
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
545
+
546
+ for sub_name, child in module.named_children():
547
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
548
+
549
+ return processors
550
+
551
+ for name, module in self.named_children():
552
+ fn_recursive_add_processors(name, module, processors)
553
+
554
+ return processors
555
+
556
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
557
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
558
+ r"""
559
+ Sets the attention processor to use to compute attention.
560
+
561
+ Parameters:
562
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
563
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
564
+ for **all** `Attention` layers.
565
+
566
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
567
+ processor. This is strongly recommended when setting trainable attention processors.
568
+
569
+ """
570
+ count = len(self.attn_processors.keys())
571
+
572
+ if isinstance(processor, dict) and len(processor) != count:
573
+ raise ValueError(
574
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
575
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
576
+ )
577
+
578
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
579
+ if hasattr(module, "set_processor"):
580
+ if not isinstance(processor, dict):
581
+ module.set_processor(processor)
582
+ else:
583
+ module.set_processor(processor.pop(f"{name}.processor"))
584
+
585
+ for sub_name, child in module.named_children():
586
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
587
+
588
+ for name, module in self.named_children():
589
+ fn_recursive_attn_processor(name, module, processor)
590
+
591
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
592
+ def set_default_attn_processor(self):
593
+ """
594
+ Disables custom attention processors and sets the default attention implementation.
595
+ """
596
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
597
+ processor = AttnAddedKVProcessor()
598
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
599
+ processor = AttnProcessor()
600
+ else:
601
+ raise ValueError(
602
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
603
+ )
604
+
605
+ self.set_attn_processor(processor)
606
+
607
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
608
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
609
+ r"""
610
+ Enable sliced attention computation.
611
+
612
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
613
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
614
+
615
+ Args:
616
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
617
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
618
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
619
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
620
+ must be a multiple of `slice_size`.
621
+ """
622
+ sliceable_head_dims = []
623
+
624
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
625
+ if hasattr(module, "set_attention_slice"):
626
+ sliceable_head_dims.append(module.sliceable_head_dim)
627
+
628
+ for child in module.children():
629
+ fn_recursive_retrieve_sliceable_dims(child)
630
+
631
+ # retrieve number of attention layers
632
+ for module in self.children():
633
+ fn_recursive_retrieve_sliceable_dims(module)
634
+
635
+ num_sliceable_layers = len(sliceable_head_dims)
636
+
637
+ if slice_size == "auto":
638
+ # half the attention head size is usually a good trade-off between
639
+ # speed and memory
640
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
641
+ elif slice_size == "max":
642
+ # make smallest slice possible
643
+ slice_size = num_sliceable_layers * [1]
644
+
645
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
646
+
647
+ if len(slice_size) != len(sliceable_head_dims):
648
+ raise ValueError(
649
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
650
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
651
+ )
652
+
653
+ for i in range(len(slice_size)):
654
+ size = slice_size[i]
655
+ dim = sliceable_head_dims[i]
656
+ if size is not None and size > dim:
657
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
658
+
659
+ # Recursively walk through all the children.
660
+ # Any children which exposes the set_attention_slice method
661
+ # gets the message
662
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
663
+ if hasattr(module, "set_attention_slice"):
664
+ module.set_attention_slice(slice_size.pop())
665
+
666
+ for child in module.children():
667
+ fn_recursive_set_attention_slice(child, slice_size)
668
+
669
+ reversed_slice_size = list(reversed(slice_size))
670
+ for module in self.children():
671
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
672
+
673
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
674
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
675
+ module.gradient_checkpointing = value
676
+
677
+ def forward(
678
+ self,
679
+ sample: torch.FloatTensor,
680
+ timestep: Union[torch.Tensor, float, int],
681
+ encoder_hidden_states: torch.Tensor,
682
+ brushnet_cond: torch.FloatTensor,
683
+ conditioning_scale: float = 1.0,
684
+ class_labels: Optional[torch.Tensor] = None,
685
+ timestep_cond: Optional[torch.Tensor] = None,
686
+ attention_mask: Optional[torch.Tensor] = None,
687
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
688
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
689
+ guess_mode: bool = False,
690
+ return_dict: bool = True,
691
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
692
+ """
693
+ The [`BrushNetModel`] forward method.
694
+
695
+ Args:
696
+ sample (`torch.FloatTensor`):
697
+ The noisy input tensor.
698
+ timestep (`Union[torch.Tensor, float, int]`):
699
+ The number of timesteps to denoise an input.
700
+ encoder_hidden_states (`torch.Tensor`):
701
+ The encoder hidden states.
702
+ brushnet_cond (`torch.FloatTensor`):
703
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
704
+ conditioning_scale (`float`, defaults to `1.0`):
705
+ The scale factor for BrushNet outputs.
706
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
707
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
708
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
709
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
710
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
711
+ embeddings.
712
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
713
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
714
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
715
+ negative values to the attention scores corresponding to "discard" tokens.
716
+ added_cond_kwargs (`dict`):
717
+ Additional conditions for the Stable Diffusion XL UNet.
718
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
719
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
720
+ guess_mode (`bool`, defaults to `False`):
721
+ In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
722
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
723
+ return_dict (`bool`, defaults to `True`):
724
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
725
+
726
+ Returns:
727
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
728
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
729
+ returned where the first element is the sample tensor.
730
+ """
731
+ # check channel order
732
+ channel_order = self.config.brushnet_conditioning_channel_order
733
+
734
+ if channel_order == "rgb":
735
+ # in rgb order by default
736
+ ...
737
+ elif channel_order == "bgr":
738
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
739
+ else:
740
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
741
+
742
+ # prepare attention_mask
743
+ if attention_mask is not None:
744
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
745
+ attention_mask = attention_mask.unsqueeze(1)
746
+
747
+ # 1. time
748
+ timesteps = timestep
749
+ if not torch.is_tensor(timesteps):
750
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
751
+ # This would be a good case for the `match` statement (Python 3.10+)
752
+ is_mps = sample.device.type == "mps"
753
+ if isinstance(timestep, float):
754
+ dtype = torch.float32 if is_mps else torch.float64
755
+ else:
756
+ dtype = torch.int32 if is_mps else torch.int64
757
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
758
+ elif len(timesteps.shape) == 0:
759
+ timesteps = timesteps[None].to(sample.device)
760
+
761
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
762
+ timesteps = timesteps.expand(sample.shape[0])
763
+
764
+ t_emb = self.time_proj(timesteps)
765
+
766
+ # timesteps does not contain any weights and will always return f32 tensors
767
+ # but time_embedding might actually be running in fp16. so we need to cast here.
768
+ # there might be better ways to encapsulate this.
769
+ t_emb = t_emb.to(dtype=sample.dtype)
770
+
771
+ emb = self.time_embedding(t_emb, timestep_cond)
772
+ aug_emb = None
773
+
774
+ if self.class_embedding is not None:
775
+ if class_labels is None:
776
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
777
+
778
+ if self.config.class_embed_type == "timestep":
779
+ class_labels = self.time_proj(class_labels)
780
+
781
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
782
+ emb = emb + class_emb
783
+
784
+ if self.config.addition_embed_type is not None:
785
+ if self.config.addition_embed_type == "text":
786
+ aug_emb = self.add_embedding(encoder_hidden_states)
787
+
788
+ elif self.config.addition_embed_type == "text_time":
789
+ if "text_embeds" not in added_cond_kwargs:
790
+ raise ValueError(
791
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
792
+ )
793
+ text_embeds = added_cond_kwargs.get("text_embeds")
794
+ if "time_ids" not in added_cond_kwargs:
795
+ raise ValueError(
796
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
797
+ )
798
+ time_ids = added_cond_kwargs.get("time_ids")
799
+ time_embeds = self.add_time_proj(time_ids.flatten())
800
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
801
+
802
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
803
+ add_embeds = add_embeds.to(emb.dtype)
804
+ aug_emb = self.add_embedding(add_embeds)
805
+
806
+ emb = emb + aug_emb if aug_emb is not None else emb
807
+
808
+ # 2. pre-process
809
+ brushnet_cond = torch.concat([sample, brushnet_cond], 1)
810
+ sample = self.conv_in_condition(brushnet_cond)
811
+
812
+ # 3. down
813
+ down_block_res_samples = (sample,)
814
+ for downsample_block in self.down_blocks:
815
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
816
+ sample, res_samples = downsample_block(
817
+ hidden_states=sample,
818
+ temb=emb,
819
+ encoder_hidden_states=encoder_hidden_states,
820
+ attention_mask=attention_mask,
821
+ cross_attention_kwargs=cross_attention_kwargs,
822
+ )
823
+ else:
824
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
825
+
826
+ down_block_res_samples += res_samples
827
+
828
+ # 4. PaintingNet down blocks
829
+ brushnet_down_block_res_samples = ()
830
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
831
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
832
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
833
+
834
+ # 5. mid
835
+ if self.mid_block is not None:
836
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
837
+ sample = self.mid_block(
838
+ sample,
839
+ emb,
840
+ encoder_hidden_states=encoder_hidden_states,
841
+ attention_mask=attention_mask,
842
+ cross_attention_kwargs=cross_attention_kwargs,
843
+ )
844
+ else:
845
+ sample = self.mid_block(sample, emb)
846
+
847
+ # 6. BrushNet mid blocks
848
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
849
+
850
+ # 7. up
851
+ up_block_res_samples = ()
852
+ for i, upsample_block in enumerate(self.up_blocks):
853
+ is_final_block = i == len(self.up_blocks) - 1
854
+
855
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
856
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
857
+
858
+ # if we have not reached the final block and need to forward the
859
+ # upsample size, we do it here
860
+ if not is_final_block:
861
+ upsample_size = down_block_res_samples[-1].shape[2:]
862
+
863
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
864
+ sample, up_res_samples = upsample_block(
865
+ hidden_states=sample,
866
+ temb=emb,
867
+ res_hidden_states_tuple=res_samples,
868
+ encoder_hidden_states=encoder_hidden_states,
869
+ cross_attention_kwargs=cross_attention_kwargs,
870
+ upsample_size=upsample_size,
871
+ attention_mask=attention_mask,
872
+ return_res_samples=True
873
+ )
874
+ else:
875
+ sample, up_res_samples = upsample_block(
876
+ hidden_states=sample,
877
+ temb=emb,
878
+ res_hidden_states_tuple=res_samples,
879
+ upsample_size=upsample_size,
880
+ return_res_samples=True
881
+ )
882
+
883
+ up_block_res_samples += up_res_samples
884
+
885
+ # 8. BrushNet up blocks
886
+ brushnet_up_block_res_samples = ()
887
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
888
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
889
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
890
+
891
+ # 6. scaling
892
+ if guess_mode and not self.config.global_pool_conditions:
893
+ scales = torch.logspace(-1, 0,
894
+ len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
895
+ device=sample.device) # 0.1 to 1.0
896
+ scales = scales * conditioning_scale
897
+
898
+ brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
899
+ scales[:len(
900
+ brushnet_down_block_res_samples)])]
901
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
902
+ brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
903
+ scales[
904
+ len(brushnet_down_block_res_samples) + 1:])]
905
+ else:
906
+ brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
907
+ brushnet_down_block_res_samples]
908
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
909
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
910
+
911
+ if self.config.global_pool_conditions:
912
+ brushnet_down_block_res_samples = [
913
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
914
+ ]
915
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
916
+ brushnet_up_block_res_samples = [
917
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
918
+ ]
919
+
920
+ if not return_dict:
921
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
922
+
923
+ return BrushNetOutput(
924
+ down_block_res_samples=brushnet_down_block_res_samples,
925
+ mid_block_res_sample=brushnet_mid_block_res_sample,
926
+ up_block_res_samples=brushnet_up_block_res_samples
927
+ )
928
+
929
+
930
+ def zero_module(module):
931
+ for p in module.parameters():
932
+ nn.init.zeros_(p)
933
+ return module
powerpaint_v2/pipeline_PowerPaint_Brushnet_CA.py ADDED
@@ -0,0 +1,1494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from diffusers import StableDiffusionMixin
9
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
10
+
11
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
12
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
13
+ from diffusers.models import AutoencoderKL, ImageProjection
14
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ deprecate,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
27
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
28
+
29
+ from .BrushNet_CA import BrushNetModel
30
+ from .unet_2d_condition import UNet2DConditionModel
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
38
+ from diffusers.utils import load_image
39
+ import torch
40
+ import cv2
41
+ import numpy as np
42
+ from PIL import Image
43
+
44
+ base_model_path = "runwayml/stable-diffusion-v1-5"
45
+ brushnet_path = "ckpt_path"
46
+
47
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
48
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
49
+ base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
50
+ )
51
+
52
+ # speed up diffusion process with faster scheduler and memory optimization
53
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
54
+ # remove following line if xformers is not installed or when using Torch 2.0.
55
+ # pipe.enable_xformers_memory_efficient_attention()
56
+ # memory optimization.
57
+ pipe.enable_model_cpu_offload()
58
+
59
+ image_path="examples/brushnet/src/test_image.jpg"
60
+ mask_path="examples/brushnet/src/test_mask.jpg"
61
+ caption="A cake on the table."
62
+
63
+ init_image = cv2.imread(image_path)
64
+ mask_image = 1.*(cv2.imread(mask_path).sum(-1)>255)[:,:,np.newaxis]
65
+ init_image = init_image * (1-mask_image)
66
+
67
+ init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
68
+ mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3,-1)*255).convert("RGB")
69
+
70
+ generator = torch.Generator("cuda").manual_seed(1234)
71
+
72
+ image = pipe(
73
+ caption,
74
+ init_image,
75
+ mask_image,
76
+ num_inference_steps=50,
77
+ generator=generator,
78
+ paintingnet_conditioning_scale=1.0
79
+ ).images[0]
80
+ image.save("output.png")
81
+ ```
82
+ """
83
+
84
+
85
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
86
+ def retrieve_timesteps(
87
+ scheduler,
88
+ num_inference_steps: Optional[int] = None,
89
+ device: Optional[Union[str, torch.device]] = None,
90
+ timesteps: Optional[List[int]] = None,
91
+ **kwargs,
92
+ ):
93
+ """
94
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
95
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
96
+
97
+ Args:
98
+ scheduler (`SchedulerMixin`):
99
+ The scheduler to get timesteps from.
100
+ num_inference_steps (`int`):
101
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
102
+ `timesteps` must be `None`.
103
+ device (`str` or `torch.device`, *optional*):
104
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
105
+ timesteps (`List[int]`, *optional*):
106
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
107
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
108
+ must be `None`.
109
+
110
+ Returns:
111
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
112
+ second element is the number of inference steps.
113
+ """
114
+ if timesteps is not None:
115
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
+ if not accepts_timesteps:
117
+ raise ValueError(
118
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
+ f" timestep schedules. Please check whether you are using the correct scheduler."
120
+ )
121
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ num_inference_steps = len(timesteps)
124
+ else:
125
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
126
+ timesteps = scheduler.timesteps
127
+ return timesteps, num_inference_steps
128
+
129
+
130
+ class StableDiffusionPowerPaintBrushNetPipeline(
131
+ DiffusionPipeline,
132
+ StableDiffusionMixin,
133
+ TextualInversionLoaderMixin,
134
+ LoraLoaderMixin,
135
+ IPAdapterMixin,
136
+ FromSingleFileMixin,
137
+ ):
138
+ r"""
139
+ Pipeline for text-to-image generation using Stable Diffusion with BrushNet guidance.
140
+
141
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
142
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
143
+
144
+ The pipeline also inherits the following loading methods:
145
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
146
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
147
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
148
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
149
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
150
+
151
+ Args:
152
+ vae ([`AutoencoderKL`]):
153
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
154
+ text_encoder ([`~transformers.CLIPTextModel`]):
155
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
156
+ tokenizer ([`~transformers.CLIPTokenizer`]):
157
+ A `CLIPTokenizer` to tokenize text.
158
+ unet ([`UNet2DConditionModel`]):
159
+ A `UNet2DConditionModel` to denoise the encoded image latents.
160
+ brushnet ([`BrushNetModel`]`):
161
+ Provides additional conditioning to the `unet` during the denoising process.
162
+ scheduler ([`SchedulerMixin`]):
163
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
164
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
165
+ safety_checker ([`StableDiffusionSafetyChecker`]):
166
+ Classification module that estimates whether generated images could be considered offensive or harmful.
167
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
168
+ about a model's potential harms.
169
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
170
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
171
+ """
172
+
173
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
174
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
175
+ _exclude_from_cpu_offload = ["safety_checker"]
176
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
177
+
178
+ def __init__(
179
+ self,
180
+ vae: AutoencoderKL,
181
+ text_encoder: CLIPTextModel,
182
+ text_encoder_brushnet: CLIPTextModel,
183
+ tokenizer: CLIPTokenizer,
184
+ unet: UNet2DConditionModel,
185
+ brushnet: BrushNetModel,
186
+ scheduler: KarrasDiffusionSchedulers,
187
+ safety_checker: StableDiffusionSafetyChecker,
188
+ feature_extractor: CLIPImageProcessor,
189
+ image_encoder: CLIPVisionModelWithProjection = None,
190
+ requires_safety_checker: bool = True,
191
+ ):
192
+ super().__init__()
193
+
194
+ if safety_checker is None and requires_safety_checker:
195
+ logger.warning(
196
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
197
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
198
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
199
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
200
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
201
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
202
+ )
203
+
204
+ if safety_checker is not None and feature_extractor is None:
205
+ raise ValueError(
206
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
207
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
208
+ )
209
+
210
+ self.register_modules(
211
+ vae=vae,
212
+ text_encoder=text_encoder,
213
+ text_encoder_brushnet=text_encoder_brushnet,
214
+ tokenizer=tokenizer,
215
+ unet=unet,
216
+ brushnet=brushnet,
217
+ scheduler=scheduler,
218
+ safety_checker=safety_checker,
219
+ feature_extractor=feature_extractor,
220
+ image_encoder=image_encoder,
221
+ )
222
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
223
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
224
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
225
+
226
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
227
+ def _encode_prompt(
228
+ self,
229
+ promptA,
230
+ promptB,
231
+ t,
232
+ device,
233
+ num_images_per_prompt,
234
+ do_classifier_free_guidance,
235
+ negative_promptA=None,
236
+ negative_promptB=None,
237
+ t_nag=None,
238
+ prompt_embeds: Optional[torch.FloatTensor] = None,
239
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
240
+ lora_scale: Optional[float] = None,
241
+ ):
242
+ r"""
243
+ Encodes the prompt into text encoder hidden states.
244
+
245
+ Args:
246
+ prompt (`str` or `List[str]`, *optional*):
247
+ prompt to be encoded
248
+ device: (`torch.device`):
249
+ torch device
250
+ num_images_per_prompt (`int`):
251
+ number of images that should be generated per prompt
252
+ do_classifier_free_guidance (`bool`):
253
+ whether to use classifier free guidance or not
254
+ negative_prompt (`str` or `List[str]`, *optional*):
255
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
256
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
257
+ less than `1`).
258
+ prompt_embeds (`torch.FloatTensor`, *optional*):
259
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
260
+ provided, text embeddings will be generated from `prompt` input argument.
261
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
262
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
263
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
264
+ argument.
265
+ lora_scale (`float`, *optional*):
266
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
267
+ """
268
+ # set lora scale so that monkey patched LoRA
269
+ # function of text encoder can correctly access it
270
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
271
+ self._lora_scale = lora_scale
272
+
273
+ prompt = promptA
274
+ negative_prompt = negative_promptA
275
+
276
+ if promptA is not None and isinstance(promptA, str):
277
+ batch_size = 1
278
+ elif promptA is not None and isinstance(promptA, list):
279
+ batch_size = len(promptA)
280
+ else:
281
+ batch_size = prompt_embeds.shape[0]
282
+
283
+ if prompt_embeds is None:
284
+ # textual inversion: procecss multi-vector tokens if necessary
285
+ if isinstance(self, TextualInversionLoaderMixin):
286
+ promptA = self.maybe_convert_prompt(promptA, self.tokenizer)
287
+
288
+ text_inputsA = self.tokenizer(
289
+ promptA,
290
+ padding="max_length",
291
+ max_length=self.tokenizer.model_max_length,
292
+ truncation=True,
293
+ return_tensors="pt",
294
+ )
295
+ text_inputsB = self.tokenizer(
296
+ promptB,
297
+ padding="max_length",
298
+ max_length=self.tokenizer.model_max_length,
299
+ truncation=True,
300
+ return_tensors="pt",
301
+ )
302
+ text_input_idsA = text_inputsA.input_ids
303
+ text_input_idsB = text_inputsB.input_ids
304
+ untruncated_ids = self.tokenizer(promptA, padding="longest", return_tensors="pt").input_ids
305
+
306
+ if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal(
307
+ text_input_idsA, untruncated_ids
308
+ ):
309
+ removed_text = self.tokenizer.batch_decode(
310
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
311
+ )
312
+ logger.warning(
313
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
314
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
315
+ )
316
+
317
+ if hasattr(self.text_encoder_brushnet.config,
318
+ "use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
319
+ attention_mask = text_inputsA.attention_mask.to(device)
320
+ else:
321
+ attention_mask = None
322
+
323
+ # print("text_input_idsA: ",text_input_idsA)
324
+ # print("text_input_idsB: ",text_input_idsB)
325
+ # print('t: ',t)
326
+
327
+ prompt_embedsA = self.text_encoder_brushnet(
328
+ text_input_idsA.to(device),
329
+ attention_mask=attention_mask,
330
+ )
331
+ prompt_embedsA = prompt_embedsA[0]
332
+
333
+ prompt_embedsB = self.text_encoder_brushnet(
334
+ text_input_idsB.to(device),
335
+ attention_mask=attention_mask,
336
+ )
337
+ prompt_embedsB = prompt_embedsB[0]
338
+ prompt_embeds = prompt_embedsA * (t) + (1 - t) * prompt_embedsB
339
+ # print("prompt_embeds: ",prompt_embeds)
340
+
341
+ if self.text_encoder_brushnet is not None:
342
+ prompt_embeds_dtype = self.text_encoder_brushnet.dtype
343
+ elif self.unet is not None:
344
+ prompt_embeds_dtype = self.unet.dtype
345
+ else:
346
+ prompt_embeds_dtype = prompt_embeds.dtype
347
+
348
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
349
+
350
+ bs_embed, seq_len, _ = prompt_embeds.shape
351
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
352
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
353
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
354
+
355
+ # get unconditional embeddings for classifier free guidance
356
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
357
+ uncond_tokensA: List[str]
358
+ uncond_tokensB: List[str]
359
+ if negative_prompt is None:
360
+ uncond_tokensA = [""] * batch_size
361
+ uncond_tokensB = [""] * batch_size
362
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
363
+ raise TypeError(
364
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
365
+ f" {type(prompt)}."
366
+ )
367
+ elif isinstance(negative_prompt, str):
368
+ uncond_tokensA = [negative_promptA]
369
+ uncond_tokensB = [negative_promptB]
370
+ elif batch_size != len(negative_prompt):
371
+ raise ValueError(
372
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
373
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
374
+ " the batch size of `prompt`."
375
+ )
376
+ else:
377
+ uncond_tokensA = negative_promptA
378
+ uncond_tokensB = negative_promptB
379
+
380
+ # textual inversion: procecss multi-vector tokens if necessary
381
+ if isinstance(self, TextualInversionLoaderMixin):
382
+ uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer)
383
+ uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer)
384
+
385
+ max_length = prompt_embeds.shape[1]
386
+ uncond_inputA = self.tokenizer(
387
+ uncond_tokensA,
388
+ padding="max_length",
389
+ max_length=max_length,
390
+ truncation=True,
391
+ return_tensors="pt",
392
+ )
393
+ uncond_inputB = self.tokenizer(
394
+ uncond_tokensB,
395
+ padding="max_length",
396
+ max_length=max_length,
397
+ truncation=True,
398
+ return_tensors="pt",
399
+ )
400
+
401
+ if hasattr(self.text_encoder_brushnet.config,
402
+ "use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
403
+ attention_mask = uncond_inputA.attention_mask.to(device)
404
+ else:
405
+ attention_mask = None
406
+
407
+ negative_prompt_embedsA = self.text_encoder_brushnet(
408
+ uncond_inputA.input_ids.to(device),
409
+ attention_mask=attention_mask,
410
+ )
411
+ negative_prompt_embedsB = self.text_encoder_brushnet(
412
+ uncond_inputB.input_ids.to(device),
413
+ attention_mask=attention_mask,
414
+ )
415
+ negative_prompt_embeds = negative_prompt_embedsA[0] * (t_nag) + (1 - t_nag) * negative_prompt_embedsB[0]
416
+
417
+ # negative_prompt_embeds = negative_prompt_embeds[0]
418
+
419
+ if do_classifier_free_guidance:
420
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
421
+ seq_len = negative_prompt_embeds.shape[1]
422
+
423
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
424
+
425
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
426
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
427
+
428
+ # For classifier free guidance, we need to do two forward passes.
429
+ # Here we concatenate the unconditional and text embeddings into a single batch
430
+ # to avoid doing two forward passes
431
+ # print("prompt_embeds: ",prompt_embeds)
432
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
433
+
434
+ return prompt_embeds
435
+
436
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
437
+ def encode_prompt(
438
+ self,
439
+ prompt,
440
+ device,
441
+ num_images_per_prompt,
442
+ do_classifier_free_guidance,
443
+ negative_prompt=None,
444
+ prompt_embeds: Optional[torch.FloatTensor] = None,
445
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
446
+ lora_scale: Optional[float] = None,
447
+ clip_skip: Optional[int] = None,
448
+ ):
449
+ r"""
450
+ Encodes the prompt into text encoder hidden states.
451
+
452
+ Args:
453
+ prompt (`str` or `List[str]`, *optional*):
454
+ prompt to be encoded
455
+ device: (`torch.device`):
456
+ torch device
457
+ num_images_per_prompt (`int`):
458
+ number of images that should be generated per prompt
459
+ do_classifier_free_guidance (`bool`):
460
+ whether to use classifier free guidance or not
461
+ negative_prompt (`str` or `List[str]`, *optional*):
462
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
463
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
464
+ less than `1`).
465
+ prompt_embeds (`torch.FloatTensor`, *optional*):
466
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
467
+ provided, text embeddings will be generated from `prompt` input argument.
468
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
469
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
470
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
471
+ argument.
472
+ lora_scale (`float`, *optional*):
473
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
474
+ clip_skip (`int`, *optional*):
475
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
476
+ the output of the pre-final layer will be used for computing the prompt embeddings.
477
+ """
478
+ # set lora scale so that monkey patched LoRA
479
+ # function of text encoder can correctly access it
480
+ # print('1 ',prompt,negative_prompt)
481
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
482
+ self._lora_scale = lora_scale
483
+
484
+ # dynamically adjust the LoRA scale
485
+ if not USE_PEFT_BACKEND:
486
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
487
+ else:
488
+ scale_lora_layers(self.text_encoder, lora_scale)
489
+ # print('2 ',prompt,negative_prompt)
490
+ if prompt is not None and isinstance(prompt, str):
491
+ batch_size = 1
492
+ elif prompt is not None and isinstance(prompt, list):
493
+ batch_size = len(prompt)
494
+ else:
495
+ batch_size = prompt_embeds.shape[0]
496
+ # print('3 ',prompt,negative_prompt)
497
+ if prompt_embeds is None:
498
+ # textual inversion: process multi-vector tokens if necessary
499
+ # print('4 ',prompt,negative_prompt)
500
+ if isinstance(self, TextualInversionLoaderMixin):
501
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
502
+
503
+ # print('5 ',prompt,negative_prompt)
504
+
505
+ text_inputs = self.tokenizer(
506
+ prompt,
507
+ padding="max_length",
508
+ max_length=self.tokenizer.model_max_length,
509
+ truncation=True,
510
+ return_tensors="pt",
511
+ )
512
+ text_input_ids = text_inputs.input_ids
513
+ # print(prompt, text_input_ids)
514
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
515
+
516
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
517
+ text_input_ids, untruncated_ids
518
+ ):
519
+ removed_text = self.tokenizer.batch_decode(
520
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
521
+ )
522
+ logger.warning(
523
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
524
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
525
+ )
526
+
527
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
528
+ attention_mask = text_inputs.attention_mask.to(device)
529
+ else:
530
+ attention_mask = None
531
+
532
+ if clip_skip is None:
533
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
534
+ prompt_embeds = prompt_embeds[0]
535
+ else:
536
+ prompt_embeds = self.text_encoder(
537
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
538
+ )
539
+ # Access the `hidden_states` first, that contains a tuple of
540
+ # all the hidden states from the encoder layers. Then index into
541
+ # the tuple to access the hidden states from the desired layer.
542
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
543
+ # We also need to apply the final LayerNorm here to not mess with the
544
+ # representations. The `last_hidden_states` that we typically use for
545
+ # obtaining the final prompt representations passes through the LayerNorm
546
+ # layer.
547
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
548
+
549
+ if self.text_encoder is not None:
550
+ prompt_embeds_dtype = self.text_encoder.dtype
551
+ elif self.unet is not None:
552
+ prompt_embeds_dtype = self.unet.dtype
553
+ else:
554
+ prompt_embeds_dtype = prompt_embeds.dtype
555
+
556
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
557
+
558
+ bs_embed, seq_len, _ = prompt_embeds.shape
559
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
560
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
561
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
562
+
563
+ # get unconditional embeddings for classifier free guidance
564
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
565
+ uncond_tokens: List[str]
566
+ if negative_prompt is None:
567
+ uncond_tokens = [""] * batch_size
568
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
569
+ raise TypeError(
570
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
571
+ f" {type(prompt)}."
572
+ )
573
+ elif isinstance(negative_prompt, str):
574
+ uncond_tokens = [negative_prompt]
575
+ elif batch_size != len(negative_prompt):
576
+ raise ValueError(
577
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
578
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
579
+ " the batch size of `prompt`."
580
+ )
581
+ else:
582
+ uncond_tokens = negative_prompt
583
+
584
+ # textual inversion: process multi-vector tokens if necessary
585
+ if isinstance(self, TextualInversionLoaderMixin):
586
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
587
+
588
+ max_length = prompt_embeds.shape[1]
589
+ uncond_input = self.tokenizer(
590
+ uncond_tokens,
591
+ padding="max_length",
592
+ max_length=max_length,
593
+ truncation=True,
594
+ return_tensors="pt",
595
+ )
596
+ # print("neg: ", uncond_input.input_ids)
597
+
598
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
599
+ attention_mask = uncond_input.attention_mask.to(device)
600
+ else:
601
+ attention_mask = None
602
+
603
+ negative_prompt_embeds = self.text_encoder(
604
+ uncond_input.input_ids.to(device),
605
+ attention_mask=attention_mask,
606
+ )
607
+ negative_prompt_embeds = negative_prompt_embeds[0]
608
+
609
+ if do_classifier_free_guidance:
610
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
611
+ seq_len = negative_prompt_embeds.shape[1]
612
+
613
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
614
+
615
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
616
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
617
+
618
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
619
+ # Retrieve the original scale by scaling back the LoRA layers
620
+ unscale_lora_layers(self.text_encoder, lora_scale)
621
+
622
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
623
+
624
+ return prompt_embeds
625
+
626
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
627
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
628
+ dtype = next(self.image_encoder.parameters()).dtype
629
+
630
+ if not isinstance(image, torch.Tensor):
631
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
632
+
633
+ image = image.to(device=device, dtype=dtype)
634
+ if output_hidden_states:
635
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
636
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
637
+ uncond_image_enc_hidden_states = self.image_encoder(
638
+ torch.zeros_like(image), output_hidden_states=True
639
+ ).hidden_states[-2]
640
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
641
+ num_images_per_prompt, dim=0
642
+ )
643
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
644
+ else:
645
+ image_embeds = self.image_encoder(image).image_embeds
646
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
647
+ uncond_image_embeds = torch.zeros_like(image_embeds)
648
+
649
+ return image_embeds, uncond_image_embeds
650
+
651
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
652
+ def prepare_ip_adapter_image_embeds(
653
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
654
+ ):
655
+ if ip_adapter_image_embeds is None:
656
+ if not isinstance(ip_adapter_image, list):
657
+ ip_adapter_image = [ip_adapter_image]
658
+
659
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
660
+ raise ValueError(
661
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
662
+ )
663
+
664
+ image_embeds = []
665
+ for single_ip_adapter_image, image_proj_layer in zip(
666
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
667
+ ):
668
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
669
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
670
+ single_ip_adapter_image, device, 1, output_hidden_state
671
+ )
672
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
673
+ single_negative_image_embeds = torch.stack(
674
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
675
+ )
676
+
677
+ if do_classifier_free_guidance:
678
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
679
+ single_image_embeds = single_image_embeds.to(device)
680
+
681
+ image_embeds.append(single_image_embeds)
682
+ else:
683
+ repeat_dims = [1]
684
+ image_embeds = []
685
+ for single_image_embeds in ip_adapter_image_embeds:
686
+ if do_classifier_free_guidance:
687
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
688
+ single_image_embeds = single_image_embeds.repeat(
689
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
690
+ )
691
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
692
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
693
+ )
694
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
695
+ else:
696
+ single_image_embeds = single_image_embeds.repeat(
697
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
698
+ )
699
+ image_embeds.append(single_image_embeds)
700
+
701
+ return image_embeds
702
+
703
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
704
+ def run_safety_checker(self, image, device, dtype):
705
+ if self.safety_checker is None:
706
+ has_nsfw_concept = None
707
+ else:
708
+ if torch.is_tensor(image):
709
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
710
+ else:
711
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
712
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
713
+ image, has_nsfw_concept = self.safety_checker(
714
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
715
+ )
716
+ return image, has_nsfw_concept
717
+
718
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
719
+ def decode_latents(self, latents):
720
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
721
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
722
+
723
+ latents = 1 / self.vae.config.scaling_factor * latents
724
+ image = self.vae.decode(latents, return_dict=False)[0]
725
+ image = (image / 2 + 0.5).clamp(0, 1)
726
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
727
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
728
+ return image
729
+
730
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
731
+ def prepare_extra_step_kwargs(self, generator, eta):
732
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
733
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
734
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
735
+ # and should be between [0, 1]
736
+
737
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
738
+ extra_step_kwargs = {}
739
+ if accepts_eta:
740
+ extra_step_kwargs["eta"] = eta
741
+
742
+ # check if the scheduler accepts generator
743
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
744
+ if accepts_generator:
745
+ extra_step_kwargs["generator"] = generator
746
+ return extra_step_kwargs
747
+
748
+ def check_inputs(
749
+ self,
750
+ prompt,
751
+ image,
752
+ mask,
753
+ callback_steps,
754
+ negative_prompt=None,
755
+ prompt_embeds=None,
756
+ negative_prompt_embeds=None,
757
+ ip_adapter_image=None,
758
+ ip_adapter_image_embeds=None,
759
+ brushnet_conditioning_scale=1.0,
760
+ control_guidance_start=0.0,
761
+ control_guidance_end=1.0,
762
+ callback_on_step_end_tensor_inputs=None,
763
+ ):
764
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
765
+ raise ValueError(
766
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
767
+ f" {type(callback_steps)}."
768
+ )
769
+
770
+ if callback_on_step_end_tensor_inputs is not None and not all(
771
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
772
+ ):
773
+ raise ValueError(
774
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
775
+ )
776
+
777
+ if prompt is not None and prompt_embeds is not None:
778
+ raise ValueError(
779
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
780
+ " only forward one of the two."
781
+ )
782
+ elif prompt is None and prompt_embeds is None:
783
+ raise ValueError(
784
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
785
+ )
786
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
787
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
788
+
789
+ if negative_prompt is not None and negative_prompt_embeds is not None:
790
+ raise ValueError(
791
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
792
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
793
+ )
794
+
795
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
796
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
797
+ raise ValueError(
798
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
799
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
800
+ f" {negative_prompt_embeds.shape}."
801
+ )
802
+
803
+ # Check `image`
804
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
805
+ self.brushnet, torch._dynamo.eval_frame.OptimizedModule
806
+ )
807
+ if (
808
+ isinstance(self.brushnet, BrushNetModel)
809
+ or is_compiled
810
+ and isinstance(self.brushnet._orig_mod, BrushNetModel)
811
+ ):
812
+ self.check_image(image, mask, prompt, prompt_embeds)
813
+ else:
814
+ assert False
815
+
816
+ # Check `brushnet_conditioning_scale`
817
+ if (
818
+ isinstance(self.brushnet, BrushNetModel)
819
+ or is_compiled
820
+ and isinstance(self.brushnet._orig_mod, BrushNetModel)
821
+ ):
822
+ if not isinstance(brushnet_conditioning_scale, float):
823
+ raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")
824
+ else:
825
+ assert False
826
+
827
+ if not isinstance(control_guidance_start, (tuple, list)):
828
+ control_guidance_start = [control_guidance_start]
829
+
830
+ if not isinstance(control_guidance_end, (tuple, list)):
831
+ control_guidance_end = [control_guidance_end]
832
+
833
+ if len(control_guidance_start) != len(control_guidance_end):
834
+ raise ValueError(
835
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
836
+ )
837
+
838
+ for start, end in zip(control_guidance_start, control_guidance_end):
839
+ if start >= end:
840
+ raise ValueError(
841
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
842
+ )
843
+ if start < 0.0:
844
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
845
+ if end > 1.0:
846
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
847
+
848
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
849
+ raise ValueError(
850
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
851
+ )
852
+
853
+ if ip_adapter_image_embeds is not None:
854
+ if not isinstance(ip_adapter_image_embeds, list):
855
+ raise ValueError(
856
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
857
+ )
858
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
859
+ raise ValueError(
860
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
861
+ )
862
+
863
+ def check_image(self, image, mask, prompt, prompt_embeds):
864
+ image_is_pil = isinstance(image, PIL.Image.Image)
865
+ image_is_tensor = isinstance(image, torch.Tensor)
866
+ image_is_np = isinstance(image, np.ndarray)
867
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
868
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
869
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
870
+
871
+ if (
872
+ not image_is_pil
873
+ and not image_is_tensor
874
+ and not image_is_np
875
+ and not image_is_pil_list
876
+ and not image_is_tensor_list
877
+ and not image_is_np_list
878
+ ):
879
+ raise TypeError(
880
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
881
+ )
882
+
883
+ mask_is_pil = isinstance(mask, PIL.Image.Image)
884
+ mask_is_tensor = isinstance(mask, torch.Tensor)
885
+ mask_is_np = isinstance(mask, np.ndarray)
886
+ mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image)
887
+ mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor)
888
+ mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)
889
+
890
+ if (
891
+ not mask_is_pil
892
+ and not mask_is_tensor
893
+ and not mask_is_np
894
+ and not mask_is_pil_list
895
+ and not mask_is_tensor_list
896
+ and not mask_is_np_list
897
+ ):
898
+ raise TypeError(
899
+ f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}"
900
+ )
901
+
902
+ if image_is_pil:
903
+ image_batch_size = 1
904
+ else:
905
+ image_batch_size = len(image)
906
+
907
+ if prompt is not None and isinstance(prompt, str):
908
+ prompt_batch_size = 1
909
+ elif prompt is not None and isinstance(prompt, list):
910
+ prompt_batch_size = len(prompt)
911
+ elif prompt_embeds is not None:
912
+ prompt_batch_size = prompt_embeds.shape[0]
913
+
914
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
915
+ raise ValueError(
916
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
917
+ )
918
+
919
+ def prepare_image(
920
+ self,
921
+ image,
922
+ width,
923
+ height,
924
+ batch_size,
925
+ num_images_per_prompt,
926
+ device,
927
+ dtype,
928
+ do_classifier_free_guidance=False,
929
+ guess_mode=False,
930
+ ):
931
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
932
+ image_batch_size = image.shape[0]
933
+
934
+ if image_batch_size == 1:
935
+ repeat_by = batch_size
936
+ else:
937
+ # image batch size is the same as prompt batch size
938
+ repeat_by = num_images_per_prompt
939
+
940
+ image = image.repeat_interleave(repeat_by, dim=0)
941
+
942
+ image = image.to(device=device, dtype=dtype)
943
+
944
+ if do_classifier_free_guidance and not guess_mode:
945
+ image = torch.cat([image] * 2)
946
+
947
+ return image.to(device=device, dtype=dtype)
948
+
949
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
950
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
951
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
952
+ if isinstance(generator, list) and len(generator) != batch_size:
953
+ raise ValueError(
954
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
955
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
956
+ )
957
+
958
+ if latents is None:
959
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
960
+ else:
961
+ noise = latents.to(device)
962
+
963
+ # scale the initial noise by the standard deviation required by the scheduler
964
+ latents = noise * self.scheduler.init_noise_sigma
965
+ return latents, noise
966
+
967
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
968
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
969
+ """
970
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
971
+
972
+ Args:
973
+ timesteps (`torch.Tensor`):
974
+ generate embedding vectors at these timesteps
975
+ embedding_dim (`int`, *optional*, defaults to 512):
976
+ dimension of the embeddings to generate
977
+ dtype:
978
+ data type of the generated embeddings
979
+
980
+ Returns:
981
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
982
+ """
983
+ assert len(w.shape) == 1
984
+ w = w * 1000.0
985
+
986
+ half_dim = embedding_dim // 2
987
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
988
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
989
+ emb = w.to(dtype)[:, None] * emb[None, :]
990
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
991
+ if embedding_dim % 2 == 1: # zero pad
992
+ emb = torch.nn.functional.pad(emb, (0, 1))
993
+ assert emb.shape == (w.shape[0], embedding_dim)
994
+ return emb
995
+
996
+ @property
997
+ def guidance_scale(self):
998
+ return self._guidance_scale
999
+
1000
+ @property
1001
+ def clip_skip(self):
1002
+ return self._clip_skip
1003
+
1004
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1005
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1006
+ # corresponds to doing no classifier free guidance.
1007
+ @property
1008
+ def do_classifier_free_guidance(self):
1009
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1010
+
1011
+ @property
1012
+ def cross_attention_kwargs(self):
1013
+ return self._cross_attention_kwargs
1014
+
1015
+ @property
1016
+ def num_timesteps(self):
1017
+ return self._num_timesteps
1018
+
1019
+ @torch.no_grad()
1020
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1021
+ def __call__(
1022
+ self,
1023
+ promptA: Union[str, List[str]] = None,
1024
+ promptB: Union[str, List[str]] = None,
1025
+ promptU: Union[str, List[str]] = None,
1026
+ tradoff: float = 1.0,
1027
+ tradoff_nag: float = 1.0,
1028
+ image: PipelineImageInput = None,
1029
+ mask: PipelineImageInput = None,
1030
+ height: Optional[int] = None,
1031
+ width: Optional[int] = None,
1032
+ num_inference_steps: int = 50,
1033
+ timesteps: List[int] = None,
1034
+ guidance_scale: float = 7.5,
1035
+ negative_promptA: Optional[Union[str, List[str]]] = None,
1036
+ negative_promptB: Optional[Union[str, List[str]]] = None,
1037
+ negative_promptU: Optional[Union[str, List[str]]] = None,
1038
+ num_images_per_prompt: Optional[int] = 1,
1039
+ eta: float = 0.0,
1040
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1041
+ latents: Optional[torch.FloatTensor] = None,
1042
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1043
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1044
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1045
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1046
+ output_type: Optional[str] = "pil",
1047
+ return_dict: bool = True,
1048
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1049
+ brushnet_conditioning_scale: Union[float, List[float]] = 1.0,
1050
+ guess_mode: bool = False,
1051
+ control_guidance_start: Union[float, List[float]] = 0.0,
1052
+ control_guidance_end: Union[float, List[float]] = 1.0,
1053
+ clip_skip: Optional[int] = None,
1054
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1055
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1056
+ **kwargs,
1057
+ ):
1058
+ r"""
1059
+ The call function to the pipeline for generation.
1060
+
1061
+ Args:
1062
+ prompt (`str` or `List[str]`, *optional*):
1063
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1064
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1065
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1066
+ The BrushNet input condition to provide guidance to the `unet` for generation. If the type is
1067
+ specified as `torch.FloatTensor`, it is passed to BrushNet as is. `PIL.Image.Image` can also be
1068
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
1069
+ and/or width are passed, `image` is resized accordingly. If multiple BrushNets are specified in
1070
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
1071
+ input to a single BrushNet. When `prompt` is a list, and if a list of images is passed for a single BrushNet,
1072
+ each will be paired with each prompt in the `prompt` list. This also applies to multiple BrushNets,
1073
+ where a list of image lists can be passed to batch for each prompt and each BrushNet.
1074
+ mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1075
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1076
+ The BrushNet input condition to provide guidance to the `unet` for generation. If the type is
1077
+ specified as `torch.FloatTensor`, it is passed to BrushNet as is. `PIL.Image.Image` can also be
1078
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
1079
+ and/or width are passed, `image` is resized accordingly. If multiple BrushNets are specified in
1080
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
1081
+ input to a single BrushNet. When `prompt` is a list, and if a list of images is passed for a single BrushNet,
1082
+ each will be paired with each prompt in the `prompt` list. This also applies to multiple BrushNets,
1083
+ where a list of image lists can be passed to batch for each prompt and each BrushNet.
1084
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1085
+ The height in pixels of the generated image.
1086
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1087
+ The width in pixels of the generated image.
1088
+ num_inference_steps (`int`, *optional*, defaults to 50):
1089
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1090
+ expense of slower inference.
1091
+ timesteps (`List[int]`, *optional*):
1092
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1093
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1094
+ passed will be used. Must be in descending order.
1095
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1096
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1097
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1098
+ negative_prompt (`str` or `List[str]`, *optional*):
1099
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1100
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1101
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1102
+ The number of images to generate per prompt.
1103
+ eta (`float`, *optional*, defaults to 0.0):
1104
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1105
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1106
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1107
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1108
+ generation deterministic.
1109
+ latents (`torch.FloatTensor`, *optional*):
1110
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1111
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1112
+ tensor is generated by sampling using the supplied random `generator`.
1113
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1114
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1115
+ provided, text embeddings are generated from the `prompt` input argument.
1116
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1117
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1118
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1119
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1120
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1121
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
1122
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
1123
+ if `do_classifier_free_guidance` is set to `True`.
1124
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
1125
+ output_type (`str`, *optional*, defaults to `"pil"`):
1126
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1127
+ return_dict (`bool`, *optional*, defaults to `True`):
1128
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1129
+ plain tuple.
1130
+ callback (`Callable`, *optional*):
1131
+ A function that calls every `callback_steps` steps during inference. The function is called with the
1132
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1133
+ callback_steps (`int`, *optional*, defaults to 1):
1134
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
1135
+ every step.
1136
+ cross_attention_kwargs (`dict`, *optional*):
1137
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1138
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1139
+ brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1140
+ The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added
1141
+ to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set
1142
+ the corresponding scale as a list.
1143
+ guess_mode (`bool`, *optional*, defaults to `False`):
1144
+ The BrushNet encoder tries to recognize the content of the input image even if you remove all
1145
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
1146
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1147
+ The percentage of total steps at which the BrushNet starts applying.
1148
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1149
+ The percentage of total steps at which the BrushNet stops applying.
1150
+ clip_skip (`int`, *optional*):
1151
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1152
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1153
+ callback_on_step_end (`Callable`, *optional*):
1154
+ A function that calls at the end of each denoising steps during the inference. The function is called
1155
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1156
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1157
+ `callback_on_step_end_tensor_inputs`.
1158
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1159
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1160
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1161
+ `._callback_tensor_inputs` attribute of your pipeine class.
1162
+
1163
+ Examples:
1164
+
1165
+ Returns:
1166
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1167
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1168
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1169
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1170
+ "not-safe-for-work" (nsfw) content.
1171
+ """
1172
+
1173
+ callback = kwargs.pop("callback", None)
1174
+ callback_steps = kwargs.pop("callback_steps", None)
1175
+
1176
+ if callback is not None:
1177
+ deprecate(
1178
+ "callback",
1179
+ "1.0.0",
1180
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1181
+ )
1182
+ if callback_steps is not None:
1183
+ deprecate(
1184
+ "callback_steps",
1185
+ "1.0.0",
1186
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1187
+ )
1188
+
1189
+ brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet
1190
+
1191
+ # align format for control guidance
1192
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1193
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1194
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1195
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1196
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1197
+ control_guidance_start, control_guidance_end = (
1198
+ [control_guidance_start],
1199
+ [control_guidance_end],
1200
+ )
1201
+
1202
+ # 1. Check inputs. Raise error if not correct
1203
+ prompt = promptA
1204
+ negative_prompt = negative_promptA
1205
+ self.check_inputs(
1206
+ prompt,
1207
+ image,
1208
+ mask,
1209
+ callback_steps,
1210
+ negative_prompt,
1211
+ prompt_embeds,
1212
+ negative_prompt_embeds,
1213
+ ip_adapter_image,
1214
+ ip_adapter_image_embeds,
1215
+ brushnet_conditioning_scale,
1216
+ control_guidance_start,
1217
+ control_guidance_end,
1218
+ callback_on_step_end_tensor_inputs,
1219
+ )
1220
+
1221
+ self._guidance_scale = guidance_scale
1222
+ self._clip_skip = clip_skip
1223
+ self._cross_attention_kwargs = cross_attention_kwargs
1224
+
1225
+ # 2. Define call parameters
1226
+ if prompt is not None and isinstance(prompt, str):
1227
+ batch_size = 1
1228
+ elif prompt is not None and isinstance(prompt, list):
1229
+ batch_size = len(prompt)
1230
+ else:
1231
+ batch_size = prompt_embeds.shape[0]
1232
+
1233
+ device = self._execution_device
1234
+
1235
+ global_pool_conditions = (
1236
+ brushnet.config.global_pool_conditions
1237
+ if isinstance(brushnet, BrushNetModel)
1238
+ else brushnet.nets[0].config.global_pool_conditions
1239
+ )
1240
+ guess_mode = guess_mode or global_pool_conditions
1241
+
1242
+ # 3. Encode input prompt
1243
+ text_encoder_lora_scale = (
1244
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1245
+ )
1246
+
1247
+ prompt_embeds = self._encode_prompt(
1248
+ promptA,
1249
+ promptB,
1250
+ tradoff,
1251
+ device,
1252
+ num_images_per_prompt,
1253
+ self.do_classifier_free_guidance,
1254
+ negative_promptA,
1255
+ negative_promptB,
1256
+ tradoff_nag,
1257
+ prompt_embeds=prompt_embeds,
1258
+ negative_prompt_embeds=negative_prompt_embeds,
1259
+ lora_scale=text_encoder_lora_scale,
1260
+ )
1261
+ prompt_embedsU = None
1262
+ negative_prompt_embedsU = None
1263
+ prompt_embedsU = self.encode_prompt(
1264
+ promptU,
1265
+ device,
1266
+ num_images_per_prompt,
1267
+ self.do_classifier_free_guidance,
1268
+ negative_promptU,
1269
+ prompt_embeds=prompt_embedsU,
1270
+ negative_prompt_embeds=negative_prompt_embedsU,
1271
+ lora_scale=text_encoder_lora_scale,
1272
+ )
1273
+
1274
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1275
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1276
+ ip_adapter_image,
1277
+ ip_adapter_image_embeds,
1278
+ device,
1279
+ batch_size * num_images_per_prompt,
1280
+ self.do_classifier_free_guidance,
1281
+ )
1282
+
1283
+ # 4. Prepare image
1284
+ if isinstance(brushnet, BrushNetModel):
1285
+ image = self.prepare_image(
1286
+ image=image,
1287
+ width=width,
1288
+ height=height,
1289
+ batch_size=batch_size * num_images_per_prompt,
1290
+ num_images_per_prompt=num_images_per_prompt,
1291
+ device=device,
1292
+ dtype=brushnet.dtype,
1293
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1294
+ guess_mode=guess_mode,
1295
+ )
1296
+ original_mask = self.prepare_image(
1297
+ image=mask,
1298
+ width=width,
1299
+ height=height,
1300
+ batch_size=batch_size * num_images_per_prompt,
1301
+ num_images_per_prompt=num_images_per_prompt,
1302
+ device=device,
1303
+ dtype=brushnet.dtype,
1304
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1305
+ guess_mode=guess_mode,
1306
+ )
1307
+ original_mask = (original_mask.sum(1)[:, None, :, :] < 0).to(image.dtype)
1308
+ height, width = image.shape[-2:]
1309
+ else:
1310
+ assert False
1311
+
1312
+ # 5. Prepare timesteps
1313
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1314
+ self._num_timesteps = len(timesteps)
1315
+
1316
+ # 6. Prepare latent variables
1317
+ num_channels_latents = self.unet.config.in_channels
1318
+ latents, noise = self.prepare_latents(
1319
+ batch_size * num_images_per_prompt,
1320
+ num_channels_latents,
1321
+ height,
1322
+ width,
1323
+ prompt_embeds.dtype,
1324
+ device,
1325
+ generator,
1326
+ latents,
1327
+ )
1328
+
1329
+ # 6.1 prepare condition latents
1330
+ # mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0))
1331
+ # mask_i.save('_mask.png')
1332
+ # print(brushnet.dtype)
1333
+ conditioning_latents = self.vae.encode(
1334
+ image.to(device=device, dtype=brushnet.dtype)).latent_dist.sample() * self.vae.config.scaling_factor
1335
+ mask = torch.nn.functional.interpolate(
1336
+ original_mask,
1337
+ size=(
1338
+ conditioning_latents.shape[-2],
1339
+ conditioning_latents.shape[-1]
1340
+ )
1341
+ )
1342
+ conditioning_latents = torch.concat([conditioning_latents, mask], 1)
1343
+ # image = self.vae.decode(conditioning_latents[:1,:4,:,:] / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
1344
+ # from torchvision import transforms
1345
+ # mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0)/2+0.5)
1346
+ # mask_i.save(str(timesteps[0]) +'_C.png')
1347
+
1348
+ # 6.5 Optionally get Guidance Scale Embedding
1349
+ timestep_cond = None
1350
+ if self.unet.config.time_cond_proj_dim is not None:
1351
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1352
+ timestep_cond = self.get_guidance_scale_embedding(
1353
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1354
+ ).to(device=device, dtype=latents.dtype)
1355
+
1356
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1357
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1358
+
1359
+ # 7.1 Add image embeds for IP-Adapter
1360
+ added_cond_kwargs = (
1361
+ {"image_embeds": image_embeds}
1362
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1363
+ else None
1364
+ )
1365
+
1366
+ # 7.2 Create tensor stating which brushnets to keep
1367
+ brushnet_keep = []
1368
+ for i in range(len(timesteps)):
1369
+ keeps = [
1370
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1371
+ for s, e in zip(control_guidance_start, control_guidance_end)
1372
+ ]
1373
+ brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps)
1374
+
1375
+ # 8. Denoising loop
1376
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1377
+ is_unet_compiled = is_compiled_module(self.unet)
1378
+ is_brushnet_compiled = is_compiled_module(self.brushnet)
1379
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1380
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1381
+ for i, t in enumerate(timesteps):
1382
+ # Relevant thread:
1383
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1384
+ if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1:
1385
+ torch._inductor.cudagraph_mark_step_begin()
1386
+ # expand the latents if we are doing classifier free guidance
1387
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1388
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1389
+
1390
+ # brushnet(s) inference
1391
+ if guess_mode and self.do_classifier_free_guidance:
1392
+ # Infer BrushNet only for the conditional batch.
1393
+ control_model_input = latents
1394
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1395
+ brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1396
+ else:
1397
+ control_model_input = latent_model_input
1398
+ brushnet_prompt_embeds = prompt_embeds
1399
+
1400
+ if isinstance(brushnet_keep[i], list):
1401
+ cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])]
1402
+ else:
1403
+ brushnet_cond_scale = brushnet_conditioning_scale
1404
+ if isinstance(brushnet_cond_scale, list):
1405
+ brushnet_cond_scale = brushnet_cond_scale[0]
1406
+ cond_scale = brushnet_cond_scale * brushnet_keep[i]
1407
+
1408
+ down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet(
1409
+ control_model_input,
1410
+ t,
1411
+ encoder_hidden_states=brushnet_prompt_embeds,
1412
+ brushnet_cond=conditioning_latents,
1413
+ conditioning_scale=cond_scale,
1414
+ guess_mode=guess_mode,
1415
+ return_dict=False,
1416
+ )
1417
+
1418
+ if guess_mode and self.do_classifier_free_guidance:
1419
+ # Infered BrushNet only for the conditional batch.
1420
+ # To apply the output of BrushNet to both the unconditional and conditional batches,
1421
+ # add 0 to the unconditional batch to keep it unchanged.
1422
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1423
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1424
+ up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples]
1425
+
1426
+ # predict the noise residual
1427
+ noise_pred = self.unet(
1428
+ latent_model_input,
1429
+ t,
1430
+ encoder_hidden_states=prompt_embedsU,
1431
+ timestep_cond=timestep_cond,
1432
+ cross_attention_kwargs=self.cross_attention_kwargs,
1433
+ down_block_add_samples=down_block_res_samples,
1434
+ mid_block_add_sample=mid_block_res_sample,
1435
+ up_block_add_samples=up_block_res_samples,
1436
+ added_cond_kwargs=added_cond_kwargs,
1437
+ return_dict=False,
1438
+ )[0]
1439
+
1440
+ # perform guidance
1441
+ if self.do_classifier_free_guidance:
1442
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1443
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1444
+
1445
+ # compute the previous noisy sample x_t -> x_t-1
1446
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1447
+
1448
+ if callback_on_step_end is not None:
1449
+ callback_kwargs = {}
1450
+ for k in callback_on_step_end_tensor_inputs:
1451
+ callback_kwargs[k] = locals()[k]
1452
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1453
+
1454
+ latents = callback_outputs.pop("latents", latents)
1455
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1456
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1457
+
1458
+ # call the callback, if provided
1459
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1460
+ progress_bar.update()
1461
+ if callback is not None and i % callback_steps == 0:
1462
+ step_idx = i // getattr(self.scheduler, "order", 1)
1463
+ callback(step_idx, t, latents)
1464
+
1465
+ # If we do sequential model offloading, let's offload unet and brushnet
1466
+ # manually for max memory savings
1467
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1468
+ self.unet.to("cpu")
1469
+ self.brushnet.to("cpu")
1470
+ torch.cuda.empty_cache()
1471
+
1472
+ if not output_type == "latent":
1473
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1474
+ 0
1475
+ ]
1476
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1477
+ else:
1478
+ image = latents
1479
+ has_nsfw_concept = None
1480
+
1481
+ if has_nsfw_concept is None:
1482
+ do_denormalize = [True] * image.shape[0]
1483
+ else:
1484
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1485
+
1486
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1487
+
1488
+ # Offload all models
1489
+ self.maybe_free_model_hooks()
1490
+
1491
+ if not return_dict:
1492
+ return (image, has_nsfw_concept)
1493
+
1494
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
powerpaint_v2/power_paint_tokenizer.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import copy
4
+ import random
5
+ from typing import Any, List, Optional, Union
6
+ from transformers import CLIPTokenizer
7
+
8
+
9
+ class PowerPaintTokenizer:
10
+ def __init__(self, tokenizer: CLIPTokenizer):
11
+ self.wrapped = tokenizer
12
+ self.token_map = {}
13
+ placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
14
+ num_vec_per_token = 10
15
+ for placeholder_token in placeholder_tokens:
16
+ output = []
17
+ for i in range(num_vec_per_token):
18
+ ith_token = placeholder_token + f"_{i}"
19
+ output.append(ith_token)
20
+ self.token_map[placeholder_token] = output
21
+
22
+ def __getattr__(self, name: str) -> Any:
23
+ if name == "wrapped":
24
+ return super().__getattr__("wrapped")
25
+
26
+ try:
27
+ return getattr(self.wrapped, name)
28
+ except AttributeError:
29
+ try:
30
+ return super().__getattr__(name)
31
+ except AttributeError:
32
+ raise AttributeError(
33
+ "'name' cannot be found in both "
34
+ f"'{self.__class__.__name__}' and "
35
+ f"'{self.__class__.__name__}.tokenizer'."
36
+ )
37
+
38
+ def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
39
+ """Attempt to add tokens to the tokenizer.
40
+
41
+ Args:
42
+ tokens (Union[str, List[str]]): The tokens to be added.
43
+ """
44
+ num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
45
+ assert num_added_tokens != 0, (
46
+ f"The tokenizer already contains the token {tokens}. Please pass "
47
+ "a different `placeholder_token` that is not already in the "
48
+ "tokenizer."
49
+ )
50
+
51
+ def get_token_info(self, token: str) -> dict:
52
+ """Get the information of a token, including its start and end index in
53
+ the current tokenizer.
54
+
55
+ Args:
56
+ token (str): The token to be queried.
57
+
58
+ Returns:
59
+ dict: The information of the token, including its start and end
60
+ index in current tokenizer.
61
+ """
62
+ token_ids = self.__call__(token).input_ids
63
+ start, end = token_ids[1], token_ids[-2] + 1
64
+ return {"name": token, "start": start, "end": end}
65
+
66
+ def add_placeholder_token(
67
+ self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
68
+ ):
69
+ """Add placeholder tokens to the tokenizer.
70
+
71
+ Args:
72
+ placeholder_token (str): The placeholder token to be added.
73
+ num_vec_per_token (int, optional): The number of vectors of
74
+ the added placeholder token.
75
+ *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
76
+ """
77
+ output = []
78
+ if num_vec_per_token == 1:
79
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
80
+ output.append(placeholder_token)
81
+ else:
82
+ output = []
83
+ for i in range(num_vec_per_token):
84
+ ith_token = placeholder_token + f"_{i}"
85
+ self.try_adding_tokens(ith_token, *args, **kwargs)
86
+ output.append(ith_token)
87
+
88
+ for token in self.token_map:
89
+ if token in placeholder_token:
90
+ raise ValueError(
91
+ f"The tokenizer already has placeholder token {token} "
92
+ f"that can get confused with {placeholder_token} "
93
+ "keep placeholder tokens independent"
94
+ )
95
+ self.token_map[placeholder_token] = output
96
+
97
+ def replace_placeholder_tokens_in_text(
98
+ self,
99
+ text: Union[str, List[str]],
100
+ vector_shuffle: bool = False,
101
+ prop_tokens_to_load: float = 1.0,
102
+ ) -> Union[str, List[str]]:
103
+ """Replace the keywords in text with placeholder tokens. This function
104
+ will be called in `self.__call__` and `self.encode`.
105
+
106
+ Args:
107
+ text (Union[str, List[str]]): The text to be processed.
108
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
109
+ Defaults to False.
110
+ prop_tokens_to_load (float, optional): The proportion of tokens to
111
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
112
+
113
+ Returns:
114
+ Union[str, List[str]]: The processed text.
115
+ """
116
+ if isinstance(text, list):
117
+ output = []
118
+ for i in range(len(text)):
119
+ output.append(
120
+ self.replace_placeholder_tokens_in_text(
121
+ text[i], vector_shuffle=vector_shuffle
122
+ )
123
+ )
124
+ return output
125
+
126
+ for placeholder_token in self.token_map:
127
+ if placeholder_token in text:
128
+ tokens = self.token_map[placeholder_token]
129
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
130
+ if vector_shuffle:
131
+ tokens = copy.copy(tokens)
132
+ random.shuffle(tokens)
133
+ text = text.replace(placeholder_token, " ".join(tokens))
134
+ return text
135
+
136
+ def replace_text_with_placeholder_tokens(
137
+ self, text: Union[str, List[str]]
138
+ ) -> Union[str, List[str]]:
139
+ """Replace the placeholder tokens in text with the original keywords.
140
+ This function will be called in `self.decode`.
141
+
142
+ Args:
143
+ text (Union[str, List[str]]): The text to be processed.
144
+
145
+ Returns:
146
+ Union[str, List[str]]: The processed text.
147
+ """
148
+ if isinstance(text, list):
149
+ output = []
150
+ for i in range(len(text)):
151
+ output.append(self.replace_text_with_placeholder_tokens(text[i]))
152
+ return output
153
+
154
+ for placeholder_token, tokens in self.token_map.items():
155
+ merged_tokens = " ".join(tokens)
156
+ if merged_tokens in text:
157
+ text = text.replace(merged_tokens, placeholder_token)
158
+ return text
159
+
160
+ def __call__(
161
+ self,
162
+ text: Union[str, List[str]],
163
+ *args,
164
+ vector_shuffle: bool = False,
165
+ prop_tokens_to_load: float = 1.0,
166
+ **kwargs,
167
+ ):
168
+ """The call function of the wrapper.
169
+
170
+ Args:
171
+ text (Union[str, List[str]]): The text to be tokenized.
172
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
173
+ Defaults to False.
174
+ prop_tokens_to_load (float, optional): The proportion of tokens to
175
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
176
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
177
+ """
178
+ replaced_text = self.replace_placeholder_tokens_in_text(
179
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
180
+ )
181
+
182
+ return self.wrapped.__call__(replaced_text, *args, **kwargs)
183
+
184
+ def encode(self, text: Union[str, List[str]], *args, **kwargs):
185
+ """Encode the passed text to token index.
186
+
187
+ Args:
188
+ text (Union[str, List[str]]): The text to be encode.
189
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
190
+ """
191
+ replaced_text = self.replace_placeholder_tokens_in_text(text)
192
+ return self.wrapped(replaced_text, *args, **kwargs)
193
+
194
+ def decode(
195
+ self, token_ids, return_raw: bool = False, *args, **kwargs
196
+ ) -> Union[str, List[str]]:
197
+ """Decode the token index to text.
198
+
199
+ Args:
200
+ token_ids: The token index to be decoded.
201
+ return_raw: Whether keep the placeholder token in the text.
202
+ Defaults to False.
203
+ *args, **kwargs: The arguments for `self.wrapped.decode`.
204
+
205
+ Returns:
206
+ Union[str, List[str]]: The decoded text.
207
+ """
208
+ text = self.wrapped.decode(token_ids, *args, **kwargs)
209
+ if return_raw:
210
+ return text
211
+ replaced_text = self.replace_text_with_placeholder_tokens(text)
212
+ return replaced_text
213
+
214
+
215
+ class EmbeddingLayerWithFixes(nn.Module):
216
+ """The revised embedding layer to support external embeddings. This design
217
+ of this class is inspired by https://github.com/AUTOMATIC1111/stable-
218
+ diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
219
+ jack.py#L224 # noqa.
220
+
221
+ Args:
222
+ wrapped (nn.Emebdding): The embedding layer to be wrapped.
223
+ external_embeddings (Union[dict, List[dict]], optional): The external
224
+ embeddings added to this layer. Defaults to None.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ wrapped: nn.Embedding,
230
+ external_embeddings: Optional[Union[dict, List[dict]]] = None,
231
+ ):
232
+ super().__init__()
233
+ self.wrapped = wrapped
234
+ self.num_embeddings = wrapped.weight.shape[0]
235
+
236
+ self.external_embeddings = []
237
+ if external_embeddings:
238
+ self.add_embeddings(external_embeddings)
239
+
240
+ self.trainable_embeddings = nn.ParameterDict()
241
+
242
+ @property
243
+ def weight(self):
244
+ """Get the weight of wrapped embedding layer."""
245
+ return self.wrapped.weight
246
+
247
+ def check_duplicate_names(self, embeddings: List[dict]):
248
+ """Check whether duplicate names exist in list of 'external
249
+ embeddings'.
250
+
251
+ Args:
252
+ embeddings (List[dict]): A list of embedding to be check.
253
+ """
254
+ names = [emb["name"] for emb in embeddings]
255
+ assert len(names) == len(set(names)), (
256
+ "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
257
+ )
258
+
259
+ def check_ids_overlap(self, embeddings):
260
+ """Check whether overlap exist in token ids of 'external_embeddings'.
261
+
262
+ Args:
263
+ embeddings (List[dict]): A list of embedding to be check.
264
+ """
265
+ ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
266
+ ids_range.sort() # sort by 'start'
267
+ # check if 'end' has overlapping
268
+ for idx in range(len(ids_range) - 1):
269
+ name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
270
+ assert ids_range[idx][1] <= ids_range[idx + 1][0], (
271
+ f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
272
+ )
273
+
274
+ def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
275
+ """Add external embeddings to this layer.
276
+
277
+ Use case:
278
+
279
+ >>> 1. Add token to tokenizer and get the token id.
280
+ >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
281
+ >>> # 'how much' in kiswahili
282
+ >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
283
+ >>>
284
+ >>> 2. Add external embeddings to the model.
285
+ >>> new_embedding = {
286
+ >>> 'name': 'ngapi', # 'how much' in kiswahili
287
+ >>> 'embedding': torch.ones(1, 15) * 4,
288
+ >>> 'start': tokenizer.get_token_info('kwaheri')['start'],
289
+ >>> 'end': tokenizer.get_token_info('kwaheri')['end'],
290
+ >>> 'trainable': False # if True, will registry as a parameter
291
+ >>> }
292
+ >>> embedding_layer = nn.Embedding(10, 15)
293
+ >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
294
+ >>> embedding_layer_wrapper.add_embeddings(new_embedding)
295
+ >>>
296
+ >>> 3. Forward tokenizer and embedding layer!
297
+ >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
298
+ >>> input_ids = tokenizer(
299
+ >>> input_text, padding='max_length', truncation=True,
300
+ >>> return_tensors='pt')['input_ids']
301
+ >>> out_feat = embedding_layer_wrapper(input_ids)
302
+ >>>
303
+ >>> 4. Let's validate the result!
304
+ >>> assert (out_feat[0, 3: 7] == 2.3).all()
305
+ >>> assert (out_feat[2, 5: 9] == 2.3).all()
306
+
307
+ Args:
308
+ embeddings (Union[dict, list[dict]]): The external embeddings to
309
+ be added. Each dict must contain the following 4 fields: 'name'
310
+ (the name of this embedding), 'embedding' (the embedding
311
+ tensor), 'start' (the start token id of this embedding), 'end'
312
+ (the end token id of this embedding). For example:
313
+ `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
314
+ """
315
+ if isinstance(embeddings, dict):
316
+ embeddings = [embeddings]
317
+
318
+ self.external_embeddings += embeddings
319
+ self.check_duplicate_names(self.external_embeddings)
320
+ self.check_ids_overlap(self.external_embeddings)
321
+
322
+ # set for trainable
323
+ added_trainable_emb_info = []
324
+ for embedding in embeddings:
325
+ trainable = embedding.get("trainable", False)
326
+ if trainable:
327
+ name = embedding["name"]
328
+ embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
329
+ self.trainable_embeddings[name] = embedding["embedding"]
330
+ added_trainable_emb_info.append(name)
331
+
332
+ added_emb_info = [emb["name"] for emb in embeddings]
333
+ added_emb_info = ", ".join(added_emb_info)
334
+ print(f"Successfully add external embeddings: {added_emb_info}.", "current")
335
+
336
+ if added_trainable_emb_info:
337
+ added_trainable_emb_info = ", ".join(added_trainable_emb_info)
338
+ print(
339
+ "Successfully add trainable external embeddings: "
340
+ f"{added_trainable_emb_info}",
341
+ "current",
342
+ )
343
+
344
+ def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
345
+ """Replace external input ids to 0.
346
+
347
+ Args:
348
+ input_ids (torch.Tensor): The input ids to be replaced.
349
+
350
+ Returns:
351
+ torch.Tensor: The replaced input ids.
352
+ """
353
+ input_ids_fwd = input_ids.clone()
354
+ input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
355
+ return input_ids_fwd
356
+
357
+ def replace_embeddings(
358
+ self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
359
+ ) -> torch.Tensor:
360
+ """Replace external embedding to the embedding layer. Noted that, in
361
+ this function we use `torch.cat` to avoid inplace modification.
362
+
363
+ Args:
364
+ input_ids (torch.Tensor): The original token ids. Shape like
365
+ [LENGTH, ].
366
+ embedding (torch.Tensor): The embedding of token ids after
367
+ `replace_input_ids` function.
368
+ external_embedding (dict): The external embedding to be replaced.
369
+
370
+ Returns:
371
+ torch.Tensor: The replaced embedding.
372
+ """
373
+ new_embedding = []
374
+
375
+ name = external_embedding["name"]
376
+ start = external_embedding["start"]
377
+ end = external_embedding["end"]
378
+ target_ids_to_replace = [i for i in range(start, end)]
379
+ ext_emb = external_embedding["embedding"]
380
+
381
+ # do not need to replace
382
+ if not (input_ids == start).any():
383
+ return embedding
384
+
385
+ # start replace
386
+ s_idx, e_idx = 0, 0
387
+ while e_idx < len(input_ids):
388
+ if input_ids[e_idx] == start:
389
+ if e_idx != 0:
390
+ # add embedding do not need to replace
391
+ new_embedding.append(embedding[s_idx:e_idx])
392
+
393
+ # check if the next embedding need to replace is valid
394
+ actually_ids_to_replace = [
395
+ int(i) for i in input_ids[e_idx : e_idx + end - start]
396
+ ]
397
+ assert actually_ids_to_replace == target_ids_to_replace, (
398
+ f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
399
+ f"Expect '{target_ids_to_replace}' for embedding "
400
+ f"'{name}' but found '{actually_ids_to_replace}'."
401
+ )
402
+
403
+ new_embedding.append(ext_emb)
404
+
405
+ s_idx = e_idx + end - start
406
+ e_idx = s_idx + 1
407
+ else:
408
+ e_idx += 1
409
+
410
+ if e_idx == len(input_ids):
411
+ new_embedding.append(embedding[s_idx:e_idx])
412
+
413
+ return torch.cat(new_embedding, dim=0)
414
+
415
+ def forward(
416
+ self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None
417
+ ):
418
+ """The forward function.
419
+
420
+ Args:
421
+ input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
422
+ [LENGTH, ].
423
+ external_embeddings (Optional[List[dict]]): The external
424
+ embeddings. If not passed, only `self.external_embeddings`
425
+ will be used. Defaults to None.
426
+
427
+ input_ids: shape like [bz, LENGTH] or [LENGTH].
428
+ """
429
+ assert input_ids.ndim in [1, 2]
430
+ if input_ids.ndim == 1:
431
+ input_ids = input_ids.unsqueeze(0)
432
+
433
+ if external_embeddings is None and not self.external_embeddings:
434
+ return self.wrapped(input_ids)
435
+
436
+ input_ids_fwd = self.replace_input_ids(input_ids)
437
+ inputs_embeds = self.wrapped(input_ids_fwd)
438
+
439
+ vecs = []
440
+
441
+ if external_embeddings is None:
442
+ external_embeddings = []
443
+ elif isinstance(external_embeddings, dict):
444
+ external_embeddings = [external_embeddings]
445
+ embeddings = self.external_embeddings + external_embeddings
446
+
447
+ for input_id, embedding in zip(input_ids, inputs_embeds):
448
+ new_embedding = embedding
449
+ for external_embedding in embeddings:
450
+ new_embedding = self.replace_embeddings(
451
+ input_id, new_embedding, external_embedding
452
+ )
453
+ vecs.append(new_embedding)
454
+
455
+ return torch.stack(vecs)
456
+
457
+
458
+ def add_tokens(
459
+ tokenizer,
460
+ text_encoder,
461
+ placeholder_tokens: list,
462
+ initialize_tokens: list = None,
463
+ num_vectors_per_token: int = 1,
464
+ ):
465
+ """Add token for training.
466
+
467
+ # TODO: support add tokens as dict, then we can load pretrained tokens.
468
+ """
469
+ if initialize_tokens is not None:
470
+ assert len(initialize_tokens) == len(
471
+ placeholder_tokens
472
+ ), "placeholder_token should be the same length as initialize_token"
473
+ for ii in range(len(placeholder_tokens)):
474
+ tokenizer.add_placeholder_token(
475
+ placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
476
+ )
477
+
478
+ # text_encoder.set_embedding_layer()
479
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
480
+ text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
481
+ embedding_layer
482
+ )
483
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
484
+
485
+ assert embedding_layer is not None, (
486
+ "Do not support get embedding layer for current text encoder. "
487
+ "Please check your configuration."
488
+ )
489
+ initialize_embedding = []
490
+ if initialize_tokens is not None:
491
+ for ii in range(len(placeholder_tokens)):
492
+ init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
493
+ temp_embedding = embedding_layer.weight[init_id]
494
+ initialize_embedding.append(
495
+ temp_embedding[None, ...].repeat(num_vectors_per_token, 1)
496
+ )
497
+ else:
498
+ for ii in range(len(placeholder_tokens)):
499
+ init_id = tokenizer("a").input_ids[1]
500
+ temp_embedding = embedding_layer.weight[init_id]
501
+ len_emb = temp_embedding.shape[0]
502
+ init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
503
+ initialize_embedding.append(init_weight)
504
+
505
+ # initialize_embedding = torch.cat(initialize_embedding,dim=0)
506
+
507
+ token_info_all = []
508
+ for ii in range(len(placeholder_tokens)):
509
+ token_info = tokenizer.get_token_info(placeholder_tokens[ii])
510
+ token_info["embedding"] = initialize_embedding[ii]
511
+ token_info["trainable"] = True
512
+ token_info_all.append(token_info)
513
+ embedding_layer.add_embeddings(token_info_all)
powerpaint_v2/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
powerpaint_v2/unet_2d_condition.py ADDED
@@ -0,0 +1,1353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ Attention,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from diffusers.models.embeddings import (
34
+ GaussianFourierProjection,
35
+ GLIGENTextBoundingboxProjection,
36
+ ImageHintTimeEmbedding,
37
+ ImageProjection,
38
+ ImageTimeEmbedding,
39
+ TextImageProjection,
40
+ TextImageTimeEmbedding,
41
+ TextTimeEmbedding,
42
+ TimestepEmbedding,
43
+ Timesteps,
44
+ )
45
+ from diffusers.models.modeling_utils import ModelMixin
46
+ from .unet_2d_blocks import (
47
+ get_down_block,
48
+ get_mid_block,
49
+ get_up_block,
50
+ )
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+
55
+ @dataclass
56
+ class UNet2DConditionOutput(BaseOutput):
57
+ """
58
+ The output of [`UNet2DConditionModel`].
59
+
60
+ Args:
61
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
62
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
63
+ """
64
+
65
+ sample: torch.FloatTensor = None
66
+
67
+
68
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
69
+ r"""
70
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
71
+ shaped output.
72
+
73
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
74
+ for all models (such as downloading or saving).
75
+
76
+ Parameters:
77
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
78
+ Height and width of input/output sample.
79
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
80
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
81
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
82
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
83
+ Whether to flip the sin to cos in the time embedding.
84
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
85
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
86
+ The tuple of downsample blocks to use.
87
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
88
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
89
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
90
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
91
+ The tuple of upsample blocks to use.
92
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
93
+ Whether to include self-attention in the basic transformer blocks, see
94
+ [`~models.attention.BasicTransformerBlock`].
95
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
96
+ The tuple of output channels for each block.
97
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
98
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
99
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
100
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
101
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
102
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
103
+ If `None`, normalization and activation layers is skipped in post-processing.
104
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
105
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
106
+ The dimension of the cross attention features.
107
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
112
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
113
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
114
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
115
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
116
+ encoder_hid_dim (`int`, *optional*, defaults to None):
117
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
118
+ dimension to `cross_attention_dim`.
119
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
120
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
121
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
122
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
123
+ num_attention_heads (`int`, *optional*):
124
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
125
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
126
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
127
+ class_embed_type (`str`, *optional*, defaults to `None`):
128
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
129
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
130
+ addition_embed_type (`str`, *optional*, defaults to `None`):
131
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
132
+ "text". "text" will use the `TextTimeEmbedding` layer.
133
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
134
+ Dimension for the timestep embeddings.
135
+ num_class_embeds (`int`, *optional*, defaults to `None`):
136
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
137
+ class conditioning with `class_embed_type` equal to `None`.
138
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
139
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
140
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
141
+ An optional override for the dimension of the projected time embedding.
142
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
143
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
144
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
145
+ timestep_post_act (`str`, *optional*, defaults to `None`):
146
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
147
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
148
+ The dimension of `cond_proj` layer in the timestep embedding.
149
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
150
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
151
+ *optional*): The dimension of the `class_labels` input when
152
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
153
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
154
+ embeddings with the class embeddings.
155
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
156
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
157
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
158
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
159
+ otherwise.
160
+ """
161
+
162
+ _supports_gradient_checkpointing = True
163
+
164
+ @register_to_config
165
+ def __init__(
166
+ self,
167
+ sample_size: Optional[int] = None,
168
+ in_channels: int = 4,
169
+ out_channels: int = 4,
170
+ center_input_sample: bool = False,
171
+ flip_sin_to_cos: bool = True,
172
+ freq_shift: int = 0,
173
+ down_block_types: Tuple[str] = (
174
+ "CrossAttnDownBlock2D",
175
+ "CrossAttnDownBlock2D",
176
+ "CrossAttnDownBlock2D",
177
+ "DownBlock2D",
178
+ ),
179
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
180
+ up_block_types: Tuple[str] = (
181
+ "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: float = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads: int = 64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ self._check_config(
241
+ down_block_types=down_block_types,
242
+ up_block_types=up_block_types,
243
+ only_cross_attention=only_cross_attention,
244
+ block_out_channels=block_out_channels,
245
+ layers_per_block=layers_per_block,
246
+ cross_attention_dim=cross_attention_dim,
247
+ transformer_layers_per_block=transformer_layers_per_block,
248
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
249
+ attention_head_dim=attention_head_dim,
250
+ num_attention_heads=num_attention_heads,
251
+ )
252
+
253
+ # input
254
+ conv_in_padding = (conv_in_kernel - 1) // 2
255
+ self.conv_in = nn.Conv2d(
256
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
257
+ )
258
+
259
+ # time
260
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
261
+ time_embedding_type,
262
+ block_out_channels=block_out_channels,
263
+ flip_sin_to_cos=flip_sin_to_cos,
264
+ freq_shift=freq_shift,
265
+ time_embedding_dim=time_embedding_dim,
266
+ )
267
+
268
+ self.time_embedding = TimestepEmbedding(
269
+ timestep_input_dim,
270
+ time_embed_dim,
271
+ act_fn=act_fn,
272
+ post_act_fn=timestep_post_act,
273
+ cond_proj_dim=time_cond_proj_dim,
274
+ )
275
+
276
+ self._set_encoder_hid_proj(
277
+ encoder_hid_dim_type,
278
+ cross_attention_dim=cross_attention_dim,
279
+ encoder_hid_dim=encoder_hid_dim,
280
+ )
281
+
282
+ # class embedding
283
+ self._set_class_embedding(
284
+ class_embed_type,
285
+ act_fn=act_fn,
286
+ num_class_embeds=num_class_embeds,
287
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
288
+ time_embed_dim=time_embed_dim,
289
+ timestep_input_dim=timestep_input_dim,
290
+ )
291
+
292
+ self._set_add_embedding(
293
+ addition_embed_type,
294
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
295
+ addition_time_embed_dim=addition_time_embed_dim,
296
+ cross_attention_dim=cross_attention_dim,
297
+ encoder_hid_dim=encoder_hid_dim,
298
+ flip_sin_to_cos=flip_sin_to_cos,
299
+ freq_shift=freq_shift,
300
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301
+ time_embed_dim=time_embed_dim,
302
+ )
303
+
304
+ if time_embedding_act_fn is None:
305
+ self.time_embed_act = None
306
+ else:
307
+ self.time_embed_act = get_activation(time_embedding_act_fn)
308
+
309
+ self.down_blocks = nn.ModuleList([])
310
+ self.up_blocks = nn.ModuleList([])
311
+
312
+ if isinstance(only_cross_attention, bool):
313
+ if mid_block_only_cross_attention is None:
314
+ mid_block_only_cross_attention = only_cross_attention
315
+
316
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
317
+
318
+ if mid_block_only_cross_attention is None:
319
+ mid_block_only_cross_attention = False
320
+
321
+ if isinstance(num_attention_heads, int):
322
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
323
+
324
+ if isinstance(attention_head_dim, int):
325
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
326
+
327
+ if isinstance(cross_attention_dim, int):
328
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
329
+
330
+ if isinstance(layers_per_block, int):
331
+ layers_per_block = [layers_per_block] * len(down_block_types)
332
+
333
+ if isinstance(transformer_layers_per_block, int):
334
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
335
+
336
+ if class_embeddings_concat:
337
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
338
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
339
+ # regular time embeddings
340
+ blocks_time_embed_dim = time_embed_dim * 2
341
+ else:
342
+ blocks_time_embed_dim = time_embed_dim
343
+
344
+ # down
345
+ output_channel = block_out_channels[0]
346
+ for i, down_block_type in enumerate(down_block_types):
347
+ input_channel = output_channel
348
+ output_channel = block_out_channels[i]
349
+ is_final_block = i == len(block_out_channels) - 1
350
+
351
+ down_block = get_down_block(
352
+ down_block_type,
353
+ num_layers=layers_per_block[i],
354
+ transformer_layers_per_block=transformer_layers_per_block[i],
355
+ in_channels=input_channel,
356
+ out_channels=output_channel,
357
+ temb_channels=blocks_time_embed_dim,
358
+ add_downsample=not is_final_block,
359
+ resnet_eps=norm_eps,
360
+ resnet_act_fn=act_fn,
361
+ resnet_groups=norm_num_groups,
362
+ cross_attention_dim=cross_attention_dim[i],
363
+ num_attention_heads=num_attention_heads[i],
364
+ downsample_padding=downsample_padding,
365
+ dual_cross_attention=dual_cross_attention,
366
+ use_linear_projection=use_linear_projection,
367
+ only_cross_attention=only_cross_attention[i],
368
+ upcast_attention=upcast_attention,
369
+ resnet_time_scale_shift=resnet_time_scale_shift,
370
+ attention_type=attention_type,
371
+ resnet_skip_time_act=resnet_skip_time_act,
372
+ resnet_out_scale_factor=resnet_out_scale_factor,
373
+ cross_attention_norm=cross_attention_norm,
374
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
375
+ dropout=dropout,
376
+ )
377
+ self.down_blocks.append(down_block)
378
+
379
+ # mid
380
+ self.mid_block = get_mid_block(
381
+ mid_block_type,
382
+ temb_channels=blocks_time_embed_dim,
383
+ in_channels=block_out_channels[-1],
384
+ resnet_eps=norm_eps,
385
+ resnet_act_fn=act_fn,
386
+ resnet_groups=norm_num_groups,
387
+ output_scale_factor=mid_block_scale_factor,
388
+ transformer_layers_per_block=transformer_layers_per_block[-1],
389
+ num_attention_heads=num_attention_heads[-1],
390
+ cross_attention_dim=cross_attention_dim[-1],
391
+ dual_cross_attention=dual_cross_attention,
392
+ use_linear_projection=use_linear_projection,
393
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
394
+ upcast_attention=upcast_attention,
395
+ resnet_time_scale_shift=resnet_time_scale_shift,
396
+ attention_type=attention_type,
397
+ resnet_skip_time_act=resnet_skip_time_act,
398
+ cross_attention_norm=cross_attention_norm,
399
+ attention_head_dim=attention_head_dim[-1],
400
+ dropout=dropout,
401
+ )
402
+
403
+ # count how many layers upsample the images
404
+ self.num_upsamplers = 0
405
+
406
+ # up
407
+ reversed_block_out_channels = list(reversed(block_out_channels))
408
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
409
+ reversed_layers_per_block = list(reversed(layers_per_block))
410
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
411
+ reversed_transformer_layers_per_block = (
412
+ list(reversed(transformer_layers_per_block))
413
+ if reverse_transformer_layers_per_block is None
414
+ else reverse_transformer_layers_per_block
415
+ )
416
+ only_cross_attention = list(reversed(only_cross_attention))
417
+
418
+ output_channel = reversed_block_out_channels[0]
419
+ for i, up_block_type in enumerate(up_block_types):
420
+ is_final_block = i == len(block_out_channels) - 1
421
+
422
+ prev_output_channel = output_channel
423
+ output_channel = reversed_block_out_channels[i]
424
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
425
+
426
+ # add upsample block for all BUT final layer
427
+ if not is_final_block:
428
+ add_upsample = True
429
+ self.num_upsamplers += 1
430
+ else:
431
+ add_upsample = False
432
+
433
+ up_block = get_up_block(
434
+ up_block_type,
435
+ num_layers=reversed_layers_per_block[i] + 1,
436
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
437
+ in_channels=input_channel,
438
+ out_channels=output_channel,
439
+ prev_output_channel=prev_output_channel,
440
+ temb_channels=blocks_time_embed_dim,
441
+ add_upsample=add_upsample,
442
+ resnet_eps=norm_eps,
443
+ resnet_act_fn=act_fn,
444
+ resolution_idx=i,
445
+ resnet_groups=norm_num_groups,
446
+ cross_attention_dim=reversed_cross_attention_dim[i],
447
+ num_attention_heads=reversed_num_attention_heads[i],
448
+ dual_cross_attention=dual_cross_attention,
449
+ use_linear_projection=use_linear_projection,
450
+ only_cross_attention=only_cross_attention[i],
451
+ upcast_attention=upcast_attention,
452
+ resnet_time_scale_shift=resnet_time_scale_shift,
453
+ attention_type=attention_type,
454
+ resnet_skip_time_act=resnet_skip_time_act,
455
+ resnet_out_scale_factor=resnet_out_scale_factor,
456
+ cross_attention_norm=cross_attention_norm,
457
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458
+ dropout=dropout,
459
+ )
460
+ self.up_blocks.append(up_block)
461
+ prev_output_channel = output_channel
462
+
463
+ # out
464
+ if norm_num_groups is not None:
465
+ self.conv_norm_out = nn.GroupNorm(
466
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
467
+ )
468
+
469
+ self.conv_act = get_activation(act_fn)
470
+
471
+ else:
472
+ self.conv_norm_out = None
473
+ self.conv_act = None
474
+
475
+ conv_out_padding = (conv_out_kernel - 1) // 2
476
+ self.conv_out = nn.Conv2d(
477
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
478
+ )
479
+
480
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
481
+
482
+ def _check_config(
483
+ self,
484
+ down_block_types: Tuple[str],
485
+ up_block_types: Tuple[str],
486
+ only_cross_attention: Union[bool, Tuple[bool]],
487
+ block_out_channels: Tuple[int],
488
+ layers_per_block: Union[int, Tuple[int]],
489
+ cross_attention_dim: Union[int, Tuple[int]],
490
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
491
+ reverse_transformer_layers_per_block: bool,
492
+ attention_head_dim: int,
493
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
494
+ ):
495
+ if len(down_block_types) != len(up_block_types):
496
+ raise ValueError(
497
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
498
+ )
499
+
500
+ if len(block_out_channels) != len(down_block_types):
501
+ raise ValueError(
502
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
503
+ )
504
+
505
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
506
+ raise ValueError(
507
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
508
+ )
509
+
510
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
511
+ raise ValueError(
512
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
513
+ )
514
+
515
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
516
+ raise ValueError(
517
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
518
+ )
519
+
520
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
521
+ raise ValueError(
522
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
523
+ )
524
+
525
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
526
+ raise ValueError(
527
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
528
+ )
529
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
530
+ for layer_number_per_block in transformer_layers_per_block:
531
+ if isinstance(layer_number_per_block, list):
532
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
533
+
534
+ def _set_time_proj(
535
+ self,
536
+ time_embedding_type: str,
537
+ block_out_channels: int,
538
+ flip_sin_to_cos: bool,
539
+ freq_shift: float,
540
+ time_embedding_dim: int,
541
+ ) -> Tuple[int, int]:
542
+ if time_embedding_type == "fourier":
543
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
544
+ if time_embed_dim % 2 != 0:
545
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
546
+ self.time_proj = GaussianFourierProjection(
547
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
548
+ )
549
+ timestep_input_dim = time_embed_dim
550
+ elif time_embedding_type == "positional":
551
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
552
+
553
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
554
+ timestep_input_dim = block_out_channels[0]
555
+ else:
556
+ raise ValueError(
557
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
558
+ )
559
+
560
+ return time_embed_dim, timestep_input_dim
561
+
562
+ def _set_encoder_hid_proj(
563
+ self,
564
+ encoder_hid_dim_type: Optional[str],
565
+ cross_attention_dim: Union[int, Tuple[int]],
566
+ encoder_hid_dim: Optional[int],
567
+ ):
568
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
569
+ encoder_hid_dim_type = "text_proj"
570
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
571
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
572
+
573
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
574
+ raise ValueError(
575
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
576
+ )
577
+
578
+ if encoder_hid_dim_type == "text_proj":
579
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
580
+ elif encoder_hid_dim_type == "text_image_proj":
581
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
582
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
583
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
584
+ self.encoder_hid_proj = TextImageProjection(
585
+ text_embed_dim=encoder_hid_dim,
586
+ image_embed_dim=cross_attention_dim,
587
+ cross_attention_dim=cross_attention_dim,
588
+ )
589
+ elif encoder_hid_dim_type == "image_proj":
590
+ # Kandinsky 2.2
591
+ self.encoder_hid_proj = ImageProjection(
592
+ image_embed_dim=encoder_hid_dim,
593
+ cross_attention_dim=cross_attention_dim,
594
+ )
595
+ elif encoder_hid_dim_type is not None:
596
+ raise ValueError(
597
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
598
+ )
599
+ else:
600
+ self.encoder_hid_proj = None
601
+
602
+ def _set_class_embedding(
603
+ self,
604
+ class_embed_type: Optional[str],
605
+ act_fn: str,
606
+ num_class_embeds: Optional[int],
607
+ projection_class_embeddings_input_dim: Optional[int],
608
+ time_embed_dim: int,
609
+ timestep_input_dim: int,
610
+ ):
611
+ if class_embed_type is None and num_class_embeds is not None:
612
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
613
+ elif class_embed_type == "timestep":
614
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
615
+ elif class_embed_type == "identity":
616
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
617
+ elif class_embed_type == "projection":
618
+ if projection_class_embeddings_input_dim is None:
619
+ raise ValueError(
620
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
621
+ )
622
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
623
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
624
+ # 2. it projects from an arbitrary input dimension.
625
+ #
626
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
627
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
628
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
629
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
630
+ elif class_embed_type == "simple_projection":
631
+ if projection_class_embeddings_input_dim is None:
632
+ raise ValueError(
633
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
634
+ )
635
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
636
+ else:
637
+ self.class_embedding = None
638
+
639
+ def _set_add_embedding(
640
+ self,
641
+ addition_embed_type: str,
642
+ addition_embed_type_num_heads: int,
643
+ addition_time_embed_dim: Optional[int],
644
+ flip_sin_to_cos: bool,
645
+ freq_shift: float,
646
+ cross_attention_dim: Optional[int],
647
+ encoder_hid_dim: Optional[int],
648
+ projection_class_embeddings_input_dim: Optional[int],
649
+ time_embed_dim: int,
650
+ ):
651
+ if addition_embed_type == "text":
652
+ if encoder_hid_dim is not None:
653
+ text_time_embedding_from_dim = encoder_hid_dim
654
+ else:
655
+ text_time_embedding_from_dim = cross_attention_dim
656
+
657
+ self.add_embedding = TextTimeEmbedding(
658
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
659
+ )
660
+ elif addition_embed_type == "text_image":
661
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
662
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
663
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
664
+ self.add_embedding = TextImageTimeEmbedding(
665
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
666
+ )
667
+ elif addition_embed_type == "text_time":
668
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
669
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
670
+ elif addition_embed_type == "image":
671
+ # Kandinsky 2.2
672
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
673
+ elif addition_embed_type == "image_hint":
674
+ # Kandinsky 2.2 ControlNet
675
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
676
+ elif addition_embed_type is not None:
677
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
678
+
679
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
680
+ if attention_type in ["gated", "gated-text-image"]:
681
+ positive_len = 768
682
+ if isinstance(cross_attention_dim, int):
683
+ positive_len = cross_attention_dim
684
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
685
+ positive_len = cross_attention_dim[0]
686
+
687
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
688
+ self.position_net = GLIGENTextBoundingboxProjection(
689
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
690
+ )
691
+
692
+ @property
693
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
694
+ r"""
695
+ Returns:
696
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
697
+ indexed by its weight name.
698
+ """
699
+ # set recursively
700
+ processors = {}
701
+
702
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
703
+ if hasattr(module, "get_processor"):
704
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
705
+
706
+ for sub_name, child in module.named_children():
707
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
708
+
709
+ return processors
710
+
711
+ for name, module in self.named_children():
712
+ fn_recursive_add_processors(name, module, processors)
713
+
714
+ return processors
715
+
716
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
717
+ r"""
718
+ Sets the attention processor to use to compute attention.
719
+
720
+ Parameters:
721
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
722
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
723
+ for **all** `Attention` layers.
724
+
725
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
726
+ processor. This is strongly recommended when setting trainable attention processors.
727
+
728
+ """
729
+ count = len(self.attn_processors.keys())
730
+
731
+ if isinstance(processor, dict) and len(processor) != count:
732
+ raise ValueError(
733
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
734
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
735
+ )
736
+
737
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
738
+ if hasattr(module, "set_processor"):
739
+ if not isinstance(processor, dict):
740
+ module.set_processor(processor)
741
+ else:
742
+ module.set_processor(processor.pop(f"{name}.processor"))
743
+
744
+ for sub_name, child in module.named_children():
745
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
746
+
747
+ for name, module in self.named_children():
748
+ fn_recursive_attn_processor(name, module, processor)
749
+
750
+ def set_default_attn_processor(self):
751
+ """
752
+ Disables custom attention processors and sets the default attention implementation.
753
+ """
754
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
755
+ processor = AttnAddedKVProcessor()
756
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
757
+ processor = AttnProcessor()
758
+ else:
759
+ raise ValueError(
760
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
761
+ )
762
+
763
+ self.set_attn_processor(processor)
764
+
765
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
766
+ r"""
767
+ Enable sliced attention computation.
768
+
769
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
770
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
771
+
772
+ Args:
773
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
774
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
775
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
776
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
777
+ must be a multiple of `slice_size`.
778
+ """
779
+ sliceable_head_dims = []
780
+
781
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
782
+ if hasattr(module, "set_attention_slice"):
783
+ sliceable_head_dims.append(module.sliceable_head_dim)
784
+
785
+ for child in module.children():
786
+ fn_recursive_retrieve_sliceable_dims(child)
787
+
788
+ # retrieve number of attention layers
789
+ for module in self.children():
790
+ fn_recursive_retrieve_sliceable_dims(module)
791
+
792
+ num_sliceable_layers = len(sliceable_head_dims)
793
+
794
+ if slice_size == "auto":
795
+ # half the attention head size is usually a good trade-off between
796
+ # speed and memory
797
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
798
+ elif slice_size == "max":
799
+ # make smallest slice possible
800
+ slice_size = num_sliceable_layers * [1]
801
+
802
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
820
+ if hasattr(module, "set_attention_slice"):
821
+ module.set_attention_slice(slice_size.pop())
822
+
823
+ for child in module.children():
824
+ fn_recursive_set_attention_slice(child, slice_size)
825
+
826
+ reversed_slice_size = list(reversed(slice_size))
827
+ for module in self.children():
828
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
829
+
830
+ def _set_gradient_checkpointing(self, module, value=False):
831
+ if hasattr(module, "gradient_checkpointing"):
832
+ module.gradient_checkpointing = value
833
+
834
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
835
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
836
+
837
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
838
+
839
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
840
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
841
+
842
+ Args:
843
+ s1 (`float`):
844
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
845
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
846
+ s2 (`float`):
847
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
850
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
851
+ """
852
+ for i, upsample_block in enumerate(self.up_blocks):
853
+ setattr(upsample_block, "s1", s1)
854
+ setattr(upsample_block, "s2", s2)
855
+ setattr(upsample_block, "b1", b1)
856
+ setattr(upsample_block, "b2", b2)
857
+
858
+ def disable_freeu(self):
859
+ """Disables the FreeU mechanism."""
860
+ freeu_keys = {"s1", "s2", "b1", "b2"}
861
+ for i, upsample_block in enumerate(self.up_blocks):
862
+ for k in freeu_keys:
863
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
864
+ setattr(upsample_block, k, None)
865
+
866
+ def fuse_qkv_projections(self):
867
+ """
868
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
869
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
870
+
871
+ <Tip warning={true}>
872
+
873
+ This API is 🧪 experimental.
874
+
875
+ </Tip>
876
+ """
877
+ self.original_attn_processors = None
878
+
879
+ for _, attn_processor in self.attn_processors.items():
880
+ if "Added" in str(attn_processor.__class__.__name__):
881
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
882
+
883
+ self.original_attn_processors = self.attn_processors
884
+
885
+ for module in self.modules():
886
+ if isinstance(module, Attention):
887
+ module.fuse_projections(fuse=True)
888
+
889
+ def unfuse_qkv_projections(self):
890
+ """Disables the fused QKV projection if enabled.
891
+
892
+ <Tip warning={true}>
893
+
894
+ This API is 🧪 experimental.
895
+
896
+ </Tip>
897
+
898
+ """
899
+ if self.original_attn_processors is not None:
900
+ self.set_attn_processor(self.original_attn_processors)
901
+
902
+ def unload_lora(self):
903
+ """Unloads LoRA weights."""
904
+ deprecate(
905
+ "unload_lora",
906
+ "0.28.0",
907
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
908
+ )
909
+ for module in self.modules():
910
+ if hasattr(module, "set_lora_layer"):
911
+ module.set_lora_layer(None)
912
+
913
+ def get_time_embed(
914
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
915
+ ) -> Optional[torch.Tensor]:
916
+ timesteps = timestep
917
+ if not torch.is_tensor(timesteps):
918
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
919
+ # This would be a good case for the `match` statement (Python 3.10+)
920
+ is_mps = sample.device.type == "mps"
921
+ if isinstance(timestep, float):
922
+ dtype = torch.float32 if is_mps else torch.float64
923
+ else:
924
+ dtype = torch.int32 if is_mps else torch.int64
925
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
926
+ elif len(timesteps.shape) == 0:
927
+ timesteps = timesteps[None].to(sample.device)
928
+
929
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
930
+ timesteps = timesteps.expand(sample.shape[0])
931
+
932
+ t_emb = self.time_proj(timesteps)
933
+ # `Timesteps` does not contain any weights and will always return f32 tensors
934
+ # but time_embedding might actually be running in fp16. so we need to cast here.
935
+ # there might be better ways to encapsulate this.
936
+ t_emb = t_emb.to(dtype=sample.dtype)
937
+ return t_emb
938
+
939
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
940
+ class_emb = None
941
+ if self.class_embedding is not None:
942
+ if class_labels is None:
943
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
944
+
945
+ if self.config.class_embed_type == "timestep":
946
+ class_labels = self.time_proj(class_labels)
947
+
948
+ # `Timesteps` does not contain any weights and will always return f32 tensors
949
+ # there might be better ways to encapsulate this.
950
+ class_labels = class_labels.to(dtype=sample.dtype)
951
+
952
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
953
+ return class_emb
954
+
955
+ def get_aug_embed(
956
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
957
+ ) -> Optional[torch.Tensor]:
958
+ aug_emb = None
959
+ if self.config.addition_embed_type == "text":
960
+ aug_emb = self.add_embedding(encoder_hidden_states)
961
+ elif self.config.addition_embed_type == "text_image":
962
+ # Kandinsky 2.1 - style
963
+ if "image_embeds" not in added_cond_kwargs:
964
+ raise ValueError(
965
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
966
+ )
967
+
968
+ image_embs = added_cond_kwargs.get("image_embeds")
969
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
970
+ aug_emb = self.add_embedding(text_embs, image_embs)
971
+ elif self.config.addition_embed_type == "text_time":
972
+ # SDXL - style
973
+ if "text_embeds" not in added_cond_kwargs:
974
+ raise ValueError(
975
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
976
+ )
977
+ text_embeds = added_cond_kwargs.get("text_embeds")
978
+ if "time_ids" not in added_cond_kwargs:
979
+ raise ValueError(
980
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
981
+ )
982
+ time_ids = added_cond_kwargs.get("time_ids")
983
+ time_embeds = self.add_time_proj(time_ids.flatten())
984
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
985
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
986
+ add_embeds = add_embeds.to(emb.dtype)
987
+ aug_emb = self.add_embedding(add_embeds)
988
+ elif self.config.addition_embed_type == "image":
989
+ # Kandinsky 2.2 - style
990
+ if "image_embeds" not in added_cond_kwargs:
991
+ raise ValueError(
992
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
993
+ )
994
+ image_embs = added_cond_kwargs.get("image_embeds")
995
+ aug_emb = self.add_embedding(image_embs)
996
+ elif self.config.addition_embed_type == "image_hint":
997
+ # Kandinsky 2.2 - style
998
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
999
+ raise ValueError(
1000
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1001
+ )
1002
+ image_embs = added_cond_kwargs.get("image_embeds")
1003
+ hint = added_cond_kwargs.get("hint")
1004
+ aug_emb = self.add_embedding(image_embs, hint)
1005
+ return aug_emb
1006
+
1007
+ def process_encoder_hidden_states(
1008
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1009
+ ) -> torch.Tensor:
1010
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1011
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1012
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1013
+ # Kadinsky 2.1 - style
1014
+ if "image_embeds" not in added_cond_kwargs:
1015
+ raise ValueError(
1016
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1017
+ )
1018
+
1019
+ image_embeds = added_cond_kwargs.get("image_embeds")
1020
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1021
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1022
+ # Kandinsky 2.2 - style
1023
+ if "image_embeds" not in added_cond_kwargs:
1024
+ raise ValueError(
1025
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1026
+ )
1027
+ image_embeds = added_cond_kwargs.get("image_embeds")
1028
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1029
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1030
+ if "image_embeds" not in added_cond_kwargs:
1031
+ raise ValueError(
1032
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1033
+ )
1034
+ image_embeds = added_cond_kwargs.get("image_embeds")
1035
+ image_embeds = self.encoder_hid_proj(image_embeds)
1036
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1037
+ return encoder_hidden_states
1038
+
1039
+ def forward(
1040
+ self,
1041
+ sample: torch.FloatTensor,
1042
+ timestep: Union[torch.Tensor, float, int],
1043
+ encoder_hidden_states: torch.Tensor,
1044
+ class_labels: Optional[torch.Tensor] = None,
1045
+ timestep_cond: Optional[torch.Tensor] = None,
1046
+ attention_mask: Optional[torch.Tensor] = None,
1047
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1048
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1049
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1050
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1051
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1052
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1053
+ return_dict: bool = True,
1054
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1055
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
1056
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1057
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1058
+ r"""
1059
+ The [`UNet2DConditionModel`] forward method.
1060
+
1061
+ Args:
1062
+ sample (`torch.FloatTensor`):
1063
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1064
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1065
+ encoder_hidden_states (`torch.FloatTensor`):
1066
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1067
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1068
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1069
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1070
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1071
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1072
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1073
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1074
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1075
+ negative values to the attention scores corresponding to "discard" tokens.
1076
+ cross_attention_kwargs (`dict`, *optional*):
1077
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1078
+ `self.processor` in
1079
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1080
+ added_cond_kwargs: (`dict`, *optional*):
1081
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1082
+ are passed along to the UNet blocks.
1083
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1084
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1085
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1086
+ A tensor that if specified is added to the residual of the middle unet block.
1087
+ encoder_attention_mask (`torch.Tensor`):
1088
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1089
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1090
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1091
+ return_dict (`bool`, *optional*, defaults to `True`):
1092
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1093
+ tuple.
1094
+ cross_attention_kwargs (`dict`, *optional*):
1095
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
1096
+ added_cond_kwargs: (`dict`, *optional*):
1097
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
1098
+ are passed along to the UNet blocks.
1099
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1100
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
1101
+ example from ControlNet side model(s)
1102
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
1103
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
1104
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1105
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1106
+
1107
+ Returns:
1108
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1109
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
1110
+ a `tuple` is returned where the first element is the sample tensor.
1111
+ """
1112
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
1113
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1114
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1115
+ # on the fly if necessary.
1116
+ default_overall_up_factor = 2 ** self.num_upsamplers
1117
+
1118
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1119
+ forward_upsample_size = False
1120
+ upsample_size = None
1121
+
1122
+ for dim in sample.shape[-2:]:
1123
+ if dim % default_overall_up_factor != 0:
1124
+ # Forward upsample size to force interpolation output size.
1125
+ forward_upsample_size = True
1126
+ break
1127
+
1128
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1129
+ # expects mask of shape:
1130
+ # [batch, key_tokens]
1131
+ # adds singleton query_tokens dimension:
1132
+ # [batch, 1, key_tokens]
1133
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1134
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1135
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1136
+ if attention_mask is not None:
1137
+ # assume that mask is expressed as:
1138
+ # (1 = keep, 0 = discard)
1139
+ # convert mask into a bias that can be added to attention scores:
1140
+ # (keep = +0, discard = -10000.0)
1141
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1142
+ attention_mask = attention_mask.unsqueeze(1)
1143
+
1144
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1145
+ if encoder_attention_mask is not None:
1146
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1147
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1148
+
1149
+ # 0. center input if necessary
1150
+ if self.config.center_input_sample:
1151
+ sample = 2 * sample - 1.0
1152
+
1153
+ # 1. time
1154
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1155
+ emb = self.time_embedding(t_emb, timestep_cond)
1156
+ aug_emb = None
1157
+
1158
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1159
+ if class_emb is not None:
1160
+ if self.config.class_embeddings_concat:
1161
+ emb = torch.cat([emb, class_emb], dim=-1)
1162
+ else:
1163
+ emb = emb + class_emb
1164
+
1165
+ aug_emb = self.get_aug_embed(
1166
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1167
+ )
1168
+ if self.config.addition_embed_type == "image_hint":
1169
+ aug_emb, hint = aug_emb
1170
+ sample = torch.cat([sample, hint], dim=1)
1171
+
1172
+ emb = emb + aug_emb if aug_emb is not None else emb
1173
+
1174
+ if self.time_embed_act is not None:
1175
+ emb = self.time_embed_act(emb)
1176
+
1177
+ encoder_hidden_states = self.process_encoder_hidden_states(
1178
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1179
+ )
1180
+
1181
+ # 2. pre-process
1182
+ sample = self.conv_in(sample)
1183
+
1184
+ # 2.5 GLIGEN position net
1185
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1186
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1187
+ gligen_args = cross_attention_kwargs.pop("gligen")
1188
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1189
+
1190
+ # 3. down
1191
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1192
+ if USE_PEFT_BACKEND:
1193
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1194
+ scale_lora_layers(self, lora_scale)
1195
+
1196
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1197
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1198
+ is_adapter = down_intrablock_additional_residuals is not None
1199
+ # maintain backward compatibility for legacy usage, where
1200
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1201
+ # but can only use one or the other
1202
+ is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
1203
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1204
+ deprecate(
1205
+ "T2I should not use down_block_additional_residuals",
1206
+ "1.3.0",
1207
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1208
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1209
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1210
+ standard_warn=False,
1211
+ )
1212
+ down_intrablock_additional_residuals = down_block_additional_residuals
1213
+ is_adapter = True
1214
+
1215
+ down_block_res_samples = (sample,)
1216
+
1217
+ if is_brushnet:
1218
+ sample = sample + down_block_add_samples.pop(0)
1219
+
1220
+ for downsample_block in self.down_blocks:
1221
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1222
+ # For t2i-adapter CrossAttnDownBlock2D
1223
+ additional_residuals = {}
1224
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1225
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1226
+
1227
+ if is_brushnet and len(down_block_add_samples) > 0:
1228
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
1229
+ for _ in range(
1230
+ len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
1231
+
1232
+ sample, res_samples = downsample_block(
1233
+ hidden_states=sample,
1234
+ temb=emb,
1235
+ encoder_hidden_states=encoder_hidden_states,
1236
+ attention_mask=attention_mask,
1237
+ cross_attention_kwargs=cross_attention_kwargs,
1238
+ encoder_attention_mask=encoder_attention_mask,
1239
+ **additional_residuals,
1240
+ )
1241
+ else:
1242
+ additional_residuals = {}
1243
+ if is_brushnet and len(down_block_add_samples) > 0:
1244
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
1245
+ for _ in range(
1246
+ len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
1247
+
1248
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale,
1249
+ **additional_residuals)
1250
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1251
+ sample += down_intrablock_additional_residuals.pop(0)
1252
+
1253
+ down_block_res_samples += res_samples
1254
+
1255
+ if is_controlnet:
1256
+ new_down_block_res_samples = ()
1257
+
1258
+ for down_block_res_sample, down_block_additional_residual in zip(
1259
+ down_block_res_samples, down_block_additional_residuals
1260
+ ):
1261
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1262
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1263
+
1264
+ down_block_res_samples = new_down_block_res_samples
1265
+
1266
+ # 4. mid
1267
+ if self.mid_block is not None:
1268
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1269
+ sample = self.mid_block(
1270
+ sample,
1271
+ emb,
1272
+ encoder_hidden_states=encoder_hidden_states,
1273
+ attention_mask=attention_mask,
1274
+ cross_attention_kwargs=cross_attention_kwargs,
1275
+ encoder_attention_mask=encoder_attention_mask,
1276
+ )
1277
+ else:
1278
+ sample = self.mid_block(sample, emb)
1279
+
1280
+ # To support T2I-Adapter-XL
1281
+ if (
1282
+ is_adapter
1283
+ and len(down_intrablock_additional_residuals) > 0
1284
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1285
+ ):
1286
+ sample += down_intrablock_additional_residuals.pop(0)
1287
+
1288
+ if is_controlnet:
1289
+ sample = sample + mid_block_additional_residual
1290
+
1291
+ if is_brushnet:
1292
+ sample = sample + mid_block_add_sample
1293
+
1294
+ # 5. up
1295
+ for i, upsample_block in enumerate(self.up_blocks):
1296
+ is_final_block = i == len(self.up_blocks) - 1
1297
+
1298
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
1299
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1300
+
1301
+ # if we have not reached the final block and need to forward the
1302
+ # upsample size, we do it here
1303
+ if not is_final_block and forward_upsample_size:
1304
+ upsample_size = down_block_res_samples[-1].shape[2:]
1305
+
1306
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1307
+ additional_residuals = {}
1308
+ if is_brushnet and len(up_block_add_samples) > 0:
1309
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
1310
+ for _ in range(
1311
+ len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
1312
+
1313
+ sample = upsample_block(
1314
+ hidden_states=sample,
1315
+ temb=emb,
1316
+ res_hidden_states_tuple=res_samples,
1317
+ encoder_hidden_states=encoder_hidden_states,
1318
+ cross_attention_kwargs=cross_attention_kwargs,
1319
+ upsample_size=upsample_size,
1320
+ attention_mask=attention_mask,
1321
+ encoder_attention_mask=encoder_attention_mask,
1322
+ **additional_residuals,
1323
+ )
1324
+ else:
1325
+ additional_residuals = {}
1326
+ if is_brushnet and len(up_block_add_samples) > 0:
1327
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
1328
+ for _ in range(
1329
+ len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
1330
+
1331
+ sample = upsample_block(
1332
+ hidden_states=sample,
1333
+ temb=emb,
1334
+ res_hidden_states_tuple=res_samples,
1335
+ upsample_size=upsample_size,
1336
+ scale=lora_scale,
1337
+ **additional_residuals,
1338
+ )
1339
+
1340
+ # 6. post-process
1341
+ if self.conv_norm_out:
1342
+ sample = self.conv_norm_out(sample)
1343
+ sample = self.conv_act(sample)
1344
+ sample = self.conv_out(sample)
1345
+
1346
+ if USE_PEFT_BACKEND:
1347
+ # remove `lora_scale` from each PEFT layer
1348
+ unscale_lora_layers(self, lora_scale)
1349
+
1350
+ if not return_dict:
1351
+ return (sample,)
1352
+
1353
+ return UNet2DConditionOutput(sample=sample)
shape-guided_result.png ADDED
text_encoder_brushnet/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "runwayml/stable-diffusion-v1-5",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.38.2",
24
+ "vocab_size": 49438
25
+ }
text_encoder_brushnet/model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d683c8f067a24a6acee712806f3b5e9b1d7cbdb6a38ba4cbe121b9c39fba3012
3
+ size 246190232
text_encoder_brushnet/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51b58244429ca454caa7eeae1d7f89ffd5a57c348676b30d7ff5c7e2b7388820
3
+ size 492357328
tokenizer/added_tokens.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "P_ctxt_0": 49408,
3
+ "P_ctxt_1": 49409,
4
+ "P_ctxt_2": 49410,
5
+ "P_ctxt_3": 49411,
6
+ "P_ctxt_4": 49412,
7
+ "P_ctxt_5": 49413,
8
+ "P_ctxt_6": 49414,
9
+ "P_ctxt_7": 49415,
10
+ "P_ctxt_8": 49416,
11
+ "P_ctxt_9": 49417,
12
+ "P_obj_0": 49428,
13
+ "P_obj_1": 49429,
14
+ "P_obj_2": 49430,
15
+ "P_obj_3": 49431,
16
+ "P_obj_4": 49432,
17
+ "P_obj_5": 49433,
18
+ "P_obj_6": 49434,
19
+ "P_obj_7": 49435,
20
+ "P_obj_8": 49436,
21
+ "P_obj_9": 49437,
22
+ "P_shape_0": 49418,
23
+ "P_shape_1": 49419,
24
+ "P_shape_2": 49420,
25
+ "P_shape_3": 49421,
26
+ "P_shape_4": 49422,
27
+ "P_shape_5": 49423,
28
+ "P_shape_6": 49424,
29
+ "P_shape_7": 49425,
30
+ "P_shape_8": 49426,
31
+ "P_shape_9": 49427
32
+ }
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49408": {
21
+ "content": "P_ctxt_0",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": false
27
+ },
28
+ "49409": {
29
+ "content": "P_ctxt_1",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": false
35
+ },
36
+ "49410": {
37
+ "content": "P_ctxt_2",
38
+ "lstrip": false,
39
+ "normalized": true,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": false
43
+ },
44
+ "49411": {
45
+ "content": "P_ctxt_3",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": false
51
+ },
52
+ "49412": {
53
+ "content": "P_ctxt_4",
54
+ "lstrip": false,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": false
59
+ },
60
+ "49413": {
61
+ "content": "P_ctxt_5",
62
+ "lstrip": false,
63
+ "normalized": true,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": false
67
+ },
68
+ "49414": {
69
+ "content": "P_ctxt_6",
70
+ "lstrip": false,
71
+ "normalized": true,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": false
75
+ },
76
+ "49415": {
77
+ "content": "P_ctxt_7",
78
+ "lstrip": false,
79
+ "normalized": true,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": false
83
+ },
84
+ "49416": {
85
+ "content": "P_ctxt_8",
86
+ "lstrip": false,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": false
91
+ },
92
+ "49417": {
93
+ "content": "P_ctxt_9",
94
+ "lstrip": false,
95
+ "normalized": true,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": false
99
+ },
100
+ "49418": {
101
+ "content": "P_shape_0",
102
+ "lstrip": false,
103
+ "normalized": true,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": false
107
+ },
108
+ "49419": {
109
+ "content": "P_shape_1",
110
+ "lstrip": false,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": false
115
+ },
116
+ "49420": {
117
+ "content": "P_shape_2",
118
+ "lstrip": false,
119
+ "normalized": true,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": false
123
+ },
124
+ "49421": {
125
+ "content": "P_shape_3",
126
+ "lstrip": false,
127
+ "normalized": true,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": false
131
+ },
132
+ "49422": {
133
+ "content": "P_shape_4",
134
+ "lstrip": false,
135
+ "normalized": true,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": false
139
+ },
140
+ "49423": {
141
+ "content": "P_shape_5",
142
+ "lstrip": false,
143
+ "normalized": true,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": false
147
+ },
148
+ "49424": {
149
+ "content": "P_shape_6",
150
+ "lstrip": false,
151
+ "normalized": true,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": false
155
+ },
156
+ "49425": {
157
+ "content": "P_shape_7",
158
+ "lstrip": false,
159
+ "normalized": true,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": false
163
+ },
164
+ "49426": {
165
+ "content": "P_shape_8",
166
+ "lstrip": false,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": false
171
+ },
172
+ "49427": {
173
+ "content": "P_shape_9",
174
+ "lstrip": false,
175
+ "normalized": true,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": false
179
+ },
180
+ "49428": {
181
+ "content": "P_obj_0",
182
+ "lstrip": false,
183
+ "normalized": true,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": false
187
+ },
188
+ "49429": {
189
+ "content": "P_obj_1",
190
+ "lstrip": false,
191
+ "normalized": true,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": false
195
+ },
196
+ "49430": {
197
+ "content": "P_obj_2",
198
+ "lstrip": false,
199
+ "normalized": true,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": false
203
+ },
204
+ "49431": {
205
+ "content": "P_obj_3",
206
+ "lstrip": false,
207
+ "normalized": true,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": false
211
+ },
212
+ "49432": {
213
+ "content": "P_obj_4",
214
+ "lstrip": false,
215
+ "normalized": true,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": false
219
+ },
220
+ "49433": {
221
+ "content": "P_obj_5",
222
+ "lstrip": false,
223
+ "normalized": true,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": false
227
+ },
228
+ "49434": {
229
+ "content": "P_obj_6",
230
+ "lstrip": false,
231
+ "normalized": true,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": false
235
+ },
236
+ "49435": {
237
+ "content": "P_obj_7",
238
+ "lstrip": false,
239
+ "normalized": true,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": false
243
+ },
244
+ "49436": {
245
+ "content": "P_obj_8",
246
+ "lstrip": false,
247
+ "normalized": true,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": false
251
+ },
252
+ "49437": {
253
+ "content": "P_obj_9",
254
+ "lstrip": false,
255
+ "normalized": true,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": false
259
+ }
260
+ },
261
+ "bos_token": "<|startoftext|>",
262
+ "clean_up_tokenization_spaces": true,
263
+ "do_lower_case": true,
264
+ "eos_token": "<|endoftext|>",
265
+ "errors": "replace",
266
+ "model_max_length": 77,
267
+ "pad_token": "<|endoftext|>",
268
+ "tokenizer_class": "CLIPTokenizer",
269
+ "unk_token": "<|endoftext|>"
270
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff