twodgirl committed
Commit
56e5a97
1 Parent(s): a4814de

Create patched transformer.

Files changed (1)
  1. flux_model.py +272 -0
flux_model.py ADDED
@@ -0,0 +1,272 @@
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.activations import FP32SiLU, get_activation
+ from diffusers.models.embeddings import Timesteps, PixArtAlphaTextProjection
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.transformer_flux import AdaLayerNormContinuous, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.utils import logging
+ import torch
+ from torch import nn
+
+ logger = logging.get_logger(__name__)
+
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin):
+     """
+     The Transformer model introduced in Flux.
+
+     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+     Parameters:
+         patch_size (`int`): Patch size to turn the input data into small patches.
+         in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+         num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
+         num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
+         attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
+         num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
+         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+         pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+         guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 1,
+         in_channels: int = 64,
+         num_layers: int = 19,
+         num_single_layers: int = 38,
+         attention_head_dim: int = 128,
+         num_attention_heads: int = 24,
+         joint_attention_dim: int = 4096,
+         pooled_projection_dim: int = 768,
+         guidance_embeds: bool = False,
+         axes_dims_rope=(16, 56, 56),
+         device=None
+     ):
+         super().__init__()
+         self.out_channels = in_channels
+         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+         # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+         self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope).to(device)
+
+         text_time_guidance_cls = (
+             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+         )
+         self.time_text_embed = text_time_guidance_cls(
+             embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+         ).to(device)
+
+         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim).to(device)
+         self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim).to(device)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 FluxTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.config.num_attention_heads,
+                     attention_head_dim=self.config.attention_head_dim,
+                 ).to(device)
+                 for i in range(self.config.num_layers)
+             ]
+         )
+
+         self.single_transformer_blocks = nn.ModuleList(
+             [
+                 FluxSingleTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.config.num_attention_heads,
+                     attention_head_dim=self.config.attention_head_dim,
+                 ).to(device)
+                 for i in range(self.config.num_single_layers)
+             ]
+         )
+
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6).to(device)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True).to(device)
+
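+         # PuLID-specific state read by forward(): pul_id holds the identity embedding and
+         # pul_id_weight scales its contribution. The companion attributes pulid_ca,
+         # pulid_double_interval and pulid_single_interval are not created here and are
+         # expected to be attached by the PuLID pipeline before inference.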
+         self.pul_id = None
+         self.pul_id_weight = 1.0
+
+         self.gradient_checkpointing = False
+
+     @property
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self):
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
+             indexed by their weight names.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: nn.Module, processors):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor()
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         pooled_projections: torch.Tensor = None,
+         timestep: torch.LongTensor = None,
+         img_ids: torch.Tensor = None,
+         txt_ids: torch.Tensor = None,
+         guidance: torch.Tensor = None,
+         joint_attention_kwargs=None,
+         controlnet_block_samples=None,
+         controlnet_single_block_samples=None,
+         return_dict: bool = True
+     ):
+         """
+         The [`FluxTransformer2DModel`] forward method.
+
+         Args:
+             hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                 Input `hidden_states`.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                 from the embeddings of input conditions.
+             timestep (`torch.LongTensor`):
+                 Used to indicate denoising step.
+             controlnet_block_samples, controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
+                 ControlNet residuals. Accepted for interface compatibility but not applied in this patched forward.
+             joint_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary intended for the `AttentionProcessor` as defined under `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+                 Not forwarded to the blocks in this patched version.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+             `tuple` where the first element is the sample tensor.
+         """
+         hidden_states = self.x_embedder(hidden_states)
+
+         timestep = timestep.to(hidden_states.dtype) * 1000
+         if guidance is not None:
+             guidance = guidance.to(hidden_states.dtype) * 1000
+         else:
+             guidance = None
+         temb = (
+             self.time_text_embed(timestep, pooled_projections)
+             if guidance is None
+             else self.time_text_embed(timestep, guidance, pooled_projections)
+         )
+         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+         ###
+         # Modified by huggingface/twodgirl.
+         # Code from diffusers and PuLID.
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)
+         image_rotary_emb = self.pos_embed(ids)
+         ca_index = 0
+
+         for index_block, block in enumerate(self.transformer_blocks):
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+             )
+
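+             # PuLID identity injection for the double-stream blocks: every `pulid_double_interval`
+             # blocks, a PuLID cross-attention module adds an ID-conditioned residual to the image
+             # tokens, scaled by `pul_id_weight`.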
+             if self.pul_id is not None and index_block % self.pulid_double_interval == 0:
+                 weighted = self.pul_id_weight * self.pulid_ca[ca_index](self.pul_id, hidden_states.to(self.pul_id.dtype))
+                 hidden_states = hidden_states + weighted.to(hidden_states.dtype)
+                 ca_index += 1
+
+         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+         for index_block, block in enumerate(self.single_transformer_blocks):
+             hidden_states = block(
+                 hidden_states=hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+             )
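+             # Same injection for the single-stream blocks: the text and image tokens are split,
+             # the ID residual is added to the image part only, and the sequence is reassembled.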
+             if self.pul_id is not None and index_block % self.pulid_single_interval == 0:
+                 encoder_hidden_states, real_ = hidden_states[:, :encoder_hidden_states.shape[1], ...], hidden_states[:, encoder_hidden_states.shape[1]:, ...]
+                 weighted = self.pul_id_weight * self.pulid_ca[ca_index](self.pul_id, real_.to(self.pul_id.dtype))
+                 real_ = real_ + weighted.to(real_.dtype)
+                 hidden_states = torch.cat([encoder_hidden_states, real_], dim=1)
+                 ca_index += 1
+
+         hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
+
+         hidden_states = self.norm_out(hidden_states, temb)
+         output = self.proj_out(hidden_states)
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
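
A minimal sketch of how the PuLID-specific attributes read by forward() might be wired up, assuming a diffusers version in which the imports above resolve (notably EmbedND in transformer_flux). The toy sizes, the PlaceholderIDAdapter class, the interval values and the identity-embedding shape are illustrative assumptions; a real PuLID pipeline supplies its own cross-attention modules and ID embedding, and for the guidance-distilled checkpoints one would also set guidance_embeds=True and pass a guidance tensor.

import torch

from flux_model import FluxTransformer2DModel

# Toy sizes so the sketch runs on CPU; the released checkpoints use the defaults
# (19 double blocks, 38 single blocks, 24 heads of dim 128).
model = FluxTransformer2DModel(
    in_channels=16,
    num_layers=2,
    num_single_layers=2,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=64,
    pooled_projection_dim=32,
    axes_dims_rope=(8, 12, 12),  # must sum to attention_head_dim
)

class PlaceholderIDAdapter(torch.nn.Module):
    # Stand-in for PuLID's cross-attention modules; it only mimics the call
    # signature used in forward() and contributes nothing to the output.
    def forward(self, id_embeds, img_tokens):
        return torch.zeros_like(img_tokens)

# Attributes read by forward() but not created in __init__; the interval values
# and the identity-embedding shape below are placeholders.
model.pulid_double_interval = 2
model.pulid_single_interval = 4
model.pulid_ca = torch.nn.ModuleList([PlaceholderIDAdapter() for _ in range(2)])
model.pul_id = torch.randn(1, 32, 64)   # placeholder identity embedding
model.pul_id_weight = 0.8

sample = model(
    hidden_states=torch.randn(1, 16, 16),         # (batch, image tokens, in_channels)
    encoder_hidden_states=torch.randn(1, 8, 64),  # (batch, text tokens, joint_attention_dim)
    pooled_projections=torch.randn(1, 32),
    timestep=torch.tensor([1.0]),
    img_ids=torch.zeros(1, 16, 3),
    txt_ids=torch.zeros(1, 8, 3),
    return_dict=False,
)[0]
print(sample.shape)  # torch.Size([1, 16, 16])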