"""Merge a directory of sharded ``.safetensors`` files into a single file."""

import json
import os

from safetensors import safe_open
from safetensors.torch import save_file


def merge_safetensors(input_dir, output_file, config_file):
    """Merge every ``.safetensors`` shard in *input_dir* into *output_file*.

    Shards are read in sorted filename order so the result is deterministic.
    If *output_file* itself lives in *input_dir* it is skipped, so re-running
    the merge does not fold a previous merged file back into the new one.

    The written file's metadata contains ``format``, ``total_size`` (the total
    number of bytes of tensor data in the merged file), and the
    ``_diffusers_version`` / ``_class_name`` entries copied from *config_file*.

    Args:
        input_dir: Directory containing the ``.safetensors`` shards.
        output_file: Path of the merged ``.safetensors`` file to write.
        config_file: Path to a JSON config file; its ``_diffusers_version``
            and ``_class_name`` values (default ``""``) are copied into the
            output metadata.

    Raises:
        ValueError: If *input_dir* contains no shards, or if the same tensor
            name appears in more than one shard.
    """
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Collect shard paths first, in a deterministic order.
    output_path = os.path.realpath(output_file)
    shard_paths = []
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(".safetensors"):
            continue
        file_path = os.path.join(input_dir, filename)
        # Skip directories named *.safetensors and the output of an earlier run.
        if not os.path.isfile(file_path) or os.path.realpath(file_path) == output_path:
            continue
        shard_paths.append(file_path)

    if not shard_paths:
        raise ValueError(f"no .safetensors shards found in {input_dir!r}")

    merged_tensors = {}
    source_of = {}  # tensor name -> shard it was read from (for error messages)
    total_size = 0  # bytes of tensor data across all shards
    for file_path in shard_paths:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key in merged_tensors:
                    # Two shards defining the same tensor would otherwise
                    # silently drop one of them.
                    raise ValueError(
                        f"tensor {key!r} appears in both "
                        f"{source_of[key]!r} and {file_path!r}"
                    )
                tensor = f.get_tensor(key)
                merged_tensors[key] = tensor
                source_of[key] = file_path
                total_size += tensor.numel() * tensor.element_size()

    # safetensors metadata must map str -> str, so stringify every value.
    # Further string-valued fields can be added at this level.
    metadata = {
        "format": "pt",
        "total_size": str(total_size),
        "_diffusers_version": str(config.get("_diffusers_version", "")),
        "_class_name": str(config.get("_class_name", "")),
    }

    save_file(merged_tensors, output_file, metadata)


if __name__ == "__main__":
    input_directory = './10_1'
    output_file = './10_1/flux1-merge-S10_D1.safetensors'
    config_file = './10_1/config.json'
    merge_safetensors(input_directory, output_file, config_file)