import os

# Must be set before huggingface_hub is imported, otherwise it has no effect;
# also requires `pip install hf_transfer`.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import glob

import safetensors.torch
import torch
from accelerate import init_empty_weights
from diffusers import FluxTransformer2DModel
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from huggingface_hub import snapshot_download
from tqdm import tqdm

# Instantiate the transformer on the meta device so no memory is allocated
# for weights until the merged state dict is loaded.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("ostris/OpenFLUX.1", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Download the transformer shards of both checkpoints.
print("Downloading OpenFLUX.1 checkpoint...")
dev_ckpt = snapshot_download(repo_id="ostris/OpenFLUX.1", allow_patterns="transformer/*")
print("Downloading schnell checkpoint...")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# Get shard file paths.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
if not dev_shards or not schnell_shards:
    raise ValueError("No shard files found. Ensure checkpoints are downloaded correctly.")
if len(dev_shards) != len(schnell_shards):
    raise ValueError(f"Shard count mismatch: {len(dev_shards)} vs {len(schnell_shards)}.")

merged_state_dict = {}
guidance_state_dict = {}

# Merge the checkpoints shard by shard, averaging every parameter except the
# guidance-related ones, which are taken from the OpenFLUX.1 checkpoint only.
print("Merging weights...")
for dev_shard, schnell_shard in tqdm(list(zip(dev_shards, schnell_shards)), desc="Merging shards", unit="shard"):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    keys = list(state_dict_dev_temp.keys())
    for k in keys:
        if "guidance" not in k:
            merged_state_dict[k] = (state_dict_dev_temp.pop(k) + state_dict_schnell_temp.pop(k)) / 2
        else:
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    if state_dict_dev_temp:
        raise ValueError(f"Unprocessed keys in dev shard: {list(state_dict_dev_temp.keys())}")
    if state_dict_schnell_temp:
        raise ValueError(f"Unprocessed keys in schnell shard: {list(state_dict_schnell_temp.keys())}")

# Add the guidance weights back into the merged state dict.
merged_state_dict.update(guidance_state_dict)

# Materialize the meta-device model with the merged weights.
print("Loading weights into the model...")
load_model_dict_into_meta(model, merged_state_dict)

# Serialize the merged transformer in bfloat16.
print("Saving the merged model...")
model.to(torch.bfloat16).save_pretrained("merlin")
print("Merged model saved successfully at 'merlin'.")
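
# ------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the merge above): once saved, the
# merged transformer can be dropped into a regular FluxPipeline. The base
# repo supplying the text encoders/VAE/scheduler, the prompt, and the
# sampling settings below are illustrative choices, not requirements.
#
# from diffusers import FluxPipeline
#
# transformer = FluxTransformer2DModel.from_pretrained("merlin", torch_dtype=torch.bfloat16)
# pipeline = FluxPipeline.from_pretrained(
#     "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
# )
# pipeline.enable_model_cpu_offload()  # offloads submodules to CPU to reduce peak VRAM
# image = pipeline(
#     "a photo of a forest wizard",  # illustrative prompt
#     num_inference_steps=4,         # illustrative; tune for the merged model
#     guidance_scale=3.5,            # illustrative; the merge keeps guidance weights
# ).images[0]
# image.save("sample.png")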