Safetensors
aredden committed
Commit d9aea20 · 0 Parent(s)

initial commit
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__
README.md ADDED
@@ -0,0 +1,267 @@
4
+ # Flux FP16 Accumulate Model Implementation with FastAPI
5
+
6
+ This repository contains an implementation of the Flux model, along with a FastAPI server that generates images from text prompts. The server is configured and launched via command-line arguments.
7
+
8
+ ## Table of Contents
9
+
10
+ - [Installation](#installation)
11
+ - [Usage](#usage)
12
+ - [Configuration](#configuration)
13
+ - [API Endpoints](#api-endpoints)
14
+ - [Examples](#examples)
15
+ - [License](#license)
16
+
17
+ ## Installation
18
+
19
+ To install the required dependencies, run:
20
+
21
+ ```bash
22
+ pip install -r requirements.txt
23
+ ```
25
+
26
+ ## Usage
27
+
28
+ You can run the API server using the following command:
29
+
30
+ ```bash
31
+ python main.py --config-path <path_to_config> --port <port_number> --host <host_address>
32
+ ```
33
+
34
+ ### Command-Line Arguments
35
+
36
+ - `--config-path`: Path to the configuration file. If not provided, the model is configured from the remaining command-line arguments.
37
+ - `--port`: Port to run the server on (default: 8088).
38
+ - `--host`: Host to run the server on (default: 0.0.0.0).
39
+ - `--flow-model-path`: Path to the flow model.
40
+ - `--text-enc-path`: Path to the text encoder.
41
+ - `--autoencoder-path`: Path to the autoencoder.
42
+ - `--model-version`: Choose model version (`flux-dev` or `flux-schnell`).
43
+ - `--flux-device`: Device to run the flow model on (default: cuda:0).
44
+ - `--text-enc-device`: Device to run the text encoder on (default: cuda:0).
45
+ - `--autoencoder-device`: Device to run the autoencoder on (default: cuda:0).
46
+ - `--num-to-quant`: Number of linear layers in the flow transformer to quantize (default: 20).
47
+
48
+ ## Configuration
49
+
50
+ The configuration files are located in the `configs` directory. You can specify different configurations for different model versions and devices.
51
+
52
+ Example configuration file (`configs/config-dev.json`):
53
+
54
+ ```json
55
+ {
56
+ "version": "flux-dev",
57
+ "params": {
58
+ "in_channels": 64,
59
+ "vec_in_dim": 768,
60
+ "context_in_dim": 4096,
61
+ "hidden_size": 3072,
62
+ "mlp_ratio": 4.0,
63
+ "num_heads": 24,
64
+ "depth": 19,
65
+ "depth_single_blocks": 38,
66
+ "axes_dim": [16, 56, 56],
67
+ "theta": 10000,
68
+ "qkv_bias": true,
69
+ "guidance_embed": true
70
+ },
71
+ "ae_params": {
72
+ "resolution": 256,
73
+ "in_channels": 3,
74
+ "ch": 128,
75
+ "out_ch": 3,
76
+ "ch_mult": [1, 2, 4, 4],
77
+ "num_res_blocks": 2,
78
+ "z_channels": 16,
79
+ "scale_factor": 0.3611,
80
+ "shift_factor": 0.1159
81
+ },
82
+ "ckpt_path": "/path/to/your/flux1-dev.sft",
83
+ "ae_path": "/path/to/your/ae.sft",
84
+ "repo_id": "black-forest-labs/FLUX.1-dev",
85
+ "repo_flow": "flux1-dev.sft",
86
+ "repo_ae": "ae.sft",
87
+ "text_enc_max_length": 512,
88
+ "text_enc_path": "path/to/your/t5-v1_1-xxl-encoder-bf16", // or "city96/t5-v1_1-xxl-encoder-bf16" for an easy-to-download version
89
+ "text_enc_device": "cuda:1",
90
+ "ae_device": "cuda:1",
91
+ "flux_device": "cuda:0",
92
+ "flow_dtype": "float16",
93
+ "ae_dtype": "bfloat16",
94
+ "text_enc_dtype": "bfloat16",
95
+ "num_to_quant": 20
96
+ }
97
+ ```
98
+
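+ As a quick sketch of loading a configuration programmatically (assuming `util.load_config_from_path` behaves as it is used in `flux_impl.py`; `util.py` is not shown in this commit view), you could do:
+
+ ```py
+ from util import load_config_from_path
+
+ # Parse the JSON config into a ModelSpec and inspect a few fields
+ config = load_config_from_path("configs/config-dev.json")
+ print(config.version, config.flux_device, config.num_to_quant)
+ ```
+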
99
+ ## API Endpoints
100
+
101
+ ### Generate Image
102
+
103
+ - **URL**: `/generate`
104
+ - **Method**: `POST`
105
+ - **Request Body**:
106
+
107
+ - `prompt` (str): The text prompt for image generation.
108
+ - `width` (int, optional): The width of the generated image (default: 720).
109
+ - `height` (int, optional): The height of the generated image (default: 1024).
110
+ - `num_steps` (int, optional): The number of steps for the generation process (default: 24).
111
+ - `guidance` (float, optional): The guidance scale for the generation process (default: 3.5).
112
+ - `seed` (int, optional): The seed for random number generation.
113
+
114
+ - **Response**: A JPEG image stream.
115
+
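+ The request body maps onto the `GenerateArgs` pydantic model in `api.py`. A minimal sketch of how defaults are filled in (assuming pydantic v2, which provides `model_dump`; the prompt below is just an example):
+
+ ```py
+ from api import GenerateArgs
+
+ # Only `prompt` is required; other fields fall back to the defaults listed above,
+ # and `seed` is drawn from numpy's RNG when omitted.
+ args = GenerateArgs(prompt="a misty forest at dawn")
+ print(args.width, args.height, args.num_steps, args.guidance)  # 720 1024 24 3.5
+ ```
+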
116
+ ## Examples
117
+
118
+ ### Running the Server
119
+
120
+ ```bash
121
+ python main.py --config-path configs/config-dev.json --port 8088 --host 0.0.0.0
122
+ ```
123
+
124
+ Or, if you need more granular control, you can launch the server with explicit arguments:
125
+
126
+ ```bash
127
+ python main.py --port 8088 --host 0.0.0.0 \
128
+ --flow-model-path /path/to/your/flux1-dev.sft \
129
+ --text-enc-path /path/to/your/t5-v1_1-xxl-encoder-bf16 \
130
+ --autoencoder-path /path/to/your/ae.sft \
131
+ --model-version flux-dev \
132
+ --flux-device cuda:0 \
133
+ --text-enc-device cuda:1 \
134
+ --autoencoder-device cuda:1 \
135
+ --num-to-quant 20
136
+ ```
137
+
138
+ ### Generating an Image
139
+
140
+ Send a POST request to `http://<host>:<port>/generate` with the following JSON body:
141
+
142
+ ```json
143
+ {
144
+ "prompt": "a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
145
+ "width": 1024,
146
+ "height": 1024,
147
+ "num_steps": 24,
148
+ "guidance": 3.0,
149
+ "seed": 13456
150
+ }
151
+ ```
152
+
153
+ Here is an example of generating an image from a Python client using the FastAPI server:
154
+
155
+ ```py
156
+ import requests
158
+
159
+ prompt = "a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns"
160
+ res = requests.post(
161
+ "http://localhost:8088/generate",
162
+ json={
163
+ "width": 1024,
164
+ "height": 720,
165
+ "num_steps": 20,
166
+ "guidance": 4,
167
+ "prompt": prompt,
168
+ },
169
+ stream=True,
170
+ )
171
+
172
+ with open("output.jpg", "wb") as f:
173
+ f.write(res.content)
174
+
175
+ ```
176
+
177
+ ## License
178
+
179
+ This project is licensed under the MIT License.
180
+
182
+
183
+ ## References
184
+
185
+ - Code for loading the pipeline from the configuration path:
186
+
187
+ ```200:310:flux_impl.py
188
+ @torch.inference_mode()
189
+ def load_pipeline_from_config(config: ModelSpec) -> Model:
190
+ models = load_models_from_config(config)
191
+ config = models.config
192
+ num_quanted = 0
193
+ max_quanted = config.num_to_quant
194
+ flux_device = into_device(config.flux_device)
195
+ ae_device = into_device(config.ae_device)
196
+ clip_device = into_device(config.text_enc_device)
197
+ t5_device = into_device(config.text_enc_device)
198
+ flux_dtype = into_dtype(config.flow_dtype)
199
+ device_index = flux_device.index or 0
200
+ flow_model = models.flow.requires_grad_(False).eval().type(flux_dtype)
201
+ for block in flow_model.single_blocks:
202
+ block.cuda(flux_device)
203
+ if num_quanted < max_quanted:
204
+ num_quanted = quant_module(
205
+ block.linear1, num_quanted, device_index=device_index
206
+ )
207
+
208
+ for block in flow_model.double_blocks:
209
+ block.cuda(flux_device)
210
+ if num_quanted < max_quanted:
211
+ num_quanted = full_quant(
212
+ block, max_quanted, num_quanted, device_index=device_index
213
+ )
214
+
215
+ to_gpu_extras = [
216
+ "vector_in",
217
+ "img_in",
218
+ "txt_in",
219
+ "time_in",
220
+ "guidance_in",
221
+ "final_layer",
222
+ "pe_embedder",
223
+ ]
224
+ for extra in to_gpu_extras:
225
+ getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
226
+ ```
227
+
228
+ - Code for the main entry point:
229
+
230
+ ```59:85:main.py
231
+ def main():
232
+ args = parse_args()
233
+
234
+ if args.config_path:
235
+ app.state.model = load_pipeline_from_config_path(args.config_path)
236
+ else:
237
+ model_version = (
238
+ ModelVersion.flux_dev
239
+ if args.model_version == "flux-dev"
240
+ else ModelVersion.flux_schnell
241
+ )
242
+ config = load_config(
243
+ model_version,
244
+ flux_path=args.flow_model_path,
245
+ flux_device=args.flux_device,
246
+ ae_path=args.autoencoder_path,
247
+ ae_device=args.autoencoder_device,
248
+ text_enc_path=args.text_enc_path,
249
+ text_enc_device=args.text_enc_device,
250
+ flow_dtype="float16",
251
+ text_enc_dtype="bfloat16",
252
+ ae_dtype="bfloat16",
253
+ num_to_quant=args.num_to_quant,
254
+ )
255
+ app.state.model = load_pipeline_from_config(config)
256
+
257
+ uvicorn.run(app, host=args.host, port=args.port)
258
+ ```
259
+
260
+ - Code for the API endpoint:
261
+
262
+ ```22:25:api.py
263
+ @app.post("/generate")
264
+ def generate(args: GenerateArgs):
265
+ result = app.state.model.generate(**args.model_dump())
266
+ return StreamingResponse(result, media_type="image/jpeg")
267
+ ```
api.py ADDED
@@ -0,0 +1,25 @@
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ from fastapi import FastAPI
5
+ from fastapi.responses import StreamingResponse
6
+ from pydantic import BaseModel, Field
7
+
8
+ app = FastAPI()
9
+
10
+
11
+ class GenerateArgs(BaseModel):
12
+ prompt: str
13
+ width: Optional[int] = Field(default=720)
14
+ height: Optional[int] = Field(default=1024)
15
+ num_steps: Optional[int] = Field(default=24)
16
+ guidance: Optional[float] = Field(default=3.5)
17
+ seed: Optional[int] = Field(
18
+ default_factory=lambda: np.random.randint(0, 2**32 - 1), ge=0, lt=2**32 - 1
19
+ )
20
+
21
+
22
+ @app.post("/generate")
23
+ def generate(args: GenerateArgs):
24
+ result = app.state.model.generate(**args.model_dump())
25
+ return StreamingResponse(result, media_type="image/jpeg")
configs/config-dev-cuda0.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "version": "flux-dev",
3
+ "params": {
4
+ "in_channels": 64,
5
+ "vec_in_dim": 768,
6
+ "context_in_dim": 4096,
7
+ "hidden_size": 3072,
8
+ "mlp_ratio": 4.0,
9
+ "num_heads": 24,
10
+ "depth": 19,
11
+ "depth_single_blocks": 38,
12
+ "axes_dim": [
13
+ 16,
14
+ 56,
15
+ 56
16
+ ],
17
+ "theta": 10000,
18
+ "qkv_bias": true,
19
+ "guidance_embed": true
20
+ },
21
+ "ae_params": {
22
+ "resolution": 256,
23
+ "in_channels": 3,
24
+ "ch": 128,
25
+ "out_ch": 3,
26
+ "ch_mult": [
27
+ 1,
28
+ 2,
29
+ 4,
30
+ 4
31
+ ],
32
+ "num_res_blocks": 2,
33
+ "z_channels": 16,
34
+ "scale_factor": 0.3611,
35
+ "shift_factor": 0.1159
36
+ },
37
+ "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
38
+ "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
39
+ "repo_id": "black-forest-labs/FLUX.1-dev",
40
+ "repo_flow": "flux1-dev.sft",
41
+ "repo_ae": "ae.sft",
42
+ "text_enc_max_length": 512,
43
+ "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
44
+ "text_enc_device": "cuda:0",
45
+ "ae_device": "cuda:0",
46
+ "flux_device": "cuda:0",
47
+ "flow_dtype": "float16",
48
+ "ae_dtype": "bfloat16",
49
+ "text_enc_dtype": "bfloat16",
50
+ "num_to_quant": 20
51
+ }
configs/config-dev.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "version": "flux-dev",
3
+ "params": {
4
+ "in_channels": 64,
5
+ "vec_in_dim": 768,
6
+ "context_in_dim": 4096,
7
+ "hidden_size": 3072,
8
+ "mlp_ratio": 4.0,
9
+ "num_heads": 24,
10
+ "depth": 19,
11
+ "depth_single_blocks": 38,
12
+ "axes_dim": [
13
+ 16,
14
+ 56,
15
+ 56
16
+ ],
17
+ "theta": 10000,
18
+ "qkv_bias": true,
19
+ "guidance_embed": true
20
+ },
21
+ "ae_params": {
22
+ "resolution": 256,
23
+ "in_channels": 3,
24
+ "ch": 128,
25
+ "out_ch": 3,
26
+ "ch_mult": [
27
+ 1,
28
+ 2,
29
+ 4,
30
+ 4
31
+ ],
32
+ "num_res_blocks": 2,
33
+ "z_channels": 16,
34
+ "scale_factor": 0.3611,
35
+ "shift_factor": 0.1159
36
+ },
37
+ "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
38
+ "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
39
+ "repo_id": "black-forest-labs/FLUX.1-dev",
40
+ "repo_flow": "flux1-dev.sft",
41
+ "repo_ae": "ae.sft",
42
+ "text_enc_max_length": 512,
43
+ "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
44
+ "text_enc_device": "cuda:1",
45
+ "ae_device": "cuda:1",
46
+ "flux_device": "cuda:0",
47
+ "flow_dtype": "float16",
48
+ "ae_dtype": "bfloat16",
49
+ "text_enc_dtype": "bfloat16",
50
+ "num_to_quant": 20
51
+ }
configs/config-schnell-cuda0.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "version": "flux-schnell",
3
+ "params": {
4
+ "in_channels": 64,
5
+ "vec_in_dim": 768,
6
+ "context_in_dim": 4096,
7
+ "hidden_size": 3072,
8
+ "mlp_ratio": 4.0,
9
+ "num_heads": 24,
10
+ "depth": 19,
11
+ "depth_single_blocks": 38,
12
+ "axes_dim": [
13
+ 16,
14
+ 56,
15
+ 56
16
+ ],
17
+ "theta": 10000,
18
+ "qkv_bias": true,
19
+ "guidance_embed": true
20
+ },
21
+ "ae_params": {
22
+ "resolution": 256,
23
+ "in_channels": 3,
24
+ "ch": 128,
25
+ "out_ch": 3,
26
+ "ch_mult": [
27
+ 1,
28
+ 2,
29
+ 4,
30
+ 4
31
+ ],
32
+ "num_res_blocks": 2,
33
+ "z_channels": 16,
34
+ "scale_factor": 0.3611,
35
+ "shift_factor": 0.1159
36
+ },
37
+ "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft",
38
+ "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft",
39
+ "repo_id": "black-forest-labs/FLUX.1-schnell",
40
+ "repo_flow": "flux1-schnell.sft",
41
+ "repo_ae": "ae.sft",
42
+ "text_enc_max_length": 256,
43
+ "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
44
+ "text_enc_device": "cuda:0",
45
+ "ae_device": "cuda:0",
46
+ "flux_device": "cuda:0",
47
+ "flow_dtype": "float16",
48
+ "ae_dtype": "bfloat16",
49
+ "text_enc_dtype": "bfloat16",
50
+ "num_to_quant": 20
51
+ }
configs/config-schnell.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "version": "flux-schnell",
3
+ "params": {
4
+ "in_channels": 64,
5
+ "vec_in_dim": 768,
6
+ "context_in_dim": 4096,
7
+ "hidden_size": 3072,
8
+ "mlp_ratio": 4.0,
9
+ "num_heads": 24,
10
+ "depth": 19,
11
+ "depth_single_blocks": 38,
12
+ "axes_dim": [
13
+ 16,
14
+ 56,
15
+ 56
16
+ ],
17
+ "theta": 10000,
18
+ "qkv_bias": true,
19
+ "guidance_embed": true
20
+ },
21
+ "ae_params": {
22
+ "resolution": 256,
23
+ "in_channels": 3,
24
+ "ch": 128,
25
+ "out_ch": 3,
26
+ "ch_mult": [
27
+ 1,
28
+ 2,
29
+ 4,
30
+ 4
31
+ ],
32
+ "num_res_blocks": 2,
33
+ "z_channels": 16,
34
+ "scale_factor": 0.3611,
35
+ "shift_factor": 0.1159
36
+ },
37
+ "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft",
38
+ "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft",
39
+ "repo_id": "black-forest-labs/FLUX.1-schnell",
40
+ "repo_flow": "flux1-schnell.sft",
41
+ "repo_ae": "ae.sft",
42
+ "text_enc_max_length": 256,
43
+ "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
44
+ "text_enc_device": "cuda:1",
45
+ "ae_device": "cuda:1",
46
+ "flux_device": "cuda:0",
47
+ "flow_dtype": "float16",
48
+ "ae_dtype": "bfloat16",
49
+ "text_enc_dtype": "bfloat16",
50
+ "num_to_quant": 20
51
+ }
cublas_linear.py ADDED
@@ -0,0 +1,152 @@
1
+ import math
2
+ from typing import Any, Literal, Optional
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ from cublas_ops_ext import _simt_hgemv
8
+ from cublas_ops_ext import cublas_hgemm_axbT as _cublas_hgemm_axbT
9
+ from cublas_ops_ext import cublas_hgemm_batched_simple as _cublas_hgemm_batched_simple
10
+ from cublas_ops_ext import (
11
+ cublaslt_hgemm_batched_simple as _cublaslt_hgemm_batched_simple,
12
+ )
13
+ from cublas_ops_ext import cublaslt_hgemm_simple as _cublaslt_hgemm_simple
14
+ from torch import Tensor, nn
15
+
16
+ global has_moved
17
+ has_moved = {idx: False for idx in range(torch.cuda.device_count())}
18
+
19
+
20
+ class StaticState:
21
+ workspace = {
22
+ idx: torch.empty((1024 * 1024 * 8,), dtype=torch.uint8)
23
+ for idx in range(torch.cuda.device_count())
24
+ }
25
+ workspace_size = workspace[0].nelement()
26
+ bias_g = {
27
+ idx: torch.tensor([], dtype=torch.float16)
28
+ for idx in range(torch.cuda.device_count())
29
+ }
30
+
31
+ @classmethod
32
+ def get(cls, __name: str, device: torch.device) -> Any:
33
+ global has_moved
34
+ idx = device.index if device.index is not None else 0
35
+ if not has_moved[idx]:
36
+ cls.workspace[idx] = cls.workspace[idx].cuda(idx)
37
+ cls.bias_g[idx] = cls.bias_g[idx].cuda(idx)
38
+ has_moved[idx] = True
39
+ if "bias" in __name:
40
+ return cls.bias_g[idx]
41
+ if "workspace" in __name:
42
+ return cls.workspace[idx]
43
+ if "workspace_size" in __name:
44
+ return cls.workspace_size
45
+
46
+
47
+ @torch.no_grad()
48
+ def hgemv_simt(vec: torch.HalfTensor, mat: torch.HalfTensor, block_dim_x: int = 32):
49
+ prev_dims = vec.shape[:-1]
50
+ out = _simt_hgemv(mat, vec.view(-1, 1), block_dim_x=block_dim_x).view(
51
+ *prev_dims, -1
52
+ )
53
+ return out
54
+
55
+
56
+ @torch.no_grad()
57
+ def cublas_half_matmul_batched_simple(a: torch.Tensor, b: torch.Tensor):
58
+ out = _cublas_hgemm_batched_simple(a, b)
59
+ return out
60
+
61
+
62
+ @torch.no_grad()
63
+ def cublas_half_matmul_simple(a: torch.Tensor, b: torch.Tensor):
64
+ out = _cublas_hgemm_axbT(b, a)
65
+ return out
66
+
67
+
68
+ @torch.no_grad()
69
+ def cublaslt_fused_half_matmul_simple(
70
+ a: torch.Tensor,
71
+ b: torch.Tensor,
72
+ bias: Optional[torch.Tensor] = None,
73
+ epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE",
74
+ ):
75
+ if bias is None:
76
+ bias = StaticState.get("bias", a.device)
77
+ out = _cublaslt_hgemm_simple(
78
+ a, b, bias, epilogue_str, StaticState.get("workspace", a.device)
79
+ )
80
+ return out
81
+
82
+
83
+ @torch.no_grad()
84
+ def cublaslt_fused_half_matmul_batched_simple(
85
+ a: torch.Tensor,
86
+ b: torch.Tensor,
87
+ bias: Optional[torch.Tensor] = None,
88
+ epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE",
89
+ ):
90
+ if bias is None:
91
+ bias = StaticState.get("bias", a.device)
92
+ out = _cublaslt_hgemm_batched_simple(
93
+ a, b, bias, epilogue_str, StaticState.get("workspace", a.device)
94
+ )
95
+ return out
96
+
97
+
98
+ class CublasLinear(nn.Linear):
99
+ def __init__(
100
+ self,
101
+ in_features,
102
+ out_features,
103
+ bias=True,
104
+ device=None,
105
+ dtype=torch.float16,
106
+ epilogue_str="NONE",
107
+ ):
108
+ super().__init__(
109
+ in_features, out_features, bias=bias, device=device, dtype=dtype
110
+ )
111
+ self._epilogue_str = epilogue_str
112
+ self.has_bias = bias
113
+ self.has_checked_weight = False
114
+
115
+ def forward(self, x: Tensor) -> Tensor:
116
+ if not self.has_checked_weight:
117
+ if not self.weight.dtype == torch.float16:
118
+ self.to(dtype=torch.float16)
119
+ self.has_checked_weight = True
120
+ out_dtype = x.dtype
121
+ needs_convert = out_dtype != torch.float16
122
+ if needs_convert:
123
+ x = x.type(torch.float16)
124
+
125
+ use_cublasLt = self.has_bias or self._epilogue_str != "NONE"
126
+ if x.ndim == 1:
127
+ x = x.unsqueeze(0)
128
+ if math.prod(x.shape) == x.shape[-1]:
129
+ out = F.linear(x, self.weight, bias=self.bias)
130
+ if self._epilogue_str == "RELU":
131
+ return F.relu(out)
132
+ elif self._epilogue_str == "GELU":
133
+ return F.gelu(out)
134
+ if needs_convert:
135
+ return out.type(out_dtype)
136
+ return out
137
+ if use_cublasLt:
138
+ leading_dims = x.shape[:-1]
139
+ x = x.reshape(-1, x.shape[-1])
140
+ out = cublaslt_fused_half_matmul_simple(
141
+ x, self.weight, bias=self.bias.data, epilogue_str=self._epilogue_str
142
+ )
143
+ if needs_convert:
144
+ return out.view(*leading_dims, out.shape[-1]).type(out_dtype)
145
+ return out.view(*leading_dims, out.shape[-1])
146
+ else:
147
+ leading_dims = x.shape[:-1]
148
+ x = x.reshape(-1, x.shape[-1])
149
+ out = cublas_half_matmul_simple(x, self.weight)
150
+ if needs_convert:
151
+ return out.view(*leading_dims, out.shape[-1]).type(out_dtype)
152
+ return out.view(*leading_dims, out.shape[-1])
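
A minimal usage sketch for the `CublasLinear` layer defined above, assuming the `cublas_ops_ext` extension is built and a CUDA device is available:

```py
import torch

from cublas_linear import CublasLinear

# Drop-in replacement for nn.Linear with an optional fused epilogue
layer = CublasLinear(3072, 3072, bias=True, device="cuda:0", epilogue_str="GELU")
x = torch.randn(2, 16, 3072, dtype=torch.float16, device="cuda:0")
y = layer(x)  # fp16 matmul + bias + GELU via cuBLASLt
print(y.shape)  # torch.Size([2, 16, 3072])
```
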
flux_impl.py ADDED
@@ -0,0 +1,272 @@
1
+ import io
2
+ from typing import List
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ torch.backends.cuda.matmul.allow_tf32 = True
8
+ torch.backends.cudnn.allow_tf32 = True
9
+ torch.backends.cudnn.benchmark = True
10
+ torch.backends.cudnn.benchmark_limit = 20
11
+ torch.set_float32_matmul_precision("high")
12
+ from torch._dynamo import config
13
+ from torch._inductor import config as ind_config
14
+
15
+ config.cache_size_limit = 10000000000
16
+ ind_config.force_fuse_int_mm_with_mul = True
17
+
18
+ from loguru import logger
19
+ from torchao.quantization.quant_api import int8_weight_only, quantize_
20
+
21
+ from cublas_linear import CublasLinear as F16Linear
22
+ from modules.flux_model import RMSNorm
23
+ from sampling import denoise, get_noise, get_schedule, prepare, unpack
24
+ from turbojpeg_imgs import TurboImage
25
+ from util import (
26
+ ModelSpec,
27
+ into_device,
28
+ into_dtype,
29
+ load_config_from_path,
30
+ load_models_from_config,
31
+ )
32
+
33
+
34
+ class Model:
35
+ def __init__(
36
+ self,
37
+ name,
38
+ offload=False,
39
+ clip=None,
40
+ t5=None,
41
+ model=None,
42
+ ae=None,
43
+ dtype=torch.bfloat16,
44
+ verbose=False,
45
+ flux_device="cuda:0",
46
+ ae_device="cuda:1",
47
+ clip_device="cuda:1",
48
+ t5_device="cuda:1",
49
+ ):
50
+
51
+ self.name = name
52
+ self.device_flux = (
53
+ flux_device
54
+ if isinstance(flux_device, torch.device)
55
+ else torch.device(flux_device)
56
+ )
57
+ self.device_ae = (
58
+ ae_device
59
+ if isinstance(ae_device, torch.device)
60
+ else torch.device(ae_device)
61
+ )
62
+ self.device_clip = (
63
+ clip_device
64
+ if isinstance(clip_device, torch.device)
65
+ else torch.device(clip_device)
66
+ )
67
+ self.device_t5 = (
68
+ t5_device
69
+ if isinstance(t5_device, torch.device)
70
+ else torch.device(t5_device)
71
+ )
72
+ self.dtype = dtype
73
+ self.offload = offload
74
+ self.clip = clip
75
+ self.t5 = t5
76
+ self.model = model
77
+ self.ae = ae
78
+ self.rng = torch.Generator(device="cpu")
79
+ self.turbojpeg = TurboImage()
80
+ self.verbose = verbose
81
+
82
+ @torch.inference_mode()
83
+ def generate(
84
+ self,
85
+ prompt,
86
+ width=720,
87
+ height=1024,
88
+ num_steps=24,
89
+ guidance=3.5,
90
+ seed=None,
91
+ ):
92
+ if num_steps is None:
93
+ num_steps = 4 if self.name == "flux-schnell" else 50
94
+
95
+ # allow for packing and conversion to latent space
96
+ height = 16 * (height // 16)
97
+ width = 16 * (width // 16)
98
+
99
+ if seed is None:
100
+ seed = self.rng.seed()
101
+ logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
102
+
103
+ x = get_noise(
104
+ 1,
105
+ height,
106
+ width,
107
+ device=self.device_t5,
108
+ dtype=torch.bfloat16,
109
+ seed=seed,
110
+ )
111
+ inp = prepare(self.t5, self.clip, x, prompt=prompt)
112
+ timesteps = get_schedule(
113
+ num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")
114
+ )
115
+ for k in inp:
116
+ inp[k] = inp[k].to(self.device_flux).type(self.dtype)
117
+
118
+ # denoise initial noise
119
+ x = denoise(
120
+ self.model,
121
+ **inp,
122
+ timesteps=timesteps,
123
+ guidance=guidance,
124
+ dtype=self.dtype,
125
+ device=self.device_flux,
126
+ )
127
+ inp.clear()
128
+ timesteps.clear()
129
+ torch.cuda.empty_cache()
130
+ x = x.to(self.device_ae)
131
+
132
+ # decode latents to pixel space
133
+ x = unpack(x.float(), height, width)
134
+ with torch.autocast(
135
+ device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
136
+ ):
137
+ x = self.ae.decode(x)
138
+
139
+ # bring into PIL format and save
140
+ x = x.clamp(-1, 1)
141
+ num_images = x.shape[0]
142
+ images: List[torch.Tensor] = []
143
+ for i in range(num_images):
144
+ img = x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
145
+ images.append(img)
146
+ if len(images) == 1:
147
+ im = images[0]
148
+ else:
149
+ im = torch.vstack(images)
150
+
151
+ im = self.turbojpeg.encode_torch(im, quality=95)
152
+ images.clear()
153
+ return io.BytesIO(im)
154
+
155
+
156
+ def quant_module(module, running_sum_quants=0, device_index=0):
157
+ if isinstance(module, nn.Linear) and not isinstance(module, F16Linear):
158
+ module.cuda(device_index)
159
+ module.compile()
160
+ quantize_(module, int8_weight_only())
161
+ running_sum_quants += 1
162
+ elif isinstance(module, F16Linear):
163
+ module.cuda(device_index)
164
+ elif isinstance(module, nn.Conv2d):
165
+ module.cuda(device_index)
166
+ elif isinstance(module, nn.Embedding):
167
+ module.cuda(device_index)
168
+ elif isinstance(module, nn.ConvTranspose2d):
169
+ module.cuda(device_index)
170
+ elif isinstance(module, nn.Conv1d):
171
+ module.cuda(device_index)
172
+ elif isinstance(module, nn.Conv3d):
173
+ module.cuda(device_index)
174
+ elif isinstance(module, nn.ConvTranspose3d):
175
+ module.cuda(device_index)
176
+ elif isinstance(module, nn.RMSNorm):
177
+ module.cuda(device_index)
178
+ elif isinstance(module, RMSNorm):
179
+ module.cuda(device_index)
180
+ elif isinstance(module, nn.LayerNorm):
181
+ module.cuda(device_index)
182
+ return running_sum_quants
183
+
184
+
185
+ def full_quant(model, max_quants=24, current_quants=0, device_index=0):
186
+ for module in model.modules():
187
+ if current_quants < max_quants:
188
+ current_quants = quant_module(
189
+ module, current_quants, device_index=device_index
190
+ )
191
+ return current_quants
192
+
193
+
194
+ @torch.inference_mode()
195
+ def load_pipeline_from_config_path(path: str) -> Model:
196
+ config = load_config_from_path(path)
197
+ return load_pipeline_from_config(config)
198
+
199
+
200
+ @torch.inference_mode()
201
+ def load_pipeline_from_config(config: ModelSpec) -> Model:
202
+ models = load_models_from_config(config)
203
+ config = models.config
204
+ num_quanted = 0
205
+ max_quanted = config.num_to_quant
206
+ flux_device = into_device(config.flux_device)
207
+ ae_device = into_device(config.ae_device)
208
+ clip_device = into_device(config.text_enc_device)
209
+ t5_device = into_device(config.text_enc_device)
210
+ flux_dtype = into_dtype(config.flow_dtype)
211
+ device_index = flux_device.index or 0
212
+ flow_model = models.flow.requires_grad_(False).eval().type(flux_dtype)
213
+ for block in flow_model.single_blocks:
214
+ block.cuda(flux_device)
215
+ if num_quanted < max_quanted:
216
+ num_quanted = quant_module(
217
+ block.linear1, num_quanted, device_index=device_index
218
+ )
219
+
220
+ for block in flow_model.double_blocks:
221
+ block.cuda(flux_device)
222
+ if num_quanted < max_quanted:
223
+ num_quanted = full_quant(
224
+ block, max_quanted, num_quanted, device_index=device_index
225
+ )
226
+
227
+ to_gpu_extras = [
228
+ "vector_in",
229
+ "img_in",
230
+ "txt_in",
231
+ "time_in",
232
+ "guidance_in",
233
+ "final_layer",
234
+ "pe_embedder",
235
+ ]
236
+ for extra in to_gpu_extras:
237
+ getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
238
+ return Model(
239
+ name=config.version,
240
+ clip=models.clip,
241
+ t5=models.t5,
242
+ model=flow_model,
243
+ ae=models.ae,
244
+ dtype=flux_dtype,
245
+ verbose=False,
246
+ flux_device=flux_device,
247
+ ae_device=ae_device,
248
+ clip_device=clip_device,
249
+ t5_device=t5_device,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ pipe = load_pipeline_from_config_path("config-dev.json")
255
+ o = pipe.generate(
256
+ prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
257
+ height=1024,
258
+ width=1024,
259
+ seed=13456,
260
+ num_steps=24,
261
+ guidance=3.0,
262
+ )
263
+ open("out.jpg", "wb").write(o.read())
264
+ o = pipe.generate(
265
+ prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
266
+ height=1024,
267
+ width=1024,
268
+ seed=7,
269
+ num_steps=24,
270
+ guidance=3.0,
271
+ )
272
+ open("out2.jpg", "wb").write(o.read())
main.py ADDED
@@ -0,0 +1,89 @@
1
+ import argparse
2
+ import uvicorn
3
+ from api import app
4
+ from flux_impl import load_pipeline_from_config, load_pipeline_from_config_path
5
+ from util import load_config, ModelVersion
6
+
7
+
8
+ def parse_args():
9
+ parser = argparse.ArgumentParser(description="Launch Flux API server")
10
+ parser.add_argument(
11
+ "--config-path",
12
+ type=str,
13
+ help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments",
14
+ )
15
+ parser.add_argument(
16
+ "--port", type=int, default=8088, help="Port to run the server on"
17
+ )
18
+ parser.add_argument(
19
+ "--host", type=str, default="0.0.0.0", help="Host to run the server on"
20
+ )
21
+ parser.add_argument("--flow-model-path", type=str, help="Path to the flow model")
22
+ parser.add_argument("--text-enc-path", type=str, help="Path to the text encoder")
23
+ parser.add_argument("--autoencoder-path", type=str, help="Path to the autoencoder")
24
+ parser.add_argument(
25
+ "--model-version",
26
+ type=str,
27
+ choices=["flux-dev", "flux-schnell"],
28
+ default="flux-dev",
29
+ help="Choose model version",
30
+ )
31
+ parser.add_argument(
32
+ "--flux-device",
33
+ type=str,
34
+ default="cuda:0",
35
+ help="Device to run the flow model on",
36
+ )
37
+ parser.add_argument(
38
+ "--text-enc-device",
39
+ type=str,
40
+ default="cuda:0",
41
+ help="Device to run the text encoder on",
42
+ )
43
+ parser.add_argument(
44
+ "--autoencoder-device",
45
+ type=str,
46
+ default="cuda:0",
47
+ help="Device to run the autoencoder on",
48
+ )
49
+ parser.add_argument(
50
+ "--num-to-quant",
51
+ type=int,
52
+ default=20,
53
+ help="Number of linear layers in flow transformer (the 'unet') to quantize",
54
+ )
55
+
56
+ return parser.parse_args()
57
+
58
+
59
+ def main():
60
+ args = parse_args()
61
+
62
+ if args.config_path:
63
+ app.state.model = load_pipeline_from_config_path(args.config_path)
64
+ else:
65
+ model_version = (
66
+ ModelVersion.flux_dev
67
+ if args.model_version == "flux-dev"
68
+ else ModelVersion.flux_schnell
69
+ )
70
+ config = load_config(
71
+ model_version,
72
+ flux_path=args.flow_model_path,
73
+ flux_device=args.flux_device,
74
+ ae_path=args.autoencoder_path,
75
+ ae_device=args.autoencoder_device,
76
+ text_enc_path=args.text_enc_path,
77
+ text_enc_device=args.text_enc_device,
78
+ flow_dtype="float16",
79
+ text_enc_dtype="bfloat16",
80
+ ae_dtype="bfloat16",
81
+ num_to_quant=args.num_to_quant,
82
+ )
83
+ app.state.model = load_pipeline_from_config(config)
84
+
85
+ uvicorn.run(app, host=args.host, port=args.port)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ main()
modules/autoencoder.py ADDED
@@ -0,0 +1,336 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor, nn
4
+ from pydantic import BaseModel
5
+
6
+
7
+ class AutoEncoderParams(BaseModel):
8
+ resolution: int
9
+ in_channels: int
10
+ ch: int
11
+ out_ch: int
12
+ ch_mult: list[int]
13
+ num_res_blocks: int
14
+ z_channels: int
15
+ scale_factor: float
16
+ shift_factor: float
17
+
18
+
19
+ def swish(x: Tensor) -> Tensor:
20
+ return x * torch.sigmoid(x)
21
+
22
+
23
+ class AttnBlock(nn.Module):
24
+ def __init__(self, in_channels: int):
25
+ super().__init__()
26
+ self.in_channels = in_channels
27
+
28
+ self.norm = nn.GroupNorm(
29
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
30
+ )
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(
63
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
64
+ )
65
+ self.conv1 = nn.Conv2d(
66
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
67
+ )
68
+ self.norm2 = nn.GroupNorm(
69
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
70
+ )
71
+ self.conv2 = nn.Conv2d(
72
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
73
+ )
74
+ if self.in_channels != self.out_channels:
75
+ self.nin_shortcut = nn.Conv2d(
76
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
77
+ )
78
+
79
+ def forward(self, x):
80
+ h = x
81
+ h = self.norm1(h)
82
+ h = swish(h)
83
+ h = self.conv1(h)
84
+
85
+ h = self.norm2(h)
86
+ h = swish(h)
87
+ h = self.conv2(h)
88
+
89
+ if self.in_channels != self.out_channels:
90
+ x = self.nin_shortcut(x)
91
+
92
+ return x + h
93
+
94
+
95
+ class Downsample(nn.Module):
96
+ def __init__(self, in_channels: int):
97
+ super().__init__()
98
+ # no asymmetric padding in torch conv, must do it ourselves
99
+ self.conv = nn.Conv2d(
100
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
101
+ )
102
+
103
+ def forward(self, x: Tensor):
104
+ pad = (0, 1, 0, 1)
105
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
106
+ x = self.conv(x)
107
+ return x
108
+
109
+
110
+ class Upsample(nn.Module):
111
+ def __init__(self, in_channels: int):
112
+ super().__init__()
113
+ self.conv = nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
115
+ )
116
+
117
+ def forward(self, x: Tensor):
118
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
119
+ x = self.conv(x)
120
+ return x
121
+
122
+
123
+ class Encoder(nn.Module):
124
+ def __init__(
125
+ self,
126
+ resolution: int,
127
+ in_channels: int,
128
+ ch: int,
129
+ ch_mult: list[int],
130
+ num_res_blocks: int,
131
+ z_channels: int,
132
+ ):
133
+ super().__init__()
134
+ self.ch = ch
135
+ self.num_resolutions = len(ch_mult)
136
+ self.num_res_blocks = num_res_blocks
137
+ self.resolution = resolution
138
+ self.in_channels = in_channels
139
+ # downsampling
140
+ self.conv_in = nn.Conv2d(
141
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
142
+ )
143
+
144
+ curr_res = resolution
145
+ in_ch_mult = (1,) + tuple(ch_mult)
146
+ self.in_ch_mult = in_ch_mult
147
+ self.down = nn.ModuleList()
148
+ block_in = self.ch
149
+ for i_level in range(self.num_resolutions):
150
+ block = nn.ModuleList()
151
+ attn = nn.ModuleList()
152
+ block_in = ch * in_ch_mult[i_level]
153
+ block_out = ch * ch_mult[i_level]
154
+ for _ in range(self.num_res_blocks):
155
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
156
+ block_in = block_out
157
+ down = nn.Module()
158
+ down.block = block
159
+ down.attn = attn
160
+ if i_level != self.num_resolutions - 1:
161
+ down.downsample = Downsample(block_in)
162
+ curr_res = curr_res // 2
163
+ self.down.append(down)
164
+
165
+ # middle
166
+ self.mid = nn.Module()
167
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
168
+ self.mid.attn_1 = AttnBlock(block_in)
169
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
170
+
171
+ # end
172
+ self.norm_out = nn.GroupNorm(
173
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
174
+ )
175
+ self.conv_out = nn.Conv2d(
176
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
177
+ )
178
+
179
+ def forward(self, x: Tensor) -> Tensor:
180
+ # downsampling
181
+ hs = [self.conv_in(x)]
182
+ for i_level in range(self.num_resolutions):
183
+ for i_block in range(self.num_res_blocks):
184
+ h = self.down[i_level].block[i_block](hs[-1])
185
+ if len(self.down[i_level].attn) > 0:
186
+ h = self.down[i_level].attn[i_block](h)
187
+ hs.append(h)
188
+ if i_level != self.num_resolutions - 1:
189
+ hs.append(self.down[i_level].downsample(hs[-1]))
190
+
191
+ # middle
192
+ h = hs[-1]
193
+ h = self.mid.block_1(h)
194
+ h = self.mid.attn_1(h)
195
+ h = self.mid.block_2(h)
196
+ # end
197
+ h = self.norm_out(h)
198
+ h = swish(h)
199
+ h = self.conv_out(h)
200
+ return h
201
+
202
+
203
+ class Decoder(nn.Module):
204
+ def __init__(
205
+ self,
206
+ ch: int,
207
+ out_ch: int,
208
+ ch_mult: list[int],
209
+ num_res_blocks: int,
210
+ in_channels: int,
211
+ resolution: int,
212
+ z_channels: int,
213
+ ):
214
+ super().__init__()
215
+ self.ch = ch
216
+ self.num_resolutions = len(ch_mult)
217
+ self.num_res_blocks = num_res_blocks
218
+ self.resolution = resolution
219
+ self.in_channels = in_channels
220
+ self.ffactor = 2 ** (self.num_resolutions - 1)
221
+
222
+ # compute in_ch_mult, block_in and curr_res at lowest res
223
+ block_in = ch * ch_mult[self.num_resolutions - 1]
224
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
225
+ self.z_shape = (1, z_channels, curr_res, curr_res)
226
+
227
+ # z to block_in
228
+ self.conv_in = nn.Conv2d(
229
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
230
+ )
231
+
232
+ # middle
233
+ self.mid = nn.Module()
234
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
235
+ self.mid.attn_1 = AttnBlock(block_in)
236
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
237
+
238
+ # upsampling
239
+ self.up = nn.ModuleList()
240
+ for i_level in reversed(range(self.num_resolutions)):
241
+ block = nn.ModuleList()
242
+ attn = nn.ModuleList()
243
+ block_out = ch * ch_mult[i_level]
244
+ for _ in range(self.num_res_blocks + 1):
245
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
246
+ block_in = block_out
247
+ up = nn.Module()
248
+ up.block = block
249
+ up.attn = attn
250
+ if i_level != 0:
251
+ up.upsample = Upsample(block_in)
252
+ curr_res = curr_res * 2
253
+ self.up.insert(0, up) # prepend to get consistent order
254
+
255
+ # end
256
+ self.norm_out = nn.GroupNorm(
257
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
258
+ )
259
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
260
+
261
+ def forward(self, z: Tensor) -> Tensor:
262
+ # z to block_in
263
+ h = self.conv_in(z)
264
+
265
+ # middle
266
+ h = self.mid.block_1(h)
267
+ h = self.mid.attn_1(h)
268
+ h = self.mid.block_2(h)
269
+
270
+ # upsampling
271
+ for i_level in reversed(range(self.num_resolutions)):
272
+ for i_block in range(self.num_res_blocks + 1):
273
+ h = self.up[i_level].block[i_block](h)
274
+ if len(self.up[i_level].attn) > 0:
275
+ h = self.up[i_level].attn[i_block](h)
276
+ if i_level != 0:
277
+ h = self.up[i_level].upsample(h)
278
+
279
+ # end
280
+ h = self.norm_out(h)
281
+ h = swish(h)
282
+ h = self.conv_out(h)
283
+ return h
284
+
285
+
286
+ class DiagonalGaussian(nn.Module):
287
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
288
+ super().__init__()
289
+ self.sample = sample
290
+ self.chunk_dim = chunk_dim
291
+
292
+ def forward(self, z: Tensor) -> Tensor:
293
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
294
+ if self.sample:
295
+ std = torch.exp(0.5 * logvar)
296
+ return mean + std * torch.randn_like(mean)
297
+ else:
298
+ return mean
299
+
300
+
301
+ class AutoEncoder(nn.Module):
302
+ def __init__(self, params: AutoEncoderParams):
303
+ super().__init__()
304
+ self.encoder = Encoder(
305
+ resolution=params.resolution,
306
+ in_channels=params.in_channels,
307
+ ch=params.ch,
308
+ ch_mult=params.ch_mult,
309
+ num_res_blocks=params.num_res_blocks,
310
+ z_channels=params.z_channels,
311
+ )
312
+ self.decoder = Decoder(
313
+ resolution=params.resolution,
314
+ in_channels=params.in_channels,
315
+ ch=params.ch,
316
+ out_ch=params.out_ch,
317
+ ch_mult=params.ch_mult,
318
+ num_res_blocks=params.num_res_blocks,
319
+ z_channels=params.z_channels,
320
+ )
321
+ self.reg = DiagonalGaussian()
322
+
323
+ self.scale_factor = params.scale_factor
324
+ self.shift_factor = params.shift_factor
325
+
326
+ def encode(self, x: Tensor) -> Tensor:
327
+ z = self.reg(self.encoder(x))
328
+ z = self.scale_factor * (z - self.shift_factor)
329
+ return z
330
+
331
+ def decode(self, z: Tensor) -> Tensor:
332
+ z = z / self.scale_factor + self.shift_factor
333
+ return self.decoder(z)
334
+
335
+ def forward(self, x: Tensor) -> Tensor:
336
+ return self.decode(self.encode(x))
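
A construction sketch for the autoencoder above, using the `ae_params` values from the configs (random weights on CPU, for shape checking only):

```py
import torch

from modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=256, in_channels=3, ch=128, out_ch=3,
    ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
    scale_factor=0.3611, shift_factor=0.1159,
)
ae = AutoEncoder(params).eval()
with torch.no_grad():
    z = ae.encode(torch.randn(1, 3, 256, 256))  # latent: (1, 16, 32, 32)
    img = ae.decode(z)                          # image:  (1, 3, 256, 256)
```
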
modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
1
+ from torch import Tensor, nn
2
+ import torch
3
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
4
+
5
+ from transformers.utils.quantization_config import BitsAndBytesConfig
6
+
7
+
8
+ class HFEmbedder(nn.Module):
9
+ def __init__(
10
+ self, version: str, max_length: int, device: torch.device | int, **hf_kwargs
11
+ ):
12
+ super().__init__()
13
+ self.is_clip = version.startswith("openai")
14
+ self.max_length = max_length
15
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
16
+
17
+ if self.is_clip:
18
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
19
+ version, max_length=max_length
20
+ )
21
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
22
+ version, **hf_kwargs
23
+ )
24
+ self.hf_module = self.hf_module.eval().requires_grad_(False).to(device)
25
+ else:
26
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
27
+ version, max_length=max_length
28
+ )
29
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
30
+ version,
31
+ **hf_kwargs,
32
+ device_map={"": device},
33
+ quantization_config=BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ ),
36
+ )
37
+
38
+ def forward(self, text: list[str]) -> Tensor:
39
+ batch_encoding = self.tokenizer(
40
+ text,
41
+ truncation=True,
42
+ max_length=self.max_length,
43
+ return_length=False,
44
+ return_overflowing_tokens=False,
45
+ padding="max_length",
46
+ return_tensors="pt",
47
+ )
48
+ outputs = self.hf_module(
49
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
50
+ attention_mask=None,
51
+ output_hidden_states=False,
52
+ )
53
+ return outputs[self.output_key]
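
A usage sketch for `HFEmbedder` above. The T5 repo id comes from the configs; the CLIP repo id is an assumed example (any `openai/...` model name selects the pooled-output branch), and the T5 branch loads 4-bit weights via bitsandbytes:

```py
import torch

from modules.conditioner import HFEmbedder

# T5 encoder: returns last_hidden_state, loaded 4-bit on device 0
t5 = HFEmbedder("city96/t5-v1_1-xxl-encoder-bf16", max_length=512, device=0,
                torch_dtype=torch.bfloat16)
# CLIP text encoder: returns pooler_output (repo id is an assumed example)
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, device=0,
                  torch_dtype=torch.bfloat16)

ctx = t5(["a photo of a cat"])    # shape (1, 512, 4096)
vec = clip(["a photo of a cat"])  # shape (1, 768)
```
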
modules/flux_model.py ADDED
@@ -0,0 +1,492 @@
1
+ import torch
2
+
3
+ torch.backends.cuda.matmul.allow_tf32 = True
4
+ torch.backends.cudnn.allow_tf32 = True
5
+ torch.backends.cudnn.benchmark = True
6
+ torch.backends.cudnn.benchmark_limit = 20
7
+ torch.set_float32_matmul_precision("high")
8
+ import math
9
+ from dataclasses import dataclass
10
+
11
+ from cublas_linear import CublasLinear as F16Linear
12
+ from einops.layers.torch import Rearrange
13
+ from torch import Tensor, nn
14
+ from torch._dynamo import config
15
+ from torch._inductor import config as ind_config
16
+ from xformers.ops import memory_efficient_attention
17
+ from pydantic import BaseModel
18
+
19
+ config.cache_size_limit = 10000000000
20
+ ind_config.force_fuse_int_mm_with_mul = True
21
+
22
+
23
+ class FluxParams(BaseModel):
24
+ in_channels: int
25
+ vec_in_dim: int
26
+ context_in_dim: int
27
+ hidden_size: int
28
+ mlp_ratio: float
29
+ num_heads: int
30
+ depth: int
31
+ depth_single_blocks: int
32
+ axes_dim: list[int]
33
+ theta: int
34
+ qkv_bias: bool
35
+ guidance_embed: bool
36
+
37
+
38
+ @torch.compile(mode="reduce-overhead")
39
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
40
+ q, k = apply_rope(q, k, pe)
41
+ x = memory_efficient_attention(
42
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
43
+ )
44
+ x = x.reshape(*x.shape[:-2], -1)
45
+ return x
46
+
47
+
48
+ @torch.compile(mode="reduce-overhead")
49
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
50
+ scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
51
+ omega = 1.0 / (theta**scale)
52
+ out = torch.einsum("...n,d->...nd", pos, omega)
53
+ out = torch.stack(
54
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
55
+ )
56
+ out = out.reshape(*out.shape[:-1], 2, 2)
57
+ return out
58
+
59
+
60
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
61
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
62
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
63
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
64
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
65
+ return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
66
+
67
+
68
+ class EmbedND(nn.Module):
69
+ def __init__(
70
+ self,
71
+ dim: int,
72
+ theta: int,
73
+ axes_dim: list[int],
74
+ dtype: torch.dtype = torch.bfloat16,
75
+ ):
76
+ super().__init__()
77
+ self.dim = dim
78
+ self.theta = theta
79
+ self.axes_dim = axes_dim
80
+ self.dtype = dtype
81
+
82
+ def forward(self, ids: Tensor) -> Tensor:
83
+ n_axes = ids.shape[-1]
84
+ emb = torch.cat(
85
+ [
86
+ rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
87
+ for i in range(n_axes)
88
+ ],
89
+ dim=-3,
90
+ )
91
+
92
+ return emb.unsqueeze(1)
93
+
94
+
95
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
96
+ """
97
+ Create sinusoidal timestep embeddings.
98
+ :param t: a 1-D Tensor of N indices, one per batch element.
99
+ These may be fractional.
100
+ :param dim: the dimension of the output.
101
+ :param max_period: controls the minimum frequency of the embeddings.
102
+ :return: an (N, D) Tensor of positional embeddings.
103
+ """
104
+ t = time_factor * t
105
+ half = dim // 2
106
+ freqs = torch.exp(
107
+ -math.log(max_period)
108
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
109
+ / half
110
+ )
111
+
112
+ args = t[:, None].float() * freqs[None]
113
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
114
+ if dim % 2:
115
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
116
+ return embedding
117
+
118
+
119
+ class MLPEmbedder(nn.Module):
120
+ def __init__(self, in_dim: int, hidden_dim: int):
121
+ super().__init__()
122
+ self.in_layer = F16Linear(in_dim, hidden_dim, bias=True)
123
+ self.silu = nn.SiLU()
124
+ self.out_layer = F16Linear(hidden_dim, hidden_dim, bias=True)
125
+
126
+ def forward(self, x: Tensor) -> Tensor:
127
+ return self.out_layer(self.silu(self.in_layer(x)))
128
+
129
+
130
+ @torch.compile(mode="reduce-overhead", dynamic=True)
131
+ def calculation(
132
+ x,
133
+ ):
134
+ rrms = torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + 1e-6)
135
+ x = x * rrms
136
+ return x
137
+
138
+
139
+ class RMSNorm(torch.nn.Module):
140
+ def __init__(self, dim: int):
141
+ super().__init__()
142
+ self.scale = nn.Parameter(torch.ones(dim))
143
+
144
+ def forward(self, x: Tensor):
145
+ return calculation(x) * self.scale
146
+
147
+
148
+ class QKNorm(torch.nn.Module):
149
+ def __init__(self, dim: int):
150
+ super().__init__()
151
+ self.query_norm = RMSNorm(dim)
152
+ self.key_norm = RMSNorm(dim)
153
+
154
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
155
+ q = self.query_norm(q)
156
+ k = self.key_norm(k)
157
+ return q, k
158
+
159
+
160
+ class SelfAttention(nn.Module):
161
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
162
+ super().__init__()
163
+ self.num_heads = num_heads
164
+ head_dim = dim // num_heads
165
+
166
+ self.qkv = F16Linear(dim, dim * 3, bias=qkv_bias)
167
+ self.norm = QKNorm(head_dim)
168
+ self.proj = F16Linear(dim, dim)
169
+ self.rearrange = Rearrange("B L (K H D) -> K B H L D", K=3, H=num_heads)
170
+
171
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
172
+ qkv = self.qkv(x)
173
+ q, k, v = self.rearrange(qkv)
174
+ q, k = self.norm(q, k, v)
175
+ x = attention(q, k, v, pe=pe)
176
+ x = self.proj(x)
177
+ return x
178
+
179
+
180
+ @dataclass
181
+ class ModulationOut:
182
+ shift: Tensor
183
+ scale: Tensor
184
+ gate: Tensor
185
+
186
+
187
+ class Modulation(nn.Module):
188
+ def __init__(self, dim: int, double: bool):
189
+ super().__init__()
190
+ self.is_double = double
191
+ self.multiplier = 6 if double else 3
192
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
193
+ self.act = nn.SiLU()
194
+
195
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
196
+ out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)
197
+
198
+ return (
199
+ ModulationOut(*out[:3]),
200
+ ModulationOut(*out[3:]) if self.is_double else None,
201
+ )
202
+
203
+
204
+ class DoubleStreamBlock(nn.Module):
205
+ def __init__(
206
+ self,
207
+ hidden_size: int,
208
+ num_heads: int,
209
+ mlp_ratio: float,
210
+ qkv_bias: bool = False,
211
+ dtype: torch.dtype = torch.bfloat16,
212
+ idx: int = 0,
213
+ ):
214
+ super().__init__()
215
+ self.dtype = dtype
216
+
217
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
218
+ self.num_heads = num_heads
219
+ self.hidden_size = hidden_size
220
+ self.img_mod = Modulation(hidden_size, double=True)
221
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
222
+ self.img_attn = SelfAttention(
223
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
224
+ )
225
+
226
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
+ self.img_mlp = nn.Sequential(
228
+ F16Linear(hidden_size, mlp_hidden_dim, bias=True),
229
+ nn.GELU(approximate="tanh"),
230
+ F16Linear(mlp_hidden_dim, hidden_size, bias=True),
231
+ )
232
+
233
+ self.txt_mod = Modulation(hidden_size, double=True)
234
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
235
+ self.txt_attn = SelfAttention(
236
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
237
+ )
238
+
239
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
240
+ self.txt_mlp = nn.Sequential(
241
+ (F16Linear(hidden_size, mlp_hidden_dim, bias=True)),
242
+ nn.GELU(approximate="tanh"),
243
+ (F16Linear(mlp_hidden_dim, hidden_size, bias=True)),
244
+ )
245
+ self.rearrange_for_norm = Rearrange(
246
+ "B L (K H D) -> K B H L D", K=3, H=num_heads
247
+ )
248
+
249
+ def forward(
250
+ self,
251
+ img: Tensor,
252
+ txt: Tensor,
253
+ vec: Tensor,
254
+ pe: Tensor,
255
+ ) -> tuple[Tensor, Tensor]:
256
+ img_mod1, img_mod2 = self.img_mod(vec)
257
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
258
+
259
+ # prepare image for attention
260
+ img_modulated = self.img_norm1(img)
261
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
262
+ img_qkv = self.img_attn.qkv(img_modulated)
263
+ img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
264
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
265
+
266
+ # prepare txt for attention
267
+ txt_modulated = self.txt_norm1(txt)
268
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
269
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
270
+ txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
271
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
272
+
273
+ q = torch.cat((txt_q, img_q), dim=2)
274
+ k = torch.cat((txt_k, img_k), dim=2)
275
+ v = torch.cat((txt_v, img_v), dim=2)
276
+
277
+ attn = attention(q, k, v, pe=pe)
278
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
279
+ # calculate the img bloks
280
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
281
+ img = img + img_mod2.gate * self.img_mlp(
282
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
283
+ ).clamp(min=-384, max=384)
284
+
285
+ # calculate the txt bloks
286
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
287
+ txt = txt + txt_mod2.gate * self.txt_mlp(
288
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
289
+ ).clamp(min=-384, max=384)
290
+
291
+ return img, txt
292
+
293
+
294
+ class SingleStreamBlock(nn.Module):
295
+ """
296
+ A DiT block with parallel linear layers as described in
297
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ hidden_size: int,
303
+ num_heads: int,
304
+ mlp_ratio: float = 4.0,
305
+ qk_scale: float | None = None,
306
+ dtype: torch.dtype = torch.bfloat16,
307
+ ):
308
+ super().__init__()
309
+ self.dtype = dtype
310
+ self.hidden_dim = hidden_size
311
+ self.num_heads = num_heads
312
+ head_dim = hidden_size // num_heads
313
+ self.scale = qk_scale or head_dim**-0.5
314
+
315
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
316
+ # qkv and mlp_in
317
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
318
+ # proj and mlp_out
319
+ self.linear2 = F16Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
320
+
321
+ self.norm = QKNorm(head_dim)
322
+
323
+ self.hidden_size = hidden_size
324
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
325
+
326
+ self.mlp_act = nn.GELU(approximate="tanh")
327
+ self.modulation = Modulation(hidden_size, double=False)
328
+ self.rearrange_for_norm = Rearrange(
329
+ "B L (K H D) -> K B H L D", K=3, H=num_heads
330
+ )
331
+
332
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
333
+ mod = self.modulation(vec)[0]
334
+ pre_norm = self.pre_norm(x)
335
+ x_mod = (1 + mod.scale) * pre_norm + mod.shift
336
+ qkv, mlp = torch.split(
337
+ self.linear1(x_mod),
338
+ [3 * self.hidden_size, self.mlp_hidden_dim],
339
+ dim=-1,
340
+ )
341
+ q, k, v = self.rearrange_for_norm(qkv)
342
+ q, k = self.norm(q, k, v)
343
+ attn = attention(q, k, v, pe=pe)
344
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
345
+ min=-384, max=384
346
+ )
347
+ return x + mod.gate * output
348
+
349
+
350
+ class LastLayer(nn.Module):
351
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
352
+ super().__init__()
353
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
354
+ self.linear = nn.Linear(
355
+ hidden_size, patch_size * patch_size * out_channels, bias=True
356
+ )
357
+ self.adaLN_modulation = nn.Sequential(
358
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
359
+ )
360
+
361
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
362
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
363
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
364
+ x = self.linear(x)
365
+ return x
366
+
367
+
368
+ class Flux(nn.Module):
369
+ """
370
+ Transformer model for flow matching on sequences.
371
+ """
372
+
373
+ def __init__(self, params: FluxParams, dtype: torch.dtype = torch.bfloat16):
374
+ super().__init__()
375
+
376
+ self.dtype = dtype
377
+ self.params = params
378
+ self.in_channels = params.in_channels
379
+ self.out_channels = self.in_channels
380
+ if params.hidden_size % params.num_heads != 0:
381
+ raise ValueError(
382
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
383
+ )
384
+ pe_dim = params.hidden_size // params.num_heads
385
+ if sum(params.axes_dim) != pe_dim:
386
+ raise ValueError(
387
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
388
+ )
389
+ self.hidden_size = params.hidden_size
390
+ self.num_heads = params.num_heads
391
+ self.pe_embedder = EmbedND(
392
+ dim=pe_dim,
393
+ theta=params.theta,
394
+ axes_dim=params.axes_dim,
395
+ dtype=self.dtype,
396
+ )
397
+ self.img_in = F16Linear(self.in_channels, self.hidden_size, bias=True)
398
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
399
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
400
+ self.guidance_in = (
401
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
402
+ if params.guidance_embed
403
+ else nn.Identity()
404
+ )
405
+ self.txt_in = F16Linear(params.context_in_dim, self.hidden_size)
406
+
407
+ self.double_blocks = nn.ModuleList(
408
+ [
409
+ DoubleStreamBlock(
410
+ self.hidden_size,
411
+ self.num_heads,
412
+ mlp_ratio=params.mlp_ratio,
413
+ qkv_bias=params.qkv_bias,
414
+ dtype=self.dtype,
415
+ idx=idx,
416
+ )
417
+ for idx in range(params.depth)
418
+ ]
419
+ )
420
+
421
+ self.single_blocks = nn.ModuleList(
422
+ [
423
+ SingleStreamBlock(
424
+ self.hidden_size,
425
+ self.num_heads,
426
+ mlp_ratio=params.mlp_ratio,
427
+ dtype=self.dtype,
428
+ )
429
+ for _ in range(params.depth_single_blocks)
430
+ ]
431
+ )
432
+
433
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
434
+
435
+ def forward(
436
+ self,
437
+ img: Tensor,
438
+ img_ids: Tensor,
439
+ txt: Tensor,
440
+ txt_ids: Tensor,
441
+ timesteps: Tensor,
442
+ y: Tensor,
443
+ guidance: Tensor | None = None,
444
+ ) -> Tensor:
445
+ if img.ndim != 3 or txt.ndim != 3:
446
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
447
+
448
+ # running on sequences img
449
+ img = self.img_in(img)
450
+ vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))
451
+
452
+ if self.params.guidance_embed:
453
+ if guidance is None:
454
+ raise ValueError(
455
+ "Didn't get guidance strength for guidance distilled model."
456
+ )
457
+ vec = vec + self.guidance_in(
458
+ timestep_embedding(guidance, 256).type(self.dtype)
459
+ )
460
+ vec = vec + self.vector_in(y)
461
+
462
+ txt = self.txt_in(txt)
463
+
464
+ ids = torch.cat((txt_ids, img_ids), dim=1)
465
+ pe = self.pe_embedder(ids)
466
+
467
+ for i, block in enumerate(self.double_blocks):
468
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
469
+
470
+ img = torch.cat((txt, img), 1)
471
+ for block in self.single_blocks:
472
+ img = block(img, vec=vec, pe=pe)
473
+
474
+ img = img[:, txt.shape[1] :, ...]
475
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
476
+ return img
477
+
478
+ @classmethod
479
+ def from_safetensors(
480
+ cls,
481
+ model_path: str,
482
+ model_params: FluxParams,
483
+ dtype: torch.dtype = torch.bfloat16,
484
+ device: torch.device = torch.device(
485
+ "cuda" if torch.cuda.is_available() else "cpu"
486
+ ),
487
+ ):
488
+
+ from safetensors.torch import load_file
+
+ model = Flux(params=model_params, dtype=dtype)
+ # model_path is a filesystem path, so load the weights via safetensors
+ model.load_state_dict(load_file(model_path, device="cpu"))
+ model.to(device)
+ return model
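For reference, a minimal sketch of building the flow transformer directly from a checkpoint via `Flux.from_safetensors`. The checkpoint path is hypothetical; the `FluxParams` values mirror the flux-dev configuration constructed in `util.load_config` further down in this commit.

```python
import torch

from modules.flux_model import Flux, FluxParams

# flux-dev parameters, matching util.load_config
params = FluxParams(
    in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072,
    mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38,
    axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True,
)

# hypothetical checkpoint location
flow = Flux.from_safetensors(
    "/path/to/flux1-dev.sft",
    model_params=params,
    dtype=torch.bfloat16,
)
```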
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ git+https://github.com/aredden/torch-cublas-hgemm.git@master
2
+ git+https://github.com/pytorch/ao.git@main
3
+ einops
4
+ PyTurboJPEG
5
+ pydantic
6
+ fastapi
7
+ bitsandbytes
8
+ xformers
9
+ loguru
10
+ transformers
11
+ tokenizers
12
+ sentencepiece
sampling.py ADDED
@@ -0,0 +1,152 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from modules.flux_model import Flux
9
+ from modules.conditioner import HFEmbedder
10
+
11
+
12
+ @torch.inference_mode()
13
+ def get_noise(
14
+ num_samples: int,
15
+ height: int,
16
+ width: int,
17
+ device: torch.device,
18
+ dtype: torch.dtype,
19
+ seed: int,
20
+ ):
21
+ return torch.randn(
22
+ num_samples,
23
+ 16,
24
+ # allow for packing
25
+ 2 * math.ceil(height / 16),
26
+ 2 * math.ceil(width / 16),
27
+ device=device,
28
+ dtype=dtype,
29
+ generator=torch.Generator(device=device).manual_seed(seed),
30
+ )
31
+
32
+
33
+ @torch.inference_mode()
34
+ def prepare(
35
+ t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]
36
+ ) -> dict[str, Tensor]:
37
+ bs, c, h, w = img.shape
38
+ if bs == 1 and not isinstance(prompt, str):
39
+ bs = len(prompt)
40
+
41
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
42
+ if img.shape[0] == 1 and bs > 1:
43
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
44
+
45
+ img_ids = torch.zeros(h // 2, w // 2, 3)
46
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
47
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
48
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
49
+
50
+ if isinstance(prompt, str):
51
+ prompt = [prompt]
52
+ txt = t5(prompt)
53
+ if txt.shape[0] == 1 and bs > 1:
54
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
55
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
56
+
57
+ vec = clip(prompt)
58
+ if vec.shape[0] == 1 and bs > 1:
59
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
60
+
61
+ return {
62
+ "img": img,
63
+ "img_ids": img_ids.to(img.device),
64
+ "txt": txt.to(img.device),
65
+ "txt_ids": txt_ids.to(img.device),
66
+ "vec": vec.to(img.device),
67
+ }
68
+
69
+
70
+ def time_shift(mu: float, sigma: float, t: Tensor):
71
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
72
+
73
+
74
+ def get_lin_function(
75
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
76
+ ) -> Callable[[float], float]:
77
+ m = (y2 - y1) / (x2 - x1)
78
+ b = y1 - m * x1
79
+ return lambda x: m * x + b
80
+
81
+
82
+ def get_schedule(
83
+ num_steps: int,
84
+ image_seq_len: int,
85
+ base_shift: float = 0.5,
86
+ max_shift: float = 1.15,
87
+ shift: bool = True,
88
+ ) -> list[float]:
89
+ # extra step for zero
90
+ timesteps = torch.linspace(1, 0, num_steps + 1)
91
+
92
+ # shifting the schedule to favor high timesteps for higher signal images
93
+ if shift:
94
+ # estimate mu based on linear interpolation between two points
95
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
96
+ timesteps = time_shift(mu, 1.0, timesteps)
97
+
98
+ return timesteps.tolist()
99
+
100
+
101
+ @torch.inference_mode()
102
+ def denoise(
103
+ model: Flux,
104
+ # model input
105
+ img: Tensor,
106
+ img_ids: Tensor,
107
+ txt: Tensor,
108
+ txt_ids: Tensor,
109
+ vec: Tensor,
110
+ # sampling parameters
111
+ timesteps: list[float],
112
+ guidance: float = 4.0,
113
+ dtype: torch.dtype = torch.bfloat16,
114
+ device: torch.device = torch.device("cuda:0"),
115
+ ):
116
+ from tqdm import tqdm
117
+
118
+ # the guidance value below is only used by guidance-distilled models and is ignored for schnell
119
+ img = img.to(device=device, dtype=dtype)
120
+ img_ids = img_ids.to(device=device, dtype=dtype)
121
+ txt = txt.to(device=device, dtype=dtype)
122
+ txt_ids = txt_ids.to(device=device, dtype=dtype)
123
+ vec = vec.to(device=device, dtype=dtype)
124
+ guidance_vec = torch.full((img.shape[0],), guidance, device=device, dtype=dtype)
125
+ for t_curr, t_prev in tqdm(
126
+ zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
127
+ ):
128
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=dtype, device=device)
129
+ pred = model(
130
+ img=img,
131
+ img_ids=img_ids,
132
+ txt=txt,
133
+ txt_ids=txt_ids,
134
+ y=vec,
135
+ timesteps=t_vec,
136
+ guidance=guidance_vec,
137
+ )
138
+
139
+ img = img + (t_prev - t_curr) * pred
140
+
141
+ return img
142
+
143
+
144
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
145
+ return rearrange(
146
+ x,
147
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
148
+ h=math.ceil(height / 16),
149
+ w=math.ceil(width / 16),
150
+ ph=2,
151
+ pw=2,
152
+ )
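Taken together, the helpers above form the full sampling loop. A minimal sketch, assuming a loaded `Flux` instance `model` on `cuda:0` and two `HFEmbedder` instances `t5` and `clip` (e.g. obtained from the loaders in `util.py`):

```python
import torch

from sampling import get_noise, prepare, get_schedule, denoise, unpack

height, width = 1024, 1024

# 1. latent noise
x = get_noise(1, height, width, device=torch.device("cuda:0"),
              dtype=torch.bfloat16, seed=0)

# 2. text conditioning + positional ids (t5, clip are HFEmbedder instances)
inputs = prepare(t5, clip, img=x, prompt="a photo of a forest with mist")

# 3. timestep schedule; shift=True for flux-dev, schnell typically runs unshifted
timesteps = get_schedule(num_steps=28, image_seq_len=inputs["img"].shape[1], shift=True)

# 4. flow-matching denoising loop
x = denoise(model, **inputs, timesteps=timesteps, guidance=4.0)

# 5. unpack back to the latent image layout; decode with the autoencoder afterwards
latents = unpack(x.float(), height, width)
```

The resulting latents are decoded by the `AutoEncoder` and can then be JPEG-encoded with `turbojpeg_imgs.TurboImage` below.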
turbojpeg_imgs.py ADDED
@@ -0,0 +1,134 @@
1
+ import numpy as np
2
+ import torch
3
+ from turbojpeg import (
4
+ TurboJPEG,
5
+ TJPF_GRAY,
6
+ TJFLAG_PROGRESSIVE,
7
+ TJFLAG_FASTUPSAMPLE,
8
+ TJFLAG_FASTDCT,
9
+ TJPF_RGB,
10
+ TJPF_BGR,
11
+ TJSAMP_GRAY,
12
+ TJSAMP_411,
13
+ TJSAMP_420,
14
+ TJSAMP_422,
15
+ TJSAMP_444,
16
+ TJSAMP_440,
17
+ TJSAMP_441,
18
+ )
19
+
20
+
21
+ class Subsampling:
22
+ S411 = TJSAMP_411
23
+ S420 = TJSAMP_420
24
+ S422 = TJSAMP_422
25
+ S444 = TJSAMP_444
26
+ S440 = TJSAMP_440
27
+ S441 = TJSAMP_441
28
+ GRAY = TJSAMP_GRAY
29
+
30
+
31
+ class Flags:
32
+ PROGRESSIVE = TJFLAG_PROGRESSIVE
33
+ FASTUPSAMPLE = TJFLAG_FASTUPSAMPLE
34
+ FASTDCT = TJFLAG_FASTDCT
35
+
36
+
37
+ class PixelFormat:
38
+ GRAY = TJPF_GRAY
39
+ RGB = TJPF_RGB
40
+ BGR = TJPF_BGR
41
+
42
+
43
+ class TurboImage:
44
+ def __init__(self):
45
+ self.tj = TurboJPEG()
46
+ self.flags = Flags.PROGRESSIVE
47
+
48
+ self.subsampling_gray = Subsampling.GRAY
49
+ self.pixel_format_gray = PixelFormat.GRAY
50
+ self.subsampling_rgb = Subsampling.S420
51
+ self.pixel_format_rgb = PixelFormat.RGB
52
+
53
+ def set_subsampling_gray(self, subsampling):
54
+ self.subsampling_gray = subsampling
55
+
56
+ def set_subsampling_rgb(self, subsampling):
57
+ self.subsampling_rgb = subsampling
58
+
59
+ def set_pixel_format_gray(self, pixel_format):
60
+ self.pixel_format_gray = pixel_format
61
+
62
+ def set_pixel_format_rgb(self, pixel_format):
63
+ self.pixel_format_rgb = pixel_format
64
+
65
+ def set_flags(self, flags):
66
+ self.flags = flags
67
+
68
+ def encode(
69
+ self,
70
+ img,
71
+ subsampling,
72
+ pixel_format,
73
+ quality=90,
74
+ ):
75
+ return self.tj.encode(
76
+ img,
77
+ quality=quality,
78
+ flags=self.flags,
79
+ pixel_format=pixel_format,
80
+ jpeg_subsample=subsampling,
81
+ )
82
+
83
+ @torch.inference_mode()
84
+ def encode_torch(self, img: torch.Tensor, quality=90):
85
+ if img.ndim == 2:
86
+ subsampling = self.subsampling_gray
87
+ pixel_format = self.pixel_format_gray
88
+ img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8)
89
+ elif img.ndim == 3:
90
+ subsampling = self.subsampling_rgb
91
+ pixel_format = self.pixel_format_rgb
92
+ if img.shape[0] == 3:
93
+ img = (
94
+ img.permute(1, 2, 0)
95
+ .clamp(0, 255)
96
+ .cpu()
97
+ .contiguous()
98
+ .numpy()
99
+ .astype(np.uint8)
100
+ )
101
+ elif img.shape[2] == 3:
102
+ img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8)
103
+ else:
104
+ raise ValueError(f"Unsupported image shape: {img.shape}")
105
+ else:
106
+ raise ValueError(f"Unsupported image num dims: {img.ndim}")
107
+
108
+ return self.encode(
109
+ img,
110
+ quality=quality,
111
+ subsampling=subsampling,
112
+ pixel_format=pixel_format,
113
+ )
114
+
115
+ def encode_numpy(self, img: np.ndarray, quality=90):
116
+ if img.ndim == 2:
117
+ subsampling = self.subsampling_gray
118
+ pixel_format = self.pixel_format_gray
119
+ elif img.ndim == 3:
120
+ if img.shape[0] == 3:
121
+ img = np.ascontiguousarray(img.transpose(1, 2, 0))
122
+ elif img.shape[2] == 3:
123
+ img = np.ascontiguousarray(img)
124
+ else:
125
+ raise ValueError(f"Unsupported image shape: {img.shape}")
126
+ subsampling = self.subsampling_rgb
127
+ pixel_format = self.pixel_format_rgb
128
+ else:
129
+ raise ValueError(f"Unsupported image num dims: {img.ndim}")
130
+
131
+ img = img.clip(0, 255).astype(np.uint8)
132
+ return self.encode(
133
+ img, quality=quality, subsampling=subsampling, pixel_format=pixel_format
134
+ )
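A short usage sketch for `TurboImage`, assuming `libturbojpeg` is installed and using a random CHW tensor as a stand-in for a decoded image:

```python
import torch

from turbojpeg_imgs import TurboImage, Subsampling

tj = TurboImage()
tj.set_subsampling_rgb(Subsampling.S444)  # optional: full chroma resolution

# stand-in image: CHW float tensor with values in [0, 255]
img = torch.rand(3, 512, 512) * 255

jpeg_bytes = tj.encode_torch(img, quality=95)
with open("out.jpg", "wb") as f:
    f.write(jpeg_bytes)
```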
util.py ADDED
@@ -0,0 +1,275 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from modules.autoencoder import AutoEncoder, AutoEncoderParams
7
+ from modules.conditioner import HFEmbedder
8
+ from modules.flux_model import Flux, FluxParams
9
+
10
+ from safetensors.torch import load_file as load_sft
11
+ from enum import StrEnum
12
+ from pydantic import BaseModel, ConfigDict
13
+ from loguru import logger
14
+
15
+
16
+ class ModelVersion(StrEnum):
17
+ flux_dev = "flux-dev"
18
+ flux_schnell = "flux-schnell"
19
+
20
+
21
+ class ModelSpec(BaseModel):
22
+ version: ModelVersion
23
+ params: FluxParams
24
+ ae_params: AutoEncoderParams
25
+ ckpt_path: str | None
26
+ ae_path: str | None
27
+ repo_id: str | None
28
+ repo_flow: str | None
29
+ repo_ae: str | None
30
+ text_enc_max_length: int = 512
31
+ text_enc_path: str | None
32
+ text_enc_device: str | torch.device | None = "cuda:0"
33
+ ae_device: str | torch.device | None = "cuda:0"
34
+ flux_device: str | torch.device | None = "cuda:0"
35
+ flow_dtype: str = "float16"
36
+ ae_dtype: str = "bfloat16"
37
+ text_enc_dtype: str = "bfloat16"
38
+ num_to_quant: Optional[int] = 20
39
+
40
+ model_config: ConfigDict = {
41
+ "arbitrary_types_allowed": True,
42
+ "use_enum_values": True,
43
+ }
44
+
45
+
46
+ def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
47
+ flow = load_flow_model(config)
48
+ ae = load_autoencoder(config)
49
+ clip, t5 = load_text_encoders(config)
50
+ return flow, ae, clip, t5
51
+
52
+
53
+ def parse_device(device: str | torch.device | None) -> torch.device:
54
+ if isinstance(device, str):
55
+ return torch.device(device)
56
+ elif isinstance(device, torch.device):
57
+ return device
58
+ else:
59
+ return torch.device("cuda:0")
60
+
61
+
62
+ def into_dtype(dtype: str) -> torch.dtype:
63
+ if dtype == "float16":
64
+ return torch.float16
65
+ elif dtype == "bfloat16":
66
+ return torch.bfloat16
67
+ elif dtype == "float32":
68
+ return torch.float32
69
+ else:
70
+ raise ValueError(f"Invalid dtype: {dtype}")
71
+
72
+
73
+ def into_device(device: str | torch.device | None) -> torch.device:
74
+ if isinstance(device, str):
75
+ return torch.device(device)
76
+ elif isinstance(device, torch.device):
77
+ return device
78
+ elif isinstance(device, int):
79
+ return torch.device(f"cuda:{device}")
80
+ else:
81
+ return torch.device("cuda:0")
82
+
83
+
84
+ def load_config(
85
+ name: ModelVersion = ModelVersion.flux_dev,
86
+ flux_path: str | None = None,
87
+ ae_path: str | None = None,
88
+ text_enc_path: str | None = None,
89
+ text_enc_device: str | torch.device | None = None,
90
+ ae_device: str | torch.device | None = None,
91
+ flux_device: str | torch.device | None = None,
92
+ flow_dtype: str = "float16",
93
+ ae_dtype: str = "bfloat16",
94
+ text_enc_dtype: str = "bfloat16",
95
+ num_to_quant: Optional[int] = 20,
96
+ ):
97
+ text_enc_device = str(parse_device(text_enc_device))
98
+ ae_device = str(parse_device(ae_device))
99
+ flux_device = str(parse_device(flux_device))
100
+ return ModelSpec(
101
+ version=name,
102
+ repo_id=(
103
+ "black-forest-labs/FLUX.1-dev"
104
+ if name == ModelVersion.flux_dev
105
+ else "black-forest-labs/FLUX.1-schnell"
106
+ ),
107
+ repo_flow=(
108
+ "flux1-dev.sft" if name == ModelVersion.flux_dev else "flux1-schnell.sft"
109
+ ),
110
+ repo_ae="ae.sft",
111
+ ckpt_path=flux_path,
112
+ params=FluxParams(
113
+ in_channels=64,
114
+ vec_in_dim=768,
115
+ context_in_dim=4096,
116
+ hidden_size=3072,
117
+ mlp_ratio=4.0,
118
+ num_heads=24,
119
+ depth=19,
120
+ depth_single_blocks=38,
121
+ axes_dim=[16, 56, 56],
122
+ theta=10_000,
123
+ qkv_bias=True,
124
+ guidance_embed=True,
125
+ ),
126
+ ae_path=ae_path,
127
+ ae_params=AutoEncoderParams(
128
+ resolution=256,
129
+ in_channels=3,
130
+ ch=128,
131
+ out_ch=3,
132
+ ch_mult=[1, 2, 4, 4],
133
+ num_res_blocks=2,
134
+ z_channels=16,
135
+ scale_factor=0.3611,
136
+ shift_factor=0.1159,
137
+ ),
138
+ text_enc_path=text_enc_path,
139
+ text_enc_device=text_enc_device,
140
+ ae_device=ae_device,
141
+ flux_device=flux_device,
142
+ flow_dtype=flow_dtype,
143
+ ae_dtype=ae_dtype,
144
+ text_enc_dtype=text_enc_dtype,
145
+ text_enc_max_length=512 if name == ModelVersion.flux_dev else 256,
146
+ num_to_quant=num_to_quant,
147
+ )
148
+
149
+
150
+ def load_config_from_path(path: str) -> ModelSpec:
151
+ path_path = Path(path)
152
+ if not path_path.exists():
153
+ raise ValueError(f"Path {path} does not exist")
154
+ if not path_path.is_file():
155
+ raise ValueError(f"Path {path} is not a file")
156
+ return ModelSpec(**json.loads(path_path.read_text()))
157
+
158
+
159
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
160
+ if len(missing) > 0 and len(unexpected) > 0:
161
+ logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
162
+ logger.warning("\n" + "-" * 79 + "\n")
163
+ logger.warning(
164
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
165
+ )
166
+ elif len(missing) > 0:
167
+ logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
168
+ elif len(unexpected) > 0:
169
+ logger.warning(
170
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
171
+ )
172
+
173
+
174
+ def load_flow_model(config: ModelSpec) -> Flux:
175
+ ckpt_path = config.ckpt_path
176
+
177
+ with torch.device("meta"):
178
+ model = Flux(config.params, dtype=into_dtype(config.flow_dtype)).type(
179
+ into_dtype(config.flow_dtype)
180
+ )
181
+
182
+ if ckpt_path is not None:
183
+ # load_sft doesn't support torch.device
184
+ sd = load_sft(ckpt_path, device="cpu")
185
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
186
+ print_load_warning(missing, unexpected)
187
+ return model
188
+
189
+
190
+ def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
191
+ clip = HFEmbedder(
192
+ "openai/clip-vit-large-patch14",
193
+ max_length=77,
194
+ torch_dtype=into_dtype(config.text_enc_dtype),
195
+ device=into_device(config.text_enc_device),
196
+ )
197
+ t5 = HFEmbedder(
198
+ config.text_enc_path,
199
+ max_length=config.text_enc_max_length,
200
+ torch_dtype=into_dtype(config.text_enc_dtype),
201
+ device=into_device(config.text_enc_device).index or 0,
202
+ )
203
+ return clip, t5
204
+
205
+
206
+ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
207
+ ckpt_path = config.ae_path
208
+ with torch.device("meta" if ckpt_path is not None else config.ae_device):
209
+ ae = AutoEncoder(config.ae_params)
210
+
211
+ if ckpt_path is not None:
212
+ sd = load_sft(ckpt_path, device=str(config.ae_device))
213
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
214
+ print_load_warning(missing, unexpected)
215
+ return ae
216
+
217
+
218
+ class LoadedModels(BaseModel):
219
+ flow: Flux
220
+ ae: AutoEncoder
221
+ clip: HFEmbedder
222
+ t5: HFEmbedder
223
+ config: ModelSpec
224
+
225
+ model_config = {
226
+ "arbitrary_types_allowed": True,
227
+ "use_enum_values": True,
228
+ }
229
+
230
+
231
+ def load_models_from_config_path(
232
+ path: str,
233
+ ) -> LoadedModels:
234
+ config = load_config_from_path(path)
235
+ clip, t5 = load_text_encoders(config)
236
+ return LoadedModels(
237
+ flow=load_flow_model(config),
238
+ ae=load_autoencoder(config),
239
+ clip=clip,
240
+ t5=t5,
241
+ config=config,
242
+ )
243
+
244
+
245
+ def load_models_from_config(config: ModelSpec) -> LoadedModels:
246
+ clip, t5 = load_text_encoders(config)
247
+ return LoadedModels(
248
+ flow=load_flow_model(config),
249
+ ae=load_autoencoder(config),
250
+ clip=clip,
251
+ t5=t5,
252
+ config=config,
253
+ )
254
+
255
+
256
+ if __name__ == "__main__":
257
+ p = "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft"
258
+ ae_p = "/big/generator-ui/flux-testing/flux/model-dir/ae.sft"
259
+
260
+ config = load_config(
261
+ ModelVersion.flux_dev,
262
+ flux_path=p,
263
+ ae_path=ae_p,
264
+ text_enc_path="city96/t5-v1_1-xxl-encoder-bf16",
265
+ text_enc_device="cuda:0",
266
+ ae_device="cuda:0",
267
+ flux_device="cuda:0",
268
+ flow_dtype="float16",
269
+ ae_dtype="bfloat16",
270
+ text_enc_dtype="bfloat16",
271
+ num_to_quant=20,
272
+ )
273
+ with open("configs/config-dev-cuda0.json", "w") as f:
274
+ json.dump(config.model_dump(), f, indent=2)
275
+ print(config)
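Finally, a minimal sketch of loading every component from a JSON config written by the block above (this assumes `configs/config-dev-cuda0.json` has been generated and its checkpoint paths point at real files):

```python
from util import load_models_from_config_path

models = load_models_from_config_path("configs/config-dev-cuda0.json")
flow, ae, clip, t5 = models.flow, models.ae, models.clip, models.t5
print(models.config.version, models.config.flux_device)
```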