depthanyvideo commited on
Commit
a1739b9
·
1 Parent(s): 66993df
models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .unets.unet_spatio_temporal_rope_condition import (
2
+ UNetSpatioTemporalRopeConditionModel,
3
+ )
4
+
5
+ __all__ = {
6
+ "UNetSpatioTemporalRopeConditionModel": UNetSpatioTemporalRopeConditionModel,
7
+ }
models/attention.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from diffusers.utils import deprecate, logging
8
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
9
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
10
+ from .attention_processor import (
11
+ Attention,
12
+ AttnProcessor2_0,
13
+ JointAttnProcessor2_0,
14
+ JointAttnROPEProcessor2_0,
15
+ AttnRopeProcessor2_0,
16
+ )
17
+ from .embeddings import SinusoidalPositionalEmbedding
18
+ from diffusers.models.normalization import (
19
+ AdaLayerNorm,
20
+ AdaLayerNormContinuous,
21
+ AdaLayerNormZero,
22
+ RMSNorm,
23
+ )
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ def _chunked_feed_forward(
30
+ ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int
31
+ ):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = (
83
+ x
84
+ + self.alpha_attn.tanh()
85
+ * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
86
+ )
87
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
88
+
89
+ return x
90
+
91
+
92
+ @maybe_allow_in_graph
93
+ class TransformerBlock(nn.Module):
94
+ r"""
95
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
96
+
97
+ Reference: https://arxiv.org/abs/2403.03206
98
+
99
+ Parameters:
100
+ dim (`int`): The number of channels in the input and output.
101
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
102
+ attention_head_dim (`int`): The number of channels in each head.
103
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
104
+ processing of `context` conditions.
105
+ """
106
+
107
+ def __init__(
108
+ self, dim, num_attention_heads, attention_head_dim, context_pre_only=False
109
+ ):
110
+ super().__init__()
111
+
112
+ self.norm1 = AdaLayerNormZero(dim)
113
+
114
+ if hasattr(F, "scaled_dot_product_attention"):
115
+ processor = AttnProcessor2_0()
116
+ else:
117
+ raise ValueError(
118
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
119
+ )
120
+ self.attn = Attention(
121
+ query_dim=dim,
122
+ cross_attention_dim=None,
123
+ added_kv_proj_dim=None,
124
+ dim_head=attention_head_dim // num_attention_heads,
125
+ heads=num_attention_heads,
126
+ out_dim=attention_head_dim,
127
+ context_pre_only=context_pre_only,
128
+ bias=True,
129
+ processor=processor,
130
+ )
131
+
132
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
133
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
134
+
135
+ # let chunk size default to None
136
+ self._chunk_size = None
137
+ self._chunk_dim = 0
138
+
139
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
140
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
141
+ # Sets chunk feed-forward
142
+ self._chunk_size = chunk_size
143
+ self._chunk_dim = dim
144
+
145
+ def forward(
146
+ self,
147
+ hidden_states: torch.FloatTensor,
148
+ temb: torch.FloatTensor,
149
+ ):
150
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
151
+ hidden_states, emb=temb
152
+ )
153
+
154
+ # Attention.
155
+ attn_output = self.attn(hidden_states=norm_hidden_states)
156
+
157
+ # Process attention outputs for the `hidden_states`.
158
+ attn_output = gate_msa.unsqueeze(1) * attn_output
159
+ hidden_states = hidden_states + attn_output
160
+
161
+ norm_hidden_states = self.norm2(hidden_states)
162
+ norm_hidden_states = (
163
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
164
+ )
165
+ if self._chunk_size is not None:
166
+ # "feed_forward_chunk_size" can be used to save memory
167
+ ff_output = _chunked_feed_forward(
168
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
169
+ )
170
+ else:
171
+ ff_output = self.ff(norm_hidden_states)
172
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
173
+
174
+ hidden_states = hidden_states + ff_output
175
+
176
+ return hidden_states
177
+
178
+
179
+ @maybe_allow_in_graph
180
+ class BasicTransformerBlock(nn.Module):
181
+ r"""
182
+ A basic Transformer block.
183
+
184
+ Parameters:
185
+ dim (`int`): The number of channels in the input and output.
186
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
187
+ attention_head_dim (`int`): The number of channels in each head.
188
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
189
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
190
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
191
+ num_embeds_ada_norm (:
192
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
193
+ attention_bias (:
194
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
195
+ only_cross_attention (`bool`, *optional*):
196
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
197
+ double_self_attention (`bool`, *optional*):
198
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
199
+ upcast_attention (`bool`, *optional*):
200
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
201
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
202
+ Whether to use learnable elementwise affine parameters for normalization.
203
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
204
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
205
+ final_dropout (`bool` *optional*, defaults to False):
206
+ Whether to apply a final dropout after the last feed-forward layer.
207
+ attention_type (`str`, *optional*, defaults to `"default"`):
208
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
209
+ positional_embeddings (`str`, *optional*, defaults to `None`):
210
+ The type of positional embeddings to apply to.
211
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
212
+ The maximum number of positional embeddings to apply.
213
+ """
214
+
215
+ def __init__(
216
+ self,
217
+ dim: int,
218
+ num_attention_heads: int,
219
+ attention_head_dim: int,
220
+ dropout=0.0,
221
+ cross_attention_dim: Optional[int] = None,
222
+ activation_fn: str = "geglu",
223
+ num_embeds_ada_norm: Optional[int] = None,
224
+ attention_bias: bool = False,
225
+ only_cross_attention: bool = False,
226
+ double_self_attention: bool = False,
227
+ upcast_attention: bool = False,
228
+ norm_elementwise_affine: bool = True,
229
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
230
+ norm_eps: float = 1e-5,
231
+ final_dropout: bool = False,
232
+ attention_type: str = "default",
233
+ positional_embeddings: Optional[str] = None,
234
+ num_positional_embeddings: Optional[int] = None,
235
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
236
+ ada_norm_bias: Optional[int] = None,
237
+ ff_inner_dim: Optional[int] = None,
238
+ ff_bias: bool = True,
239
+ attention_out_bias: bool = True,
240
+ ):
241
+ super().__init__()
242
+ self.only_cross_attention = only_cross_attention
243
+
244
+ # We keep these boolean flags for backward-compatibility.
245
+ self.use_ada_layer_norm_zero = (
246
+ num_embeds_ada_norm is not None
247
+ ) and norm_type == "ada_norm_zero"
248
+ self.use_ada_layer_norm = (
249
+ num_embeds_ada_norm is not None
250
+ ) and norm_type == "ada_norm"
251
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
252
+ self.use_layer_norm = norm_type == "layer_norm"
253
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
254
+
255
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
256
+ raise ValueError(
257
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
258
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
259
+ )
260
+
261
+ self.norm_type = norm_type
262
+ self.num_embeds_ada_norm = num_embeds_ada_norm
263
+
264
+ if positional_embeddings and (num_positional_embeddings is None):
265
+ raise ValueError(
266
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
267
+ )
268
+
269
+ if positional_embeddings == "sinusoidal":
270
+ self.pos_embed = SinusoidalPositionalEmbedding(
271
+ dim, max_seq_length=num_positional_embeddings
272
+ )
273
+ else:
274
+ self.pos_embed = None
275
+
276
+ # Define 3 blocks. Each block has its own normalization layer.
277
+ # 1. Self-Attn
278
+ if norm_type == "ada_norm":
279
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
280
+ elif norm_type == "ada_norm_zero":
281
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
282
+ elif norm_type == "ada_norm_continuous":
283
+ self.norm1 = AdaLayerNormContinuous(
284
+ dim,
285
+ ada_norm_continous_conditioning_embedding_dim,
286
+ norm_elementwise_affine,
287
+ norm_eps,
288
+ ada_norm_bias,
289
+ "rms_norm",
290
+ )
291
+ else:
292
+ self.norm1 = nn.LayerNorm(
293
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
294
+ )
295
+
296
+ self.attn1 = Attention(
297
+ query_dim=dim,
298
+ heads=num_attention_heads,
299
+ dim_head=attention_head_dim,
300
+ dropout=dropout,
301
+ bias=attention_bias,
302
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
303
+ upcast_attention=upcast_attention,
304
+ out_bias=attention_out_bias,
305
+ )
306
+
307
+ # 2. Cross-Attn
308
+ if cross_attention_dim is not None or double_self_attention:
309
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
310
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
311
+ # the second cross attention block.
312
+ if norm_type == "ada_norm":
313
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
314
+ elif norm_type == "ada_norm_continuous":
315
+ self.norm2 = AdaLayerNormContinuous(
316
+ dim,
317
+ ada_norm_continous_conditioning_embedding_dim,
318
+ norm_elementwise_affine,
319
+ norm_eps,
320
+ ada_norm_bias,
321
+ "rms_norm",
322
+ )
323
+ else:
324
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
325
+
326
+ self.attn2 = Attention(
327
+ query_dim=dim,
328
+ cross_attention_dim=(
329
+ cross_attention_dim if not double_self_attention else None
330
+ ),
331
+ heads=num_attention_heads,
332
+ dim_head=attention_head_dim,
333
+ dropout=dropout,
334
+ bias=attention_bias,
335
+ upcast_attention=upcast_attention,
336
+ out_bias=attention_out_bias,
337
+ ) # is self-attn if encoder_hidden_states is none
338
+ else:
339
+ self.norm2 = None
340
+ self.attn2 = None
341
+
342
+ # 3. Feed-forward
343
+ if norm_type == "ada_norm_continuous":
344
+ self.norm3 = AdaLayerNormContinuous(
345
+ dim,
346
+ ada_norm_continous_conditioning_embedding_dim,
347
+ norm_elementwise_affine,
348
+ norm_eps,
349
+ ada_norm_bias,
350
+ "layer_norm",
351
+ )
352
+
353
+ elif norm_type in [
354
+ "ada_norm_zero",
355
+ "ada_norm",
356
+ "layer_norm",
357
+ "ada_norm_continuous",
358
+ ]:
359
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
360
+ elif norm_type == "layer_norm_i2vgen":
361
+ self.norm3 = None
362
+
363
+ self.ff = FeedForward(
364
+ dim,
365
+ dropout=dropout,
366
+ activation_fn=activation_fn,
367
+ final_dropout=final_dropout,
368
+ inner_dim=ff_inner_dim,
369
+ bias=ff_bias,
370
+ )
371
+
372
+ # 4. Fuser
373
+ if attention_type == "gated" or attention_type == "gated-text-image":
374
+ self.fuser = GatedSelfAttentionDense(
375
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
376
+ )
377
+
378
+ # 5. Scale-shift for PixArt-Alpha.
379
+ if norm_type == "ada_norm_single":
380
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
381
+
382
+ # let chunk size default to None
383
+ self._chunk_size = None
384
+ self._chunk_dim = 0
385
+
386
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
387
+ # Sets chunk feed-forward
388
+ self._chunk_size = chunk_size
389
+ self._chunk_dim = dim
390
+
391
+ def forward(
392
+ self,
393
+ hidden_states: torch.Tensor,
394
+ attention_mask: Optional[torch.Tensor] = None,
395
+ encoder_hidden_states: Optional[torch.Tensor] = None,
396
+ encoder_attention_mask: Optional[torch.Tensor] = None,
397
+ timestep: Optional[torch.LongTensor] = None,
398
+ cross_attention_kwargs: Dict[str, Any] = None,
399
+ class_labels: Optional[torch.LongTensor] = None,
400
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
401
+ ) -> torch.Tensor:
402
+ if cross_attention_kwargs is not None:
403
+ if cross_attention_kwargs.get("scale", None) is not None:
404
+ logger.warning(
405
+ "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
406
+ )
407
+
408
+ # Notice that normalization is always applied before the real computation in the following blocks.
409
+ # 0. Self-Attention
410
+ batch_size = hidden_states.shape[0]
411
+
412
+ if self.norm_type == "ada_norm":
413
+ norm_hidden_states = self.norm1(hidden_states, timestep)
414
+ elif self.norm_type == "ada_norm_zero":
415
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
416
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
417
+ )
418
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
419
+ norm_hidden_states = self.norm1(hidden_states)
420
+ elif self.norm_type == "ada_norm_continuous":
421
+ norm_hidden_states = self.norm1(
422
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
423
+ )
424
+ elif self.norm_type == "ada_norm_single":
425
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
426
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
427
+ ).chunk(6, dim=1)
428
+ norm_hidden_states = self.norm1(hidden_states)
429
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
430
+ norm_hidden_states = norm_hidden_states.squeeze(1)
431
+ else:
432
+ raise ValueError("Incorrect norm used")
433
+
434
+ if self.pos_embed is not None:
435
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
436
+
437
+ # 1. Prepare GLIGEN inputs
438
+ cross_attention_kwargs = (
439
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
440
+ )
441
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
442
+
443
+ attn_output = self.attn1(
444
+ norm_hidden_states,
445
+ encoder_hidden_states=(
446
+ encoder_hidden_states if self.only_cross_attention else None
447
+ ),
448
+ attention_mask=attention_mask,
449
+ **cross_attention_kwargs,
450
+ )
451
+ if self.norm_type == "ada_norm_zero":
452
+ attn_output = gate_msa.unsqueeze(1) * attn_output
453
+ elif self.norm_type == "ada_norm_single":
454
+ attn_output = gate_msa * attn_output
455
+
456
+ hidden_states = attn_output + hidden_states
457
+ if hidden_states.ndim == 4:
458
+ hidden_states = hidden_states.squeeze(1)
459
+
460
+ # 1.2 GLIGEN Control
461
+ if gligen_kwargs is not None:
462
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
463
+
464
+ # 3. Cross-Attention
465
+ if self.attn2 is not None:
466
+ if self.norm_type == "ada_norm":
467
+ norm_hidden_states = self.norm2(hidden_states, timestep)
468
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
469
+ norm_hidden_states = self.norm2(hidden_states)
470
+ elif self.norm_type == "ada_norm_single":
471
+ # For PixArt norm2 isn't applied here:
472
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
473
+ norm_hidden_states = hidden_states
474
+ elif self.norm_type == "ada_norm_continuous":
475
+ norm_hidden_states = self.norm2(
476
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
477
+ )
478
+ else:
479
+ raise ValueError("Incorrect norm")
480
+
481
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
482
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
483
+
484
+ attn_output = self.attn2(
485
+ norm_hidden_states,
486
+ encoder_hidden_states=encoder_hidden_states,
487
+ attention_mask=encoder_attention_mask,
488
+ **cross_attention_kwargs,
489
+ )
490
+ hidden_states = attn_output + hidden_states
491
+
492
+ # 4. Feed-forward
493
+ # i2vgen doesn't have this norm 🤷‍♂️
494
+ if self.norm_type == "ada_norm_continuous":
495
+ norm_hidden_states = self.norm3(
496
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
497
+ )
498
+ elif not self.norm_type == "ada_norm_single":
499
+ norm_hidden_states = self.norm3(hidden_states)
500
+
501
+ if self.norm_type == "ada_norm_zero":
502
+ norm_hidden_states = (
503
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
504
+ )
505
+
506
+ if self.norm_type == "ada_norm_single":
507
+ norm_hidden_states = self.norm2(hidden_states)
508
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
509
+
510
+ if self._chunk_size is not None:
511
+ # "feed_forward_chunk_size" can be used to save memory
512
+ ff_output = _chunked_feed_forward(
513
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
514
+ )
515
+ else:
516
+ ff_output = self.ff(norm_hidden_states)
517
+
518
+ if self.norm_type == "ada_norm_zero":
519
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
520
+ elif self.norm_type == "ada_norm_single":
521
+ ff_output = gate_mlp * ff_output
522
+
523
+ hidden_states = ff_output + hidden_states
524
+ if hidden_states.ndim == 4:
525
+ hidden_states = hidden_states.squeeze(1)
526
+
527
+ return hidden_states
528
+
529
+
530
+ @maybe_allow_in_graph
531
+ class TemporalRopeBasicTransformerBlock(nn.Module):
532
+ r"""
533
+ A basic Transformer block for video like data.
534
+
535
+ Parameters:
536
+ dim (`int`): The number of channels in the input and output.
537
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
538
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
539
+ attention_head_dim (`int`): The number of channels in each head.
540
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
541
+ """
542
+
543
+ def __init__(
544
+ self,
545
+ dim: int,
546
+ time_mix_inner_dim: int,
547
+ num_attention_heads: int,
548
+ attention_head_dim: int,
549
+ cross_attention_dim: Optional[int] = None,
550
+ ):
551
+ super().__init__()
552
+ self.is_res = dim == time_mix_inner_dim
553
+
554
+ self.norm_in = nn.LayerNorm(dim)
555
+
556
+ # Define 3 blocks. Each block has its own normalization layer.
557
+ # 1. Self-Attn
558
+ self.ff_in = FeedForward(
559
+ dim,
560
+ dim_out=time_mix_inner_dim,
561
+ activation_fn="geglu",
562
+ )
563
+
564
+ processor = AttnRopeProcessor2_0()
565
+
566
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
567
+ self.attn1 = Attention(
568
+ query_dim=time_mix_inner_dim,
569
+ heads=num_attention_heads,
570
+ dim_head=attention_head_dim,
571
+ cross_attention_dim=None,
572
+ processor=processor,
573
+ )
574
+
575
+ # 2. Cross-Attn
576
+ if cross_attention_dim is not None:
577
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
578
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
579
+ # the second cross attention block.
580
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
581
+ self.attn2 = Attention(
582
+ query_dim=time_mix_inner_dim,
583
+ cross_attention_dim=cross_attention_dim,
584
+ heads=num_attention_heads,
585
+ dim_head=attention_head_dim,
586
+ processor=processor,
587
+ ) # is self-attn if encoder_hidden_states is none
588
+ else:
589
+ self.norm2 = None
590
+ self.attn2 = None
591
+
592
+ # 3. Feed-forward
593
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
594
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
595
+
596
+ # let chunk size default to None
597
+ self._chunk_size = None
598
+ self._chunk_dim = None
599
+
600
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
601
+ # Sets chunk feed-forward
602
+ self._chunk_size = chunk_size
603
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
604
+ self._chunk_dim = 1
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.Tensor,
609
+ num_frames: int,
610
+ encoder_hidden_states: Optional[torch.Tensor] = None,
611
+ frame_rotary_emb=None,
612
+ ) -> torch.Tensor:
613
+ # Notice that normalization is always applied before the real computation in the following blocks.
614
+ # 0. Self-Attention
615
+ batch_size = hidden_states.shape[0]
616
+
617
+ batch_frames, seq_length, channels = hidden_states.shape
618
+ batch_size = batch_frames // num_frames
619
+
620
+ hidden_states = hidden_states[None, :].reshape(
621
+ batch_size, num_frames, seq_length, channels
622
+ )
623
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
624
+ hidden_states = hidden_states.reshape(
625
+ batch_size * seq_length, num_frames, channels
626
+ )
627
+
628
+ residual = hidden_states
629
+ hidden_states = self.norm_in(hidden_states)
630
+
631
+ if self._chunk_size is not None:
632
+ hidden_states = _chunked_feed_forward(
633
+ self.ff_in, hidden_states, self._chunk_dim, self._chunk_size
634
+ )
635
+ else:
636
+ hidden_states = self.ff_in(hidden_states)
637
+
638
+ if self.is_res:
639
+ hidden_states = hidden_states + residual
640
+
641
+ norm_hidden_states = self.norm1(hidden_states)
642
+ attn_output = self.attn1(
643
+ norm_hidden_states,
644
+ encoder_hidden_states=None,
645
+ frame_rotary_emb=frame_rotary_emb,
646
+ )
647
+ hidden_states = attn_output + hidden_states
648
+
649
+ # 3. Cross-Attention
650
+ if self.attn2 is not None:
651
+ norm_hidden_states = self.norm2(hidden_states)
652
+ attn_output = self.attn2(
653
+ norm_hidden_states,
654
+ encoder_hidden_states=encoder_hidden_states,
655
+ frame_rotary_emb=frame_rotary_emb,
656
+ )
657
+ hidden_states = attn_output + hidden_states
658
+
659
+ # 4. Feed-forward
660
+ norm_hidden_states = self.norm3(hidden_states)
661
+
662
+ if self._chunk_size is not None:
663
+ ff_output = _chunked_feed_forward(
664
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
665
+ )
666
+ else:
667
+ ff_output = self.ff(norm_hidden_states)
668
+
669
+ if self.is_res:
670
+ hidden_states = ff_output + hidden_states
671
+ else:
672
+ hidden_states = ff_output
673
+
674
+ hidden_states = hidden_states[None, :].reshape(
675
+ batch_size, seq_length, num_frames, channels
676
+ )
677
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
678
+ hidden_states = hidden_states.reshape(
679
+ batch_size * num_frames, seq_length, channels
680
+ )
681
+
682
+ return hidden_states
683
+
684
+
685
+ class FeedForward(nn.Module):
686
+ r"""
687
+ A feed-forward layer.
688
+
689
+ Parameters:
690
+ dim (`int`): The number of channels in the input.
691
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
692
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
693
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
694
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
695
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
696
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
697
+ """
698
+
699
+ def __init__(
700
+ self,
701
+ dim: int,
702
+ dim_out: Optional[int] = None,
703
+ mult: int = 4,
704
+ dropout: float = 0.0,
705
+ activation_fn: str = "geglu",
706
+ final_dropout: bool = False,
707
+ inner_dim=None,
708
+ bias: bool = True,
709
+ ):
710
+ super().__init__()
711
+ if inner_dim is None:
712
+ inner_dim = int(dim * mult)
713
+ dim_out = dim_out if dim_out is not None else dim
714
+
715
+ if activation_fn == "gelu":
716
+ act_fn = GELU(dim, inner_dim, bias=bias)
717
+ if activation_fn == "gelu-approximate":
718
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
719
+ elif activation_fn == "geglu":
720
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
721
+ elif activation_fn == "geglu-approximate":
722
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
723
+
724
+ self.net = nn.ModuleList([])
725
+ # project in
726
+ self.net.append(act_fn)
727
+ # project dropout
728
+ self.net.append(nn.Dropout(dropout))
729
+ # project out
730
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
731
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
732
+ if final_dropout:
733
+ self.net.append(nn.Dropout(dropout))
734
+
735
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
736
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
737
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
738
+ deprecate("scale", "1.0.0", deprecation_message)
739
+ for module in self.net:
740
+ hidden_states = module(hidden_states)
741
+ return hidden_states
models/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
models/embeddings.py ADDED
@@ -0,0 +1,1539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.utils import deprecate
23
+ from diffusers.models.activations import FP32SiLU, get_activation
24
+ from diffusers.models.attention_processor import Attention
25
+
26
+
27
+ def get_timestep_embedding(
28
+ timesteps: torch.Tensor,
29
+ embedding_dim: int,
30
+ flip_sin_to_cos: bool = False,
31
+ downscale_freq_shift: float = 1,
32
+ scale: float = 1,
33
+ max_period: int = 10000,
34
+ ):
35
+ """
36
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
37
+
38
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
39
+ These may be fractional.
40
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
41
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
42
+ """
43
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
44
+
45
+ half_dim = embedding_dim // 2
46
+ exponent = -math.log(max_period) * torch.arange(
47
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
48
+ )
49
+ exponent = exponent / (half_dim - downscale_freq_shift)
50
+
51
+ emb = torch.exp(exponent)
52
+ emb = timesteps[:, None].float() * emb[None, :]
53
+
54
+ # scale embeddings
55
+ emb = scale * emb
56
+
57
+ # concat sine and cosine embeddings
58
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
59
+
60
+ # flip sine and cosine embeddings
61
+ if flip_sin_to_cos:
62
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
63
+
64
+ # zero pad
65
+ if embedding_dim % 2 == 1:
66
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
67
+ return emb
68
+
69
+
70
+ def get_2d_sincos_pos_embed(
71
+ embed_dim,
72
+ grid_size,
73
+ cls_token=False,
74
+ extra_tokens=0,
75
+ interpolation_scale=1.0,
76
+ base_size=16,
77
+ ):
78
+ """
79
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
80
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
81
+ """
82
+ if isinstance(grid_size, int):
83
+ grid_size = (grid_size, grid_size)
84
+
85
+ grid_h = (
86
+ np.arange(grid_size[0], dtype=np.float32)
87
+ / (grid_size[0] / base_size)
88
+ / interpolation_scale
89
+ )
90
+ grid_w = (
91
+ np.arange(grid_size[1], dtype=np.float32)
92
+ / (grid_size[1] / base_size)
93
+ / interpolation_scale
94
+ )
95
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
96
+ grid = np.stack(grid, axis=0)
97
+
98
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
99
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
100
+ if cls_token and extra_tokens > 0:
101
+ pos_embed = np.concatenate(
102
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
103
+ )
104
+ return pos_embed
105
+
106
+
107
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
108
+ if embed_dim % 2 != 0:
109
+ raise ValueError("embed_dim must be divisible by 2")
110
+
111
+ # use half of dimensions to encode grid_h
112
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
113
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
114
+
115
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
116
+ return emb
117
+
118
+
119
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
120
+ """
121
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
122
+ """
123
+ if embed_dim % 2 != 0:
124
+ raise ValueError("embed_dim must be divisible by 2")
125
+
126
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
127
+ omega /= embed_dim / 2.0
128
+ omega = 1.0 / 10000**omega # (D/2,)
129
+
130
+ pos = pos.reshape(-1) # (M,)
131
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
132
+
133
+ emb_sin = np.sin(out) # (M, D/2)
134
+ emb_cos = np.cos(out) # (M, D/2)
135
+
136
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
137
+ return emb
138
+
139
+
140
+ class PatchEmbed(nn.Module):
141
+ """2D Image to Patch Embedding with support for SD3 cropping."""
142
+
143
+ def __init__(
144
+ self,
145
+ height=224,
146
+ width=224,
147
+ patch_size=16,
148
+ in_channels=3,
149
+ embed_dim=768,
150
+ layer_norm=False,
151
+ flatten=True,
152
+ bias=True,
153
+ interpolation_scale=1,
154
+ pos_embed_type="sincos",
155
+ pos_embed_max_size=None, # For SD3 cropping
156
+ ):
157
+ super().__init__()
158
+
159
+ num_patches = (height // patch_size) * (width // patch_size)
160
+ self.flatten = flatten
161
+ self.layer_norm = layer_norm
162
+ self.pos_embed_max_size = pos_embed_max_size
163
+
164
+ self.proj = nn.Conv2d(
165
+ in_channels,
166
+ embed_dim,
167
+ kernel_size=(patch_size, patch_size),
168
+ stride=patch_size,
169
+ bias=bias,
170
+ )
171
+ if layer_norm:
172
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
173
+ else:
174
+ self.norm = None
175
+
176
+ self.patch_size = patch_size
177
+ self.height, self.width = height // patch_size, width // patch_size
178
+ self.base_size = height // patch_size
179
+ self.interpolation_scale = interpolation_scale
180
+
181
+ # Calculate positional embeddings based on max size or default
182
+ if pos_embed_max_size:
183
+ grid_size = pos_embed_max_size
184
+ else:
185
+ grid_size = int(num_patches**0.5)
186
+
187
+ if pos_embed_type is None:
188
+ self.pos_embed = None
189
+ elif pos_embed_type == "sincos":
190
+ pos_embed = get_2d_sincos_pos_embed(
191
+ embed_dim,
192
+ grid_size,
193
+ base_size=self.base_size,
194
+ interpolation_scale=self.interpolation_scale,
195
+ )
196
+ persistent = True if pos_embed_max_size else False
197
+ self.register_buffer(
198
+ "pos_embed",
199
+ torch.from_numpy(pos_embed).float().unsqueeze(0),
200
+ persistent=persistent,
201
+ )
202
+ else:
203
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
204
+
205
+ def cropped_pos_embed(self, height, width):
206
+ """Crops positional embeddings for SD3 compatibility."""
207
+ if self.pos_embed_max_size is None:
208
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
209
+
210
+ height = height // self.patch_size
211
+ width = width // self.patch_size
212
+ if height > self.pos_embed_max_size:
213
+ raise ValueError(
214
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
215
+ )
216
+ if width > self.pos_embed_max_size:
217
+ raise ValueError(
218
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
219
+ )
220
+
221
+ top = (self.pos_embed_max_size - height) // 2
222
+ left = (self.pos_embed_max_size - width) // 2
223
+ spatial_pos_embed = self.pos_embed.reshape(
224
+ 1, self.pos_embed_max_size, self.pos_embed_max_size, -1
225
+ )
226
+ spatial_pos_embed = spatial_pos_embed[
227
+ :, top : top + height, left : left + width, :
228
+ ]
229
+ spatial_pos_embed = spatial_pos_embed.reshape(
230
+ 1, -1, spatial_pos_embed.shape[-1]
231
+ )
232
+ return spatial_pos_embed
233
+
234
+ def forward(self, latent):
235
+ if self.pos_embed_max_size is not None:
236
+ height, width = latent.shape[-2:]
237
+ else:
238
+ height, width = (
239
+ latent.shape[-2] // self.patch_size,
240
+ latent.shape[-1] // self.patch_size,
241
+ )
242
+
243
+ latent = self.proj(latent)
244
+ if self.flatten:
245
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
246
+ if self.layer_norm:
247
+ latent = self.norm(latent)
248
+ if self.pos_embed is None:
249
+ return latent.to(latent.dtype)
250
+ # Interpolate or crop positional embeddings as needed
251
+ if self.pos_embed_max_size:
252
+ pos_embed = self.cropped_pos_embed(height, width)
253
+ else:
254
+ if self.height != height or self.width != width:
255
+ pos_embed = get_2d_sincos_pos_embed(
256
+ embed_dim=self.pos_embed.shape[-1],
257
+ grid_size=(height, width),
258
+ base_size=self.base_size,
259
+ interpolation_scale=self.interpolation_scale,
260
+ )
261
+ pos_embed = (
262
+ torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
263
+ )
264
+ else:
265
+ pos_embed = self.pos_embed
266
+
267
+ return (latent + pos_embed).to(latent.dtype)
268
+
269
+
270
+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
271
+ """
272
+ RoPE for image tokens with 2d structure.
273
+
274
+ Args:
275
+ embed_dim: (`int`):
276
+ The embedding dimension size
277
+ crops_coords (`Tuple[int]`)
278
+ The top-left and bottom-right coordinates of the crop.
279
+ grid_size (`Tuple[int]`):
280
+ The grid size of the positional embedding.
281
+ use_real (`bool`):
282
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
283
+
284
+ Returns:
285
+ `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
286
+ """
287
+ start, stop = crops_coords
288
+ grid_h = np.linspace(
289
+ start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32
290
+ )
291
+ grid_w = np.linspace(
292
+ start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32
293
+ )
294
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
295
+ grid = np.stack(grid, axis=0) # [2, W, H]
296
+
297
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
298
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
299
+ return pos_embed
300
+
301
+
302
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
303
+ assert embed_dim % 4 == 0
304
+
305
+ # use half of dimensions to encode grid_h
306
+ emb_h = get_1d_rotary_pos_embed(
307
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
308
+ ) # (H*W, D/4)
309
+ emb_w = get_1d_rotary_pos_embed(
310
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
311
+ ) # (H*W, D/4)
312
+
313
+ if use_real:
314
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
315
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
316
+ return cos, sin
317
+ else:
318
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
319
+ return emb
320
+
321
+
322
+ def get_1d_rotary_pos_embed(
323
+ dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False
324
+ ):
325
+ """
326
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
327
+
328
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
329
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
330
+ data type.
331
+
332
+ Args:
333
+ dim (`int`): Dimension of the frequency tensor.
334
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
335
+ theta (`float`, *optional*, defaults to 10000.0):
336
+ Scaling factor for frequency computation. Defaults to 10000.0.
337
+ use_real (`bool`, *optional*):
338
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
339
+
340
+ Returns:
341
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
342
+ """
343
+ if isinstance(pos, int):
344
+ pos = np.arange(pos)
345
+ freqs = 1.0 / (
346
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
347
+ ) # [D/2]
348
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
349
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
350
+ if use_real:
351
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
352
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
353
+ return freqs_cos, freqs_sin
354
+ else:
355
+ freqs_cis = torch.polar(
356
+ torch.ones_like(freqs), freqs
357
+ ) # complex64 # [S, D/2]
358
+ return freqs_cis
359
+
360
+
361
+ def apply_rotary_emb(
362
+ x: torch.Tensor,
363
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
364
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
365
+ """
366
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
367
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
368
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
369
+ tensors contain rotary embeddings and are returned as real tensors.
370
+
371
+ Args:
372
+ x (`torch.Tensor`):
373
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
374
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
375
+
376
+ Returns:
377
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
378
+ """
379
+ cos, sin = freqs_cis # [S, D]
380
+ cos = cos[None, None]
381
+ sin = sin[None, None]
382
+ cos, sin = cos.to(x.device), sin.to(x.device)
383
+
384
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
385
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
386
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
387
+
388
+ return out
389
+
390
+
391
+ def rope(pos: torch.Tensor, dim: int, theta=10000.0) -> torch.Tensor:
392
+ assert dim % 2 == 0, "The dimension must be even."
393
+
394
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
395
+ omega = 1.0 / (theta**scale)
396
+
397
+ batch_size, seq_length = pos.shape
398
+ # (B, N, d/2)
399
+ out = torch.einsum("...n,d->...nd", pos, omega)
400
+ cos_out = torch.cos(out)
401
+ sin_out = torch.sin(out)
402
+
403
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
404
+ # (B, 1, N, d/2, 2, 2)
405
+ out = stacked_out.view(batch_size, 1, -1, dim // 2, 2, 2)
406
+ return out.float()
407
+
408
+
409
+ def apply_rope(x, freqs_cis):
410
+ # (B, num_heads, N, d/2, 1, 2)
411
+ x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
412
+ # cos * q0 - sin * q1, sin * q0 + cos * q1
413
+ x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
414
+ return x_out.reshape(*x.shape).type_as(x)
415
+
416
+
417
+ class TimestepEmbedding(nn.Module):
418
+ def __init__(
419
+ self,
420
+ in_channels: int,
421
+ time_embed_dim: int,
422
+ act_fn: str = "silu",
423
+ out_dim: int = None,
424
+ post_act_fn: Optional[str] = None,
425
+ cond_proj_dim=None,
426
+ sample_proj_bias=True,
427
+ ):
428
+ super().__init__()
429
+
430
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
431
+
432
+ if cond_proj_dim is not None:
433
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
434
+ else:
435
+ self.cond_proj = None
436
+
437
+ self.act = get_activation(act_fn)
438
+
439
+ if out_dim is not None:
440
+ time_embed_dim_out = out_dim
441
+ else:
442
+ time_embed_dim_out = time_embed_dim
443
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
444
+
445
+ if post_act_fn is None:
446
+ self.post_act = None
447
+ else:
448
+ self.post_act = get_activation(post_act_fn)
449
+
450
+ def forward(self, sample, condition=None):
451
+ if condition is not None:
452
+ sample = sample + self.cond_proj(condition)
453
+ sample = self.linear_1(sample)
454
+
455
+ if self.act is not None:
456
+ sample = self.act(sample)
457
+
458
+ sample = self.linear_2(sample)
459
+
460
+ if self.post_act is not None:
461
+ sample = self.post_act(sample)
462
+ return sample
463
+
464
+
465
+ class Timesteps(nn.Module):
466
+ def __init__(
467
+ self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float
468
+ ):
469
+ super().__init__()
470
+ self.num_channels = num_channels
471
+ self.flip_sin_to_cos = flip_sin_to_cos
472
+ self.downscale_freq_shift = downscale_freq_shift
473
+
474
+ def forward(self, timesteps):
475
+ t_emb = get_timestep_embedding(
476
+ timesteps,
477
+ self.num_channels,
478
+ flip_sin_to_cos=self.flip_sin_to_cos,
479
+ downscale_freq_shift=self.downscale_freq_shift,
480
+ )
481
+ return t_emb
482
+
483
+
484
+ class GaussianFourierProjection(nn.Module):
485
+ """Gaussian Fourier embeddings for noise levels."""
486
+
487
+ def __init__(
488
+ self,
489
+ embedding_size: int = 256,
490
+ scale: float = 1.0,
491
+ set_W_to_weight=True,
492
+ log=True,
493
+ flip_sin_to_cos=False,
494
+ ):
495
+ super().__init__()
496
+ self.weight = nn.Parameter(
497
+ torch.randn(embedding_size) * scale, requires_grad=False
498
+ )
499
+ self.log = log
500
+ self.flip_sin_to_cos = flip_sin_to_cos
501
+
502
+ if set_W_to_weight:
503
+ # to delete later
504
+ self.W = nn.Parameter(
505
+ torch.randn(embedding_size) * scale, requires_grad=False
506
+ )
507
+
508
+ self.weight = self.W
509
+
510
+ def forward(self, x):
511
+ if self.log:
512
+ x = torch.log(x)
513
+
514
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
515
+
516
+ if self.flip_sin_to_cos:
517
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
518
+ else:
519
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
520
+ return out
521
+
522
+
523
+ class SinusoidalPositionalEmbedding(nn.Module):
524
+ """Apply positional information to a sequence of embeddings.
525
+
526
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
527
+ them
528
+
529
+ Args:
530
+ embed_dim: (int): Dimension of the positional embedding.
531
+ max_seq_length: Maximum sequence length to apply positional embeddings
532
+
533
+ """
534
+
535
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
536
+ super().__init__()
537
+ position = torch.arange(max_seq_length).unsqueeze(1)
538
+ div_term = torch.exp(
539
+ torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
540
+ )
541
+ pe = torch.zeros(1, max_seq_length, embed_dim)
542
+ pe[0, :, 0::2] = torch.sin(position * div_term)
543
+ pe[0, :, 1::2] = torch.cos(position * div_term)
544
+ self.register_buffer("pe", pe)
545
+
546
+ def forward(self, x):
547
+ _, seq_length, _ = x.shape
548
+ x = x + self.pe[:, :seq_length]
549
+ return x
550
+
551
+
552
+ class ImagePositionalEmbeddings(nn.Module):
553
+ """
554
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
555
+ height and width of the latent space.
556
+
557
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
558
+
559
+ For VQ-diffusion:
560
+
561
+ Output vector embeddings are used as input for the transformer.
562
+
563
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
564
+
565
+ Args:
566
+ num_embed (`int`):
567
+ Number of embeddings for the latent pixels embeddings.
568
+ height (`int`):
569
+ Height of the latent image i.e. the number of height embeddings.
570
+ width (`int`):
571
+ Width of the latent image i.e. the number of width embeddings.
572
+ embed_dim (`int`):
573
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
574
+ """
575
+
576
+ def __init__(
577
+ self,
578
+ num_embed: int,
579
+ height: int,
580
+ width: int,
581
+ embed_dim: int,
582
+ ):
583
+ super().__init__()
584
+
585
+ self.height = height
586
+ self.width = width
587
+ self.num_embed = num_embed
588
+ self.embed_dim = embed_dim
589
+
590
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
591
+ self.height_emb = nn.Embedding(self.height, embed_dim)
592
+ self.width_emb = nn.Embedding(self.width, embed_dim)
593
+
594
+ def forward(self, index):
595
+ emb = self.emb(index)
596
+
597
+ height_emb = self.height_emb(
598
+ torch.arange(self.height, device=index.device).view(1, self.height)
599
+ )
600
+
601
+ # 1 x H x D -> 1 x H x 1 x D
602
+ height_emb = height_emb.unsqueeze(2)
603
+
604
+ width_emb = self.width_emb(
605
+ torch.arange(self.width, device=index.device).view(1, self.width)
606
+ )
607
+
608
+ # 1 x W x D -> 1 x 1 x W x D
609
+ width_emb = width_emb.unsqueeze(1)
610
+
611
+ pos_emb = height_emb + width_emb
612
+
613
+ # 1 x H x W x D -> 1 x L xD
614
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
615
+
616
+ emb = emb + pos_emb[:, : emb.shape[1], :]
617
+
618
+ return emb
619
+
620
+
621
+ class LabelEmbedding(nn.Module):
622
+ """
623
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
624
+
625
+ Args:
626
+ num_classes (`int`): The number of classes.
627
+ hidden_size (`int`): The size of the vector embeddings.
628
+ dropout_prob (`float`): The probability of dropping a label.
629
+ """
630
+
631
+ def __init__(self, num_classes, hidden_size, dropout_prob):
632
+ super().__init__()
633
+ use_cfg_embedding = dropout_prob > 0
634
+ self.embedding_table = nn.Embedding(
635
+ num_classes + use_cfg_embedding, hidden_size
636
+ )
637
+ self.num_classes = num_classes
638
+ self.dropout_prob = dropout_prob
639
+
640
+ def token_drop(self, labels, force_drop_ids=None):
641
+ """
642
+ Drops labels to enable classifier-free guidance.
643
+ """
644
+ if force_drop_ids is None:
645
+ drop_ids = (
646
+ torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
647
+ )
648
+ else:
649
+ drop_ids = torch.tensor(force_drop_ids == 1)
650
+ labels = torch.where(drop_ids, self.num_classes, labels)
651
+ return labels
652
+
653
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
654
+ use_dropout = self.dropout_prob > 0
655
+ if (self.training and use_dropout) or (force_drop_ids is not None):
656
+ labels = self.token_drop(labels, force_drop_ids)
657
+ embeddings = self.embedding_table(labels)
658
+ return embeddings
659
+
660
+
661
+ class TextImageProjection(nn.Module):
662
+ def __init__(
663
+ self,
664
+ text_embed_dim: int = 1024,
665
+ image_embed_dim: int = 768,
666
+ cross_attention_dim: int = 768,
667
+ num_image_text_embeds: int = 10,
668
+ ):
669
+ super().__init__()
670
+
671
+ self.num_image_text_embeds = num_image_text_embeds
672
+ self.image_embeds = nn.Linear(
673
+ image_embed_dim, self.num_image_text_embeds * cross_attention_dim
674
+ )
675
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
676
+
677
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
678
+ batch_size = text_embeds.shape[0]
679
+
680
+ # image
681
+ image_text_embeds = self.image_embeds(image_embeds)
682
+ image_text_embeds = image_text_embeds.reshape(
683
+ batch_size, self.num_image_text_embeds, -1
684
+ )
685
+
686
+ # text
687
+ text_embeds = self.text_proj(text_embeds)
688
+
689
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
690
+
691
+
692
+ class ImageProjection(nn.Module):
693
+ def __init__(
694
+ self,
695
+ image_embed_dim: int = 768,
696
+ cross_attention_dim: int = 768,
697
+ num_image_text_embeds: int = 32,
698
+ ):
699
+ super().__init__()
700
+
701
+ self.num_image_text_embeds = num_image_text_embeds
702
+ self.image_embeds = nn.Linear(
703
+ image_embed_dim, self.num_image_text_embeds * cross_attention_dim
704
+ )
705
+ self.norm = nn.LayerNorm(cross_attention_dim)
706
+
707
+ def forward(self, image_embeds: torch.Tensor):
708
+ batch_size = image_embeds.shape[0]
709
+
710
+ # image
711
+ image_embeds = self.image_embeds(image_embeds)
712
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
713
+ image_embeds = self.norm(image_embeds)
714
+ return image_embeds
715
+
716
+
717
+ class IPAdapterFullImageProjection(nn.Module):
718
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
719
+ super().__init__()
720
+ from .attention import FeedForward
721
+
722
+ self.ff = FeedForward(
723
+ image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu"
724
+ )
725
+ self.norm = nn.LayerNorm(cross_attention_dim)
726
+
727
+ def forward(self, image_embeds: torch.Tensor):
728
+ return self.norm(self.ff(image_embeds))
729
+
730
+
731
+ class IPAdapterFaceIDImageProjection(nn.Module):
732
+ def __init__(
733
+ self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1
734
+ ):
735
+ super().__init__()
736
+ from .attention import FeedForward
737
+
738
+ self.num_tokens = num_tokens
739
+ self.cross_attention_dim = cross_attention_dim
740
+ self.ff = FeedForward(
741
+ image_embed_dim,
742
+ cross_attention_dim * num_tokens,
743
+ mult=mult,
744
+ activation_fn="gelu",
745
+ )
746
+ self.norm = nn.LayerNorm(cross_attention_dim)
747
+
748
+ def forward(self, image_embeds: torch.Tensor):
749
+ x = self.ff(image_embeds)
750
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
751
+ return self.norm(x)
752
+
753
+
754
+ class CombinedTimestepLabelEmbeddings(nn.Module):
755
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
756
+ super().__init__()
757
+
758
+ self.time_proj = Timesteps(
759
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1
760
+ )
761
+ self.timestep_embedder = TimestepEmbedding(
762
+ in_channels=256, time_embed_dim=embedding_dim
763
+ )
764
+ self.class_embedder = LabelEmbedding(
765
+ num_classes, embedding_dim, class_dropout_prob
766
+ )
767
+
768
+ def forward(self, timestep, class_labels, hidden_dtype=None):
769
+ timesteps_proj = self.time_proj(timestep)
770
+ timesteps_emb = self.timestep_embedder(
771
+ timesteps_proj.to(dtype=hidden_dtype)
772
+ ) # (N, D)
773
+
774
+ class_labels = self.class_embedder(class_labels) # (N, D)
775
+
776
+ conditioning = timesteps_emb + class_labels # (N, D)
777
+
778
+ return conditioning
779
+
780
+
781
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
782
+ def __init__(self, embedding_dim, pooled_projection_dim):
783
+ super().__init__()
784
+
785
+ self.time_proj = Timesteps(
786
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
787
+ )
788
+ self.timestep_embedder = TimestepEmbedding(
789
+ in_channels=256, time_embed_dim=embedding_dim
790
+ )
791
+ self.text_embedder = PixArtAlphaTextProjection(
792
+ pooled_projection_dim, embedding_dim, act_fn="silu"
793
+ )
794
+
795
+ def forward(self, timestep, pooled_projection):
796
+ timesteps_proj = self.time_proj(timestep)
797
+ timesteps_emb = self.timestep_embedder(
798
+ timesteps_proj.to(dtype=pooled_projection.dtype)
799
+ ) # (N, D)
800
+
801
+ pooled_projections = self.text_embedder(pooled_projection)
802
+
803
+ conditioning = timesteps_emb + pooled_projections
804
+
805
+ return conditioning
806
+
807
+
808
+ class TimestepEmbeddings(nn.Module):
809
+ def __init__(self, embedding_dim):
810
+ super().__init__()
811
+
812
+ self.time_proj = Timesteps(
813
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
814
+ )
815
+ self.timestep_embedder = TimestepEmbedding(
816
+ in_channels=256, time_embed_dim=embedding_dim
817
+ )
818
+
819
+ def forward(self, timestep):
820
+ timesteps_proj = self.time_proj(timestep)
821
+ timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
822
+
823
+ conditioning = timesteps_emb
824
+
825
+ return conditioning
826
+
827
+
828
+ class HunyuanDiTAttentionPool(nn.Module):
829
+ # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
830
+
831
+ def __init__(
832
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
833
+ ):
834
+ super().__init__()
835
+ self.positional_embedding = nn.Parameter(
836
+ torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5
837
+ )
838
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
839
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
840
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
841
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
842
+ self.num_heads = num_heads
843
+
844
+ def forward(self, x):
845
+ x = x.permute(1, 0, 2) # NLC -> LNC
846
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
847
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
848
+ x, _ = F.multi_head_attention_forward(
849
+ query=x[:1],
850
+ key=x,
851
+ value=x,
852
+ embed_dim_to_check=x.shape[-1],
853
+ num_heads=self.num_heads,
854
+ q_proj_weight=self.q_proj.weight,
855
+ k_proj_weight=self.k_proj.weight,
856
+ v_proj_weight=self.v_proj.weight,
857
+ in_proj_weight=None,
858
+ in_proj_bias=torch.cat(
859
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
860
+ ),
861
+ bias_k=None,
862
+ bias_v=None,
863
+ add_zero_attn=False,
864
+ dropout_p=0,
865
+ out_proj_weight=self.c_proj.weight,
866
+ out_proj_bias=self.c_proj.bias,
867
+ use_separate_proj_weight=True,
868
+ training=self.training,
869
+ need_weights=False,
870
+ )
871
+ return x.squeeze(0)
872
+
873
+
874
+ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
875
+ def __init__(
876
+ self,
877
+ embedding_dim,
878
+ pooled_projection_dim=1024,
879
+ seq_len=256,
880
+ cross_attention_dim=2048,
881
+ ):
882
+ super().__init__()
883
+
884
+ self.time_proj = Timesteps(
885
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
886
+ )
887
+ self.timestep_embedder = TimestepEmbedding(
888
+ in_channels=256, time_embed_dim=embedding_dim
889
+ )
890
+
891
+ self.pooler = HunyuanDiTAttentionPool(
892
+ seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
893
+ )
894
+ # Here we use a default learned embedder layer for future extension.
895
+ self.style_embedder = nn.Embedding(1, embedding_dim)
896
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
897
+ self.extra_embedder = PixArtAlphaTextProjection(
898
+ in_features=extra_in_dim,
899
+ hidden_size=embedding_dim * 4,
900
+ out_features=embedding_dim,
901
+ act_fn="silu_fp32",
902
+ )
903
+
904
+ def forward(
905
+ self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None
906
+ ):
907
+ timesteps_proj = self.time_proj(timestep)
908
+ timesteps_emb = self.timestep_embedder(
909
+ timesteps_proj.to(dtype=hidden_dtype)
910
+ ) # (N, 256)
911
+
912
+ # extra condition1: text
913
+ pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
914
+
915
+ # extra condition2: image meta size embdding
916
+ image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
917
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
918
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
919
+
920
+ # extra condition3: style embedding
921
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
922
+
923
+ # Concatenate all extra vectors
924
+ extra_cond = torch.cat(
925
+ [pooled_projections, image_meta_size, style_embedding], dim=1
926
+ )
927
+ conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
928
+
929
+ return conditioning
930
+
931
+
932
+ class TextTimeEmbedding(nn.Module):
933
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
934
+ super().__init__()
935
+ self.norm1 = nn.LayerNorm(encoder_dim)
936
+ self.pool = AttentionPooling(num_heads, encoder_dim)
937
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
938
+ self.norm2 = nn.LayerNorm(time_embed_dim)
939
+
940
+ def forward(self, hidden_states):
941
+ hidden_states = self.norm1(hidden_states)
942
+ hidden_states = self.pool(hidden_states)
943
+ hidden_states = self.proj(hidden_states)
944
+ hidden_states = self.norm2(hidden_states)
945
+ return hidden_states
946
+
947
+
948
+ class TextImageTimeEmbedding(nn.Module):
949
+ def __init__(
950
+ self,
951
+ text_embed_dim: int = 768,
952
+ image_embed_dim: int = 768,
953
+ time_embed_dim: int = 1536,
954
+ ):
955
+ super().__init__()
956
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
957
+ self.text_norm = nn.LayerNorm(time_embed_dim)
958
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
959
+
960
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
961
+ # text
962
+ time_text_embeds = self.text_proj(text_embeds)
963
+ time_text_embeds = self.text_norm(time_text_embeds)
964
+
965
+ # image
966
+ time_image_embeds = self.image_proj(image_embeds)
967
+
968
+ return time_image_embeds + time_text_embeds
969
+
970
+
971
+ class ImageTimeEmbedding(nn.Module):
972
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
973
+ super().__init__()
974
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
975
+ self.image_norm = nn.LayerNorm(time_embed_dim)
976
+
977
+ def forward(self, image_embeds: torch.Tensor):
978
+ # image
979
+ time_image_embeds = self.image_proj(image_embeds)
980
+ time_image_embeds = self.image_norm(time_image_embeds)
981
+ return time_image_embeds
982
+
983
+
984
+ class ImageHintTimeEmbedding(nn.Module):
985
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
986
+ super().__init__()
987
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
988
+ self.image_norm = nn.LayerNorm(time_embed_dim)
989
+ self.input_hint_block = nn.Sequential(
990
+ nn.Conv2d(3, 16, 3, padding=1),
991
+ nn.SiLU(),
992
+ nn.Conv2d(16, 16, 3, padding=1),
993
+ nn.SiLU(),
994
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
995
+ nn.SiLU(),
996
+ nn.Conv2d(32, 32, 3, padding=1),
997
+ nn.SiLU(),
998
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
999
+ nn.SiLU(),
1000
+ nn.Conv2d(96, 96, 3, padding=1),
1001
+ nn.SiLU(),
1002
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
1003
+ nn.SiLU(),
1004
+ nn.Conv2d(256, 4, 3, padding=1),
1005
+ )
1006
+
1007
+ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
1008
+ # image
1009
+ time_image_embeds = self.image_proj(image_embeds)
1010
+ time_image_embeds = self.image_norm(time_image_embeds)
1011
+ hint = self.input_hint_block(hint)
1012
+ return time_image_embeds, hint
1013
+
1014
+
1015
+ class AttentionPooling(nn.Module):
1016
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
1017
+
1018
+ def __init__(self, num_heads, embed_dim, dtype=None):
1019
+ super().__init__()
1020
+ self.dtype = dtype
1021
+ self.positional_embedding = nn.Parameter(
1022
+ torch.randn(1, embed_dim) / embed_dim**0.5
1023
+ )
1024
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
1025
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
1026
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
1027
+ self.num_heads = num_heads
1028
+ self.dim_per_head = embed_dim // self.num_heads
1029
+
1030
+ def forward(self, x):
1031
+ bs, length, width = x.size()
1032
+
1033
+ def shape(x):
1034
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
1035
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
1036
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
1037
+ x = x.transpose(1, 2)
1038
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
1039
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
1040
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
1041
+ x = x.transpose(1, 2)
1042
+ return x
1043
+
1044
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(
1045
+ x.dtype
1046
+ )
1047
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
1048
+
1049
+ # (bs*n_heads, class_token_length, dim_per_head)
1050
+ q = shape(self.q_proj(class_token))
1051
+ # (bs*n_heads, length+class_token_length, dim_per_head)
1052
+ k = shape(self.k_proj(x))
1053
+ v = shape(self.v_proj(x))
1054
+
1055
+ # (bs*n_heads, class_token_length, length+class_token_length):
1056
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
1057
+ weight = torch.einsum(
1058
+ "bct,bcs->bts", q * scale, k * scale
1059
+ ) # More stable with f16 than dividing afterwards
1060
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
1061
+
1062
+ # (bs*n_heads, dim_per_head, class_token_length)
1063
+ a = torch.einsum("bts,bcs->bct", weight, v)
1064
+
1065
+ # (bs, length+1, width)
1066
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
1067
+
1068
+ return a[:, 0, :] # cls_token
1069
+
1070
+
1071
+ def get_fourier_embeds_from_boundingbox(embed_dim, box):
1072
+ """
1073
+ Args:
1074
+ embed_dim: int
1075
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
1076
+ Returns:
1077
+ [B x N x embed_dim] tensor of positional embeddings
1078
+ """
1079
+
1080
+ batch_size, num_boxes = box.shape[:2]
1081
+
1082
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
1083
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
1084
+ emb = emb * box.unsqueeze(-1)
1085
+
1086
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
1087
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
1088
+
1089
+ return emb
1090
+
1091
+
1092
+ class GLIGENTextBoundingboxProjection(nn.Module):
1093
+ def __init__(
1094
+ self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8
1095
+ ):
1096
+ super().__init__()
1097
+ self.positive_len = positive_len
1098
+ self.out_dim = out_dim
1099
+
1100
+ self.fourier_embedder_dim = fourier_freqs
1101
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
1102
+
1103
+ if isinstance(out_dim, tuple):
1104
+ out_dim = out_dim[0]
1105
+
1106
+ if feature_type == "text-only":
1107
+ self.linears = nn.Sequential(
1108
+ nn.Linear(self.positive_len + self.position_dim, 512),
1109
+ nn.SiLU(),
1110
+ nn.Linear(512, 512),
1111
+ nn.SiLU(),
1112
+ nn.Linear(512, out_dim),
1113
+ )
1114
+ self.null_positive_feature = torch.nn.Parameter(
1115
+ torch.zeros([self.positive_len])
1116
+ )
1117
+
1118
+ elif feature_type == "text-image":
1119
+ self.linears_text = nn.Sequential(
1120
+ nn.Linear(self.positive_len + self.position_dim, 512),
1121
+ nn.SiLU(),
1122
+ nn.Linear(512, 512),
1123
+ nn.SiLU(),
1124
+ nn.Linear(512, out_dim),
1125
+ )
1126
+ self.linears_image = nn.Sequential(
1127
+ nn.Linear(self.positive_len + self.position_dim, 512),
1128
+ nn.SiLU(),
1129
+ nn.Linear(512, 512),
1130
+ nn.SiLU(),
1131
+ nn.Linear(512, out_dim),
1132
+ )
1133
+ self.null_text_feature = torch.nn.Parameter(
1134
+ torch.zeros([self.positive_len])
1135
+ )
1136
+ self.null_image_feature = torch.nn.Parameter(
1137
+ torch.zeros([self.positive_len])
1138
+ )
1139
+
1140
+ self.null_position_feature = torch.nn.Parameter(
1141
+ torch.zeros([self.position_dim])
1142
+ )
1143
+
1144
+ def forward(
1145
+ self,
1146
+ boxes,
1147
+ masks,
1148
+ positive_embeddings=None,
1149
+ phrases_masks=None,
1150
+ image_masks=None,
1151
+ phrases_embeddings=None,
1152
+ image_embeddings=None,
1153
+ ):
1154
+ masks = masks.unsqueeze(-1)
1155
+
1156
+ # embedding position (it may includes padding as placeholder)
1157
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(
1158
+ self.fourier_embedder_dim, boxes
1159
+ ) # B*N*4 -> B*N*C
1160
+
1161
+ # learnable null embedding
1162
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
1163
+
1164
+ # replace padding with learnable null embedding
1165
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
1166
+
1167
+ # positionet with text only information
1168
+ if positive_embeddings is not None:
1169
+ # learnable null embedding
1170
+ positive_null = self.null_positive_feature.view(1, 1, -1)
1171
+
1172
+ # replace padding with learnable null embedding
1173
+ positive_embeddings = (
1174
+ positive_embeddings * masks + (1 - masks) * positive_null
1175
+ )
1176
+
1177
+ objs = self.linears(
1178
+ torch.cat([positive_embeddings, xyxy_embedding], dim=-1)
1179
+ )
1180
+
1181
+ # positionet with text and image infomation
1182
+ else:
1183
+ phrases_masks = phrases_masks.unsqueeze(-1)
1184
+ image_masks = image_masks.unsqueeze(-1)
1185
+
1186
+ # learnable null embedding
1187
+ text_null = self.null_text_feature.view(1, 1, -1)
1188
+ image_null = self.null_image_feature.view(1, 1, -1)
1189
+
1190
+ # replace padding with learnable null embedding
1191
+ phrases_embeddings = (
1192
+ phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
1193
+ )
1194
+ image_embeddings = (
1195
+ image_embeddings * image_masks + (1 - image_masks) * image_null
1196
+ )
1197
+
1198
+ objs_text = self.linears_text(
1199
+ torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)
1200
+ )
1201
+ objs_image = self.linears_image(
1202
+ torch.cat([image_embeddings, xyxy_embedding], dim=-1)
1203
+ )
1204
+ objs = torch.cat([objs_text, objs_image], dim=1)
1205
+
1206
+ return objs
1207
+
1208
+
1209
+ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
1210
+ """
1211
+ For PixArt-Alpha.
1212
+
1213
+ Reference:
1214
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
1215
+ """
1216
+
1217
+ def __init__(
1218
+ self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False
1219
+ ):
1220
+ super().__init__()
1221
+
1222
+ self.outdim = size_emb_dim
1223
+ self.time_proj = Timesteps(
1224
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
1225
+ )
1226
+ self.timestep_embedder = TimestepEmbedding(
1227
+ in_channels=256, time_embed_dim=embedding_dim
1228
+ )
1229
+
1230
+ self.use_additional_conditions = use_additional_conditions
1231
+ if use_additional_conditions:
1232
+ self.additional_condition_proj = Timesteps(
1233
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
1234
+ )
1235
+ self.resolution_embedder = TimestepEmbedding(
1236
+ in_channels=256, time_embed_dim=size_emb_dim
1237
+ )
1238
+ self.aspect_ratio_embedder = TimestepEmbedding(
1239
+ in_channels=256, time_embed_dim=size_emb_dim
1240
+ )
1241
+
1242
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
1243
+ timesteps_proj = self.time_proj(timestep)
1244
+ timesteps_emb = self.timestep_embedder(
1245
+ timesteps_proj.to(dtype=hidden_dtype)
1246
+ ) # (N, D)
1247
+
1248
+ if self.use_additional_conditions:
1249
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(
1250
+ hidden_dtype
1251
+ )
1252
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(
1253
+ batch_size, -1
1254
+ )
1255
+ aspect_ratio_emb = self.additional_condition_proj(
1256
+ aspect_ratio.flatten()
1257
+ ).to(hidden_dtype)
1258
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(
1259
+ batch_size, -1
1260
+ )
1261
+ conditioning = timesteps_emb + torch.cat(
1262
+ [resolution_emb, aspect_ratio_emb], dim=1
1263
+ )
1264
+ else:
1265
+ conditioning = timesteps_emb
1266
+
1267
+ return conditioning
1268
+
1269
+
1270
+ class PixArtAlphaTextProjection(nn.Module):
1271
+ """
1272
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
1273
+
1274
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
1275
+ """
1276
+
1277
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
1278
+ super().__init__()
1279
+ if out_features is None:
1280
+ out_features = hidden_size
1281
+ self.linear_1 = nn.Linear(
1282
+ in_features=in_features, out_features=hidden_size, bias=True
1283
+ )
1284
+ if act_fn == "gelu_tanh":
1285
+ self.act_1 = nn.GELU(approximate="tanh")
1286
+ elif act_fn == "silu":
1287
+ self.act_1 = nn.SiLU()
1288
+ elif act_fn == "silu_fp32":
1289
+ self.act_1 = FP32SiLU()
1290
+ else:
1291
+ raise ValueError(f"Unknown activation function: {act_fn}")
1292
+ self.linear_2 = nn.Linear(
1293
+ in_features=hidden_size, out_features=out_features, bias=True
1294
+ )
1295
+
1296
+ def forward(self, caption):
1297
+ hidden_states = self.linear_1(caption)
1298
+ hidden_states = self.act_1(hidden_states)
1299
+ hidden_states = self.linear_2(hidden_states)
1300
+ return hidden_states
1301
+
1302
+
1303
+ class IPAdapterPlusImageProjectionBlock(nn.Module):
1304
+ def __init__(
1305
+ self,
1306
+ embed_dims: int = 768,
1307
+ dim_head: int = 64,
1308
+ heads: int = 16,
1309
+ ffn_ratio: float = 4,
1310
+ ) -> None:
1311
+ super().__init__()
1312
+ from .attention import FeedForward
1313
+
1314
+ self.ln0 = nn.LayerNorm(embed_dims)
1315
+ self.ln1 = nn.LayerNorm(embed_dims)
1316
+ self.attn = Attention(
1317
+ query_dim=embed_dims,
1318
+ dim_head=dim_head,
1319
+ heads=heads,
1320
+ out_bias=False,
1321
+ )
1322
+ self.ff = nn.Sequential(
1323
+ nn.LayerNorm(embed_dims),
1324
+ FeedForward(
1325
+ embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False
1326
+ ),
1327
+ )
1328
+
1329
+ def forward(self, x, latents, residual):
1330
+ encoder_hidden_states = self.ln0(x)
1331
+ latents = self.ln1(latents)
1332
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
1333
+ latents = self.attn(latents, encoder_hidden_states) + residual
1334
+ latents = self.ff(latents) + latents
1335
+ return latents
1336
+
1337
+
1338
+ class IPAdapterPlusImageProjection(nn.Module):
1339
+ """Resampler of IP-Adapter Plus.
1340
+
1341
+ Args:
1342
+ embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
1343
+ that is the same
1344
+ number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
1345
+ hidden_dims (int):
1346
+ The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
1347
+ to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
1348
+ Defaults to 16. num_queries (int):
1349
+ The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
1350
+ of feedforward network hidden
1351
+ layer channels. Defaults to 4.
1352
+ """
1353
+
1354
+ def __init__(
1355
+ self,
1356
+ embed_dims: int = 768,
1357
+ output_dims: int = 1024,
1358
+ hidden_dims: int = 1280,
1359
+ depth: int = 4,
1360
+ dim_head: int = 64,
1361
+ heads: int = 16,
1362
+ num_queries: int = 8,
1363
+ ffn_ratio: float = 4,
1364
+ ) -> None:
1365
+ super().__init__()
1366
+ self.latents = nn.Parameter(
1367
+ torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5
1368
+ )
1369
+
1370
+ self.proj_in = nn.Linear(embed_dims, hidden_dims)
1371
+
1372
+ self.proj_out = nn.Linear(hidden_dims, output_dims)
1373
+ self.norm_out = nn.LayerNorm(output_dims)
1374
+
1375
+ self.layers = nn.ModuleList(
1376
+ [
1377
+ IPAdapterPlusImageProjectionBlock(
1378
+ hidden_dims, dim_head, heads, ffn_ratio
1379
+ )
1380
+ for _ in range(depth)
1381
+ ]
1382
+ )
1383
+
1384
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1385
+ """Forward pass.
1386
+
1387
+ Args:
1388
+ x (torch.Tensor): Input Tensor.
1389
+ Returns:
1390
+ torch.Tensor: Output Tensor.
1391
+ """
1392
+ latents = self.latents.repeat(x.size(0), 1, 1)
1393
+
1394
+ x = self.proj_in(x)
1395
+
1396
+ for block in self.layers:
1397
+ residual = latents
1398
+ latents = block(x, latents, residual)
1399
+
1400
+ latents = self.proj_out(latents)
1401
+ return self.norm_out(latents)
1402
+
1403
+
1404
+ class IPAdapterFaceIDPlusImageProjection(nn.Module):
1405
+ """FacePerceiverResampler of IP-Adapter Plus.
1406
+
1407
+ Args:
1408
+ embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
1409
+ that is the same
1410
+ number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
1411
+ hidden_dims (int):
1412
+ The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
1413
+ to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
1414
+ Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
1415
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
1416
+ layer channels. Defaults to 4.
1417
+ ffproj_ratio (float): The expansion ratio of feedforward network hidden
1418
+ layer channels (for ID embeddings). Defaults to 4.
1419
+ """
1420
+
1421
+ def __init__(
1422
+ self,
1423
+ embed_dims: int = 768,
1424
+ output_dims: int = 768,
1425
+ hidden_dims: int = 1280,
1426
+ id_embeddings_dim: int = 512,
1427
+ depth: int = 4,
1428
+ dim_head: int = 64,
1429
+ heads: int = 16,
1430
+ num_tokens: int = 4,
1431
+ num_queries: int = 8,
1432
+ ffn_ratio: float = 4,
1433
+ ffproj_ratio: int = 2,
1434
+ ) -> None:
1435
+ super().__init__()
1436
+ from .attention import FeedForward
1437
+
1438
+ self.num_tokens = num_tokens
1439
+ self.embed_dim = embed_dims
1440
+ self.clip_embeds = None
1441
+ self.shortcut = False
1442
+ self.shortcut_scale = 1.0
1443
+
1444
+ self.proj = FeedForward(
1445
+ id_embeddings_dim,
1446
+ embed_dims * num_tokens,
1447
+ activation_fn="gelu",
1448
+ mult=ffproj_ratio,
1449
+ )
1450
+ self.norm = nn.LayerNorm(embed_dims)
1451
+
1452
+ self.proj_in = nn.Linear(hidden_dims, embed_dims)
1453
+
1454
+ self.proj_out = nn.Linear(embed_dims, output_dims)
1455
+ self.norm_out = nn.LayerNorm(output_dims)
1456
+
1457
+ self.layers = nn.ModuleList(
1458
+ [
1459
+ IPAdapterPlusImageProjectionBlock(
1460
+ embed_dims, dim_head, heads, ffn_ratio
1461
+ )
1462
+ for _ in range(depth)
1463
+ ]
1464
+ )
1465
+
1466
+ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
1467
+ """Forward pass.
1468
+
1469
+ Args:
1470
+ id_embeds (torch.Tensor): Input Tensor (ID embeds).
1471
+ Returns:
1472
+ torch.Tensor: Output Tensor.
1473
+ """
1474
+ id_embeds = id_embeds.to(self.clip_embeds.dtype)
1475
+ id_embeds = self.proj(id_embeds)
1476
+ id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
1477
+ id_embeds = self.norm(id_embeds)
1478
+ latents = id_embeds
1479
+
1480
+ clip_embeds = self.proj_in(self.clip_embeds)
1481
+ x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
1482
+
1483
+ for block in self.layers:
1484
+ residual = latents
1485
+ latents = block(x, latents, residual)
1486
+
1487
+ latents = self.proj_out(latents)
1488
+ out = self.norm_out(latents)
1489
+ if self.shortcut:
1490
+ out = id_embeds + self.shortcut_scale * out
1491
+ return out
1492
+
1493
+
1494
+ class MultiIPAdapterImageProjection(nn.Module):
1495
+ def __init__(
1496
+ self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]
1497
+ ):
1498
+ super().__init__()
1499
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
1500
+
1501
+ def forward(self, image_embeds: List[torch.Tensor]):
1502
+ projected_image_embeds = []
1503
+
1504
+ # currently, we accept `image_embeds` as
1505
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
1506
+ # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
1507
+ if not isinstance(image_embeds, list):
1508
+ deprecation_message = (
1509
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
1510
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
1511
+ )
1512
+ deprecate(
1513
+ "image_embeds not a list",
1514
+ "1.0.0",
1515
+ deprecation_message,
1516
+ standard_warn=False,
1517
+ )
1518
+ image_embeds = [image_embeds.unsqueeze(1)]
1519
+
1520
+ if len(image_embeds) != len(self.image_projection_layers):
1521
+ raise ValueError(
1522
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
1523
+ )
1524
+
1525
+ for image_embed, image_projection_layer in zip(
1526
+ image_embeds, self.image_projection_layers
1527
+ ):
1528
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
1529
+ image_embed = image_embed.reshape(
1530
+ (batch_size * num_images,) + image_embed.shape[2:]
1531
+ )
1532
+ image_embed = image_projection_layer(image_embed)
1533
+ image_embed = image_embed.reshape(
1534
+ (batch_size, num_images) + image_embed.shape[1:]
1535
+ )
1536
+
1537
+ projected_image_embeds.append(image_embed)
1538
+
1539
+ return projected_image_embeds
models/resnet.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.utils import deprecate
9
+ from diffusers.models.activations import get_activation
10
+ from diffusers.models.attention_processor import SpatialNorm
11
+ from diffusers.models.downsampling import ( # noqa
12
+ Downsample2D,
13
+ downsample_2d,
14
+ )
15
+ from diffusers.models.normalization import AdaGroupNorm
16
+ from diffusers.models.upsampling import ( # noqa
17
+ Upsample2D,
18
+ upsample_2d,
19
+ )
20
+
21
+
22
+ class ResnetBlock2D(nn.Module):
23
+ r"""
24
+ A Resnet block.
25
+
26
+ Parameters:
27
+ in_channels (`int`): The number of channels in the input.
28
+ out_channels (`int`, *optional*, default to be `None`):
29
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
30
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
31
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
32
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
33
+ groups_out (`int`, *optional*, default to None):
34
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
35
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
36
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
37
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
38
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
39
+ for a stronger conditioning with scale and shift.
40
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
41
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
42
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
43
+ use_in_shortcut (`bool`, *optional*, default to `True`):
44
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
45
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
46
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
47
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
48
+ `conv_shortcut` output.
49
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
50
+ If None, same as `out_channels`.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ *,
56
+ in_channels: int,
57
+ out_channels: Optional[int] = None,
58
+ conv_shortcut: bool = False,
59
+ dropout: float = 0.0,
60
+ temb_channels: int = 512,
61
+ groups: int = 32,
62
+ groups_out: Optional[int] = None,
63
+ pre_norm: bool = True,
64
+ eps: float = 1e-6,
65
+ non_linearity: str = "swish",
66
+ skip_time_act: bool = False,
67
+ time_embedding_norm: str = "default", # default, scale_shift,
68
+ kernel: Optional[torch.FloatTensor] = None,
69
+ output_scale_factor: float = 1.0,
70
+ use_in_shortcut: Optional[bool] = None,
71
+ up: bool = False,
72
+ down: bool = False,
73
+ conv_shortcut_bias: bool = True,
74
+ conv_2d_out_channels: Optional[int] = None,
75
+ ):
76
+ super().__init__()
77
+ if time_embedding_norm == "ada_group":
78
+ raise ValueError(
79
+ "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
80
+ )
81
+ if time_embedding_norm == "spatial":
82
+ raise ValueError(
83
+ "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
84
+ )
85
+
86
+ self.pre_norm = True
87
+ self.in_channels = in_channels
88
+ out_channels = in_channels if out_channels is None else out_channels
89
+ self.out_channels = out_channels
90
+ self.use_conv_shortcut = conv_shortcut
91
+ self.up = up
92
+ self.down = down
93
+ self.output_scale_factor = output_scale_factor
94
+ self.time_embedding_norm = time_embedding_norm
95
+ self.skip_time_act = skip_time_act
96
+
97
+ linear_cls = nn.Linear
98
+ conv_cls = nn.Conv2d
99
+
100
+ if groups_out is None:
101
+ groups_out = groups
102
+
103
+ self.norm1 = torch.nn.GroupNorm(
104
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
105
+ )
106
+
107
+ self.conv1 = conv_cls(
108
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
109
+ )
110
+
111
+ if temb_channels is not None:
112
+ if self.time_embedding_norm == "default":
113
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
114
+ elif self.time_embedding_norm == "scale_shift":
115
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
116
+ else:
117
+ raise ValueError(
118
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
119
+ )
120
+ else:
121
+ self.time_emb_proj = None
122
+
123
+ self.norm2 = torch.nn.GroupNorm(
124
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
125
+ )
126
+
127
+ self.dropout = torch.nn.Dropout(dropout)
128
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
129
+ self.conv2 = conv_cls(
130
+ out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1
131
+ )
132
+
133
+ self.nonlinearity = get_activation(non_linearity)
134
+
135
+ self.upsample = self.downsample = None
136
+ if self.up:
137
+ if kernel == "fir":
138
+ fir_kernel = (1, 3, 3, 1)
139
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
140
+ elif kernel == "sde_vp":
141
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
142
+ else:
143
+ self.upsample = Upsample2D(in_channels, use_conv=False)
144
+ elif self.down:
145
+ if kernel == "fir":
146
+ fir_kernel = (1, 3, 3, 1)
147
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
148
+ elif kernel == "sde_vp":
149
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
150
+ else:
151
+ self.downsample = Downsample2D(
152
+ in_channels, use_conv=False, padding=1, name="op"
153
+ )
154
+
155
+ self.use_in_shortcut = (
156
+ self.in_channels != conv_2d_out_channels
157
+ if use_in_shortcut is None
158
+ else use_in_shortcut
159
+ )
160
+
161
+ self.conv_shortcut = None
162
+ if self.use_in_shortcut:
163
+ self.conv_shortcut = conv_cls(
164
+ in_channels,
165
+ conv_2d_out_channels,
166
+ kernel_size=1,
167
+ stride=1,
168
+ padding=0,
169
+ bias=conv_shortcut_bias,
170
+ )
171
+
172
+ def forward(
173
+ self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs
174
+ ) -> torch.FloatTensor:
175
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
176
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
177
+ deprecate("scale", "1.0.0", deprecation_message)
178
+
179
+ hidden_states = input_tensor
180
+
181
+ hidden_states = self.norm1(hidden_states)
182
+ hidden_states = self.nonlinearity(hidden_states)
183
+
184
+ if self.upsample is not None:
185
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
186
+ if hidden_states.shape[0] >= 64:
187
+ input_tensor = input_tensor.contiguous()
188
+ hidden_states = hidden_states.contiguous()
189
+ input_tensor = self.upsample(input_tensor)
190
+ hidden_states = self.upsample(hidden_states)
191
+ elif self.downsample is not None:
192
+ input_tensor = self.downsample(input_tensor)
193
+ hidden_states = self.downsample(hidden_states)
194
+
195
+ hidden_states = self.conv1(hidden_states)
196
+
197
+ if self.time_emb_proj is not None:
198
+ if not self.skip_time_act:
199
+ temb = self.nonlinearity(temb)
200
+ temb = self.time_emb_proj(temb)[:, :, None, None]
201
+
202
+ if self.time_embedding_norm == "default":
203
+ if temb is not None:
204
+ hidden_states = hidden_states + temb
205
+ hidden_states = self.norm2(hidden_states)
206
+ elif self.time_embedding_norm == "scale_shift":
207
+ if temb is None:
208
+ raise ValueError(
209
+ f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
210
+ )
211
+ time_scale, time_shift = torch.chunk(temb, 2, dim=1)
212
+ hidden_states = self.norm2(hidden_states)
213
+ hidden_states = hidden_states * (1 + time_scale) + time_shift
214
+ else:
215
+ hidden_states = self.norm2(hidden_states)
216
+
217
+ hidden_states = self.nonlinearity(hidden_states)
218
+
219
+ hidden_states = self.dropout(hidden_states)
220
+ hidden_states = self.conv2(hidden_states)
221
+
222
+ if self.conv_shortcut is not None:
223
+ input_tensor = self.conv_shortcut(input_tensor)
224
+
225
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
226
+
227
+ return output_tensor
228
+
229
+
230
+ class TemporalResnetBlock(nn.Module):
231
+ r"""
232
+ A Resnet block.
233
+
234
+ Parameters:
235
+ in_channels (`int`): The number of channels in the input.
236
+ out_channels (`int`, *optional*, default to be `None`):
237
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
238
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
239
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ in_channels: int,
245
+ out_channels: Optional[int] = None,
246
+ temb_channels: int = 512,
247
+ eps: float = 1e-6,
248
+ ):
249
+ super().__init__()
250
+ self.in_channels = in_channels
251
+ out_channels = in_channels if out_channels is None else out_channels
252
+ self.out_channels = out_channels
253
+
254
+ kernel_size = (3, 1, 1)
255
+ padding = [k // 2 for k in kernel_size]
256
+
257
+ self.norm1 = torch.nn.GroupNorm(
258
+ num_groups=32, num_channels=in_channels, eps=eps, affine=True
259
+ )
260
+ self.conv1 = nn.Conv3d(
261
+ in_channels,
262
+ out_channels,
263
+ kernel_size=kernel_size,
264
+ stride=1,
265
+ padding=padding,
266
+ )
267
+
268
+ if temb_channels is not None:
269
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
270
+ else:
271
+ self.time_emb_proj = None
272
+
273
+ self.norm2 = torch.nn.GroupNorm(
274
+ num_groups=32, num_channels=out_channels, eps=eps, affine=True
275
+ )
276
+
277
+ self.dropout = torch.nn.Dropout(0.0)
278
+ self.conv2 = nn.Conv3d(
279
+ out_channels,
280
+ out_channels,
281
+ kernel_size=kernel_size,
282
+ stride=1,
283
+ padding=padding,
284
+ )
285
+
286
+ self.nonlinearity = get_activation("silu")
287
+
288
+ self.use_in_shortcut = self.in_channels != out_channels
289
+
290
+ self.conv_shortcut = None
291
+ if self.use_in_shortcut:
292
+ self.conv_shortcut = nn.Conv3d(
293
+ in_channels,
294
+ out_channels,
295
+ kernel_size=1,
296
+ stride=1,
297
+ padding=0,
298
+ )
299
+
300
+ def forward(
301
+ self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor
302
+ ) -> torch.FloatTensor:
303
+ hidden_states = input_tensor
304
+
305
+ hidden_states = self.norm1(hidden_states)
306
+ hidden_states = self.nonlinearity(hidden_states)
307
+ hidden_states = self.conv1(hidden_states)
308
+
309
+ if self.time_emb_proj is not None:
310
+ temb = self.nonlinearity(temb)
311
+ temb = self.time_emb_proj(temb)[:, :, :, None, None]
312
+ temb = temb.permute(0, 2, 1, 3, 4)
313
+ hidden_states = hidden_states + temb
314
+
315
+ hidden_states = self.norm2(hidden_states)
316
+ hidden_states = self.nonlinearity(hidden_states)
317
+ hidden_states = self.dropout(hidden_states)
318
+ hidden_states = self.conv2(hidden_states)
319
+
320
+ if self.conv_shortcut is not None:
321
+ input_tensor = self.conv_shortcut(input_tensor)
322
+
323
+ output_tensor = input_tensor + hidden_states
324
+
325
+ return output_tensor
326
+
327
+
328
+ # VideoResBlock
329
+ class SpatioTemporalResBlock(nn.Module):
330
+ r"""
331
+ A SpatioTemporal Resnet block.
332
+
333
+ Parameters:
334
+ in_channels (`int`): The number of channels in the input.
335
+ out_channels (`int`, *optional*, default to be `None`):
336
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
337
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
338
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet.
339
+ temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
340
+ merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
341
+ merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
342
+ The merge strategy to use for the temporal mixing.
343
+ switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
344
+ If `True`, switch the spatial and temporal mixing.
345
+ """
346
+
347
+ def __init__(
348
+ self,
349
+ in_channels: int,
350
+ out_channels: Optional[int] = None,
351
+ temb_channels: int = 512,
352
+ eps: float = 1e-6,
353
+ temporal_eps: Optional[float] = None,
354
+ merge_factor: float = 0.5,
355
+ merge_strategy="learned_with_images",
356
+ switch_spatial_to_temporal_mix: bool = False,
357
+ ):
358
+ super().__init__()
359
+
360
+ self.spatial_res_block = ResnetBlock2D(
361
+ in_channels=in_channels,
362
+ out_channels=out_channels,
363
+ temb_channels=temb_channels,
364
+ eps=eps,
365
+ )
366
+
367
+ self.temporal_res_block = TemporalResnetBlock(
368
+ in_channels=out_channels if out_channels is not None else in_channels,
369
+ out_channels=out_channels if out_channels is not None else in_channels,
370
+ temb_channels=temb_channels,
371
+ eps=temporal_eps if temporal_eps is not None else eps,
372
+ )
373
+
374
+ self.time_mixer = AlphaBlender(
375
+ alpha=merge_factor,
376
+ merge_strategy=merge_strategy,
377
+ switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
378
+ )
379
+
380
+ def forward(
381
+ self,
382
+ hidden_states: torch.FloatTensor,
383
+ temb: Optional[torch.FloatTensor] = None,
384
+ image_only_indicator: Optional[torch.Tensor] = None,
385
+ ):
386
+ num_frames = image_only_indicator.shape[-1]
387
+ hidden_states = self.spatial_res_block(hidden_states, temb)
388
+
389
+ batch_frames, channels, height, width = hidden_states.shape
390
+ batch_size = batch_frames // num_frames
391
+
392
+ hidden_states_mix = (
393
+ hidden_states[None, :]
394
+ .reshape(batch_size, num_frames, channels, height, width)
395
+ .permute(0, 2, 1, 3, 4)
396
+ )
397
+ hidden_states = (
398
+ hidden_states[None, :]
399
+ .reshape(batch_size, num_frames, channels, height, width)
400
+ .permute(0, 2, 1, 3, 4)
401
+ )
402
+
403
+ if temb is not None:
404
+ temb = temb.reshape(batch_size, num_frames, -1)
405
+
406
+ hidden_states = self.temporal_res_block(hidden_states, temb)
407
+ hidden_states = self.time_mixer(
408
+ x_spatial=hidden_states_mix,
409
+ x_temporal=hidden_states,
410
+ image_only_indicator=image_only_indicator,
411
+ )
412
+
413
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
414
+ batch_frames, channels, height, width
415
+ )
416
+ return hidden_states
417
+
418
+
419
+ class AlphaBlender(nn.Module):
420
+ r"""
421
+ A module to blend spatial and temporal features.
422
+
423
+ Parameters:
424
+ alpha (`float`): The initial value of the blending factor.
425
+ merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
426
+ The merge strategy to use for the temporal mixing.
427
+ switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
428
+ If `True`, switch the spatial and temporal mixing.
429
+ """
430
+
431
+ strategies = ["learned", "fixed", "learned_with_images"]
432
+
433
+ def __init__(
434
+ self,
435
+ alpha: float,
436
+ merge_strategy: str = "learned_with_images",
437
+ switch_spatial_to_temporal_mix: bool = False,
438
+ ):
439
+ super().__init__()
440
+ self.merge_strategy = merge_strategy
441
+ self.switch_spatial_to_temporal_mix = (
442
+ switch_spatial_to_temporal_mix # For TemporalVAE
443
+ )
444
+
445
+ if merge_strategy not in self.strategies:
446
+ raise ValueError(f"merge_strategy needs to be in {self.strategies}")
447
+
448
+ if self.merge_strategy == "fixed":
449
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
450
+ elif (
451
+ self.merge_strategy == "learned"
452
+ or self.merge_strategy == "learned_with_images"
453
+ ):
454
+ self.register_parameter(
455
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
456
+ )
457
+ else:
458
+ raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
459
+
460
+ def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
461
+ if self.merge_strategy == "fixed":
462
+ alpha = self.mix_factor
463
+
464
+ elif self.merge_strategy == "learned":
465
+ alpha = torch.sigmoid(self.mix_factor)
466
+
467
+ elif self.merge_strategy == "learned_with_images":
468
+ if image_only_indicator is None:
469
+ raise ValueError(
470
+ "Please provide image_only_indicator to use learned_with_images merge strategy"
471
+ )
472
+
473
+ alpha = torch.where(
474
+ image_only_indicator.bool(),
475
+ torch.ones(1, 1, device=image_only_indicator.device),
476
+ torch.sigmoid(self.mix_factor)[..., None],
477
+ )
478
+
479
+ # (batch, channel, frames, height, width)
480
+ if ndims == 5:
481
+ alpha = alpha[:, None, :, None, None]
482
+ # (batch*frames, height*width, channels)
483
+ elif ndims == 3:
484
+ alpha = alpha.reshape(-1)[:, None, None]
485
+ else:
486
+ raise ValueError(
487
+ f"Unexpected ndims {ndims}. Dimensions should be 3 or 5"
488
+ )
489
+
490
+ else:
491
+ raise NotImplementedError
492
+
493
+ return alpha
494
+
495
+ def forward(
496
+ self,
497
+ x_spatial: torch.Tensor,
498
+ x_temporal: torch.Tensor,
499
+ image_only_indicator: Optional[torch.Tensor] = None,
500
+ ) -> torch.Tensor:
501
+ alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
502
+ alpha = alpha.to(x_spatial.dtype)
503
+
504
+ if self.switch_spatial_to_temporal_mix:
505
+ alpha = 1.0 - alpha
506
+
507
+ x = alpha * x_spatial + (1.0 - alpha) * x_temporal
508
+ return x
models/transformers/transformer_temporal_rope.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ from ..resnet import AlphaBlender
25
+ from ..attention import (
26
+ BasicTransformerBlock,
27
+ TemporalRopeBasicTransformerBlock,
28
+ )
29
+ from ..embeddings import rope
30
+
31
+
32
+ @dataclass
33
+ class TransformerTemporalModelOutput(BaseOutput):
34
+ """
35
+ The output of [`TransformerTemporalModel`].
36
+
37
+ Args:
38
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
39
+ The hidden states output conditioned on `encoder_hidden_states` input.
40
+ """
41
+
42
+ sample: torch.FloatTensor
43
+
44
+
45
+ class TransformerSpatioTemporalModel(nn.Module):
46
+ """
47
+ A Transformer model for video-like data.
48
+
49
+ Parameters:
50
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
51
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
52
+ in_channels (`int`, *optional*):
53
+ The number of channels in the input and output (specify if the input is **continuous**).
54
+ out_channels (`int`, *optional*):
55
+ The number of channels in the output (specify if the input is **continuous**).
56
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
57
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ num_attention_heads: int = 16,
63
+ attention_head_dim: int = 88,
64
+ in_channels: int = 320,
65
+ out_channels: Optional[int] = None,
66
+ num_layers: int = 1,
67
+ cross_attention_dim: Optional[int] = None,
68
+ ):
69
+ super().__init__()
70
+ self.num_attention_heads = num_attention_heads
71
+ self.attention_head_dim = attention_head_dim
72
+
73
+ inner_dim = num_attention_heads * attention_head_dim
74
+ self.inner_dim = inner_dim
75
+
76
+ # 2. Define input layers
77
+ self.in_channels = in_channels
78
+ self.norm = torch.nn.GroupNorm(
79
+ num_groups=32, num_channels=in_channels, eps=1e-6
80
+ )
81
+ self.proj_in = nn.Linear(in_channels, inner_dim)
82
+
83
+ # 3. Define transformers blocks
84
+ self.transformer_blocks = nn.ModuleList(
85
+ [
86
+ BasicTransformerBlock(
87
+ inner_dim,
88
+ num_attention_heads,
89
+ attention_head_dim,
90
+ cross_attention_dim=cross_attention_dim,
91
+ )
92
+ for d in range(num_layers)
93
+ ]
94
+ )
95
+
96
+ time_mix_inner_dim = inner_dim
97
+ self.temporal_transformer_blocks = nn.ModuleList(
98
+ [
99
+ TemporalRopeBasicTransformerBlock(
100
+ inner_dim,
101
+ time_mix_inner_dim,
102
+ num_attention_heads,
103
+ attention_head_dim,
104
+ cross_attention_dim=cross_attention_dim,
105
+ )
106
+ for _ in range(num_layers)
107
+ ]
108
+ )
109
+
110
+ self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
111
+
112
+ # 4. Define output layers
113
+ self.out_channels = in_channels if out_channels is None else out_channels
114
+ # TODO: should use out_channels for continuous projections
115
+ self.proj_out = nn.Linear(inner_dim, in_channels)
116
+
117
+ self.gradient_checkpointing = False
118
+
119
+ def forward(
120
+ self,
121
+ hidden_states: torch.Tensor,
122
+ encoder_hidden_states: Optional[torch.Tensor] = None,
123
+ image_only_indicator: Optional[torch.Tensor] = None,
124
+ return_dict: bool = True,
125
+ position_ids: Optional[torch.Tensor] = None,
126
+ ):
127
+ """
128
+ Args:
129
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
130
+ Input hidden_states.
131
+ num_frames (`int`):
132
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
133
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
134
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
135
+ self-attention.
136
+ image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
137
+ A tensor indicating whether the input contains only images. 1 indicates that the input contains only
138
+ images, 0 indicates that the input contains video frames.
139
+ return_dict (`bool`, *optional*, defaults to `True`):
140
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
141
+ tuple.
142
+
143
+ Returns:
144
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
145
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
146
+ returned, otherwise a `tuple` where the first element is the sample tensor.
147
+ """
148
+ # 1. Input
149
+ batch_frames, _, height, width = hidden_states.shape
150
+ num_frames = image_only_indicator.shape[-1]
151
+ batch_size = batch_frames // num_frames
152
+
153
+ # (B*F, 1, C)
154
+ time_context = encoder_hidden_states
155
+ # (B, 1, C)
156
+ time_context_first_timestep = time_context[None, :].reshape(
157
+ batch_size, num_frames, -1, time_context.shape[-1]
158
+ )[:, 0]
159
+
160
+ # (B*N, 1, C)
161
+ time_context = time_context_first_timestep.repeat_interleave(
162
+ height * width, dim=0
163
+ )
164
+
165
+ residual = hidden_states
166
+
167
+ hidden_states = self.norm(hidden_states)
168
+ inner_dim = hidden_states.shape[1]
169
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
170
+ batch_frames, height * width, inner_dim
171
+ )
172
+ hidden_states = self.proj_in(hidden_states)
173
+
174
+ if position_ids is None:
175
+ # (B, F)
176
+ frame_rotary_emb = torch.arange(num_frames, device=hidden_states.device)
177
+ frame_rotary_emb = frame_rotary_emb[None, :].repeat(batch_size, 1)
178
+ else:
179
+ frame_rotary_emb = position_ids
180
+
181
+ # (B, 1, F, d/2, 2, 2)
182
+ frame_rotary_emb = rope(frame_rotary_emb, self.attention_head_dim)
183
+ # (B*N, 1, F, d/2, 2, 2)
184
+ frame_rotary_emb = frame_rotary_emb.repeat_interleave(height * width, dim=0)
185
+
186
+ # 2. Blocks
187
+ for block, temporal_block in zip(
188
+ self.transformer_blocks, self.temporal_transformer_blocks
189
+ ):
190
+ if self.training and self.gradient_checkpointing:
191
+ hidden_states = torch.utils.checkpoint.checkpoint(
192
+ block,
193
+ hidden_states,
194
+ None,
195
+ encoder_hidden_states,
196
+ None,
197
+ use_reentrant=False,
198
+ )
199
+ else:
200
+ hidden_states = block(
201
+ hidden_states,
202
+ encoder_hidden_states=encoder_hidden_states,
203
+ )
204
+
205
+ hidden_states_mix = temporal_block(
206
+ hidden_states,
207
+ num_frames=num_frames,
208
+ encoder_hidden_states=time_context,
209
+ frame_rotary_emb=frame_rotary_emb,
210
+ )
211
+ hidden_states = self.time_mixer(
212
+ x_spatial=hidden_states,
213
+ x_temporal=hidden_states_mix,
214
+ image_only_indicator=image_only_indicator,
215
+ )
216
+
217
+ # 3. Output
218
+ hidden_states = self.proj_out(hidden_states)
219
+ hidden_states = (
220
+ hidden_states.reshape(batch_frames, height, width, inner_dim)
221
+ .permute(0, 3, 1, 2)
222
+ .contiguous()
223
+ )
224
+
225
+ output = hidden_states + residual
226
+
227
+ if not return_dict:
228
+ return (output,)
229
+
230
+ return TransformerTemporalModelOutput(sample=output)
models/unets/unet_3d_rope_blocks.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from diffusers.utils import is_torch_version, logging
7
+ from ..resnet import (
8
+ Downsample2D,
9
+ SpatioTemporalResBlock,
10
+ Upsample2D,
11
+ )
12
+
13
+ from ..transformers.transformer_temporal_rope import TransformerSpatioTemporalModel
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ def get_down_block(
19
+ down_block_type: str,
20
+ num_layers: int,
21
+ in_channels: int,
22
+ out_channels: int,
23
+ temb_channels: int,
24
+ add_downsample: bool,
25
+ resnet_eps: float,
26
+ resnet_act_fn: str,
27
+ num_attention_heads: int,
28
+ resnet_groups: Optional[int] = None,
29
+ cross_attention_dim: Optional[int] = None,
30
+ downsample_padding: Optional[int] = None,
31
+ dual_cross_attention: bool = False,
32
+ use_linear_projection: bool = True,
33
+ only_cross_attention: bool = False,
34
+ upcast_attention: bool = False,
35
+ resnet_time_scale_shift: str = "default",
36
+ temporal_num_attention_heads: int = 8,
37
+ temporal_max_seq_length: int = 32,
38
+ transformer_layers_per_block: int = 1,
39
+ ) -> Union[
40
+ "DownBlockSpatioTemporal",
41
+ "CrossAttnDownBlockSpatioTemporal",
42
+ ]:
43
+ if down_block_type == "DownBlockSpatioTemporal":
44
+ # added for SDV
45
+ return DownBlockSpatioTemporal(
46
+ num_layers=num_layers,
47
+ in_channels=in_channels,
48
+ out_channels=out_channels,
49
+ temb_channels=temb_channels,
50
+ add_downsample=add_downsample,
51
+ )
52
+ elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
53
+ # added for SDV
54
+ if cross_attention_dim is None:
55
+ raise ValueError(
56
+ "cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal"
57
+ )
58
+ return CrossAttnDownBlockSpatioTemporal(
59
+ in_channels=in_channels,
60
+ out_channels=out_channels,
61
+ temb_channels=temb_channels,
62
+ num_layers=num_layers,
63
+ transformer_layers_per_block=transformer_layers_per_block,
64
+ add_downsample=add_downsample,
65
+ cross_attention_dim=cross_attention_dim,
66
+ num_attention_heads=num_attention_heads,
67
+ )
68
+
69
+ raise ValueError(f"{down_block_type} does not exist.")
70
+
71
+
72
+ def get_up_block(
73
+ up_block_type: str,
74
+ num_layers: int,
75
+ in_channels: int,
76
+ out_channels: int,
77
+ prev_output_channel: int,
78
+ temb_channels: int,
79
+ add_upsample: bool,
80
+ resnet_eps: float,
81
+ resnet_act_fn: str,
82
+ num_attention_heads: int,
83
+ resolution_idx: Optional[int] = None,
84
+ resnet_groups: Optional[int] = None,
85
+ cross_attention_dim: Optional[int] = None,
86
+ dual_cross_attention: bool = False,
87
+ use_linear_projection: bool = True,
88
+ only_cross_attention: bool = False,
89
+ upcast_attention: bool = False,
90
+ resnet_time_scale_shift: str = "default",
91
+ temporal_num_attention_heads: int = 8,
92
+ temporal_cross_attention_dim: Optional[int] = None,
93
+ temporal_max_seq_length: int = 32,
94
+ transformer_layers_per_block: int = 1,
95
+ dropout: float = 0.0,
96
+ ) -> Union[
97
+ "UpBlockSpatioTemporal",
98
+ "CrossAttnUpBlockSpatioTemporal",
99
+ ]:
100
+ if up_block_type == "UpBlockSpatioTemporal":
101
+ # added for SDV
102
+ return UpBlockSpatioTemporal(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ prev_output_channel=prev_output_channel,
107
+ temb_channels=temb_channels,
108
+ resolution_idx=resolution_idx,
109
+ add_upsample=add_upsample,
110
+ )
111
+ elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
112
+ # added for SDV
113
+ if cross_attention_dim is None:
114
+ raise ValueError(
115
+ "cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal"
116
+ )
117
+ return CrossAttnUpBlockSpatioTemporal(
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ prev_output_channel=prev_output_channel,
121
+ temb_channels=temb_channels,
122
+ num_layers=num_layers,
123
+ transformer_layers_per_block=transformer_layers_per_block,
124
+ add_upsample=add_upsample,
125
+ cross_attention_dim=cross_attention_dim,
126
+ num_attention_heads=num_attention_heads,
127
+ resolution_idx=resolution_idx,
128
+ )
129
+
130
+ raise ValueError(f"{up_block_type} does not exist.")
131
+
132
+
133
+ class UNetMidBlockSpatioTemporal(nn.Module):
134
+ def __init__(
135
+ self,
136
+ in_channels: int,
137
+ temb_channels: int,
138
+ num_layers: int = 1,
139
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
140
+ num_attention_heads: int = 1,
141
+ cross_attention_dim: int = 1280,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.has_cross_attention = True
146
+ self.num_attention_heads = num_attention_heads
147
+
148
+ # support for variable transformer layers per block
149
+ if isinstance(transformer_layers_per_block, int):
150
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
151
+
152
+ # there is always at least one resnet
153
+ resnets = [
154
+ SpatioTemporalResBlock(
155
+ in_channels=in_channels,
156
+ out_channels=in_channels,
157
+ temb_channels=temb_channels,
158
+ eps=1e-5,
159
+ )
160
+ ]
161
+ attentions = []
162
+
163
+ for i in range(num_layers):
164
+ attentions.append(
165
+ TransformerSpatioTemporalModel(
166
+ num_attention_heads,
167
+ in_channels // num_attention_heads,
168
+ in_channels=in_channels,
169
+ num_layers=transformer_layers_per_block[i],
170
+ cross_attention_dim=cross_attention_dim,
171
+ )
172
+ )
173
+
174
+ resnets.append(
175
+ SpatioTemporalResBlock(
176
+ in_channels=in_channels,
177
+ out_channels=in_channels,
178
+ temb_channels=temb_channels,
179
+ eps=1e-5,
180
+ )
181
+ )
182
+
183
+ self.attentions = nn.ModuleList(attentions)
184
+ self.resnets = nn.ModuleList(resnets)
185
+
186
+ self.gradient_checkpointing = False
187
+
188
+ def forward(
189
+ self,
190
+ hidden_states: torch.FloatTensor,
191
+ temb: Optional[torch.FloatTensor] = None,
192
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
193
+ image_only_indicator: Optional[torch.Tensor] = None,
194
+ position_ids: Optional[torch.Tensor] = None,
195
+ ) -> torch.FloatTensor:
196
+ hidden_states = self.resnets[0](
197
+ hidden_states,
198
+ temb,
199
+ image_only_indicator=image_only_indicator,
200
+ )
201
+
202
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
203
+ if self.training and self.gradient_checkpointing: # TODO
204
+
205
+ def create_custom_forward(module, return_dict=None):
206
+ def custom_forward(*inputs):
207
+ if return_dict is not None:
208
+ return module(*inputs, return_dict=return_dict)
209
+ else:
210
+ return module(*inputs)
211
+
212
+ return custom_forward
213
+
214
+ ckpt_kwargs: Dict[str, Any] = (
215
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
216
+ )
217
+ hidden_states = attn(
218
+ hidden_states,
219
+ encoder_hidden_states=encoder_hidden_states,
220
+ image_only_indicator=image_only_indicator,
221
+ return_dict=False,
222
+ position_ids=position_ids,
223
+ )[0]
224
+ hidden_states = torch.utils.checkpoint.checkpoint(
225
+ create_custom_forward(resnet),
226
+ hidden_states,
227
+ temb,
228
+ image_only_indicator,
229
+ **ckpt_kwargs,
230
+ )
231
+ else:
232
+ hidden_states = attn(
233
+ hidden_states,
234
+ encoder_hidden_states=encoder_hidden_states,
235
+ image_only_indicator=image_only_indicator,
236
+ return_dict=False,
237
+ position_ids=position_ids,
238
+ )[0]
239
+ hidden_states = resnet(
240
+ hidden_states,
241
+ temb,
242
+ image_only_indicator=image_only_indicator,
243
+ )
244
+
245
+ return hidden_states
246
+
247
+
248
+ class DownBlockSpatioTemporal(nn.Module):
249
+ def __init__(
250
+ self,
251
+ in_channels: int,
252
+ out_channels: int,
253
+ temb_channels: int,
254
+ num_layers: int = 1,
255
+ add_downsample: bool = True,
256
+ ):
257
+ super().__init__()
258
+ resnets = []
259
+
260
+ for i in range(num_layers):
261
+ in_channels = in_channels if i == 0 else out_channels
262
+ resnets.append(
263
+ SpatioTemporalResBlock(
264
+ in_channels=in_channels,
265
+ out_channels=out_channels,
266
+ temb_channels=temb_channels,
267
+ eps=1e-5,
268
+ )
269
+ )
270
+
271
+ self.resnets = nn.ModuleList(resnets)
272
+
273
+ if add_downsample:
274
+ self.downsamplers = nn.ModuleList(
275
+ [
276
+ Downsample2D(
277
+ out_channels,
278
+ use_conv=True,
279
+ out_channels=out_channels,
280
+ name="op",
281
+ )
282
+ ]
283
+ )
284
+ else:
285
+ self.downsamplers = None
286
+
287
+ self.gradient_checkpointing = False
288
+
289
+ def forward(
290
+ self,
291
+ hidden_states: torch.FloatTensor,
292
+ temb: Optional[torch.FloatTensor] = None,
293
+ image_only_indicator: Optional[torch.Tensor] = None,
294
+ position_ids=None,
295
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
296
+ output_states = ()
297
+ for resnet in self.resnets:
298
+ if self.training and self.gradient_checkpointing:
299
+
300
+ def create_custom_forward(module):
301
+ def custom_forward(*inputs):
302
+ return module(*inputs)
303
+
304
+ return custom_forward
305
+
306
+ if is_torch_version(">=", "1.11.0"):
307
+ hidden_states = torch.utils.checkpoint.checkpoint(
308
+ create_custom_forward(resnet),
309
+ hidden_states,
310
+ temb,
311
+ image_only_indicator,
312
+ use_reentrant=False,
313
+ )
314
+ else:
315
+ hidden_states = torch.utils.checkpoint.checkpoint(
316
+ create_custom_forward(resnet),
317
+ hidden_states,
318
+ temb,
319
+ image_only_indicator,
320
+ )
321
+ else:
322
+ hidden_states = resnet(
323
+ hidden_states,
324
+ temb,
325
+ image_only_indicator=image_only_indicator,
326
+ )
327
+
328
+ output_states = output_states + (hidden_states,)
329
+
330
+ if self.downsamplers is not None:
331
+ for downsampler in self.downsamplers:
332
+ hidden_states = downsampler(hidden_states)
333
+
334
+ output_states = output_states + (hidden_states,)
335
+
336
+ return hidden_states, output_states
337
+
338
+
339
+ class CrossAttnDownBlockSpatioTemporal(nn.Module):
340
+ def __init__(
341
+ self,
342
+ in_channels: int,
343
+ out_channels: int,
344
+ temb_channels: int,
345
+ num_layers: int = 1,
346
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
347
+ num_attention_heads: int = 1,
348
+ cross_attention_dim: int = 1280,
349
+ add_downsample: bool = True,
350
+ ):
351
+ super().__init__()
352
+ resnets = []
353
+ attentions = []
354
+
355
+ self.has_cross_attention = True
356
+ self.num_attention_heads = num_attention_heads
357
+ if isinstance(transformer_layers_per_block, int):
358
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
359
+
360
+ for i in range(num_layers):
361
+ in_channels = in_channels if i == 0 else out_channels
362
+ resnets.append(
363
+ SpatioTemporalResBlock(
364
+ in_channels=in_channels,
365
+ out_channels=out_channels,
366
+ temb_channels=temb_channels,
367
+ eps=1e-6,
368
+ )
369
+ )
370
+ attentions.append(
371
+ TransformerSpatioTemporalModel(
372
+ num_attention_heads,
373
+ out_channels // num_attention_heads,
374
+ in_channels=out_channels,
375
+ num_layers=transformer_layers_per_block[i],
376
+ cross_attention_dim=cross_attention_dim,
377
+ )
378
+ )
379
+
380
+ self.attentions = nn.ModuleList(attentions)
381
+ self.resnets = nn.ModuleList(resnets)
382
+
383
+ if add_downsample:
384
+ self.downsamplers = nn.ModuleList(
385
+ [
386
+ Downsample2D(
387
+ out_channels,
388
+ use_conv=True,
389
+ out_channels=out_channels,
390
+ padding=1,
391
+ name="op",
392
+ )
393
+ ]
394
+ )
395
+ else:
396
+ self.downsamplers = None
397
+
398
+ self.gradient_checkpointing = False
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states: torch.FloatTensor,
403
+ temb: Optional[torch.FloatTensor] = None,
404
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
405
+ image_only_indicator: Optional[torch.Tensor] = None,
406
+ position_ids=None,
407
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
408
+ output_states = ()
409
+
410
+ blocks = list(zip(self.resnets, self.attentions))
411
+ for resnet, attn in blocks:
412
+ if self.training and self.gradient_checkpointing: # TODO
413
+
414
+ def create_custom_forward(module, return_dict=None):
415
+ def custom_forward(*inputs):
416
+ if return_dict is not None:
417
+ return module(*inputs, return_dict=return_dict)
418
+ else:
419
+ return module(*inputs)
420
+
421
+ return custom_forward
422
+
423
+ ckpt_kwargs: Dict[str, Any] = (
424
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
425
+ )
426
+ hidden_states = torch.utils.checkpoint.checkpoint(
427
+ create_custom_forward(resnet),
428
+ hidden_states,
429
+ temb,
430
+ image_only_indicator,
431
+ **ckpt_kwargs,
432
+ )
433
+
434
+ hidden_states = attn(
435
+ hidden_states,
436
+ encoder_hidden_states=encoder_hidden_states,
437
+ image_only_indicator=image_only_indicator,
438
+ return_dict=False,
439
+ position_ids=position_ids,
440
+ )[0]
441
+ else:
442
+ hidden_states = resnet(
443
+ hidden_states,
444
+ temb,
445
+ image_only_indicator=image_only_indicator,
446
+ )
447
+ hidden_states = attn(
448
+ hidden_states,
449
+ encoder_hidden_states=encoder_hidden_states,
450
+ image_only_indicator=image_only_indicator,
451
+ return_dict=False,
452
+ position_ids=position_ids,
453
+ )[0]
454
+
455
+ output_states = output_states + (hidden_states,)
456
+
457
+ if self.downsamplers is not None:
458
+ for downsampler in self.downsamplers:
459
+ hidden_states = downsampler(hidden_states)
460
+
461
+ output_states = output_states + (hidden_states,)
462
+
463
+ return hidden_states, output_states
464
+
465
+
466
+ class UpBlockSpatioTemporal(nn.Module):
467
+ def __init__(
468
+ self,
469
+ in_channels: int,
470
+ prev_output_channel: int,
471
+ out_channels: int,
472
+ temb_channels: int,
473
+ resolution_idx: Optional[int] = None,
474
+ num_layers: int = 1,
475
+ resnet_eps: float = 1e-6,
476
+ add_upsample: bool = True,
477
+ ):
478
+ super().__init__()
479
+ resnets = []
480
+
481
+ for i in range(num_layers):
482
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
483
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
484
+
485
+ resnets.append(
486
+ SpatioTemporalResBlock(
487
+ in_channels=resnet_in_channels + res_skip_channels,
488
+ out_channels=out_channels,
489
+ temb_channels=temb_channels,
490
+ eps=resnet_eps,
491
+ )
492
+ )
493
+
494
+ self.resnets = nn.ModuleList(resnets)
495
+
496
+ if add_upsample:
497
+ self.upsamplers = nn.ModuleList(
498
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
499
+ )
500
+ else:
501
+ self.upsamplers = None
502
+
503
+ self.gradient_checkpointing = False
504
+ self.resolution_idx = resolution_idx
505
+
506
+ def forward(
507
+ self,
508
+ hidden_states: torch.FloatTensor,
509
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
510
+ temb: Optional[torch.FloatTensor] = None,
511
+ upsample_size: Optional[int] = None,
512
+ image_only_indicator: Optional[torch.Tensor] = None,
513
+ position_ids: Optional[torch.Tensor] = None,
514
+ ) -> torch.FloatTensor:
515
+ for resnet in self.resnets:
516
+ # pop res hidden states
517
+ res_hidden_states = res_hidden_states_tuple[-1]
518
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
519
+
520
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
521
+
522
+ if self.training and self.gradient_checkpointing:
523
+
524
+ def create_custom_forward(module):
525
+ def custom_forward(*inputs):
526
+ return module(*inputs)
527
+
528
+ return custom_forward
529
+
530
+ if is_torch_version(">=", "1.11.0"):
531
+ hidden_states = torch.utils.checkpoint.checkpoint(
532
+ create_custom_forward(resnet),
533
+ hidden_states,
534
+ temb,
535
+ image_only_indicator,
536
+ use_reentrant=False,
537
+ )
538
+ else:
539
+ hidden_states = torch.utils.checkpoint.checkpoint(
540
+ create_custom_forward(resnet),
541
+ hidden_states,
542
+ temb,
543
+ image_only_indicator,
544
+ )
545
+ else:
546
+ hidden_states = resnet(
547
+ hidden_states,
548
+ temb,
549
+ image_only_indicator=image_only_indicator,
550
+ )
551
+
552
+ if self.upsamplers is not None:
553
+ for upsampler in self.upsamplers:
554
+ hidden_states = upsampler(hidden_states, upsample_size)
555
+
556
+ return hidden_states
557
+
558
+
559
+ class CrossAttnUpBlockSpatioTemporal(nn.Module):
560
+ def __init__(
561
+ self,
562
+ in_channels: int,
563
+ out_channels: int,
564
+ prev_output_channel: int,
565
+ temb_channels: int,
566
+ resolution_idx: Optional[int] = None,
567
+ num_layers: int = 1,
568
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
569
+ resnet_eps: float = 1e-6,
570
+ num_attention_heads: int = 1,
571
+ cross_attention_dim: int = 1280,
572
+ add_upsample: bool = True,
573
+ ):
574
+ super().__init__()
575
+ resnets = []
576
+ attentions = []
577
+
578
+ self.has_cross_attention = True
579
+ self.num_attention_heads = num_attention_heads
580
+
581
+ if isinstance(transformer_layers_per_block, int):
582
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
583
+
584
+ for i in range(num_layers):
585
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
586
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
587
+
588
+ resnets.append(
589
+ SpatioTemporalResBlock(
590
+ in_channels=resnet_in_channels + res_skip_channels,
591
+ out_channels=out_channels,
592
+ temb_channels=temb_channels,
593
+ eps=resnet_eps,
594
+ )
595
+ )
596
+ attentions.append(
597
+ TransformerSpatioTemporalModel(
598
+ num_attention_heads,
599
+ out_channels // num_attention_heads,
600
+ in_channels=out_channels,
601
+ num_layers=transformer_layers_per_block[i],
602
+ cross_attention_dim=cross_attention_dim,
603
+ )
604
+ )
605
+
606
+ self.attentions = nn.ModuleList(attentions)
607
+ self.resnets = nn.ModuleList(resnets)
608
+
609
+ if add_upsample:
610
+ self.upsamplers = nn.ModuleList(
611
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
612
+ )
613
+ else:
614
+ self.upsamplers = None
615
+
616
+ self.gradient_checkpointing = False
617
+ self.resolution_idx = resolution_idx
618
+
619
+ def forward(
620
+ self,
621
+ hidden_states: torch.FloatTensor,
622
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
623
+ temb: Optional[torch.FloatTensor] = None,
624
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
625
+ image_only_indicator: Optional[torch.Tensor] = None,
626
+ upsample_size: Optional[int] = None,
627
+ position_ids=None,
628
+ ) -> torch.FloatTensor:
629
+ for resnet, attn in zip(self.resnets, self.attentions):
630
+ # pop res hidden states
631
+ res_hidden_states = res_hidden_states_tuple[-1]
632
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
633
+
634
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
635
+
636
+ if self.training and self.gradient_checkpointing: # TODO
637
+
638
+ def create_custom_forward(module, return_dict=None):
639
+ def custom_forward(*inputs):
640
+ if return_dict is not None:
641
+ return module(*inputs, return_dict=return_dict)
642
+ else:
643
+ return module(*inputs)
644
+
645
+ return custom_forward
646
+
647
+ ckpt_kwargs: Dict[str, Any] = (
648
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
649
+ )
650
+ hidden_states = torch.utils.checkpoint.checkpoint(
651
+ create_custom_forward(resnet),
652
+ hidden_states,
653
+ temb,
654
+ image_only_indicator,
655
+ **ckpt_kwargs,
656
+ )
657
+ hidden_states = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ image_only_indicator=image_only_indicator,
661
+ return_dict=False,
662
+ position_ids=position_ids,
663
+ )[0]
664
+ else:
665
+ hidden_states = resnet(
666
+ hidden_states,
667
+ temb,
668
+ image_only_indicator=image_only_indicator,
669
+ )
670
+ hidden_states = attn(
671
+ hidden_states,
672
+ encoder_hidden_states=encoder_hidden_states,
673
+ image_only_indicator=image_only_indicator,
674
+ return_dict=False,
675
+ position_ids=position_ids,
676
+ )[0]
677
+
678
+ if self.upsamplers is not None:
679
+ for upsampler in self.upsamplers:
680
+ hidden_states = upsampler(hidden_states, upsample_size)
681
+
682
+ return hidden_states
models/unets/unet_spatio_temporal_rope_condition.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.loaders import UNet2DConditionLoadersMixin
9
+ from diffusers.utils import BaseOutput, logging
10
+ from diffusers.models.attention_processor import (
11
+ CROSS_ATTENTION_PROCESSORS,
12
+ AttentionProcessor,
13
+ AttnProcessor,
14
+ )
15
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from .unet_3d_rope_blocks import (
18
+ UNetMidBlockSpatioTemporal,
19
+ get_down_block,
20
+ get_up_block,
21
+ )
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ @dataclass
27
+ class UNetSpatioTemporalRopeConditionOutput(BaseOutput):
28
+ """
29
+ The output of [`UNetSpatioTemporalConditionModel`].
30
+
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
33
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
34
+ """
35
+
36
+ sample: torch.FloatTensor = None
37
+
38
+
39
+ class UNetSpatioTemporalRopeConditionModel(
40
+ ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin
41
+ ):
42
+ r"""
43
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
44
+ shaped output.
45
+
46
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
47
+ for all models (such as downloading or saving).
48
+
49
+ Parameters:
50
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
51
+ Height and width of input/output sample.
52
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
53
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
54
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
55
+ The tuple of downsample blocks to use.
56
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
57
+ The tuple of upsample blocks to use.
58
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
59
+ The tuple of output channels for each block.
60
+ addition_time_embed_dim: (`int`, defaults to 256):
61
+ Dimension to to encode the additional time ids.
62
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
63
+ The dimension of the projection of encoded `added_time_ids`.
64
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
65
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
66
+ The dimension of the cross attention features.
67
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
68
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
69
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
70
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
71
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
72
+ The number of attention heads.
73
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
74
+ """
75
+
76
+ _supports_gradient_checkpointing = True
77
+
78
+ @register_to_config
79
+ def __init__(
80
+ self,
81
+ sample_size: Optional[int] = None,
82
+ in_channels: int = 8,
83
+ out_channels: int = 4,
84
+ down_block_types: Tuple[str] = (
85
+ "CrossAttnDownBlockSpatioTemporal",
86
+ "CrossAttnDownBlockSpatioTemporal",
87
+ "CrossAttnDownBlockSpatioTemporal",
88
+ "DownBlockSpatioTemporal",
89
+ ),
90
+ up_block_types: Tuple[str] = (
91
+ "UpBlockSpatioTemporal",
92
+ "CrossAttnUpBlockSpatioTemporal",
93
+ "CrossAttnUpBlockSpatioTemporal",
94
+ "CrossAttnUpBlockSpatioTemporal",
95
+ ),
96
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
97
+ addition_time_embed_dim: int = 256,
98
+ projection_class_embeddings_input_dim: int = 768,
99
+ layers_per_block: Union[int, Tuple[int]] = 2,
100
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
101
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
102
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
103
+ num_frames: int = 25,
104
+ ):
105
+ super().__init__()
106
+
107
+ self.sample_size = sample_size
108
+
109
+ # Check inputs
110
+ if len(down_block_types) != len(up_block_types):
111
+ raise ValueError(
112
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
113
+ )
114
+
115
+ if len(block_out_channels) != len(down_block_types):
116
+ raise ValueError(
117
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
118
+ )
119
+
120
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
121
+ down_block_types
122
+ ):
123
+ raise ValueError(
124
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
125
+ )
126
+
127
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
128
+ down_block_types
129
+ ):
130
+ raise ValueError(
131
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
132
+ )
133
+
134
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
135
+ down_block_types
136
+ ):
137
+ raise ValueError(
138
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
139
+ )
140
+
141
+ # input
142
+ self.conv_in = nn.Conv2d(
143
+ in_channels,
144
+ block_out_channels[0],
145
+ kernel_size=3,
146
+ padding=1,
147
+ )
148
+
149
+ # time
150
+ time_embed_dim = block_out_channels[0] * 4
151
+
152
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
153
+ timestep_input_dim = block_out_channels[0]
154
+
155
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
156
+
157
+ self.down_blocks = nn.ModuleList([])
158
+ self.up_blocks = nn.ModuleList([])
159
+
160
+ if isinstance(num_attention_heads, int):
161
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
162
+
163
+ if isinstance(cross_attention_dim, int):
164
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
165
+
166
+ if isinstance(layers_per_block, int):
167
+ layers_per_block = [layers_per_block] * len(down_block_types)
168
+
169
+ if isinstance(transformer_layers_per_block, int):
170
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
171
+ down_block_types
172
+ )
173
+
174
+ blocks_time_embed_dim = time_embed_dim
175
+
176
+ # down
177
+ output_channel = block_out_channels[0]
178
+ for i, down_block_type in enumerate(down_block_types):
179
+ input_channel = output_channel
180
+ output_channel = block_out_channels[i]
181
+ is_final_block = i == len(block_out_channels) - 1
182
+
183
+ down_block = get_down_block(
184
+ down_block_type,
185
+ num_layers=layers_per_block[i],
186
+ transformer_layers_per_block=transformer_layers_per_block[i],
187
+ in_channels=input_channel,
188
+ out_channels=output_channel,
189
+ temb_channels=blocks_time_embed_dim,
190
+ add_downsample=not is_final_block,
191
+ resnet_eps=1e-5,
192
+ cross_attention_dim=cross_attention_dim[i],
193
+ num_attention_heads=num_attention_heads[i],
194
+ resnet_act_fn="silu",
195
+ )
196
+ self.down_blocks.append(down_block)
197
+
198
+ # mid
199
+ self.mid_block = UNetMidBlockSpatioTemporal(
200
+ block_out_channels[-1],
201
+ temb_channels=blocks_time_embed_dim,
202
+ transformer_layers_per_block=transformer_layers_per_block[-1],
203
+ cross_attention_dim=cross_attention_dim[-1],
204
+ num_attention_heads=num_attention_heads[-1],
205
+ )
206
+
207
+ # count how many layers upsample the images
208
+ self.num_upsamplers = 0
209
+
210
+ # up
211
+ reversed_block_out_channels = list(reversed(block_out_channels))
212
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
213
+ reversed_layers_per_block = list(reversed(layers_per_block))
214
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
215
+ reversed_transformer_layers_per_block = list(
216
+ reversed(transformer_layers_per_block)
217
+ )
218
+
219
+ output_channel = reversed_block_out_channels[0]
220
+ for i, up_block_type in enumerate(up_block_types):
221
+ is_final_block = i == len(block_out_channels) - 1
222
+
223
+ prev_output_channel = output_channel
224
+ output_channel = reversed_block_out_channels[i]
225
+ input_channel = reversed_block_out_channels[
226
+ min(i + 1, len(block_out_channels) - 1)
227
+ ]
228
+
229
+ # add upsample block for all BUT final layer
230
+ if not is_final_block:
231
+ add_upsample = True
232
+ self.num_upsamplers += 1
233
+ else:
234
+ add_upsample = False
235
+
236
+ up_block = get_up_block(
237
+ up_block_type,
238
+ num_layers=reversed_layers_per_block[i] + 1,
239
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
240
+ in_channels=input_channel,
241
+ out_channels=output_channel,
242
+ prev_output_channel=prev_output_channel,
243
+ temb_channels=blocks_time_embed_dim,
244
+ add_upsample=add_upsample,
245
+ resnet_eps=1e-5,
246
+ resolution_idx=i,
247
+ cross_attention_dim=reversed_cross_attention_dim[i],
248
+ num_attention_heads=reversed_num_attention_heads[i],
249
+ resnet_act_fn="silu",
250
+ )
251
+ self.up_blocks.append(up_block)
252
+ prev_output_channel = output_channel
253
+
254
+ # out
255
+ self.conv_norm_out = nn.GroupNorm(
256
+ num_channels=block_out_channels[0], num_groups=32, eps=1e-5
257
+ )
258
+ self.conv_act = nn.SiLU()
259
+
260
+ self.conv_out = nn.Conv2d(
261
+ block_out_channels[0],
262
+ out_channels,
263
+ kernel_size=3,
264
+ padding=1,
265
+ )
266
+
267
+ @property
268
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
269
+ r"""
270
+ Returns:
271
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
272
+ indexed by its weight name.
273
+ """
274
+ # set recursively
275
+ processors = {}
276
+
277
+ def fn_recursive_add_processors(
278
+ name: str,
279
+ module: torch.nn.Module,
280
+ processors: Dict[str, AttentionProcessor],
281
+ ):
282
+ if hasattr(module, "get_processor"):
283
+ processors[f"{name}.processor"] = module.get_processor(
284
+ return_deprecated_lora=True
285
+ )
286
+
287
+ for sub_name, child in module.named_children():
288
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
289
+
290
+ return processors
291
+
292
+ for name, module in self.named_children():
293
+ fn_recursive_add_processors(name, module, processors)
294
+
295
+ return processors
296
+
297
+ def set_attn_processor(
298
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
299
+ ):
300
+ r"""
301
+ Sets the attention processor to use to compute attention.
302
+
303
+ Parameters:
304
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
305
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
306
+ for **all** `Attention` layers.
307
+
308
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
309
+ processor. This is strongly recommended when setting trainable attention processors.
310
+
311
+ """
312
+ count = len(self.attn_processors.keys())
313
+
314
+ if isinstance(processor, dict) and len(processor) != count:
315
+ raise ValueError(
316
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
317
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
318
+ )
319
+
320
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
321
+ if hasattr(module, "set_processor"):
322
+ if not isinstance(processor, dict):
323
+ module.set_processor(processor)
324
+ else:
325
+ module.set_processor(processor.pop(f"{name}.processor"))
326
+
327
+ for sub_name, child in module.named_children():
328
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
329
+
330
+ for name, module in self.named_children():
331
+ fn_recursive_attn_processor(name, module, processor)
332
+
333
+ def set_default_attn_processor(self):
334
+ """
335
+ Disables custom attention processors and sets the default attention implementation.
336
+ """
337
+ if all(
338
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
339
+ for proc in self.attn_processors.values()
340
+ ):
341
+ processor = AttnProcessor()
342
+ else:
343
+ raise ValueError(
344
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
345
+ )
346
+
347
+ self.set_attn_processor(processor)
348
+
349
+ def _set_gradient_checkpointing(self, module, value=False):
350
+ if hasattr(module, "gradient_checkpointing"):
351
+ module.gradient_checkpointing = value
352
+
353
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
354
+ def enable_forward_chunking(
355
+ self, chunk_size: Optional[int] = None, dim: int = 0
356
+ ) -> None:
357
+ """
358
+ Sets the attention processor to use [feed forward
359
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
360
+
361
+ Parameters:
362
+ chunk_size (`int`, *optional*):
363
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
364
+ over each tensor of dim=`dim`.
365
+ dim (`int`, *optional*, defaults to `0`):
366
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
367
+ or dim=1 (sequence length).
368
+ """
369
+ if dim not in [0, 1]:
370
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
371
+
372
+ # By default chunk size is 1
373
+ chunk_size = chunk_size or 1
374
+
375
+ def fn_recursive_feed_forward(
376
+ module: torch.nn.Module, chunk_size: int, dim: int
377
+ ):
378
+ if hasattr(module, "set_chunk_feed_forward"):
379
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
380
+
381
+ for child in module.children():
382
+ fn_recursive_feed_forward(child, chunk_size, dim)
383
+
384
+ for module in self.children():
385
+ fn_recursive_feed_forward(module, chunk_size, dim)
386
+
387
+ def forward(
388
+ self,
389
+ sample: torch.FloatTensor,
390
+ timestep: Union[torch.Tensor, float, int],
391
+ encoder_hidden_states: torch.Tensor,
392
+ return_dict: bool = True,
393
+ position_ids=None,
394
+ ) -> Union[UNetSpatioTemporalRopeConditionOutput, Tuple]:
395
+ r"""
396
+ The [`UNetSpatioTemporalConditionModel`] forward method.
397
+
398
+ Args:
399
+ sample (`torch.FloatTensor`):
400
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
401
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
402
+ encoder_hidden_states (`torch.FloatTensor`):
403
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
404
+ return_dict (`bool`, *optional*, defaults to `True`):
405
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
406
+ tuple.
407
+ Returns:
408
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
409
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
410
+ a `tuple` is returned where the first element is the sample tensor.
411
+ """
412
+ default_overall_up_factor = 2**self.num_upsamplers
413
+
414
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
415
+ forward_upsample_size = False
416
+ upsample_size = None
417
+
418
+ for dim in sample.shape[-2:]:
419
+ if dim % default_overall_up_factor != 0:
420
+ # Forward upsample size to force interpolation output size.
421
+ forward_upsample_size = True
422
+ break
423
+
424
+ # 1. time
425
+ timesteps = timestep
426
+ if not torch.is_tensor(timesteps):
427
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
428
+ # This would be a good case for the `match` statement (Python 3.10+)
429
+ is_mps = sample.device.type == "mps"
430
+ if isinstance(timestep, float):
431
+ dtype = torch.float32 if is_mps else torch.float64
432
+ else:
433
+ dtype = torch.int32 if is_mps else torch.int64
434
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
435
+ elif len(timesteps.shape) == 0:
436
+ timesteps = timesteps[None].to(sample.device)
437
+
438
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
439
+ batch_size, num_frames = sample.shape[:2]
440
+ timesteps = timesteps.expand(batch_size)
441
+
442
+ t_emb = self.time_proj(timesteps)
443
+
444
+ # `Timesteps` does not contain any weights and will always return f32 tensors
445
+ # but time_embedding might actually be running in fp16. so we need to cast here.
446
+ # there might be better ways to encapsulate this.
447
+ t_emb = t_emb.to(dtype=sample.dtype)
448
+
449
+ emb = self.time_embedding(t_emb)
450
+
451
+ # Flatten the batch and frames dimensions
452
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
453
+ sample = sample.flatten(0, 1)
454
+ # Repeat the embeddings num_video_frames times
455
+ # emb: [batch, channels] -> [batch * frames, channels]
456
+ emb = emb.repeat_interleave(num_frames, dim=0)
457
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
458
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
459
+ num_frames, dim=0
460
+ )
461
+
462
+ # 2. pre-process
463
+ sample = self.conv_in(sample)
464
+
465
+ image_only_indicator = torch.zeros(
466
+ batch_size, num_frames, dtype=sample.dtype, device=sample.device
467
+ )
468
+
469
+ down_block_res_samples = (sample,)
470
+ for downsample_block in self.down_blocks:
471
+ if (
472
+ hasattr(downsample_block, "has_cross_attention")
473
+ and downsample_block.has_cross_attention
474
+ ):
475
+ sample, res_samples = downsample_block(
476
+ hidden_states=sample,
477
+ temb=emb,
478
+ encoder_hidden_states=encoder_hidden_states,
479
+ image_only_indicator=image_only_indicator,
480
+ position_ids=position_ids,
481
+ )
482
+ else:
483
+ sample, res_samples = downsample_block(
484
+ hidden_states=sample,
485
+ temb=emb,
486
+ image_only_indicator=image_only_indicator,
487
+ position_ids=position_ids,
488
+ )
489
+
490
+ down_block_res_samples += res_samples
491
+
492
+ # 4. mid
493
+ sample = self.mid_block(
494
+ hidden_states=sample,
495
+ temb=emb,
496
+ encoder_hidden_states=encoder_hidden_states,
497
+ image_only_indicator=image_only_indicator,
498
+ position_ids=position_ids,
499
+ )
500
+
501
+ # 5. up
502
+ for i, upsample_block in enumerate(self.up_blocks):
503
+ is_final_block = i == len(self.up_blocks) - 1
504
+
505
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
506
+ down_block_res_samples = down_block_res_samples[
507
+ : -len(upsample_block.resnets)
508
+ ]
509
+ if not is_final_block and forward_upsample_size:
510
+ upsample_size = down_block_res_samples[-1].shape[2:]
511
+
512
+ if (
513
+ hasattr(upsample_block, "has_cross_attention")
514
+ and upsample_block.has_cross_attention
515
+ ):
516
+ sample = upsample_block(
517
+ hidden_states=sample,
518
+ temb=emb,
519
+ res_hidden_states_tuple=res_samples,
520
+ encoder_hidden_states=encoder_hidden_states,
521
+ image_only_indicator=image_only_indicator,
522
+ upsample_size=upsample_size,
523
+ position_ids=position_ids,
524
+ )
525
+ else:
526
+ sample = upsample_block(
527
+ hidden_states=sample,
528
+ temb=emb,
529
+ res_hidden_states_tuple=res_samples,
530
+ image_only_indicator=image_only_indicator,
531
+ upsample_size=upsample_size,
532
+ position_ids=position_ids,
533
+ )
534
+
535
+ # 6. post-process
536
+ sample = self.conv_norm_out(sample)
537
+ sample = self.conv_act(sample)
538
+ sample = self.conv_out(sample)
539
+
540
+ # 7. Reshape back to original shape
541
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
542
+
543
+ if not return_dict:
544
+ return (sample,)
545
+
546
+ return UNetSpatioTemporalRopeConditionOutput(sample=sample)
pipelines/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .dav_pipeline import DAVPipeline
2
+
3
+ __all__ = {
4
+ "DAVPipeline": DAVPipeline,
5
+ }
pipelines/dav_pipeline.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tqdm
3
+ import numpy as np
4
+ from diffusers import DiffusionPipeline
5
+ from diffusers.utils import BaseOutput
6
+ import matplotlib
7
+
8
+
9
+ def colorize_depth(depth, cmap="Spectral"):
10
+ # colorize
11
+ cm = matplotlib.colormaps[cmap]
12
+ # (B, N, H, W, 3)
13
+ depth_colored = cm(depth, bytes=False)[..., 0:3] # value from 0 to 1
14
+ return depth_colored
15
+
16
+
17
+ class DAVOutput(BaseOutput):
18
+ r"""
19
+ Output class for zero-shot text-to-video pipeline.
20
+
21
+ Args:
22
+ frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
23
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
24
+ num_channels)`.
25
+ """
26
+
27
+ disparity: np.ndarray
28
+ disparity_colored: np.ndarray
29
+ image: np.ndarray
30
+
31
+
32
+ class DAVPipeline(DiffusionPipeline):
33
+ def __init__(self, vae, unet, unet_interp, scheduler):
34
+ super().__init__()
35
+ self.register_modules(
36
+ vae=vae, unet=unet, unet_interp=unet_interp, scheduler=scheduler
37
+ )
38
+
39
+ def encode(self, input):
40
+ num_frames = input.shape[1]
41
+ input = input.flatten(0, 1)
42
+ latent = self.vae.encode(input.to(self.vae.dtype)).latent_dist.mode()
43
+ latent = latent * self.vae.config.scaling_factor
44
+ latent = latent.reshape(-1, num_frames, *latent.shape[1:])
45
+ return latent
46
+
47
+ def decode(self, latents, decode_chunk_size=16):
48
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
49
+ num_frames = latents.shape[1]
50
+ latents = latents.flatten(0, 1)
51
+ latents = latents / self.vae.config.scaling_factor
52
+
53
+ # decode decode_chunk_size frames at a time to avoid OOM
54
+ frames = []
55
+ for i in range(0, latents.shape[0], decode_chunk_size):
56
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
57
+ frame = self.vae.decode(
58
+ latents[i : i + decode_chunk_size].to(self.vae.dtype),
59
+ num_frames=num_frames_in,
60
+ ).sample
61
+ frames.append(frame)
62
+ frames = torch.cat(frames, dim=0)
63
+
64
+ # [batch, frames, channels, height, width]
65
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:])
66
+ return frames.to(torch.float32)
67
+
68
+ def single_infer(self, rgb, position_ids=None, num_inference_steps=None):
69
+ rgb_latent = self.encode(rgb)
70
+ noise_latent = torch.randn_like(rgb_latent)
71
+
72
+ self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)
73
+ timesteps = self.scheduler.timesteps
74
+
75
+ image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to(
76
+ noise_latent
77
+ )
78
+
79
+ for i, t in enumerate(timesteps):
80
+ latent_model_input = noise_latent
81
+
82
+ latent_model_input = torch.cat([latent_model_input, rgb_latent], dim=2)
83
+
84
+ # [batch_size, num_frame, 4, h, w]
85
+ model_output = self.unet(
86
+ latent_model_input,
87
+ t,
88
+ encoder_hidden_states=image_embeddings,
89
+ position_ids=position_ids,
90
+ ).sample
91
+
92
+ # compute the previous noisy sample x_t -> x_t-1
93
+ noise_latent = self.scheduler.step(
94
+ model_output, t, noise_latent
95
+ ).prev_sample
96
+
97
+ return noise_latent
98
+
99
+ def single_interp_infer(
100
+ self, rgb, masked_depth_latent, mask, num_inference_steps=None
101
+ ):
102
+ rgb_latent = self.encode(rgb)
103
+ noise_latent = torch.randn_like(rgb_latent)
104
+
105
+ self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)
106
+ timesteps = self.scheduler.timesteps
107
+
108
+ image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to(
109
+ noise_latent
110
+ )
111
+
112
+ for i, t in enumerate(timesteps):
113
+ latent_model_input = noise_latent
114
+
115
+ latent_model_input = torch.cat(
116
+ [latent_model_input, rgb_latent, masked_depth_latent, mask], dim=2
117
+ )
118
+
119
+ # [batch_size, num_frame, 4, h, w]
120
+ model_output = self.unet_interp(
121
+ latent_model_input, t, encoder_hidden_states=image_embeddings
122
+ ).sample
123
+
124
+ # compute the previous noisy sample x_t -> x_t-1
125
+ noise_latent = self.scheduler.step(
126
+ model_output, t, noise_latent
127
+ ).prev_sample
128
+
129
+ return noise_latent
130
+
131
+ def __call__(
132
+ self,
133
+ image,
134
+ num_frames,
135
+ num_overlap_frames,
136
+ num_interp_frames,
137
+ decode_chunk_size,
138
+ num_inference_steps,
139
+ ):
140
+ self.vae.to(dtype=torch.float16)
141
+
142
+ # (1, N, 3, H, W)
143
+ image = image.unsqueeze(0)
144
+ B, N = image.shape[:2]
145
+ rgb = image * 2 - 1 # [-1, 1]
146
+
147
+ if N <= num_frames or N <= num_interp_frames + 2 - num_overlap_frames:
148
+ depth_latent = self.single_infer(
149
+ rgb, num_inference_steps=num_inference_steps
150
+ )
151
+ else:
152
+ assert 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
153
+ assert num_frames % 2 == 0
154
+
155
+ key_frame_indices = []
156
+ for i in range(0, N, num_interp_frames + 2 - num_overlap_frames):
157
+ if (
158
+ i + num_interp_frames + 1 >= N
159
+ or len(key_frame_indices) >= num_frames
160
+ ):
161
+ break
162
+ key_frame_indices.append(i)
163
+ key_frame_indices.append(i + num_interp_frames + 1)
164
+
165
+ key_frame_indices = torch.tensor(key_frame_indices, device=rgb.device)
166
+
167
+ sorted_key_frame_indices, origin_indices = torch.sort(key_frame_indices)
168
+ key_rgb = rgb[:, sorted_key_frame_indices]
169
+ key_depth_latent = self.single_infer(
170
+ key_rgb,
171
+ sorted_key_frame_indices.unsqueeze(0).repeat(B, 1),
172
+ num_inference_steps=num_inference_steps,
173
+ )
174
+ key_depth_latent = key_depth_latent[:, origin_indices]
175
+
176
+ torch.cuda.empty_cache()
177
+
178
+ depth_latent = []
179
+ pre_latent = None
180
+ for i in tqdm.tqdm(range(0, len(key_frame_indices), 2)):
181
+ frame1 = key_depth_latent[:, i]
182
+ frame2 = key_depth_latent[:, i + 1]
183
+ masked_depth_latent = torch.zeros(
184
+ (B, num_interp_frames + 2, *key_depth_latent.shape[2:])
185
+ ).to(key_depth_latent)
186
+ masked_depth_latent[:, 0] = frame1
187
+ masked_depth_latent[:, -1] = frame2
188
+
189
+ mask = torch.zeros_like(masked_depth_latent)
190
+ mask[:, [0, -1]] = 1.0
191
+
192
+ latent = self.single_interp_infer(
193
+ rgb[:, key_frame_indices[i] : key_frame_indices[i + 1] + 1],
194
+ masked_depth_latent,
195
+ mask,
196
+ num_inference_steps=num_inference_steps,
197
+ )
198
+ latent = latent[:, 1:-1]
199
+
200
+ if pre_latent is not None:
201
+ overlap_a = pre_latent[
202
+ :, pre_latent.shape[1] - (num_overlap_frames - 2) :
203
+ ]
204
+ overlap_b = latent[:, : (num_overlap_frames - 2)]
205
+ ratio = (
206
+ torch.linspace(0, 1, num_overlap_frames - 2)
207
+ .to(overlap_a)
208
+ .view(1, -1, 1, 1, 1)
209
+ )
210
+ overlap = overlap_a * (1 - ratio) + overlap_b * ratio
211
+ pre_latent[:, pre_latent.shape[1] - (num_overlap_frames - 2) :] = (
212
+ overlap
213
+ )
214
+ depth_latent.append(pre_latent)
215
+
216
+ pre_latent = latent[:, (num_overlap_frames - 2) if i > 0 else 0 :]
217
+
218
+ torch.cuda.empty_cache()
219
+
220
+ depth_latent.append(pre_latent)
221
+ depth_latent = torch.cat(depth_latent, dim=1)
222
+
223
+ # dicard the first and last key frames
224
+ image = image[:, key_frame_indices[0] + 1 : key_frame_indices[-1]]
225
+ assert depth_latent.shape[1] == image.shape[1]
226
+
227
+ disparity = self.decode(depth_latent, decode_chunk_size=decode_chunk_size)
228
+ disparity = disparity.mean(dim=2, keepdim=False)
229
+ disparity = torch.clamp(disparity * 0.5 + 0.5, 0.0, 1.0)
230
+
231
+ # (N, H, W)
232
+ disparity = disparity.squeeze(0)
233
+ # (N, H, W, 3)
234
+ mid_d, max_d = disparity.min(), disparity.max()
235
+ disparity_colored = torch.clamp((max_d - disparity) / (max_d - mid_d), 0.0, 1.0)
236
+ disparity_colored = colorize_depth(disparity_colored.cpu().numpy())
237
+ disparity_colored = (disparity_colored * 255).astype(np.uint8)
238
+ image = image.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()
239
+ image = (image * 255).astype(np.uint8)
240
+ disparity = disparity.cpu().numpy()
241
+
242
+ return DAVOutput(
243
+ disparity=disparity,
244
+ disparity_colored=disparity_colored,
245
+ image=image,
246
+ )
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.0
2
+ torchvision==0.19.0
3
+ diffusers==0.30.0
4
+ accelerate==0.31.0
5
+ transformers==4.43.2
6
+ huggingface-hub==0.24.2
7
+ opencv-python
8
+ tqdm
9
+ matplotlib
10
+ scipy
11
+ pillow
12
+ easydict
utils/img_utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import tempfile
6
+
7
+
8
+ def resize(img, size):
9
+ assert img.dtype == np.uint8
10
+ pil_image = Image.fromarray(img)
11
+ pil_image = pil_image.resize(size, Image.LANCZOS)
12
+ resized_img = np.array(pil_image)
13
+ return resized_img
14
+
15
+
16
+ def crop(img, start_h, start_w, crop_h, crop_w):
17
+ img_src = np.zeros((crop_h, crop_w, *img.shape[2:]), dtype=img.dtype)
18
+ hsize, wsize = crop_h, crop_w
19
+ dh, dw, sh, sw = start_h, start_w, 0, 0
20
+ if dh < 0:
21
+ sh = -dh
22
+ hsize += dh
23
+ dh = 0
24
+ if dh + hsize > img.shape[0]:
25
+ hsize = img.shape[0] - dh
26
+ if dw < 0:
27
+ sw = -dw
28
+ wsize += dw
29
+ dw = 0
30
+ if dw + wsize > img.shape[1]:
31
+ wsize = img.shape[1] - dw
32
+ img_src[sh : sh + hsize, sw : sw + wsize] = img[dh : dh + hsize, dw : dw + wsize]
33
+ return img_src
34
+
35
+
36
+ def imresize_max(img, size, min_side=False):
37
+ new_img = []
38
+ for i, _img in enumerate(img):
39
+ h, w = _img.shape[:2]
40
+ ori_size = min(h, w) if min_side else max(h, w)
41
+ _resize = min(size / ori_size, 1.0)
42
+ new_size = (int(w * _resize), int(h * _resize))
43
+ new_img.append(resize(_img, new_size))
44
+ return new_img
45
+
46
+
47
+ def imcrop_multi(img, multiple=32):
48
+ new_img = []
49
+ for i, _img in enumerate(img):
50
+ crop_size = (
51
+ _img.shape[0] // multiple * multiple,
52
+ _img.shape[1] // multiple * multiple,
53
+ )
54
+ start_h = int(0.5 * max(0, _img.shape[0] - crop_size[0]))
55
+ start_w = int(0.5 * max(0, _img.shape[1] - crop_size[1]))
56
+ _img_src = crop(_img, start_h, start_w, crop_size[0], crop_size[1])
57
+ new_img.append(_img_src)
58
+ return new_img
59
+
60
+
61
+ def read_video(video_path, max_frames=None):
62
+ cap = cv2.VideoCapture(video_path)
63
+ fps = cap.get(cv2.CAP_PROP_FPS)
64
+ frames = []
65
+ count = 0
66
+ while True:
67
+ ret, frame = cap.read()
68
+ if not ret:
69
+ break
70
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
71
+ frames.append(frame)
72
+ count += 1
73
+ if max_frames is not None and count >= max_frames:
74
+ break
75
+ cap.release()
76
+ # (N, H, W, 3)
77
+ return frames, fps
78
+
79
+
80
+ def read_image(image_path):
81
+ frame = cv2.imread(image_path)
82
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
83
+ # (1, H, W, 3)
84
+ return [frame]
85
+
86
+
87
+ def write_video(video_path, frames, fps):
88
+ tmp_dir = os.path.join(os.path.dirname(video_path), "tmp")
89
+ os.makedirs(tmp_dir, exist_ok=True)
90
+ for i, frame in enumerate(frames):
91
+ write_image(os.path.join(tmp_dir, f"{i:06d}.png"), frame)
92
+ # it will cause visual compression artifacts
93
+ ffmpeg_command = [
94
+ "ffmpeg",
95
+ "-f",
96
+ "image2",
97
+ "-framerate",
98
+ f"{fps}",
99
+ "-i",
100
+ os.path.join(tmp_dir, "%06d.png"),
101
+ "-b:v",
102
+ "5626k",
103
+ "-y",
104
+ video_path,
105
+ ]
106
+ os.system(" ".join(ffmpeg_command))
107
+ os.system(f"rm -rf {tmp_dir}")
108
+
109
+
110
+ def write_image(image_path, frame):
111
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
112
+ cv2.imwrite(image_path, frame)