jozee committed on
Commit
1f2d1fd
·
verified ·
1 Parent(s): 94cac87

Create controlnet_union.py

Browse files
Files changed (1) hide show
  1. controlnet_union.py +1085 -0
controlnet_union.py ADDED
@@ -0,0 +1,1085 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import OrderedDict
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.loaders import FromOriginalModelMixin
21
+ from diffusers.models.attention_processor import (
22
+ ADDED_KV_ATTENTION_PROCESSORS,
23
+ CROSS_ATTENTION_PROCESSORS,
24
+ AttentionProcessor,
25
+ AttnAddedKVProcessor,
26
+ AttnProcessor,
27
+ )
28
+ from diffusers.models.embeddings import (
29
+ TextImageProjection,
30
+ TextImageTimeEmbedding,
31
+ TextTimeEmbedding,
32
+ TimestepEmbedding,
33
+ Timesteps,
34
+ )
35
+ from diffusers.models.modeling_utils import ModelMixin
36
+ from diffusers.models.unets.unet_2d_blocks import (
37
+ CrossAttnDownBlock2D,
38
+ DownBlock2D,
39
+ UNetMidBlock2DCrossAttn,
40
+ get_down_block,
41
+ )
42
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
43
+ from diffusers.utils import BaseOutput, logging
44
+ from torch import nn
45
+ from torch.nn import functional as F
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ # Transformer Block
51
+ # Used to exchange info between different conditions and input image
52
+ # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
53
class QuickGELU(nn.Module):
    """Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        # The sigmoid acts as a smooth gate on the input itself.
        gate = torch.sigmoid(1.702 * x)
        return x * gate
56
+
57
+
58
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose result is always cast back to the input's dtype (fp16-friendly)."""

    def forward(self, x: torch.Tensor):
        # Remember the caller's dtype so the normalised tensor can be returned in it.
        input_dtype = x.dtype
        normalized = super().forward(x)
        return normalized.type(input_dtype)
65
+
66
+
67
class ResidualAttentionBlock(nn.Module):
    """
    Transformer block made of self-attention and a 4x-wide MLP.

    Each sub-layer is preceded by a `LayerNorm` and wrapped in a residual connection.

    Args:
        d_model: Embedding width of the tokens.
        n_head: Number of attention heads.
        attn_mask: Optional mask forwarded to `nn.MultiheadAttention` on every call.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        hidden_width = d_model * 4

        # Sub-modules are created in the order attn, ln_1, mlp, ln_2 so parameter
        # names and seeded initialisation stay stable.
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_width)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(hidden_width, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Run self-attention over `x`, moving the stored mask (if any) to `x`'s dtype and device first."""
        if self.attn_mask is not None:
            # The converted mask is stored back on the module, as in every call.
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers, each added residually to `x`."""
        after_attention = x + self.attention(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))
97
+
98
+
99
+ # -----------------------------------------------------------------------------------------------------
100
+
101
+
102
@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    # One residual tensor per down-sampling stage.
    down_block_res_samples: Tuple[torch.Tensor]
    # Residual tensor for the middle block.
    mid_block_res_sample: torch.Tensor
120
+
121
+
122
class ControlNetConditioningEmbedding(nn.Module):
    """
    Small convolutional encoder for image-space conditions.

    Follows the "tiny network E(.)" described in the ControlNet paper
    (https://arxiv.org/abs/2302.05543), which converts image-based conditions to the
    latent feature space. Here every stage is a 3x3 convolution followed by a
    stride-2 3x3 convolution, with SiLU activations in between.

    Args:
        conditioning_embedding_channels: Number of channels produced by the final convolution.
        conditioning_channels: Number of channels of the conditioning image.
        block_out_channels: Channel widths of the successive stages; each pair of
            neighbouring entries adds one stride-2 convolution.
    """

    # original setting is (16, 32, 96, 256)
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (48, 96, 192, 384),
    ):
        super().__init__()

        first_width = block_out_channels[0]
        last_width = block_out_channels[-1]

        self.conv_in = nn.Conv2d(
            conditioning_channels, first_width, kernel_size=3, padding=1
        )

        # For each neighbouring pair of widths: a same-width conv, then a
        # stride-2 conv that halves the resolution and changes the width.
        stages = []
        for width_in, width_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            stages.append(nn.Conv2d(width_in, width_in, kernel_size=3, padding=1))
            stages.append(
                nn.Conv2d(width_in, width_out, kernel_size=3, padding=1, stride=2)
            )
        self.blocks = nn.ModuleList(stages)

        # `zero_module` is defined elsewhere in this file; presumably it zeroes the
        # layer's parameters so the embedding starts out as a no-op -- confirm there.
        self.conv_out = zero_module(
            nn.Conv2d(
                last_width,
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, conditioning):
        """Encode `conditioning` with SiLU after every convolution except the last."""
        hidden = F.silu(self.conv_in(conditioning))
        for conv in self.blocks:
            hidden = F.silu(conv(hidden))
        return self.conv_out(hidden)
177
+
178
+
179
+ class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
180
+ """
181
+ A ControlNet model.
182
+
183
+ Args:
184
+ in_channels (`int`, defaults to 4):
185
+ The number of channels in the input sample.
186
+ flip_sin_to_cos (`bool`, defaults to `True`):
187
+ Whether to flip the sin to cos in the time embedding.
188
+ freq_shift (`int`, defaults to 0):
189
+ The frequency shift to apply to the time embedding.
190
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
191
+ The tuple of downsample blocks to use.
192
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
193
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
194
+ The tuple of output channels for each block.
195
+ layers_per_block (`int`, defaults to 2):
196
+ The number of layers per block.
197
+ downsample_padding (`int`, defaults to 1):
198
+ The padding to use for the downsampling convolution.
199
+ mid_block_scale_factor (`float`, defaults to 1):
200
+ The scale factor to use for the mid block.
201
+ act_fn (`str`, defaults to "silu"):
202
+ The activation function to use.
203
+ norm_num_groups (`int`, *optional*, defaults to 32):
204
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
205
+ in post-processing.
206
+ norm_eps (`float`, defaults to 1e-5):
207
+ The epsilon to use for the normalization.
208
+ cross_attention_dim (`int`, defaults to 1280):
209
+ The dimension of the cross attention features.
210
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
211
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
212
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
213
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
214
+ encoder_hid_dim (`int`, *optional*, defaults to None):
215
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
216
+ dimension to `cross_attention_dim`.
217
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
218
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
219
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
220
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
221
+ The dimension of the attention heads.
222
+ use_linear_projection (`bool`, defaults to `False`):
223
+ class_embed_type (`str`, *optional*, defaults to `None`):
224
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
225
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
226
+ addition_embed_type (`str`, *optional*, defaults to `None`):
227
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
228
+ "text". "text" will use the `TextTimeEmbedding` layer.
229
+ num_class_embeds (`int`, *optional*, defaults to 0):
230
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
231
+ class conditioning with `class_embed_type` equal to `None`.
232
+ upcast_attention (`bool`, defaults to `False`):
233
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
234
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
235
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
236
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
237
+ `class_embed_type="projection"`.
238
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
239
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
240
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
241
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
242
+ global_pool_conditions (`bool`, defaults to `False`):
243
+ """
244
+
245
+ _supports_gradient_checkpointing = True
246
+
247
@register_to_config
def __init__(
    self,
    in_channels: int = 4,
    conditioning_channels: int = 3,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: int = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: int = 1280,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    projection_class_embeddings_input_dim: Optional[int] = None,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
    global_pool_conditions: bool = False,
    addition_embed_type_num_heads=64,
    num_control_type=6,
):
    """
    Build all sub-modules of the union ControlNet.

    Most arguments are described in the class docstring. Arguments it does not cover:

    Args:
        conditioning_channels: Input channels of the conditioning image passed to
            `ControlNetConditioningEmbedding`.
        num_attention_heads: Heads per down block; falls back to `attention_head_dim`
            when not given.
        addition_time_embed_dim: Channel count of the sinusoidal `Timesteps`
            projections (`add_time_proj` for `"text_time"` and `control_type_proj`).
            Must be set, because the control-type encoder is always built.
        addition_embed_type_num_heads: `num_heads` of `TextTimeEmbedding` when
            `addition_embed_type == "text"`.
        num_control_type: Number of control types; sets the number of rows of
            `task_embedding` and the input width of `control_add_embedding`.

    Raises:
        ValueError: If the per-block argument lengths disagree with
            `down_block_types`, if `encoder_hid_dim`/`encoder_hid_dim_type`,
            `class_embed_type` or `addition_embed_type` are inconsistent or
            unknown, or if `addition_time_embed_dim` is missing.
    """
    super().__init__()

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(
        only_cross_attention
    ) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
        down_block_types
    ):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(
            down_block_types
        )

    # input
    conv_in_kernel = 3
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels,
        block_out_channels[0],
        kernel_size=conv_in_kernel,
        padding=conv_in_padding,
    )

    # time
    time_embed_dim = block_out_channels[0] * 4
    self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
    timestep_input_dim = block_out_channels[0]
    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
    )

    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info(
            "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
        )

    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
        self.encoder_hid_proj = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )

    elif encoder_hid_dim_type is not None:
        raise ValueError(
            f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
        )
    else:
        self.encoder_hid_proj = None

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
        # 2. it projects from an arbitrary input dimension.
        #
        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
        self.class_embedding = TimestepEmbedding(
            projection_class_embeddings_input_dim, time_embed_dim
        )
    else:
        self.class_embedding = None

    # additional embedding summed with the time embedding; `add_embedding` is
    # intentionally left undefined when `addition_embed_type` is None.
    if addition_embed_type == "text":
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim

        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim,
            time_embed_dim,
            num_heads=addition_embed_type_num_heads,
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim,
            image_embed_dim=cross_attention_dim,
            time_embed_dim=time_embed_dim,
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(
            addition_time_embed_dim, flip_sin_to_cos, freq_shift
        )
        self.add_embedding = TimestepEmbedding(
            projection_class_embeddings_input_dim, time_embed_dim
        )

    elif addition_embed_type is not None:
        # 'text_time' is handled above, so it belongs in the list of accepted values.
        raise ValueError(
            f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
        )

    # control net conditioning embedding
    self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=block_out_channels[0],
        block_out_channels=conditioning_embedding_out_channels,
        conditioning_channels=conditioning_channels,
    )

    # Copyright by Qi Xin(2024/07/06)
    # Condition Transformer (fuses single/multiple conditions with the input image).
    # The Condition Transformer augments the feature representation of conditions.
    # The overall design is somewhat like a ResNet: its output is used to predict a
    # condition bias that is added to the original condition feature.
    # NOTE(review): these widths are hard-coded to 320, which equals the default
    # `block_out_channels[0]`; presumably they must match it -- confirm against forward().
    num_trans_channel = 320
    num_trans_head = 8
    num_trans_layer = 1
    num_proj_channel = 320
    task_scale_factor = num_trans_channel**0.5

    self.task_embedding = nn.Parameter(
        task_scale_factor * torch.randn(num_control_type, num_trans_channel)
    )
    # The attribute name (including its spelling) is part of the checkpoint's
    # state-dict keys, so it must not be renamed.
    self.transformer_layes = nn.Sequential(
        *[
            ResidualAttentionBlock(num_trans_channel, num_trans_head)
            for _ in range(num_trans_layer)
        ]
    )
    self.spatial_ch_projs = zero_module(
        nn.Linear(num_trans_channel, num_proj_channel)
    )
    # -----------------------------------------------------------------------------------------------------

    # Copyright by Qi Xin(2024/07/06)
    # Control Encoder to distinguish different control conditions.
    # A simple but effective module, consisting of an embedding layer and a linear
    # layer, that injects the control info into the time embedding.
    if addition_time_embed_dim is None:
        # Without this check the multiplication below fails with an opaque
        # `TypeError` (None * int); fail early with an actionable message instead.
        raise ValueError(
            "`addition_time_embed_dim` must be set: the control-type encoder projects "
            "`addition_time_embed_dim * num_control_type` features into the time embedding."
        )
    self.control_type_proj = Timesteps(
        addition_time_embed_dim, flip_sin_to_cos, freq_shift
    )
    self.control_add_embedding = TimestepEmbedding(
        addition_time_embed_dim * num_control_type, time_embed_dim
    )
    # -----------------------------------------------------------------------------------------------------

    self.down_blocks = nn.ModuleList([])
    self.controlnet_down_blocks = nn.ModuleList([])

    # Broadcast scalar settings to one entry per down block.
    if isinstance(only_cross_attention, bool):
        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    # down
    output_channel = block_out_channels[0]

    # 1x1 projection for the `conv_in` output, wrapped with `zero_module`.
    controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
    controlnet_block = zero_module(controlnet_block)
    self.controlnet_down_blocks.append(controlnet_block)

    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block,
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[i],
            attention_head_dim=attention_head_dim[i]
            if attention_head_dim[i] is not None
            else output_channel,
            downsample_padding=downsample_padding,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
        self.down_blocks.append(down_block)

        # One 1x1 projection per layer of the block ...
        for _ in range(layers_per_block):
            controlnet_block = nn.Conv2d(
                output_channel, output_channel, kernel_size=1
            )
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_down_blocks.append(controlnet_block)

        # ... plus one for the downsampler, which every block but the last has.
        if not is_final_block:
            controlnet_block = nn.Conv2d(
                output_channel, output_channel, kernel_size=1
            )
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_down_blocks.append(controlnet_block)

    # mid
    mid_block_channel = block_out_channels[-1]

    controlnet_block = nn.Conv2d(
        mid_block_channel, mid_block_channel, kernel_size=1
    )
    controlnet_block = zero_module(controlnet_block)
    self.controlnet_mid_block = controlnet_block

    self.mid_block = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block[-1],
        in_channels=mid_block_channel,
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        resnet_time_scale_shift=resnet_time_scale_shift,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads[-1],
        resnet_groups=norm_num_groups,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )
560
+
561
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
):
    r"""
    Create a ControlNet whose configuration mirrors a [`UNet2DConditionModel`].

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet whose configuration (and optionally weights) is copied.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            Forwarded unchanged to the constructor.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*):
            Forwarded unchanged to the constructor.
        load_weights_from_unet (`bool`, defaults to `True`):
            When set, copies the weights of `conv_in`, `time_proj`, `time_embedding`,
            `class_embedding` (if present), `down_blocks` and `mid_block` from `unet`.
            The last two are loaded non-strictly.

    Returns:
        The newly constructed ControlNet.
    """
    unet_cfg = unet.config

    def optional_setting(key, fallback=None):
        # These keys may be missing from a UNet config; use the fallback then.
        return getattr(unet_cfg, key) if key in unet_cfg else fallback

    init_kwargs = {
        "encoder_hid_dim": optional_setting("encoder_hid_dim"),
        "encoder_hid_dim_type": optional_setting("encoder_hid_dim_type"),
        "addition_embed_type": optional_setting("addition_embed_type"),
        "addition_time_embed_dim": optional_setting("addition_time_embed_dim"),
        "transformer_layers_per_block": optional_setting(
            "transformer_layers_per_block", 1
        ),
    }

    # Settings read straight from the UNet config, under the same name.
    mirrored_keys = (
        "in_channels",
        "flip_sin_to_cos",
        "freq_shift",
        "down_block_types",
        "only_cross_attention",
        "block_out_channels",
        "layers_per_block",
        "downsample_padding",
        "mid_block_scale_factor",
        "act_fn",
        "norm_num_groups",
        "norm_eps",
        "cross_attention_dim",
        "attention_head_dim",
        "num_attention_heads",
        "use_linear_projection",
        "class_embed_type",
        "num_class_embeds",
        "upcast_attention",
        "resnet_time_scale_shift",
        "projection_class_embeddings_input_dim",
    )
    for key in mirrored_keys:
        init_kwargs[key] = getattr(unet_cfg, key)

    init_kwargs["controlnet_conditioning_channel_order"] = (
        controlnet_conditioning_channel_order
    )
    init_kwargs["conditioning_embedding_out_channels"] = (
        conditioning_embedding_out_channels
    )

    controlnet = cls(**init_kwargs)

    if not load_weights_from_unet:
        return controlnet

    # Stem and time-embedding weights are copied strictly.
    for attr in ("conv_in", "time_proj", "time_embedding"):
        getattr(controlnet, attr).load_state_dict(getattr(unet, attr).state_dict())

    if controlnet.class_embedding:
        controlnet.class_embedding.load_state_dict(
            unet.class_embedding.state_dict()
        )

    # Encoder blocks are copied with strict=False, tolerating key mismatches.
    for attr in ("down_blocks", "mid_block"):
        getattr(controlnet, attr).load_state_dict(
            getattr(unet, attr).state_dict(), strict=False
        )

    return controlnet
651
+
652
@property
# Adapted from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: every processor used in the model, keyed by the
        dotted path of its owning module followed by `.processor`.
    """
    collected = {}

    # Explicit-stack pre-order walk; children are pushed reversed so modules are
    # visited (and keys inserted) in registration order.
    pending = list(self.named_children())
    pending.reverse()
    while pending:
        path, module = pending.pop()

        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor(
                return_deprecated_lora=True
            )

        children = [
            (f"{path}.{child_name}", child)
            for child_name, child in module.named_children()
        ]
        pending.extend(reversed(children))

    return collected
682
+
683
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
684
def set_attn_processor(
    self,
    processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
    _remove_lora=False,
):
    r"""
    Install the attention processor(s) used to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor instance, applied to **every** module that exposes
            `set_processor`, or a dict mapping `"<module path>.processor"` to the processor
            for that module. A dict must contain exactly one entry per attention layer and
            is consumed (its entries are popped) as they are assigned.
        _remove_lora (`bool`, defaults to `False`):
            Forwarded to each module's `set_processor`.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
    """
    layer_count = len(self.attn_processors.keys())
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != layer_count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {layer_count}. Please make sure to pass {layer_count} processor classes."
        )

    def assign(path: str, module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            # Per-layer mode looks the processor up by path; otherwise share one.
            chosen = processor.pop(f"{path}.processor") if per_layer else processor
            module.set_processor(chosen, _remove_lora=_remove_lora)

        for child_name, child in module.named_children():
            assign(f"{path}.{child_name}", child)

    for top_name, top_module in self.named_children():
        assign(top_name, top_module)
723
+
724
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
725
def set_default_attn_processor(self):
    """
    Replace all current attention processors with the matching default implementation.

    `AttnAddedKVProcessor` is used when every current processor class is listed in
    `ADDED_KV_ATTENTION_PROCESSORS`, `AttnProcessor` when every class is listed in
    `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: if the current processors fit neither group.
    """
    current = list(self.attn_processors.values())
    classes = [proc.__class__ for proc in current]

    if all(cls in ADDED_KV_ATTENTION_PROCESSORS for cls in classes):
        default = AttnAddedKVProcessor()
    elif all(cls in CROSS_ATTENTION_PROCESSORS for cls in classes):
        default = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(current))}"
        )

    self.set_attn_processor(default, _remove_lora=True)
745
+
746
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
747
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    With slicing enabled, attention modules split their input into slices and compute attention in several
    steps, trading a little speed for lower memory use.

    Args:
        slice_size (`str` or `int` or `list(int)`):
            `"auto"` halves each layer's `sliceable_head_dim`; `"max"` uses a slice size of 1 for every layer.
            Any other non-list value is applied to every sliceable layer. A list must contain exactly one
            entry per sliceable layer, in traversal order.

    Raises:
        ValueError: if a list of the wrong length is given, or if any size exceeds the corresponding
            layer's `sliceable_head_dim`.
    """
    head_dims = []

    def _gather_dims(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            head_dims.append(module.sliceable_head_dim)

        for child in module.children():
            _gather_dims(child)

    # First pass: find every sliceable layer and its head dimension.
    for top_module in self.children():
        _gather_dims(top_module)

    layer_count = len(head_dims)

    if slice_size == "auto":
        # Half the head size is usually a good speed/memory trade-off.
        slice_size = [dim // 2 for dim in head_dims]
    elif slice_size == "max":
        # Smallest possible slice for every layer.
        slice_size = layer_count * [1]

    if not isinstance(slice_size, list):
        slice_size = layer_count * [slice_size]

    if len(slice_size) != len(head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(head_dims)}."
        )

    for size, dim in zip(slice_size, head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    def _apply(module: torch.nn.Module, remaining: List[int]):
        # Same traversal order as `_gather_dims`; each sliceable module
        # consumes the next pending size.
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(remaining.pop())

        for child in module.children():
            _apply(child, remaining)

    # Reversed so that `pop()` hands the sizes out in their original order.
    pending = list(reversed(slice_size))
    for top_module in self.children():
        _apply(top_module, pending)
817
+
818
def _set_gradient_checkpointing(self, module, value=False):
    """Set `gradient_checkpointing` to `value` on `module` if it is a `CrossAttnDownBlock2D` or `DownBlock2D`."""
    checkpointable = (CrossAttnDownBlock2D, DownBlock2D)
    if not isinstance(module, checkpointable):
        return
    module.gradient_checkpointing = value
821
+
822
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    controlnet_cond_list: List[torch.FloatTensor],
    conditioning_scale: float = 1.0,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guess_mode: bool = False,
    return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
    """
    Run the union ControlNet and return the residuals for the down blocks and the mid block.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor.
        timestep (`Union[torch.Tensor, float, int]`):
            The current timestep. Python scalars and 0-dim tensors are converted to a 1-D tensor and expanded
            to the batch size of `sample`.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states.
        controlnet_cond_list:
            Collection of conditioning inputs indexed by control-type position. Only the entries whose position
            is non-zero in `added_cond_kwargs["control_type"][0]` are read; each is passed through
            `self.controlnet_cond_embedding`. Presumably one image tensor per supported control type - confirm
            against the calling pipeline.
        conditioning_scale (`float`, defaults to `1.0`):
            The scale factor for ControlNet outputs.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            Additional conditional embeddings for timestep, passed to `self.time_embedding` together with the
            projected timestep.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        added_cond_kwargs (`dict`):
            Required despite the `None` default. Must contain `control_type`; note that only `control_type[0]`
            (the first row) decides which conditions are fused for the whole batch. When the config's
            `addition_embed_type` is `"text_time"` it must also contain `text_embeds` and `time_ids`.
        cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
            A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
        guess_mode (`bool`, defaults to `False`):
            When `True` (and `global_pool_conditions` is off) the residuals are scaled with logarithmically
            spaced factors from 0.1 to 1.0, multiplied by `conditioning_scale`, instead of a single uniform scale.
        return_dict (`bool`, defaults to `True`):
            Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

    Returns:
        [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
            If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise the
            tuple `(down_block_res_samples, mid_block_res_sample)`.

    Raises:
        ValueError: if the configured channel order is not `"rgb"`, if `class_labels` is missing while a class
            embedding is configured, or if `added_cond_kwargs` is missing or lacks a required key.
    """
    # check channel order: only the default "rgb" order is handled; flipping for
    # "bgr" is not implemented for the list of conditions.
    channel_order = self.config.controlnet_conditioning_channel_order
    if channel_order != "rgb":
        raise ValueError(
            f"unknown `controlnet_conditioning_channel_order`: {channel_order}"
        )

    # prepare attention_mask: turn the keep(1)/discard(0) mask into an additive bias
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError(
                "class_labels should be provided when num_class_embeds > 0"
            )

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # `control_type` is read unconditionally below, so the kwargs dict is mandatory.
    # Fail with a clear message instead of an AttributeError/TypeError on `None`.
    if added_cond_kwargs is None:
        raise ValueError(
            f"{self.__class__} requires `added_cond_kwargs` containing the keyword argument `control_type`"
        )

    if self.config.addition_embed_type is not None:
        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)

        elif self.config.addition_embed_type == "text_time":
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)

    # Copyright by Qi Xin(2024/07/06)
    # inject control type info to time embedding to distinguish different control conditions
    control_type = added_cond_kwargs.get("control_type")
    if control_type is None:
        raise ValueError(
            f"{self.__class__} requires the keyword argument `control_type` to be passed in `added_cond_kwargs`"
        )
    control_embeds = self.control_type_proj(control_type.flatten())
    control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
    control_embeds = control_embeds.to(emb.dtype)
    control_emb = self.control_add_embedding(control_embeds)
    emb = emb + control_emb
    # ---------------------------------------------------------------------------------

    emb = emb + aug_emb if aug_emb is not None else emb

    # 2. pre-process
    sample = self.conv_in(sample)
    # Positions of the active control types, taken from the first row only.
    indices = torch.nonzero(control_type[0])
    num_conditions = indices.shape[0]

    # Copyright by Qi Xin(2024/07/06)
    # add single/multi conditions to input image.
    # Condition Transformer provides an easy and effective way to fuse different features naturally
    inputs = []
    condition_list = []

    for idx in range(num_conditions):
        type_idx = indices[idx][0]
        controlnet_cond = self.controlnet_cond_embedding(
            controlnet_cond_list[type_idx]
        )
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))  # N * C
        feat_seq = feat_seq + self.task_embedding[type_idx]

        inputs.append(feat_seq.unsqueeze(1))
        condition_list.append(controlnet_cond)

    # The image features go last in the token sequence and get no task embedding.
    feat_seq = torch.mean(sample, dim=(2, 3))  # N * C
    inputs.append(feat_seq.unsqueeze(1))
    condition_list.append(sample)

    x = torch.cat(inputs, dim=1)  # NxLxC
    x = self.transformer_layes(x)

    controlnet_cond_fuser = sample * 0.0
    for idx in range(num_conditions):
        alpha = self.spatial_ch_projs(x[:, idx])
        alpha = alpha.unsqueeze(-1).unsqueeze(-1)
        controlnet_cond_fuser += condition_list[idx] + alpha

    sample = sample + controlnet_cond_fuser
    # -------------------------------------------------------------------------------------------

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if (
            hasattr(downsample_block, "has_cross_attention")
            and downsample_block.has_cross_attention
        ):
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )

    # 5. Control net blocks
    controlnet_down_block_res_samples = ()

    for down_block_res_sample, controlnet_block in zip(
        down_block_res_samples, self.controlnet_down_blocks
    ):
        down_block_res_sample = controlnet_block(down_block_res_sample)
        controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
            down_block_res_sample,
        )

    down_block_res_samples = controlnet_down_block_res_samples

    mid_block_res_sample = self.controlnet_mid_block(sample)

    # 6. scaling
    if guess_mode and not self.config.global_pool_conditions:
        scales = torch.logspace(
            -1, 0, len(down_block_res_samples) + 1, device=sample.device
        )  # 0.1 to 1.0
        scales = scales * conditioning_scale
        down_block_res_samples = [
            res * scale for res, scale in zip(down_block_res_samples, scales)
        ]
        mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
    else:
        down_block_res_samples = [
            res * conditioning_scale for res in down_block_res_samples
        ]
        mid_block_res_sample = mid_block_res_sample * conditioning_scale

    if self.config.global_pool_conditions:
        down_block_res_samples = [
            torch.mean(res, dim=(2, 3), keepdim=True)
            for res in down_block_res_samples
        ]
        mid_block_res_sample = torch.mean(
            mid_block_res_sample, dim=(2, 3), keepdim=True
        )

    if not return_dict:
        return (down_block_res_samples, mid_block_res_sample)

    return ControlNetOutput(
        down_block_res_samples=down_block_res_samples,
        mid_block_res_sample=mid_block_res_sample,
    )
1080
+
1081
+
1082
def zero_module(module):
    """Fill every parameter of `module` with zeros in place and return the same module."""
    zero_fill = nn.init.zeros_
    for weight in module.parameters():
        zero_fill(weight)
    return module