SunderAli17 committed on
Commit
3fa6c11
·
verified ·
1 Parent(s): 4ff3ac9

Create model.py

Browse files
Files changed (1) hide show
  1. flux/model.py +135 -0
flux/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+
15
+
16
@dataclass
class FluxParams:
    """
    Hyper-parameters consumed by ``Flux.__init__`` to build the model.

    Field order is part of the interface: the generated ``__init__`` accepts
    the values positionally in the order declared below.
    """

    # Feature size of the incoming image tokens (input width of ``img_in``);
    # ``Flux`` also uses it as its output channel count.
    in_channels: int
    # Input width of the ``vector_in`` embedder applied to ``y`` in ``forward``.
    vec_in_dim: int
    # Feature size of the incoming text tokens (input width of ``txt_in``).
    context_in_dim: int
    # Width of the transformer; must be divisible by ``num_heads``.
    hidden_size: int
    # Forwarded to every double/single stream block as ``mlp_ratio``.
    mlp_ratio: float
    # Number of attention heads; ``hidden_size // num_heads`` is the
    # positional-embedding dimension.
    num_heads: int
    # How many ``DoubleStreamBlock`` layers are stacked.
    depth: int
    # How many ``SingleStreamBlock`` layers are stacked.
    depth_single_blocks: int
    # Per-axis positional dims; their sum must equal ``hidden_size // num_heads``.
    axes_dim: list[int]
    # Forwarded to ``EmbedND`` as ``theta``.
    theta: int
    # Forwarded to every ``DoubleStreamBlock`` as ``qkv_bias``.
    qkv_bias: bool
    # When true, a guidance embedder is built and ``forward`` requires ``guidance``.
    guidance_embed: bool
30
+
31
+
32
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Image and text token sequences are first processed by a stack of
    double-stream blocks (separate img/txt streams), then concatenated and
    processed by a stack of single-stream blocks. Optionally, identity
    cross-attention modules stored in ``pulid_ca`` are applied to the image
    tokens at fixed block intervals.
    """

    def __init__(self, params: FluxParams):
        """
        Build all sub-modules from ``params``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``,
                or if ``sum(axes_dim)`` differs from ``hidden_size // num_heads``.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        # 256 matches the embedding width requested from timestep_embedding in forward().
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        # Identity cross-attention modules, indexed in the order they are
        # consumed by forward(). Left as None here; presumably attached by
        # external code before an ``id`` embedding is passed -- TODO confirm.
        self.pulid_ca = None
        # An injection happens after every N-th double / single block (i % N == 0).
        self.pulid_double_interval = 2
        self.pulid_single_interval = 4

    def _add_pulid_id(self, img: Tensor, id_embedding: Tensor, id_weight: float, ca_idx: int) -> Tensor:
        """Return ``img`` plus the weighted output of the ``ca_idx``-th PuLID module."""
        if self.pulid_ca is None:
            # Without this check the failure is an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            raise ValueError("Got an id embedding but pulid_ca is None; attach the PuLID modules first.")
        return img + id_weight * self.pulid_ca[ca_idx](id_embedding, img)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        id: Tensor = None,
        id_weight: float = 1.0,
    ) -> Tensor:
        """
        Run the model on an image/text token pair.

        Args:
            img: image tokens; must be 3-dimensional.
            img_ids: positional ids for ``img``; concatenated after ``txt_ids``
                along dim 1 before being embedded.
            txt: text tokens; must be 3-dimensional.
            txt_ids: positional ids for ``txt``.
            timesteps: values fed to ``timestep_embedding``.
            y: vector conditioning, embedded by ``vector_in`` and added to the
                modulation vector.
            guidance: guidance strength; required when
                ``params.guidance_embed`` is true, ignored otherwise.
            id: optional identity embedding. When given, the modules in
                ``pulid_ca`` are applied to the image tokens after every
                ``pulid_double_interval``-th double block and every
                ``pulid_single_interval``-th single block, one module per
                injection, in order.
            id_weight: scale applied to each identity injection.

        Returns:
            The output of ``final_layer`` on the image tokens.

        Raises:
            ValueError: if ``img`` or ``txt`` is not 3-dimensional, if guidance
                is required but missing, or if ``id`` is given while
                ``pulid_ca`` is None.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        inject_id = id is not None
        # Index of the next PuLID module; shared across both block stacks.
        ca_idx = 0
        for i, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

            if inject_id and i % self.pulid_double_interval == 0:
                img = self._add_pulid_id(img, id, id_weight, ca_idx)
                ca_idx += 1

        # Single-stream blocks operate on the joint [txt, img] sequence; the
        # text length is fixed from here on, so compute it once.
        txt_len = txt.shape[1]
        img = torch.cat((txt, img), 1)
        for i, block in enumerate(self.single_blocks):
            img = block(img, vec=vec, pe=pe)

            # Only split and re-join the sequence when an injection actually
            # happens; otherwise the block output is passed on untouched.
            if inject_id and i % self.pulid_single_interval == 0:
                txt_part = img[:, :txt_len, ...]
                real_img = img[:, txt_len:, ...]
                real_img = self._add_pulid_id(real_img, id, id_weight, ca_idx)
                ca_idx += 1
                img = torch.cat((txt_part, real_img), 1)

        # Drop the text tokens before the final projection.
        img = img[:, txt_len:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img