flux-schnell-cfg / merge.py
miike-ai's picture
Add files using upload-large-folder tool
eeffd4d verified
raw
history blame
2.44 kB
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch
from tqdm import tqdm
import os
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER at import time;
# since huggingface_hub is already imported above, this assignment likely has no
# effect — move it above the imports (or export it in the shell) to actually
# enable hf_transfer. TODO confirm against the installed huggingface_hub version.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Initialize empty weights for memory efficiency: the transformer is built on
# the "meta" device, so no real memory is allocated until weights are loaded.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("ostris/OpenFLUX.1", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)
# Download checkpoints with progress tracking. allow_patterns restricts the
# snapshot to the transformer subfolder; the return value is the local cache
# path of the downloaded repo snapshot.
print("Downloading dev checkpoint...")
dev_ckpt = snapshot_download(repo_id="ostris/OpenFLUX.1", allow_patterns="transformer/*")
print("Downloading schnell checkpoint...")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")
# Collect the sorted per-shard safetensors paths for both checkpoints.
# Sorting makes shard i of one checkpoint line up with shard i of the other,
# which the merge loop below relies on.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
if not dev_shards or not schnell_shards:
    raise ValueError("No shard files found. Ensure checkpoints are downloaded correctly.")
# The merge loop pairs shards by index; a count mismatch would otherwise
# surface as an uninformative IndexError (or silently drop trailing shards).
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} dev shard(s) vs "
        f"{len(schnell_shards)} schnell shard(s). The checkpoints must be sharded identically."
    )
# Accumulators: averaged weights, plus guidance-specific weights that are
# taken from the dev checkpoint only (schnell has no guidance embedder).
merged_state_dict = {}
guidance_state_dict = {}
# Merge weights with progress tracking
print("Merging weights...")
for shard_idx, dev_path in enumerate(tqdm(dev_shards, desc="Merging shards", unit="shard")):
    dev_weights = safetensors.torch.load_file(dev_path)
    schnell_weights = safetensors.torch.load_file(schnell_shards[shard_idx])
    for key in list(dev_weights):
        if "guidance" in key:
            # Guidance tensors exist only in dev; keep them verbatim.
            guidance_state_dict[key] = dev_weights.pop(key)
        else:
            # Element-wise average of the two checkpoints' tensors.
            merged_state_dict[key] = (dev_weights.pop(key) + schnell_weights.pop(key)) / 2
    # Everything should have been popped; leftovers indicate a key mismatch.
    if dev_weights:
        raise ValueError(f"Unprocessed keys in dev shard: {list(dev_weights.keys())}")
    if schnell_weights:
        raise ValueError(f"Unprocessed keys in schnell shard: {list(schnell_weights.keys())}")
# Update with guidance weights
merged_state_dict.update(guidance_state_dict)
# Load merged weights into the model (materializes the meta-device tensors).
# NOTE(review): load_model_dict_into_meta is a private diffusers helper; its
# signature (device/dtype kwargs) differs across diffusers versions — verify
# against the pinned diffusers release before running.
print("Loading weights into the model...")
load_model_dict_into_meta(model, merged_state_dict)
# Save the model, cast to bfloat16, under the local directory "merlin".
print("Saving the merged model...")
model.to(torch.bfloat16).save_pretrained("merlin")
print("Merged model saved successfully at 'merlin'.")