DamarJati committed
Commit a126ec1
1 Parent(s): 132db28

Upload 3 files

src/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
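The module is self-contained, so a quick smoke test is easy to run. The sketch below is illustrative only: the AutoEncoderParams values are small made-up numbers rather than the configuration of any released checkpoint, and it assumes the src directory is on PYTHONPATH so the package imports as flux.modules.autoencoder.

# Minimal sketch: round-trip a random image through the autoencoder.
# Parameter values are illustrative assumptions, not an official config.
import torch

from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64,
    in_channels=3,
    ch=32,
    out_ch=3,
    ch_mult=[1, 2, 4],
    num_res_blocks=1,
    z_channels=4,
    scale_factor=1.0,
    shift_factor=0.0,
)

ae = AutoEncoder(params).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 64, 64)   # dummy image batch
    z = ae.encode(x)                # latent: (1, 4, 16, 16) with three resolution levels
    x_rec = ae.decode(z)            # reconstruction: same shape as x

print(z.shape, x_rec.shape)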
src/flux/modules/conditioner.py ADDED
@@ -0,0 +1,38 @@
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)


class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
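HFEmbedder routes any repo id starting with "openai" through the CLIP branch (pooled output) and everything else through the T5 branch (last hidden state). A minimal usage sketch, assuming network access to download a checkpoint and using openai/clip-vit-large-patch14 purely as an example id:

# Minimal sketch: embed a prompt with a CLIP text encoder via HFEmbedder.
# Model id and max_length are illustrative assumptions.
import torch

from flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32)

with torch.no_grad():
    vec = clip(["a photo of a cat"])

print(vec.shape)  # pooled embedding, (1, 768) for this checkpoint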
src/flux/modules/layers.py ADDED
@@ -0,0 +1,355 @@
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from ..math import attention, rope


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class FLuxSelfAttnProcessor:
    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x)
        return x


class LoraFluxAttnProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self):
        # attention itself is computed by the block's processor; this module only holds the layers
        pass


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlockProcessor:
    def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        processor = DoubleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        return self.processor(self, img, txt, vec, pe)


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
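For orientation, the sketch below shows how the pieces defined here are typically combined into the per-timestep modulation signal: sinusoidal timestep features go through MLPEmbedder, and Modulation turns the resulting vector into shift/scale/gate triples. The dimensions are illustrative assumptions, and importing this module presumes the sibling flux.math module (attention, rope) is available.

# Minimal sketch: build a conditioning vector and read out its modulation triples.
# hidden_size and the embedding width 256 are illustrative assumptions.
import torch

from flux.modules.layers import MLPEmbedder, Modulation, timestep_embedding

hidden_size = 64
time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)
mod = Modulation(hidden_size, double=True)

t = torch.rand(2)                             # one timestep per batch element
vec = time_in(timestep_embedding(t, 256))     # (2, hidden_size) conditioning vector
m1, m2 = mod(vec)                             # two ModulationOut triples (shift, scale, gate)

print(vec.shape, m1.scale.shape)              # torch.Size([2, 64]) torch.Size([2, 1, 64])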