import json
import os
from dataclasses import dataclass
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_sft
from flux.model import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str
ae_path: str
repo_id: str
repo_flow: str
repo_ae: str
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path='models/flux1-dev.safetensors',
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path='models/ae.safetensors',
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
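# Example (a sketch with hypothetical paths): an entry that only loads local
# files can set the repo fields to None, which skips the hf_hub_download
# branch in the loaders below, while reusing the flux-dev architecture params:
#
#     configs["flux-dev-local"] = ModelSpec(
#         repo_id=None,
#         repo_flow=None,
#         repo_ae=None,
#         ckpt_path="/path/to/flux1-dev.safetensors",
#         ae_path="/path/to/ae.safetensors",
#         params=configs["flux-dev"].params,
#         ae_params=configs["flux-dev"].ae_params,
#     )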
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True) -> Flux:
# Loading Flux
print("Init model")
ckpt_path = configs[name].ckpt_path
    # ckpt_path may be None (e.g. an unset FLUX_SCHNELL env var), so guard before os.path.exists
    if (
        (ckpt_path is None or not os.path.exists(ckpt_path))
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
with torch.device(device):
model = Flux(configs[name].params).to(torch.bfloat16)
if ckpt_path is not None:
print("Loading checkpoint")
# load_sft doesn't support torch.device
sd = load_sft(ckpt_path, device=str(device))
missing, unexpected = model.load_state_dict(sd, strict=False)
print_load_warning(missing, unexpected)
return model
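# Example (sketch): a typical bf16 load of the dev transformer. Assumes a CUDA
# device with enough memory for the full weights and, if the checkpoint is not
# already under models/, network access to the HuggingFace Hub:
#
#     flux = load_flow_model("flux-dev", device="cuda")
#     flux.eval()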
def load_flow_model_quintized(
name: str,
device: str = "cuda",
hf_download: bool = True,
    cache_path: str | None = None,
):
"""
Loads (or downloads) a FLUX-fp8 checkpoint, performs quantization once,
and caches the quantized model to disk. Future calls load from cache.
    :param name: model name key in configs (e.g. "flux-dev")
:param device: Torch device string ("cuda" or "cpu")
:param hf_download: Whether to download from HF if local ckpt is missing
:param cache_path: Filepath for cached quantized model
:return: A quantized FLUX model on the specified device.
"""
if cache_path is None:
cache_path = os.path.join(os.path.expanduser("~"), ".cache/flux_dev_fp8_quantized_model.pth")
# 1) Check if we already have a cached, quantized model
if os.path.exists(cache_path):
print(f"Loading cached quantized model from '{cache_path}'...")
        # The cache is a fully pickled module, so weights_only must stay False
        # (torch >= 2.6 defaults it to True and would refuse to load this file).
        model = torch.load(cache_path, map_location=device, weights_only=False)
return model.to(device)
# 2) If no cache, build and quantize for the first time.
print("No cached model found. Initializing + quantizing from scratch.")
# (A) Download or specify checkpoint paths
ckpt_path = "models/flux-dev-fp8.safetensors"
if not os.path.exists(ckpt_path) and hf_download:
print("Downloading model checkpoint from HF...")
ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
print("Model downloaded to:", ckpt_path)
json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux_dev_quantization_map.json")
# (B) Build the unquantized model
print("Initializing model in bfloat16...")
model = Flux(configs[name].params).to(torch.bfloat16)
# (C) Load the unquantized weights
print("Loading unquantized checkpoint to CPU...")
sd = load_sft(ckpt_path, device="cpu") # CPU load
# (D) Load quantization map
with open(json_path, "r") as f:
quantization_map = json.load(f)
# (E) Quantize
print("Starting quantization process...")
    # Imported lazily so optimum-quanto is only required for this fp8 path
    from optimum.quanto import requantize
requantize(model, sd, quantization_map, device=device)
print("Quantization complete.")
# (F) Cache the fully quantized model to disk
print(f"Saving the quantized model to '{cache_path}'...")
torch.save(model, cache_path)
print("Model saved. Future runs will load from cache.")
return model.to(device)
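# Example (sketch): the first call downloads, quantizes, and writes the cache
# file; later calls deserialize the pickled module directly. The cache location
# defaults to ~/.cache/flux_dev_fp8_quantized_model.pth and can be overridden:
#
#     flux = load_flow_model_quintized("flux-dev", device="cuda")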
def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
    # max_length values of 64, 128, 256, and 512 should work (if your sequence is short enough)
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
def load_clip(device: str = "cuda") -> HFEmbedder:
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
ckpt_path = configs[name].ae_path
    # ae_path may be None (e.g. an unset AE env var), so guard before os.path.exists
    if (
        (ckpt_path is None or not os.path.exists(ckpt_path))
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')
# Loading the autoencoder
print("Init AE")
with torch.device(device):
ae = AutoEncoder(configs[name].ae_params)
if ckpt_path is not None:
sd = load_sft(ckpt_path, device=str(device))
missing, unexpected = ae.load_state_dict(sd, strict=False)
print_load_warning(missing, unexpected)
return ae
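if __name__ == "__main__":
    # Minimal smoke test, a sketch only: loads every "flux-dev" component.
    # Assumes a GPU with room for the bf16 weights and network access to the
    # HuggingFace Hub for any files missing under models/.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t5 = load_t5(device, max_length=512)
    clip = load_clip(device)
    flux = load_flow_model("flux-dev", device=device)
    ae = load_ae("flux-dev", device=device)
    opts = SamplingOptions(
        prompt="a photo of a forest",
        width=1024,
        height=1024,
        num_steps=50,
        guidance=3.5,
        seed=42,
    )
    print(f"Loaded flux-dev on {device} with options: {opts}")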