Safetensors
File size: 6,291 Bytes
d9aea20
 
 
 
 
 
 
 
c2ecfb5
d9aea20
 
 
 
 
c2ecfb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9aea20
 
c2ecfb5
d9aea20
 
c2ecfb5
d9aea20
 
 
 
 
 
 
c2ecfb5
d9aea20
 
 
 
 
 
c2ecfb5
d9aea20
 
 
 
 
 
c2ecfb5
d9aea20
 
 
 
 
 
c2ecfb5
d9aea20
 
 
 
 
c2ecfb5
 
 
 
 
 
 
9dc5b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604f17d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9aea20
 
 
 
 
 
9dc5b0b
 
 
 
d9aea20
28dec30
 
 
d9aea20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ecfb5
 
9dc5b0b
 
 
 
 
 
 
 
604f17d
 
d9aea20
c2ecfb5
d9aea20
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import argparse
import uvicorn
from api import app


def parse_args(argv=None):
    """Build the Flux API server's command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches
            the original zero-argument behaviour.

    Returns:
        argparse.Namespace: the parsed options. It holds the server bind
        options (``host``, ``port``), optional model/config paths,
        per-component devices, and the quantization, offload and compile
        switches read by ``main``.
    """
    parser = argparse.ArgumentParser(description="Launch Flux API server")
    parser.add_argument(
        "-c",
        "--config-path",
        type=str,
        help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8088,
        help="Port to run the server on",
    )
    parser.add_argument(
        "-H",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to run the server on",
    )
    parser.add_argument(
        "-f", "--flow-model-path", type=str, help="Path to the flow model"
    )
    parser.add_argument(
        "-t", "--text-enc-path", type=str, help="Path to the text encoder"
    )
    parser.add_argument(
        "-a", "--autoencoder-path", type=str, help="Path to the autoencoder"
    )
    parser.add_argument(
        "-m",
        "--model-version",
        type=str,
        choices=["flux-dev", "flux-schnell"],
        default="flux-dev",
        help="Choose model version",
    )
    parser.add_argument(
        "-F",
        "--flux-device",
        type=str,
        default="cuda:0",
        help="Device to run the flow model on",
    )
    parser.add_argument(
        "-T",
        "--text-enc-device",
        type=str,
        default="cuda:0",
        help="Device to run the text encoder on",
    )
    parser.add_argument(
        "-A",
        "--autoencoder-device",
        type=str,
        default="cuda:0",
        help="Device to run the autoencoder on",
    )
    parser.add_argument(
        "-q",
        "--num-to-quant",
        type=int,
        default=20,
        help="Number of linear layers in flow transformer (the 'unet') to quantize",
    )
    parser.add_argument(
        "-C",
        "--compile",
        action="store_true",
        default=False,
        help="Compile the flow model with extra optimizations",
    )
    # "bf16" is a sentinel meaning "do not quantize the text encoder".
    parser.add_argument(
        "-qT",
        "--quant-text-enc",
        type=str,
        default="qfloat8",
        choices=["qint4", "qfloat8", "qint2", "qint8", "bf16"],
        help="Quantize the t5 text encoder to the given dtype, if bf16, will not quantize",
        dest="quant_text_enc",
    )
    parser.add_argument(
        "-qA",
        "--quant-ae",
        action="store_true",
        default=False,
        help="Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16",
        dest="quant_ae",
    )
    parser.add_argument(
        "-OF",
        "--offload-flow",
        action="store_true",
        default=False,
        dest="offload_flow",
        help="Offload the flow model to the CPU when not being used to save memory",
    )
    # The two "--no-offload-*" flags are store_false: offloading defaults to
    # on, and passing the flag turns it off.
    parser.add_argument(
        "-OA",
        "--no-offload-ae",
        action="store_false",
        default=True,
        dest="offload_ae",
        help="Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed",
    )
    parser.add_argument(
        "-OT",
        "--no-offload-text-enc",
        action="store_false",
        default=True,
        dest="offload_text_enc",
        help="Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed",
    )
    parser.add_argument(
        "-PF",
        "--prequantized-flow",
        action="store_true",
        default=False,
        dest="prequantized_flow",
        # argparse %-formats help strings, so a literal percent sign must be
        # written as "%%". A bare "50% &" makes help formatting raise
        # ValueError (unsupported format character '&').
        help="Load the flow model from a prequantized checkpoint "
        "(requires loading the flow model, running a minimum of 24 steps, "
        "and then saving the state_dict as a safetensors file), "
        "which reduces the size of the checkpoint by about 50%% & reduces startup time",
    )
    parser.add_argument(
        "-nqfm",
        "--no-quantize-flow-modulation",
        action="store_false",
        default=True,
        dest="quantize_modulation",
        help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements",
    )
    parser.add_argument(
        "-qfl",
        "--quantize-flow-embedder-layers",
        action="store_true",
        default=False,
        dest="quantize_flow_embedder_layers",
        help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
    )
    # argv=None lets argparse read sys.argv[1:], preserving CLI behaviour.
    return parser.parse_args(argv)


def main():
    """Entry point: parse CLI options, load the Flux pipeline, and serve it.

    The loaded pipeline is stored on ``app.state.model``. Then
    ``uvicorn.run`` serves ``app`` on the requested host and port.
    """
    cli = parse_args()

    # Heavy imports are deferred until after argument parsing, so that
    # ``--help`` and argument errors return without waiting on torch.
    from flux_pipeline import FluxPipeline
    from util import load_config, ModelVersion

    if args_use_config := cli.config_path:
        # Config-file path: only --flow-model-path is forwarded as an
        # override; the other model flags are not applied in this branch.
        pipeline = FluxPipeline.load_pipeline_from_config_path(
            args_use_config, flow_model_path=cli.flow_model_path
        )
    else:
        if cli.model_version == "flux-dev":
            version = ModelVersion.flux_dev
        else:
            version = ModelVersion.flux_schnell

        # "bf16" is the CLI sentinel for "leave the text encoder unquantized".
        text_enc_quant = cli.quant_text_enc
        if text_enc_quant == "bf16":
            text_enc_quant = None

        pipeline_config = load_config(
            version,
            flux_path=cli.flow_model_path,
            flux_device=cli.flux_device,
            ae_path=cli.autoencoder_path,
            ae_device=cli.autoencoder_device,
            text_enc_path=cli.text_enc_path,
            text_enc_device=cli.text_enc_device,
            flow_dtype="float16",
            text_enc_dtype="bfloat16",
            ae_dtype="bfloat16",
            num_to_quant=cli.num_to_quant,
            # A single --compile switch drives both compile options.
            compile_extras=cli.compile,
            compile_blocks=cli.compile,
            quant_text_enc=text_enc_quant,
            quant_ae=cli.quant_ae,
            offload_flow=cli.offload_flow,
            offload_ae=cli.offload_ae,
            offload_text_enc=cli.offload_text_enc,
            prequantized_flow=cli.prequantized_flow,
            quantize_modulation=cli.quantize_modulation,
            quantize_flow_embedder_layers=cli.quantize_flow_embedder_layers,
        )
        pipeline = FluxPipeline.load_pipeline_from_config(pipeline_config)

    app.state.model = pipeline

    uvicorn.run(app, host=cli.host, port=cli.port)


if __name__ == "__main__":
    main()