jimmycarter committed on
Commit
0745964
·
verified ·
1 Parent(s): 26f2e08

Upload transformer.py

Browse files
Files changed (1) hide show
  1. transformer/transformer.py +757 -0
transformer/transformer.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
2
+ #
3
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
4
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
5
+
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
14
+ from diffusers.models.attention import FeedForward
15
+ from diffusers.models.attention_processor import (
16
+ Attention,
17
+ apply_rope,
18
+ )
19
+ from diffusers.models.modeling_utils import ModelMixin
20
+ from diffusers.models.normalization import (
21
+ AdaLayerNormContinuous,
22
+ AdaLayerNormZero,
23
+ AdaLayerNormZeroSingle,
24
+ )
25
+ from diffusers.utils import (
26
+ USE_PEFT_BACKEND,
27
+ is_torch_version,
28
+ logging,
29
+ scale_lora_layers,
30
+ unscale_lora_layers,
31
+ )
32
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
33
+ from diffusers.models.embeddings import (
34
+ CombinedTimestepGuidanceTextProjEmbeddings,
35
+ CombinedTimestepTextProjEmbeddings,
36
+ )
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
class FluxSingleAttnProcessor2_0:
    r"""
    Self-attention processor for the Flux single-stream blocks, built on
    `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0).
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Run self-attention over `hidden_states` using the projections owned by `attn`.

        Args:
            attn: Attention module providing `to_q`/`to_k`/`to_v`, `heads` and the
                optional `norm_q`/`norm_k` layers.
            hidden_states: Either a 3-D `(batch, seq, dim)` tensor or a 4-D
                `(batch, channel, height, width)` tensor; the 4-D form is flattened
                to a sequence and restored before returning.
            encoder_hidden_states: Accepted for interface compatibility; not read here.
            attention_mask: Optional per-token mask; it is unsqueezed twice so it
                broadcasts over heads and query positions, and entries `> 0` are
                treated as set.
            image_rotary_emb: Optional rotary embedding applied to query and key.

        Returns:
            The attention output, in the same layout family (3-D or 4-D) as the input.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            flat = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flat.transpose(1, 2)

        batch_size, _, _ = hidden_states.shape

        projected_q = attn.to_q(hidden_states)
        projected_k = attn.to_k(hidden_states)
        projected_v = attn.to_v(hidden_states)

        num_heads = attn.heads
        head_dim = projected_k.shape[-1] // num_heads

        def to_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = to_heads(projected_q)
        key = to_heads(projected_k)
        value = to_heads(projected_v)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Rotary position embedding is optional.
        if image_rotary_emb is not None:
            # YiYi to-do: update using apply_rotary_emb
            query, key = apply_rope(query, key, image_rotary_emb)

        if attention_mask is not None:
            # NOTE(review): the boolean mask is cast to the hidden-state dtype, so
            # SDPA receives a floating-point 1.0/0.0 mask. PyTorch documents float
            # masks as being *added* to the attention scores rather than excluding
            # positions -- confirm this is intended before changing it, since
            # trained weights may depend on the current behaviour.
            is_set = attention_mask.unsqueeze(1).unsqueeze(2) > 0
            attention_mask = is_set.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )

        # Merge the heads back into the channel dimension.
        merged = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = merged.to(query.dtype)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        return hidden_states
125
+
126
+
127
class FluxAttnProcessor2_0:
    """Joint text/image attention processor for the Flux double-stream (SD3-like) blocks."""

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """
        Attend jointly over the context tokens and the sample tokens.

        The context (`encoder_hidden_states`) projections are concatenated in front
        of the sample (`hidden_states`) projections along the sequence axis, one
        attention pass is run over the combined sequence, and the result is split
        back into the two streams.

        Args:
            attn: Attention module providing the sample projections
                (`to_q`/`to_k`/`to_v`/`to_out`), the context projections
                (`add_q_proj`/`add_k_proj`/`add_v_proj`/`to_add_out`), `heads` and
                the optional q/k norm layers.
            hidden_states: Sample tokens, 3-D or 4-D (4-D is flattened to a sequence).
            encoder_hidden_states: Context tokens, 3-D or 4-D. Although the default
                is `None`, the value is dereferenced unconditionally, so it must be
                provided.
            attention_mask: Optional per-token mask over the combined sequence;
                entries `> 0` are treated as set.
            image_rotary_emb: Optional rotary embedding applied to query and key.

        Returns:
            A `(hidden_states, encoder_hidden_states)` tuple with the attended
            sample and context streams.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        sample_q = attn.to_q(hidden_states)
        sample_k = attn.to_k(hidden_states)
        sample_v = attn.to_v(hidden_states)

        num_heads = attn.heads
        head_dim = sample_k.shape[-1] // num_heads

        def to_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        sample_q = to_heads(sample_q)
        sample_k = to_heads(sample_k)
        sample_v = to_heads(sample_v)

        if attn.norm_q is not None:
            sample_q = attn.norm_q(sample_q)
        if attn.norm_k is not None:
            sample_k = attn.norm_k(sample_k)

        # `context` projections.
        context_q = attn.add_q_proj(encoder_hidden_states)
        context_k = attn.add_k_proj(encoder_hidden_states)
        context_v = attn.add_v_proj(encoder_hidden_states)

        context_q = to_heads(context_q)
        context_k = to_heads(context_k)
        context_v = to_heads(context_v)

        if attn.norm_added_q is not None:
            context_q = attn.norm_added_q(context_q)
        if attn.norm_added_k is not None:
            context_k = attn.norm_added_k(context_k)

        # Joint sequence: context tokens first, then sample tokens.
        query = torch.cat([context_q, sample_q], dim=2)
        key = torch.cat([context_k, sample_k], dim=2)
        value = torch.cat([context_v, sample_v], dim=2)

        if image_rotary_emb is not None:
            # YiYi to-do: update using apply_rotary_emb
            query, key = apply_rope(query, key, image_rotary_emb)

        if attention_mask is not None:
            # NOTE(review): the boolean mask is cast to the hidden-state dtype, so
            # SDPA receives a floating-point 1.0/0.0 mask. PyTorch documents float
            # masks as being *added* to the attention scores rather than excluding
            # positions -- confirm this is intended before changing it, since
            # trained weights may depend on the current behaviour.
            is_set = attention_mask.unsqueeze(1).unsqueeze(2) > 0
            attention_mask = is_set.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )

        attended = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )
        # Merge heads back into the channel dimension.
        attended = attended.transpose(1, 2).reshape(
            batch_size, -1, num_heads * head_dim
        )
        attended = attended.to(query.dtype)

        # Split the joint sequence back into its context and sample parts.
        context_len = encoder_hidden_states.shape[1]
        encoder_hidden_states = attended[:, :context_len]
        hidden_states = attended[:, context_len:]

        # Sample stream: linear projection followed by dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))
        # Context stream: its own output projection.
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        return hidden_states, encoder_hidden_states
253
+
254
+
255
+ # YiYi to-do: refactor rope related functions/classes
256
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Build 2x2 rotary rotation matrices for a batch of positions.

    Args:
        pos: 2-D tensor of positions, shape `(batch, seq_len)`.
        dim: Number of channels covered by this axis; must be even, since channels
            are rotated in pairs.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape `(batch, seq_len, dim // 2, 2, 2)` whose trailing
        2x2 block for angle `a` is `[[cos a, -sin a], [sin a, cos a]]`.

    Raises:
        AssertionError: If `dim` is odd.
        ValueError: If `pos` is not 2-D (raised by the shape unpacking).
    """
    assert dim % 2 == 0, "The dimension must be even."

    # The frequency table is computed in float64 for accuracy, but the MPS backend
    # does not support float64, so fall back to float32 on that device only.
    freq_dtype = torch.float32 if pos.device.type == "mps" else torch.float64
    exponents = (
        torch.arange(0, dim, 2, dtype=freq_dtype, device=pos.device) / dim
    )
    omega = 1.0 / (theta**exponents)

    batch_size, _ = pos.shape
    # Outer product: one angle per (position, frequency) pair.
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(angles)
    sin_out = torch.sin(angles)

    # Lay the four matrix entries out row-major, then fold them into 2x2 blocks.
    rotation = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    rotation = rotation.view(batch_size, -1, dim // 2, 2, 2)
    return rotation.float()
279
+
280
+
281
+ # YiYi to-do: refactor rope related functions/classes
282
class EmbedND(nn.Module):
    """
    Multi-axis rotary embedding: one `rope` table per position axis, concatenated.

    Attributes are stored as given: `dim` (not read by `forward`), `theta` (the
    rotary base passed to `rope`) and `axes_dim` (channels assigned to each axis).
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Compute rotary tables for `ids`, whose last dimension indexes the axes.

        Each axis `i` is embedded with `rope(ids[..., i], axes_dim[i], theta)`; the
        per-axis tables are concatenated along dim -3 and a singleton dimension is
        inserted at position 1.
        """
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))

        return torch.cat(per_axis, dim=-3).unsqueeze(1)
297
+
298
+
299
def expand_flux_attention_mask(
    hidden_states: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Expand a mask so that the image is included.

    The returned mask covers the full sequence length of `hidden_states`: its
    leading `attn_mask.shape[1]` columns are copied from `attn_mask`, and every
    remaining column is set to 1.

    Args:
        hidden_states: Tensor whose dim 0 is the batch and dim 1 the full sequence
            length the mask must cover.
        attn_mask: `(batch, mask_seq_len)` mask for the leading part of the sequence.

    Returns:
        A `(batch, hidden_states.shape[1])` tensor in the default floating dtype,
        allocated on `hidden_states.device`.

    Raises:
        AssertionError: If the batch sizes of the two inputs differ.
    """
    bsz = attn_mask.shape[0]
    assert bsz == hidden_states.shape[0]
    residual_seq_len = hidden_states.shape[1]
    mask_seq_len = attn_mask.shape[1]

    # Allocate on the device where the mask will be consumed; a bare torch.ones()
    # would land on the CPU and force a host/device round trip on every call.
    expanded_mask = torch.ones(bsz, residual_seq_len, device=hidden_states.device)
    expanded_mask[:, :mask_seq_len] = attn_mask

    return expanded_mask
315
+
316
+
317
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    r"""
    Single-stream Flux transformer block: self-attention and an MLP are computed
    in parallel from the same normalised input, concatenated, projected back to
    `dim`, gated, and added to the residual.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        mlp_ratio (`float`, defaults to 4.0): Width of the MLP branch as a multiple of `dim`.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        # Projects the concatenated [attention, MLP] features back to `dim`.
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=FluxSingleAttnProcessor2_0(),
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Apply the block to `hidden_states`.

        Args:
            hidden_states: Input token sequence; also used as the residual.
            temb: Conditioning embedding consumed by the adaptive norm.
            image_rotary_emb: Optional rotary embedding forwarded to attention.
            attention_mask: Optional mask; when given it is first expanded to the
                full sequence length with `expand_flux_attention_mask`.

        Returns:
            The updated token sequence, same shape as `hidden_states`.
        """
        skip = hidden_states
        normed, gate = self.norm(hidden_states, emb=temb)
        mlp_branch = self.act_mlp(self.proj_mlp(normed))

        if attention_mask is not None:
            attention_mask = expand_flux_attention_mask(hidden_states, attention_mask)

        attn_branch = self.attn(
            hidden_states=normed,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
        )

        # Fuse both branches along the channel axis, project, gate, add residual.
        fused = torch.cat([attn_branch, mlp_branch], dim=2)
        gated = gate.unsqueeze(1) * self.proj_out(fused)
        return skip + gated
384
+
385
+
386
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    r"""
    Double-stream Flux transformer block following the MMDiT architecture
    introduced in Stable Diffusion 3: image and context tokens keep separate
    norms and feed-forward layers but share one joint attention pass.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        qk_norm (`str`, defaults to `"rms_norm"`): Query/key normalisation passed to `Attention`.
        eps (`float`, defaults to 1e-6): Epsilon passed to `Attention`.
    """

    def __init__(
        self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(
            dim=dim, dim_out=dim, activation_fn="gelu-approximate"
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    @staticmethod
    def _gated_residual_update(
        stream, attn_out, gate_msa, norm, shift_mlp, scale_mlp, feed_forward, gate_mlp
    ):
        """Add the gated attention output, then the gated, modulated feed-forward output."""
        stream = stream + gate_msa.unsqueeze(1) * attn_out

        modulated = norm(stream)
        modulated = modulated * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        return stream + gate_mlp.unsqueeze(1) * feed_forward(modulated)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Apply the block to the image and context streams.

        Args:
            hidden_states: Image token sequence.
            encoder_hidden_states: Context token sequence.
            temb: Conditioning embedding consumed by both adaptive norms.
            image_rotary_emb: Optional rotary embedding forwarded to attention.
            attention_mask: Optional mask; when given it is expanded with
                `expand_flux_attention_mask` to cover the context tokens followed
                by the image tokens.

        Returns:
            `(encoder_hidden_states, hidden_states)` -- note the context stream
            comes first.
        """
        normed_sample, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )
        normed_context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            self.norm1_context(encoder_hidden_states, emb=temb)
        )

        if attention_mask is not None:
            joint_sequence = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            attention_mask = expand_flux_attention_mask(joint_sequence, attention_mask)

        # Joint attention over both streams.
        sample_attn, context_attn = self.attn(
            hidden_states=normed_sample,
            encoder_hidden_states=normed_context,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
        )

        # Image stream.
        hidden_states = self._gated_residual_update(
            hidden_states,
            sample_attn,
            gate_msa,
            self.norm2,
            shift_mlp,
            scale_mlp,
            self.ff,
            gate_mlp,
        )

        # Context stream.
        encoder_hidden_states = self._gated_residual_update(
            encoder_hidden_states,
            context_attn,
            c_gate_msa,
            self.norm2_context,
            c_shift_mlp,
            c_scale_mlp,
            self.ff_context,
            c_gate_mlp,
        )

        return encoder_hidden_states, hidden_states
503
+
504
+
505
class FluxTransformer2DModelWithMasking(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    """
    The Transformer model introduced in Flux, with an optional attention mask.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`, defaults to 1): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`, defaults to 768): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
        axes_dims_rope (`List[int]`, defaults to `[16, 56, 56]`): Rotary channels assigned to each position axis.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = (
            self.config.num_attention_heads * self.config.attention_head_dim
        )

        # Copy the list so the module never aliases the shared default argument.
        self.pos_embed = EmbedND(
            dim=self.inner_dim, theta=10000, axes_dim=list(axes_dims_rope)
        )
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings
            if guidance_embeds
            else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=self.config.pooled_projection_dim,
        )

        self.context_embedder = nn.Linear(
            self.config.joint_attention_dim, self.inner_dim
        )
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
        )
        self.proj_out = nn.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle `gradient_checkpointing` on `module` if it exposes that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @staticmethod
    def _checkpointed_call(block, *inputs):
        """Run `block(*inputs)` under activation checkpointing."""
        ckpt_kwargs: Dict[str, Any] = (
            {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        )
        return torch.utils.checkpoint.checkpoint(block, *inputs, **ckpt_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModelWithMasking`] forward method.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input image tokens. They are embedded by a linear layer over the last
                dimension and concatenated with the text tokens along dim 1, so a
                `(batch size, image_sequence_len, in_channels)` layout is assumed --
                TODO confirm against the calling pipeline.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step. It is cast to the hidden-state dtype and multiplied by 1000.
            img_ids (`torch.Tensor`):
                Position ids of the image tokens, concatenated after `txt_ids` along dim 1 to build the
                rotary embedding.
            txt_ids (`torch.Tensor`):
                Position ids of the text tokens.
            guidance (`torch.Tensor`, *optional*):
                Guidance value; when given it is scaled by 1000 and passed to the time/text embedding.
            joint_attention_kwargs (`dict`, *optional*):
                Only the `scale` entry is read, as the LoRA scale applied when the PEFT backend is in
                use. No entry is forwarded to the attention processors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            attention_mask (`torch.Tensor`, *optional*):
                Mask for the text tokens, presumably of shape `(batch size, text_sequence_len)`. Each block
                expands it with ones for the image tokens before attention.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Record whether a scale was supplied *before* popping it; checking after
        # the pop can never succeed, which silently disabled the warning below.
        scale_requested = (
            joint_attention_kwargs is not None
            and joint_attention_kwargs.get("scale", None) is not None
        )
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
            temb = self.time_text_embed(timestep, guidance, pooled_projections)
        else:
            temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # Text ids come first, matching the token order used inside the blocks.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pos_embed(ids)

        use_checkpointing = self.training and self.gradient_checkpointing

        for block in self.transformer_blocks:
            if use_checkpointing:
                # Positional order must match FluxTransformerBlock.forward.
                encoder_hidden_states, hidden_states = self._checkpointed_call(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    attention_mask,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_mask=attention_mask,
                )

        # Flux places the text tokens in front of the image tokens in the
        # sequence.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for block in self.single_transformer_blocks:
            if use_checkpointing:
                # Positional order must match FluxSingleTransformerBlock.forward.
                hidden_states = self._checkpointed_call(
                    block,
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    attention_mask,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_mask=attention_mask,
                )

        # Drop the leading text tokens; only image tokens are projected out.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)