import os
import gc
import glob
import time
from multiprocessing import Pool

import torch
from tqdm import tqdm
from safetensors.torch import load_file
from diffusers import FluxTransformer2DModel, FluxPipeline
from huggingface_hub import snapshot_download
from PIL import Image

# Configuration
DEVICE = torch.device("cpu")
# If True, uses pipeline.enable_sequential_cpu_offload(). Make sure DEVICE is CPU.
USE_CPU_OFFLOAD = True
DTYPE = torch.bfloat16
NUM_WORKERS = 1
SEED = 0
IMAGE_WIDTH = 880   # 688
IMAGE_HEIGHT = 656  # 512

PROMPTS = [
    "a tiny astronaut hatching from an egg on the moon",
    # "photo of a female cyberpunk hacker, plugged in and hacking, far future, neon lights"
    'photo of a man on a beach holding a sign that says "Premature optimization is the root of all evil - test your shit!"',
]
STEP_COUNTS = [4, 8, 16, 32, 50]

# Each ratio is (schnell_weight, dev_weight), ordered from pure Schnell to pure Dev.
MERGE_RATIOS = [
    # (1, 0), (4, 1), (3, 1), (2, 1), (1, 1), (1, 2), (1, 3), (1, 4), (0, 1)
    (1, 0), (12, 1), (10, 1), (7, 1), (5.5, 1), (4, 1), (3.5, 1),
    (3, 1), (2.5, 1), (2, 1), (1.5, 1), (0, 1),
]
MERGE_LABELS = [
    # "Pure Schnell", "4:1", "3:1", "2:1", "1:1 Merge", "1:2", "1:3", "1:4", "Pure Dev"
    "Pure Schnell", "12:1", "10:1", "7:1", "5.5:1", "4:1", "3.5:1",
    "3:1", "2.5:1", "2:1", "1.5:1", "Pure Dev",
]
assert len(MERGE_RATIOS) == len(MERGE_LABELS)

# Output directories
IMAGE_OUTPUT_DIR = "./outputs"
MODEL_OUTPUT_DIR = "./merged_models"
SAVE_MODELS = False
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)


# Utility function for cleanup
def cleanup():
    gc.collect()
    torch.cuda.empty_cache()  # No-op when CUDA is unavailable or uninitialized.


# Start timing
start_time = time.time()


def merge_models(dev_shards, schnell_shards, ratio):
    """Compute a weighted average of the schnell and dev weights, shard by shard.

    Guidance-related keys exist only in the dev checkpoint, so they are copied
    from dev unchanged rather than averaged.
    """
    schnell_weight, dev_weight = ratio
    total_weight = schnell_weight + dev_weight
    merged_state_dict = {}
    guidance_state_dict = {}
    for i in tqdm(range(len(dev_shards)), "Processing shards...", dynamic_ncols=True):
        state_dict_dev = load_file(dev_shards[i])
        state_dict_schnell = load_file(schnell_shards[i])
        keys = list(state_dict_dev.keys())
        for k in tqdm(keys, f"\tProcessing keys of shard {i}...", dynamic_ncols=True):
            if "guidance" not in k:
                merged_state_dict[k] = (
                    state_dict_schnell[k] * schnell_weight
                    + state_dict_dev[k] * dev_weight
                ) / total_weight
            else:
                guidance_state_dict[k] = state_dict_dev[k]
    merged_state_dict.update(guidance_state_dict)
    return merged_state_dict


# Function to create a merged model
def create_merged_model(dev_ckpt, schnell_ckpt, ratio):
    config = FluxTransformer2DModel.load_config(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer"
    )
    model = FluxTransformer2DModel.from_config(config)
    dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
    schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
    merged_state_dict = merge_models(dev_shards, schnell_shards, ratio)
    model.load_state_dict(merged_state_dict)
    del merged_state_dict
    cleanup()
    return model.to(DTYPE)
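# Quick sanity check of the merge formula on scalar tensors (illustrative only,
# not part of the experiment; the variable names below are made up for the
# check): with ratio=(3, 1), i.e. 3 parts schnell to 1 part dev, a parameter
# that is 0.0 in schnell and 4.0 in dev should merge to
# (0.0 * 3 + 4.0 * 1) / (3 + 1) == 1.0.
_schnell_param, _dev_param = torch.tensor(0.0), torch.tensor(4.0)
assert torch.isclose((_schnell_param * 3 + _dev_param * 1) / (3 + 1), torch.tensor(1.0))
del _schnell_param, _dev_param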
def generate_image(pipeline, prompt, num_steps, output_path):
    if os.path.exists(output_path):
        print("Image already exists, skipping...")
        return
    # FluxPipeline.__call__ parameters (cleaned up from the diffusers docs):
    #   prompt – The prompt or prompts to guide image generation. If not defined, prompt_embeds must be passed instead.
    #   prompt_2 – The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. If not defined, prompt is used instead.
    #   height – The height in pixels of the generated image. Defaults to 1024 for the best results.
    #   width – The width in pixels of the generated image. Defaults to 1024 for the best results.
    #   num_inference_steps – The number of denoising steps. More steps usually yield a higher-quality image at the expense of slower inference.
    #   timesteps – Custom timesteps for schedulers whose set_timesteps method supports a timesteps argument; must be in descending order. If not defined, the default behavior for num_inference_steps is used.
    #   guidance_scale – Guidance scale as defined in Classifier-Free Diffusion Guidance (https://arxiv.org/abs/2207.12598); corresponds to w in equation 2 of the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf). Enabled when guidance_scale > 1. Higher values encourage images closely linked to the text prompt, usually at the expense of image quality.
    #   num_images_per_prompt – The number of images to generate per prompt.
    #   generator – One or a list of torch.Generator (https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic.
    #   latents – Pre-generated noisy latents sampled from a Gaussian distribution, used as inputs for image generation; can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling with the supplied generator.
    #   prompt_embeds – Pre-generated text embeddings, e.g. for prompt weighting. If not provided, they are generated from prompt.
    #   pooled_prompt_embeds – Pre-generated pooled text embeddings, e.g. for prompt weighting. If not provided, they are generated from prompt.
    #   output_type – The output format of the generated image: PIL.Image.Image (https://pillow.readthedocs.io/en/stable/) or np.array.
    #   return_dict – Whether or not to return a ~pipelines.flux.FluxPipelineOutput instead of a plain tuple.
    #   joint_attention_kwargs – A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor
    #     (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
    #   callback_on_step_end – A function called at the end of each denoising step as callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict); callback_kwargs includes the tensors listed in callback_on_step_end_tensor_inputs.
    #   callback_on_step_end_tensor_inputs – The tensors to pass to callback_on_step_end via callback_kwargs; only variables listed in the pipeline's ._callback_tensor_inputs attribute may be included.
    #   max_sequence_length – Maximum sequence length to use with the prompt.
    # Returns: a ~pipelines.flux.FluxPipelineOutput if return_dict is True, otherwise a tuple whose first element is the list of generated images.
    image = pipeline(
        prompt=prompt,
        guidance_scale=3.5,
        num_inference_steps=num_steps,
        height=IMAGE_HEIGHT,
        width=IMAGE_WIDTH,
        max_sequence_length=512,
        generator=torch.manual_seed(SEED),
    ).images[0]
    image.save(output_path)
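# Note: torch.manual_seed(SEED) reseeds the *global* RNG and returns the default
# torch.Generator, so every image in the sweep starts from the same initial
# noise. An equivalent, more self-contained sketch (same result, assuming
# nothing else draws from the global RNG) would pass a dedicated generator:
#
#   generator = torch.Generator(device="cpu").manual_seed(SEED)
#   image = pipeline(prompt=prompt, num_inference_steps=num_steps,
#                    generator=generator).images[0]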
def process_model(ratio, label, dev_ckpt, schnell_ckpt):
    image_output_dir = os.path.join(IMAGE_OUTPUT_DIR, label.replace(":", "_"))
    os.makedirs(image_output_dir, exist_ok=True)
    existing_images = len([
        name for name in os.listdir(image_output_dir)
        if os.path.isfile(os.path.join(image_output_dir, name))
    ])
    if existing_images == len(PROMPTS) * len(STEP_COUNTS):
        print(f"\nModel {label} already complete, skipping...")
        return
    print(f"\nProcessing {label} model...")

    if ratio == (1, 0):  # Pure Schnell
        model = FluxTransformer2DModel.from_pretrained(
            schnell_ckpt, subfolder="transformer", torch_dtype=DTYPE
        )
    elif ratio == (0, 1):  # Pure Dev
        model = FluxTransformer2DModel.from_pretrained(
            dev_ckpt, subfolder="transformer", torch_dtype=DTYPE
        )
    else:
        model = create_merged_model(dev_ckpt, schnell_ckpt, ratio)

    if SAVE_MODELS:
        model_output_dir = os.path.join(MODEL_OUTPUT_DIR, label.replace(":", "_"))
        print(f"Saving model to {model_output_dir}...")
        model.save_pretrained(
            model_output_dir, max_shard_size="50GB", safe_serialization=True
        )

    pipeline = FluxPipeline.from_pretrained(
        dev_ckpt,
        transformer=model,
        torch_dtype=DTYPE,
    ).to(DEVICE)
    if USE_CPU_OFFLOAD:
        pipeline.enable_sequential_cpu_offload()
    # pipeline.enable_xformers_memory_efficient_attention()

    for prompt_idx, prompt in enumerate(PROMPTS):
        for step_count in STEP_COUNTS:
            output_path = os.path.join(
                image_output_dir, f"prompt{prompt_idx + 1}_steps{step_count}.png"
            )
            generate_image(pipeline, prompt, step_count, output_path)

    del pipeline
    del model
    cleanup()


def main():
    dev_ckpt = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-dev",
        ignore_patterns=["flux1-dev.sft", "flux1-dev.safetensors"],
        local_dir="./models/dev/",
    )
    schnell_ckpt = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-schnell",
        allow_patterns="transformer/*",
        local_dir="./models/schnell/",
    )
    with Pool(NUM_WORKERS) as pool:
        results = [
            pool.apply_async(process_model, (ratio, label, dev_ckpt, schnell_ckpt))
            for ratio, label in zip(MERGE_RATIOS, MERGE_LABELS)
        ]
        for result in tqdm(results):
            result.get()  # Blocks until the result is ready; re-raises worker errors.
        pool.close()
        pool.join()


def create_image_grid(image_paths, output_path, padding=10):
    # top_pad and left_pad leave white margins for annotating the grid
    # (merge-ratio columns, step-count rows) after the fact.
    width = IMAGE_WIDTH // 2
    height = IMAGE_HEIGHT // 2
    images = [Image.open(path).resize((width, height)) for path in image_paths]
    grid_cols = len(MERGE_RATIOS)
    grid_rows = len(STEP_COUNTS)
    top_pad = 250
    left_pad = 200
    grid_width = (width * grid_cols) + (padding * (grid_cols + 1)) + left_pad
    grid_height = (height * grid_rows) + (padding * (grid_rows + 1)) + top_pad
    grid_image = Image.new("RGB", (grid_width, grid_height), color=(255, 255, 255))
    for idx, img in enumerate(images):
        row = idx // grid_cols
        col = idx % grid_cols
        x_position = (col * width) + (padding * (col + 1)) + left_pad
        y_position = (row * height) + (padding * (row + 1)) + top_pad
        grid_image.paste(img, (x_position, y_position))
    grid_image.save(output_path)
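# create_image_grid reserves the top/left margins but draws no text into them.
# If you want the labels rendered in-script, a minimal Pillow sketch (assumed
# placement values; tune to taste) could run just before grid_image.save():
#
#   from PIL import ImageDraw
#   draw = ImageDraw.Draw(grid_image)
#   for col, label in enumerate(MERGE_LABELS):  # column headers: merge ratios
#       x = (col * width) + (padding * (col + 1)) + left_pad
#       draw.text((x, top_pad // 2), label, fill=(0, 0, 0))
#   for row, steps in enumerate(STEP_COUNTS):  # row labels: step counts
#       y = (row * height) + (padding * (row + 1)) + top_pad
#       draw.text((left_pad // 2, y), f"{steps} steps", fill=(0, 0, 0))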
if __name__ == "__main__":
    # The entry-point guard matters here: multiprocessing workers re-import this
    # module, and an unguarded main() would be re-executed in every worker.

    # Run the main process
    main()

    # Create the image grids
    print("Creating image comparison grid...")
    # Reconstruct the image paths in grid order: rows iterate over step counts,
    # columns over merge ratios.
    all_image_paths = [
        os.path.join(
            IMAGE_OUTPUT_DIR,
            label.replace(":", "_"),
            f"prompt{prompt_idx + 1}_steps{step_count}.png",
        )
        for prompt_idx in range(len(PROMPTS))
        for step_count in STEP_COUNTS
        for label in MERGE_LABELS
    ]
    missing_images = [path for path in all_image_paths if not os.path.exists(path)]
    if missing_images:
        print(f"Warning: {len(missing_images)} images were not generated:")
        for path in missing_images[:5]:  # Show the first 5
            print(f"  • {path}")
        if len(missing_images) > 5:
            print(f"  (and {len(missing_images) - 5} more...)")

    # Create grid images, skipping any prompt whose images are incomplete
    # (Image.open would otherwise fail on a missing path).
    for prompt_idx in range(len(PROMPTS)):
        prompt_images = [
            path for path in all_image_paths if f"prompt{prompt_idx + 1}" in path
        ]
        if any(path in missing_images for path in prompt_images):
            print(f"Skipping grid for prompt {prompt_idx + 1} (missing images).")
            continue
        grid_output_path = os.path.join(
            IMAGE_OUTPUT_DIR, f"grid_prompt{prompt_idx + 1}.png"
        )
        create_image_grid(prompt_images, grid_output_path)

    # Final report
    end_time = time.time()
    total_time = end_time - start_time
    num_images = len(all_image_paths) - len(missing_images)
    print("\nProcessing complete!")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Total images generated: {num_images}")
    if num_images:
        print(f"Average time per image: {total_time / num_images:.2f} seconds")
    print(f"Output directory: {IMAGE_OUTPUT_DIR}")