import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import safetensors.torch
import torch


def load_text_encoder(index_path: Path) -> Dict:
    """Load every tensor of a sharded safetensors checkpoint into one dict.

    Args:
        index_path: Path to the shard index JSON. Its ``"weight_map"`` values
            are shard file names, resolved relative to the index's directory.

    Returns:
        Mapping of tensor name to tensor, loaded on CPU. Empty when the index
        has no ``"weight_map"`` entry.
    """
    with open(index_path, "r", encoding="utf-8") as f:
        index: Dict = json.load(f)

    loaded_tensors = {}
    # Many tensors share a shard, so load each shard file only once. Sorting
    # makes the load order deterministic (a plain set iterates arbitrarily).
    for part_file in sorted(set(index.get("weight_map", {}).values())):
        loaded_tensors.update(
            safetensors.torch.load_file(index_path.parent / part_file, device="cpu")
        )
    return loaded_tensors


def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally prefix UNet state-dict keys with ``model.diffusion_model.``.

    Args:
        unet: UNet state dict.
        add_prefix: Add the prefix (single-file layout). When False the dict
            is returned unchanged (same object).

    Returns:
        The (possibly re-keyed) state dict.
    """
    if add_prefix:
        return {"model.diffusion_model." + key: value for key, value in unet.items()}
    return unet


def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Load the VAE state dict and, if present, its per-channel statistics.

    Reads ``autoencoder.pth`` from *vae_path*. When
    ``per_channel_statistics.json`` exists in the same directory, each of its
    columns is added as a tensor under ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory containing the VAE checkpoint files.
        add_prefix: Prefix every key with ``"vae."`` (single-file layout).

    Returns:
        The combined state dict; statistics entries override same-named keys.
    """
    prefix = "vae." if add_prefix else ""
    # Load onto CPU so checkpoints saved from GPU convert on CPU-only machines.
    state_dict = torch.load(
        vae_path / "autoencoder.pth", map_location="cpu", weights_only=True
    )
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # "data" is a list of rows aligned with "columns"; transpose it to get
        # one sequence of values per column.
        column_values = zip(*data["data"])
        result.update(
            {
                f"{prefix}per_channel_statistics.{col}": torch.tensor(vals)
                for col, vals in zip(data["columns"], column_values)
            }
        )
    return result


def convert_encoder(encoder: Dict) -> Dict:
    """Prefix text-encoder keys with ``text_encoders.t5xxl.transformer.``."""
    return {
        "text_encoders.t5xxl.transformer." + key: value
        for key, value in encoder.items()
    }


def save_config(config_src: Union[str, Path], config_dst: Union[str, Path]):
    """Copy a config file from *config_src* to *config_dst*."""
    shutil.copy(config_src, config_dst)


def load_vae_config(vae_path: Path) -> str:
    """Return the path of ``config.json`` inside *vae_path*.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    config_path = vae_path / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"VAE config file {config_path} not found.")
    return str(config_path)


def main(
    unet_path: str,
    vae_path: str,
    out_path: str,
    mode: str,
    unet_config_path: Optional[str] = None,
    scheduler_config_path: Optional[str] = None,
) -> None:
    """Convert UNet and VAE checkpoints to safetensors.

    Args:
        unet_path: UNet checkpoint loadable with ``torch.load``.
        vae_path: Directory with ``autoencoder.pth`` (and, for separate mode,
            ``config.json``).
        out_path: Output file (single mode) or output directory (separate mode).
        mode: ``"single"`` writes one file whose keys are prefixed with
            ``model.diffusion_model.`` / ``vae.``; ``"separate"`` writes
            ``unet/``, ``vae/`` and ``scheduler/`` sub-directories with
            unprefixed weights and copied configs.
        unet_config_path: Optional UNet config copied in separate mode.
        scheduler_config_path: Optional scheduler config copied in separate mode.

    Raises:
        ValueError: If *mode* is not ``"single"`` or ``"separate"``.
        FileNotFoundError: In separate mode, if ``<vae_path>/config.json`` is
            missing.
    """
    if mode not in ("single", "separate"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'separate'.")
    single = mode == "single"
    vae_src = Path(vae_path)

    # The VAE config is only needed (and copied) in separate mode; validate it
    # before loading any weights so a missing file fails fast.
    vae_config_path = None if single else load_vae_config(vae_src)

    unet = convert_unet(
        torch.load(unet_path, map_location="cpu", weights_only=True),
        add_prefix=single,
    )
    vae = convert_vae(vae_src, add_prefix=single)

    if single:
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    # Separate mode: one sub-directory per component.
    out_dir = Path(out_path)
    unet_dir = out_dir / "unet"
    vae_dir = out_dir / "vae"
    scheduler_dir = out_dir / "scheduler"
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Weights are written with "unet_"/"vae_"-prefixed file names rather than
    # the plain "diffusion_pytorch_model.safetensors"; keep these in sync with
    # whatever loader consumes this directory layout.
    safetensors.torch.save_file(
        unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
    )
    safetensors.torch.save_file(
        vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
    )

    if unet_config_path:
        save_config(unet_config_path, unet_dir / "config.json")
    save_config(vae_config_path, vae_dir / "config.json")
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
    parser.add_argument("--vae_path", "-v", type=str, default="vae/")
    parser.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
    parser.add_argument(
        "--mode",
        "-m",
        type=str,
        choices=["single", "separate"],
        default="single",
        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
    )
    parser.add_argument(
        "--unet_config_path",
        type=str,
        help="Path to the UNet config file (for separate mode)",
    )
    parser.add_argument(
        "--scheduler_config_path",
        type=str,
        help="Path to the Scheduler config file (for separate mode)",
    )
    args = parser.parse_args()
    main(**vars(args))