|
from diffusers import FluxTransformer2DModel |
|
from huggingface_hub import snapshot_download |
|
from accelerate import init_empty_weights |
|
from diffusers.models.model_loading_utils import load_model_dict_into_meta |
|
import safetensors.torch |
|
import glob |
|
import torch |
|
from tqdm import tqdm |
|
import os |
|
|
|
|
|
# Enable the Rust-based hf_transfer backend for faster checkpoint downloads.
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER at *import*
# time, and `from huggingface_hub import snapshot_download` has already run
# above — this assignment is therefore likely a no-op. Move it above the
# huggingface_hub import (or export it in the shell) for it to take effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
|
|
|
|
# Fetch the transformer config (JSON only — no weights) and instantiate the
# model on the "meta" device so no real parameter memory is allocated yet;
# the merged tensors are materialized into it later via
# load_model_dict_into_meta.
transformer_config = FluxTransformer2DModel.load_config(
    "ostris/OpenFLUX.1", subfolder="transformer"
)
with init_empty_weights():
    model = FluxTransformer2DModel.from_config(transformer_config)
|
|
|
|
|
# Download both checkpoints (transformer weights only) and locate their
# safetensors shards. The merge loop below pairs shard i of one checkpoint
# with shard i of the other, so the shard lists must line up.
print("Downloading dev checkpoint...")
dev_ckpt = snapshot_download(repo_id="ostris/OpenFLUX.1", allow_patterns="transformer/*")

print("Downloading schnell checkpoint...")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# sorted() gives a deterministic shard order (…-00001-of-0000N, …) so the
# index-wise pairing in the merge loop is stable across runs.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))

if not dev_shards or not schnell_shards:
    raise ValueError("No shard files found. Ensure checkpoints are downloaded correctly.")

# Fail fast with a clear message instead of a confusing IndexError/KeyError
# deep inside the merge loop when the checkpoints are sharded differently.
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} dev shard(s) vs "
        f"{len(schnell_shards)} schnell shard(s); cannot merge index-wise."
    )
|
|
|
# Accumulators: averaged weights, plus guidance-embedder tensors that exist
# only in the dev checkpoint and are carried over unmerged.
merged_state_dict = {}
guidance_state_dict = {}

print("Merging weights...")
# Index-based pairing: shard i of the dev checkpoint is merged against shard
# i of the schnell checkpoint.
for shard_index in tqdm(range(len(dev_shards)), desc="Merging shards", unit="shard"):
    dev_sd = safetensors.torch.load_file(dev_shards[shard_index])
    schnell_sd = safetensors.torch.load_file(schnell_shards[shard_index])

    for key in list(dev_sd):
        dev_tensor = dev_sd.pop(key)
        if "guidance" in key:
            # Guidance weights have no schnell counterpart; keep dev's copy.
            guidance_state_dict[key] = dev_tensor
            continue
        merged_state_dict[key] = (dev_tensor + schnell_sd.pop(key)) / 2

    # Both dicts are drained via pop(); anything left behind means the two
    # shards did not contain the same keys.
    if dev_sd:
        raise ValueError(f"Unprocessed keys in dev shard: {list(dev_sd.keys())}")
    if schnell_sd:
        raise ValueError(f"Unprocessed keys in schnell shard: {list(schnell_sd.keys())}")

# Fold the untouched guidance tensors back into the final state dict.
merged_state_dict.update(guidance_state_dict)
|
|
|
|
|
# Materialize the merged tensors into the meta-device model, then persist.
print("Loading weights into the model...")
load_model_dict_into_meta(model, merged_state_dict)

print("Saving the merged model...")
# Cast to bfloat16 before serializing to halve the on-disk footprint.
bf16_model = model.to(torch.bfloat16)
bf16_model.save_pretrained("merlin")
print("Merged model saved successfully at 'merlin'.")
|
|