|
import json |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import torch |
|
from modules.autoencoder import AutoEncoder, AutoEncoderParams |
|
from modules.conditioner import HFEmbedder |
|
from modules.flux_model import Flux, FluxParams |
|
|
|
from safetensors.torch import load_file as load_sft |
|
from enum import StrEnum |
|
from pydantic import BaseModel, ConfigDict |
|
from loguru import logger |
|
|
|
|
|
class ModelVersion(StrEnum): |
|
flux_dev = "flux-dev" |
|
flux_schnell = "flux-schnell" |
|
|
|
|
|
class ModelSpec(BaseModel):
    """Specification of where to find, how to build, and where to place Flux.

    Fields declared without a default are required on construction, although
    ``None`` is an accepted value for the ``str | None`` ones. Because
    ``use_enum_values`` is enabled, ``version`` is stored as the enum member's
    string value rather than the member itself.
    """

    # Which model variant this spec describes.
    version: ModelVersion
    # Architecture hyper-parameters of the flow transformer and autoencoder.
    params: FluxParams
    ae_params: AutoEncoderParams

    # Local checkpoint paths; None means no checkpoint is loaded.
    ckpt_path: str | None
    ae_path: str | None

    # Remote repository id (e.g. "black-forest-labs/FLUX.1-dev") and the
    # file names of the flow / autoencoder weights inside it.
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None

    # Text encoder settings: maximum token length and model path / id.
    text_enc_max_length: int = 512
    text_enc_path: str | None

    # Device placement of each component.
    text_enc_device: str | torch.device | None = "cuda:0"
    ae_device: str | torch.device | None = "cuda:0"
    flux_device: str | torch.device | None = "cuda:0"

    # dtype names; the loaders in this module resolve them with into_dtype,
    # which accepts "float16", "bfloat16" and "float32".
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"

    # NOTE(review): presumably the number of layers to quantize; it is not
    # consumed anywhere in this file, so confirm against the caller.
    num_to_quant: int | None = 20

    model_config = ConfigDict(
        # Required for non-pydantic field types such as torch.device.
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
|
|
|
|
|
def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
    """Load every model component described by ``config``.

    Components are loaded in this order: flow model, autoencoder, then the
    two text encoders.

    Returns:
        ``(flow, ae, clip, t5)``.
    """
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    clip_encoder, t5_encoder = load_text_encoders(config)
    return (flow_model, autoencoder, clip_encoder, t5_encoder)
|
|
|
|
|
def parse_device(device: str | torch.device | None) -> torch.device: |
|
if isinstance(device, str): |
|
return torch.device(device) |
|
elif isinstance(device, torch.device): |
|
return device |
|
else: |
|
return torch.device("cuda:0") |
|
|
|
|
|
def into_dtype(dtype: str) -> torch.dtype: |
|
if dtype == "float16": |
|
return torch.float16 |
|
elif dtype == "bfloat16": |
|
return torch.bfloat16 |
|
elif dtype == "float32": |
|
return torch.float32 |
|
else: |
|
raise ValueError(f"Invalid dtype: {dtype}") |
|
|
|
|
|
def into_device(device: str | torch.device | None) -> torch.device: |
|
if isinstance(device, str): |
|
return torch.device(device) |
|
elif isinstance(device, torch.device): |
|
return device |
|
elif isinstance(device, int): |
|
return torch.device(f"cuda:{device}") |
|
else: |
|
return torch.device("cuda:0") |
|
|
|
|
|
def load_config(
    name: ModelVersion = ModelVersion.flux_dev,
    flux_path: str | None = None,
    ae_path: str | None = None,
    text_enc_path: str | None = None,
    text_enc_device: str | torch.device | None = None,
    ae_device: str | torch.device | None = None,
    flux_device: str | torch.device | None = None,
    flow_dtype: str = "float16",
    ae_dtype: str = "bfloat16",
    text_enc_dtype: str = "bfloat16",
    num_to_quant: Optional[int] = 20,
) -> ModelSpec:
    """Build a ``ModelSpec`` for one of the FLUX.1 variants.

    Args:
        name: Model variant; selects repo id, weight file name, text encoder
            max length and whether the guidance embedder is built.
        flux_path: Local path of the flow-model checkpoint, or ``None``.
        ae_path: Local path of the autoencoder checkpoint, or ``None``.
        text_enc_path: Path / id of the T5 text encoder.
        text_enc_device, ae_device, flux_device: Device specs; normalised via
            ``parse_device`` (``None`` becomes ``cuda:0``) and stored as strings.
        flow_dtype, ae_dtype, text_enc_dtype: dtype names for each component.
        num_to_quant: Passed through to ``ModelSpec.num_to_quant``.

    Returns:
        The populated ``ModelSpec``.
    """
    is_dev = name == ModelVersion.flux_dev

    # Store devices as strings so the spec serialises cleanly to JSON.
    text_enc_device = str(parse_device(text_enc_device))
    ae_device = str(parse_device(ae_device))
    flux_device = str(parse_device(flux_device))

    return ModelSpec(
        version=name,
        repo_id=(
            "black-forest-labs/FLUX.1-dev"
            if is_dev
            else "black-forest-labs/FLUX.1-schnell"
        ),
        repo_flow="flux1-dev.sft" if is_dev else "flux1-schnell.sft",
        repo_ae="ae.sft",
        ckpt_path=flux_path,
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            # Only the dev variant has a guidance embedder; the upstream
            # reference config uses guidance_embed=False for schnell. Enabling
            # it for schnell would leave guidance_in.* un-loaded (on the meta
            # device) because the checkpoint has no such weights.
            guidance_embed=is_dev,
        ),
        ae_path=ae_path,
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
        text_enc_path=text_enc_path,
        text_enc_device=text_enc_device,
        ae_device=ae_device,
        flux_device=flux_device,
        flow_dtype=flow_dtype,
        ae_dtype=ae_dtype,
        text_enc_dtype=text_enc_dtype,
        text_enc_max_length=512 if is_dev else 256,
        num_to_quant=num_to_quant,
    )
|
|
|
|
|
def load_config_from_path(path: str) -> ModelSpec:
    """Load a ``ModelSpec`` from a JSON file.

    Args:
        path: Filesystem path of the JSON config.

    Returns:
        The ``ModelSpec`` built from the file's top-level JSON object.

    Raises:
        ValueError: if ``path`` does not exist or is not a regular file
            (``json.JSONDecodeError``, a ``ValueError`` subclass, is raised
            for malformed JSON).
    """
    config_file = Path(path)
    if not config_file.exists():
        raise ValueError(f"Path {path} does not exist")
    if not config_file.is_file():
        raise ValueError(f"Path {path} is not a file")
    # Read as UTF-8 explicitly rather than relying on the platform's locale
    # encoding, which differs e.g. on Windows.
    raw_json = config_file.read_text(encoding="utf-8")
    return ModelSpec(**json.loads(raw_json))
|
|
|
|
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Log missing / unexpected state-dict keys reported by ``load_state_dict``.

    Missing keys are logged first. A separator line is logged only when both
    lists are non-empty. Nothing is logged when both lists are empty.
    """
    if missing:
        logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if missing and unexpected:
        logger.warning("\n" + "-" * 79 + "\n")
    if unexpected:
        logger.warning(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
|
|
|
|
|
def load_flow_model(config: ModelSpec) -> Flux:
    """Construct the Flux transformer and, if configured, load its weights.

    The module is always constructed on the ``meta`` device, so no parameter
    memory is allocated up front. When ``config.ckpt_path`` is set, the
    checkpoint is read onto the CPU and its tensors are assigned into the
    module. Without a checkpoint the returned model still holds meta tensors.

    Args:
        config: Spec providing ``params``, ``flow_dtype`` and ``ckpt_path``.

    Returns:
        The Flux model.
    """
    target_dtype = into_dtype(config.flow_dtype)
    checkpoint = config.ckpt_path

    with torch.device("meta"):
        flux = Flux(config.params, dtype=target_dtype).type(target_dtype)

    if checkpoint is None:
        return flux

    # load_sft is given a device string; weights land on the CPU.
    state_dict = load_sft(checkpoint, device="cpu")
    # assign=True replaces the meta placeholders with the loaded tensors
    # (keeping the checkpoint tensors' properties) instead of copying into them.
    missing_keys, unexpected_keys = flux.load_state_dict(
        state_dict, strict=False, assign=True
    )
    print_load_warning(missing_keys, unexpected_keys)
    return flux
|
|
|
|
|
def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
    """Load the CLIP and T5 text embedders described by ``config``.

    Both use ``config.text_enc_dtype`` and ``config.text_enc_device``. CLIP is
    fixed to ``openai/clip-vit-large-patch14`` with 77 tokens; T5 is loaded
    from ``config.text_enc_path`` with ``config.text_enc_max_length`` tokens.

    Returns:
        ``(clip, t5)``.
    """
    dtype = into_dtype(config.text_enc_dtype)
    device = into_device(config.text_enc_device)

    # The T5 embedder has been given a bare CUDA ordinal. Keep that for CUDA
    # devices, but pass other devices (e.g. "cpu") through unchanged. Before,
    # their index of None was coerced to 0, silently selecting cuda:0.
    # NOTE(review): assumes HFEmbedder's T5 path accepts a torch.device, as
    # its CLIP path does here. Confirm against modules.conditioner.
    t5_device = (device.index or 0) if device.type == "cuda" else device

    clip = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=dtype,
        device=device,
    )
    t5 = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=dtype,
        device=t5_device,
    )
    return clip, t5
|
|
|
|
|
def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    """Construct the autoencoder and, if configured, load its weights.

    With a checkpoint (``config.ae_path``), the module is built on the
    ``meta`` device and the checkpoint tensors are read directly onto the
    target device and assigned in. Without one, the module is built directly
    on the target device.

    Args:
        config: Spec providing ``ae_params``, ``ae_path`` and ``ae_device``.

    Returns:
        The autoencoder.
    """
    ckpt_path = config.ae_path
    # Resolve the device once. ModelSpec allows ae_device=None, which would
    # otherwise reach torch.device(None) / load_sft(device="None") and fail.
    # into_device maps it to the same cuda:0 default used elsewhere.
    device = into_device(config.ae_device)

    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(config.ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        # assign=True swaps the loaded tensors in place of the meta placeholders.
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae
|
|
|
|
|
class LoadedModels(BaseModel):
    """Container for every loaded model component plus the spec used to load them."""

    # Flow transformer.
    flow: Flux
    # Image autoencoder.
    ae: AutoEncoder
    # CLIP text embedder.
    clip: HFEmbedder
    # T5 text embedder.
    t5: HFEmbedder
    # The spec these components were loaded from.
    config: ModelSpec

    model_config = ConfigDict(
        # The model classes are not pydantic types.
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
|
|
|
|
|
def load_models_from_config_path(
    path: str,
) -> LoadedModels:
    """Read a JSON ``ModelSpec`` from ``path`` and load all components it names.

    Components are loaded in this order: text encoders, flow model, then the
    autoencoder.
    """
    spec = load_config_from_path(path)
    return load_models_from_config(spec)
|
|
|
|
|
def load_models_from_config(config: ModelSpec) -> LoadedModels:
    """Load every component described by ``config`` into a ``LoadedModels``.

    Components are loaded in this order: text encoders, flow model, then the
    autoencoder.
    """
    clip_encoder, t5_encoder = load_text_encoders(config)
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    return LoadedModels(
        flow=flow_model,
        ae=autoencoder,
        clip=clip_encoder,
        t5=t5_encoder,
        config=config,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Developer helper: build a flux-dev spec pointing at local checkpoints
    # and save it as a reusable JSON config.
    p = "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft"
    ae_p = "/big/generator-ui/flux-testing/flux/model-dir/ae.sft"

    config = load_config(
        ModelVersion.flux_dev,
        flux_path=p,
        ae_path=ae_p,
        text_enc_path="city96/t5-v1_1-xxl-encoder-bf16",
        text_enc_device="cuda:0",
        ae_device="cuda:0",
        flux_device="cuda:0",
        flow_dtype="float16",
        ae_dtype="bfloat16",
        text_enc_dtype="bfloat16",
        num_to_quant=20,
    )

    out_path = Path("configs/config-dev-cuda0.json")
    # Create the output directory on first run instead of failing with
    # FileNotFoundError. Write UTF-8 to match load_config_from_path.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(config.model_dump(), indent=2), encoding="utf-8")
    print(config)
|
|