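"""Entry point for the Flux API server.

Parses command-line arguments, builds a FluxPipeline either from a
configuration file (--config-path) or from individual CLI flags, attaches it
to the ASGI app's state, and serves the app with uvicorn.
"""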
import argparse
import uvicorn
from api import app
def parse_args():
    parser = argparse.ArgumentParser(description="Launch Flux API server")

    # Server options
    parser.add_argument(
        "-c",
        "--config-path",
        type=str,
        help="Path to the configuration file; if omitted, the model is configured from the other command-line arguments",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8088,
        help="Port to run the server on",
    )
    parser.add_argument(
        "-H",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to run the server on",
    )

    # Model weight paths and version
    parser.add_argument(
        "-f", "--flow-model-path", type=str, help="Path to the flow model"
    )
    parser.add_argument(
        "-t", "--text-enc-path", type=str, help="Path to the text encoder"
    )
    parser.add_argument(
        "-a", "--autoencoder-path", type=str, help="Path to the autoencoder"
    )
    parser.add_argument(
        "-m",
        "--model-version",
        type=str,
        choices=["flux-dev", "flux-schnell"],
        default="flux-dev",
        help="Choose model version",
    )

    # Device placement
    parser.add_argument(
        "-F",
        "--flux-device",
        type=str,
        default="cuda:0",
        help="Device to run the flow model on",
    )
    parser.add_argument(
        "-T",
        "--text-enc-device",
        type=str,
        default="cuda:0",
        help="Device to run the text encoder on",
    )
    parser.add_argument(
        "-A",
        "--autoencoder-device",
        type=str,
        default="cuda:0",
        help="Device to run the autoencoder on",
    )

    # Quantization and compilation options
    parser.add_argument(
        "-q",
        "--num-to-quant",
        type=int,
        default=20,
        help="Number of linear layers in the flow transformer (the 'unet') to quantize",
    )
    parser.add_argument(
        "-C",
        "--compile",
        action="store_true",
        default=False,
        help="Compile the flow model with extra optimizations",
    )
    parser.add_argument(
        "-qT",
        "--quant-text-enc",
        type=str,
        default="qfloat8",
        choices=["qint4", "qfloat8", "qint2", "qint8", "bf16"],
        help="Quantize the T5 text encoder to the given dtype; 'bf16' leaves it unquantized",
        dest="quant_text_enc",
    )
    parser.add_argument(
        "-qA",
        "--quant-ae",
        action="store_true",
        default=False,
        help="Quantize the autoencoder with float8 linear layers; otherwise bfloat16 is used",
        dest="quant_ae",
    )

    # CPU-offload options
    parser.add_argument(
        "-OF",
        "--offload-flow",
        action="store_true",
        default=False,
        dest="offload_flow",
        help="Offload the flow model to the CPU when not in use to save memory",
    )
    parser.add_argument(
        "-OA",
        "--no-offload-ae",
        action="store_false",
        default=True,
        dest="offload_ae",
        help="Keep the autoencoder on the GPU (disable CPU offloading) to increase end-to-end inference speed",
    )
    parser.add_argument(
        "-OT",
        "--no-offload-text-enc",
        action="store_false",
        default=True,
        dest="offload_text_enc",
        help="Keep the text encoder on the GPU (disable CPU offloading) to increase end-to-end inference speed",
    )
    parser.add_argument(
        "-PF",
        "--prequantized-flow",
        action="store_true",
        default=False,
        dest="prequantized_flow",
        help="Load the flow model from a prequantized checkpoint "
        "(requires loading the flow model, running a minimum of 24 steps, "
        "and then saving the state_dict as a safetensors file), "
        # '%' must be doubled: argparse applies %-formatting to help strings.
        "which reduces the checkpoint size by about 50%% and reduces startup time",
    )
    parser.add_argument(
        "-nqfm",
        "--no-quantize-flow-modulation",
        action="store_false",
        default=True,
        dest="quantize_modulation",
        help="Disable quantization of the modulation layers in the flow model; costs ~2 GB of extra VRAM for a moderate precision improvement",
    )
    parser.add_argument(
        "-qfl",
        "--quantize-flow-embedder-layers",
        action="store_true",
        default=False,
        dest="quantize_flow_embedder_layers",
        help="Quantize the flow embedder layers in the flow model; saves ~512 MB of VRAM, but the precision loss is very noticeable",
    )
    return parser.parse_args()
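
# Example invocations (a sketch: the checkpoint file names below are
# placeholders, not files shipped with this repo):
#   python main.py --config-path config.json --port 8088
#   python main.py -f flux.safetensors -t t5_encoder -a ae.safetensors \
#       -m flux-dev -C -q 20
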
def main():
    args = parse_args()

    # Import heavyweight modules lazily so the CLI (e.g. --help) returns
    # quickly instead of waiting for torch to load.
    from flux_pipeline import FluxPipeline
    from util import load_config, ModelVersion

    if args.config_path:
        # Load from a configuration file; the flow model path flag is passed
        # through alongside it.
        app.state.model = FluxPipeline.load_pipeline_from_config_path(
            args.config_path, flow_model_path=args.flow_model_path
        )
    else:
        # Build the pipeline configuration from the individual CLI flags.
        model_version = (
            ModelVersion.flux_dev
            if args.model_version == "flux-dev"
            else ModelVersion.flux_schnell
        )
        config = load_config(
            model_version,
            flux_path=args.flow_model_path,
            flux_device=args.flux_device,
            ae_path=args.autoencoder_path,
            ae_device=args.autoencoder_device,
            text_enc_path=args.text_enc_path,
            text_enc_device=args.text_enc_device,
            flow_dtype="float16",
            text_enc_dtype="bfloat16",
            ae_dtype="bfloat16",
            num_to_quant=args.num_to_quant,
            compile_extras=args.compile,
            compile_blocks=args.compile,
            quant_text_enc=(
                None if args.quant_text_enc == "bf16" else args.quant_text_enc
            ),
            quant_ae=args.quant_ae,
            offload_flow=args.offload_flow,
            offload_ae=args.offload_ae,
            offload_text_enc=args.offload_text_enc,
            prequantized_flow=args.prequantized_flow,
            quantize_modulation=args.quantize_modulation,
            quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,
        )
        app.state.model = FluxPipeline.load_pipeline_from_config(config)

    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
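
# Note: the HTTP routes themselves live in api.py; this module only wires up
# the pipeline and serves the imported `app` with uvicorn.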