# FLUX.1-Merges / merge_compare.py
import os
import gc
import glob
from multiprocessing import Pool
import time
from tqdm import tqdm
import torch
from safetensors.torch import load_file
from diffusers import FluxTransformer2DModel, FluxPipeline
from huggingface_hub import snapshot_download
from PIL import Image
# Configuration
DEVICE = torch.device("cpu")
# If True, uses pipeline.enable_sequential_cpu_offload(). Make sure device is CPU.
USE_CPU_OFFLOAD = True
DTYPE = torch.bfloat16
NUM_WORKERS = 1
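# Note: each worker materializes a full copy of the transformer in system
# RAM, so raising NUM_WORKERS multiplies peak memory accordingly.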
SEED = 0
IMAGE_WIDTH = 880 # 688
IMAGE_HEIGHT = 656 # 512
PROMPTS = [
"a tiny astronaut hatching from an egg on the moon",
#"photo of a female cyberpunk hacker, plugged in and hacking, far future, neon lights"
'photo of a man on a beach holding a sign that says "Premature optimization is the root of all evil - test your shit!"'
]
STEP_COUNTS = [4, 8, 16, 32, 50]
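# Each ratio is (schnell_weight, dev_weight): (1, 0) is pure Schnell, (0, 1)
# pure Dev, and e.g. (4, 1) averages 4 parts Schnell with 1 part Dev.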
MERGE_RATIOS = [
# (1, 0), (4, 1), (3, 1), (2, 1), (1, 1), (1, 2), (1, 3), (1, 4), (0, 1)
(1, 0), (12, 1), (10, 1), (7, 1), (5.5, 1), (4, 1), (3.5, 1), (3, 1), (2.5, 1), (2, 1), (1.5, 1), (0, 1)
]
MERGE_LABELS = [
# "Pure Schnell", "4:1", "3:1", "2:1", "1:1 Merge", "1:2", "1:3", "1:4", "Pure Dev"
"Pure Schnell", "12:1", "10:1", "7:1", "5.5:1", "4:1", "3.5:1", "3:1", "2.5:1", "2:1", "1.5:1", "Pure Dev"
]
assert len(MERGE_RATIOS) == len(MERGE_LABELS)
# Output directories
IMAGE_OUTPUT_DIR = "./outputs"
MODEL_OUTPUT_DIR = "./merged_models"
SAVE_MODELS = False
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
# Utility function for cleanup
def cleanup():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Start timing
start_time = time.time()
def merge_models(dev_shards, schnell_shards, ratio):
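    """Linearly interpolate two sharded checkpoints, shard pair by shard pair.

    Each non-guidance tensor becomes the weighted average
    (schnell_weight * schnell + dev_weight * dev) / (schnell_weight + dev_weight),
    so e.g. ratio (4, 1) yields 80% Schnell / 20% Dev. Guidance-embedder
    tensors exist only in the guidance-distilled Dev checkpoint, so they are
    copied from Dev unchanged.
    """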
schnell_weight, dev_weight = ratio
total_weight = schnell_weight + dev_weight
merged_state_dict = {}
guidance_state_dict = {}
for i in tqdm(range(len(dev_shards)), "Processing shards...", dynamic_ncols=True):
state_dict_dev = load_file(dev_shards[i])
state_dict_schnell = load_file(schnell_shards[i])
keys = list(state_dict_dev.keys())
for k in tqdm(keys, f"\tProcessing keys of shard {i}...", dynamic_ncols=True):
if "guidance" not in k:
merged_state_dict[k] = (
state_dict_schnell[k] * schnell_weight +
state_dict_dev[k] * dev_weight
) / total_weight
else:
guidance_state_dict[k] = state_dict_dev[k]
merged_state_dict.update(guidance_state_dict)
return merged_state_dict
# Function to create merged model
def create_merged_model(dev_ckpt, schnell_ckpt, ratio):
config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
model = FluxTransformer2DModel.from_config(config)
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
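    # Assumption: both repos shard the transformer into the same number of
    # files with matching keys, so shard i of Dev pairs with shard i of Schnell.
    assert len(dev_shards) == len(schnell_shards), "Shard layouts differ between Dev and Schnell"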
merged_state_dict = merge_models(dev_shards, schnell_shards, ratio)
model.load_state_dict(merged_state_dict)
del merged_state_dict
cleanup()
return model.to(DTYPE)
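# Example (hypothetical standalone call): a Schnell-leaning 4:1 merge built
# from the snapshots that main() downloads would be
#   merged = create_merged_model("./models/dev", "./models/schnell", (4, 1))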
def generate_image(pipeline, prompt, num_steps, output_path):
if not os.path.exists(output_path):
        # Params:
        #   prompt – The prompt or prompts to guide image generation. If not defined, prompt_embeds must be passed instead.
        #   prompt_2 – The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. If not defined, prompt is used instead.
        #   height – The height in pixels of the generated image. Defaults to 1024 for the best results.
        #   width – The width in pixels of the generated image. Defaults to 1024 for the best results.
        #   num_inference_steps – The number of denoising steps. More steps usually yield a higher-quality image at the expense of slower inference.
        #   timesteps – Custom timesteps for schedulers that support a timesteps argument in their set_timesteps method. Must be in descending order. If not defined, the default behavior for num_inference_steps is used.
        #   guidance_scale – Guidance scale as defined in Classifier-Free Diffusion Guidance (https://arxiv.org/abs/2207.12598); it is w in Eq. 2 of the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf). Enabled when guidance_scale > 1. Higher values encourage images more closely linked to the prompt, usually at the expense of image quality.
        #   num_images_per_prompt – The number of images to generate per prompt.
        #   generator – One or a list of torch generators (https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic.
        #   latents – Pre-generated noisy latents, sampled from a Gaussian distribution, to use as inputs; can be used to tweak the same generation with different prompts. If not provided, latents are sampled using the supplied generator.
        #   prompt_embeds – Pre-generated text embeddings, e.g. for prompt weighting. If not provided, embeddings are generated from prompt.
        #   pooled_prompt_embeds – Pre-generated pooled text embeddings, e.g. for prompt weighting. If not provided, they are generated from prompt.
        #   output_type – The output format of the generated image: PIL (https://pillow.readthedocs.io/en/stable/) PIL.Image.Image or np.array.
        #   return_dict – Whether or not to return a ~pipelines.flux.FluxPipelineOutput instead of a plain tuple.
        #   joint_attention_kwargs – A kwargs dictionary passed along to the AttentionProcessor defined under self.processor in diffusers.models.attention_processor (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        #   callback_on_step_end – A function called at the end of each denoising step during inference, as callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs includes the tensors listed in callback_on_step_end_tensor_inputs.
        #   callback_on_step_end_tensor_inputs – The list of tensor inputs for callback_on_step_end, passed via callback_kwargs. Only variables listed in the pipeline class's ._callback_tensor_inputs attribute may be included.
        #   max_sequence_length – Maximum sequence length to use with the prompt.
        # Returns:
        #   ~pipelines.flux.FluxPipelineOutput if return_dict is True, otherwise a tuple whose first element is the list of generated images.
image = pipeline(
prompt=prompt,
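            # Dev consumes guidance_scale via its guidance embedder; Pure
            # Schnell (guidance_embeds=False) is expected to ignore it.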
guidance_scale=3.5,
num_inference_steps=num_steps,
height=IMAGE_HEIGHT,
width=IMAGE_WIDTH,
max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(SEED),  # dedicated generator; leaves the global RNG untouched
).images[0]
image.save(output_path)
else:
print("Image already exists, skipping...")
def process_model(ratio, label, dev_ckpt, schnell_ckpt):
image_output_dir = os.path.join(IMAGE_OUTPUT_DIR, label.replace(":", "_"))
os.makedirs(image_output_dir, exist_ok=True)
existing_images = len([name for name in os.listdir(image_output_dir) if os.path.isfile(os.path.join(image_output_dir, name))])
if existing_images == len(PROMPTS) * len(STEP_COUNTS):
print(f"\nModel {label} already complete, skipping...")
return
else:
print(f"\nProcessing {label} model...")
if ratio == (1, 0): # Pure Schnell
model = FluxTransformer2DModel.from_pretrained(schnell_ckpt, subfolder="transformer", torch_dtype=DTYPE)
elif ratio == (0, 1): # Pure Dev
        model = FluxTransformer2DModel.from_pretrained(dev_ckpt, subfolder="transformer", torch_dtype=DTYPE)
else:
model = create_merged_model(dev_ckpt, schnell_ckpt, ratio)
if SAVE_MODELS:
model_output_dir = os.path.join(MODEL_OUTPUT_DIR, label.replace(":", "_"))
print(f"Saving model to {model_output_dir}...")
        model.save_pretrained(model_output_dir, max_shard_size="50GB", safe_serialization=True)
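    # Build the full pipeline from the Dev checkpoint (text encoders, VAE,
    # tokenizers, scheduler) and swap in the merged/loaded transformer.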
pipeline = FluxPipeline.from_pretrained(
dev_ckpt,
transformer=model,
torch_dtype=DTYPE,
).to(DEVICE)
if USE_CPU_OFFLOAD:
pipeline.enable_sequential_cpu_offload()
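        # Sequential offload keeps weights in system RAM and streams each
        # submodule to the GPU only for its forward pass: slow, but it fits
        # the roughly 12B-parameter transformer in very little VRAM.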
#pipeline.enable_xformers_memory_efficient_attention()
for prompt_idx, prompt in enumerate(PROMPTS):
for step_count in STEP_COUNTS:
output_path = os.path.join(
image_output_dir,
f"prompt{prompt_idx + 1}_steps{step_count}.png"
)
generate_image(pipeline, prompt, step_count, output_path)
del pipeline
del model
cleanup()
def main():
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", ignore_patterns=["flux1-dev.sft","flux1-dev.safetensors"],
local_dir="./models/dev/")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*",
local_dir="./models/schnell/")
with Pool(NUM_WORKERS) as pool:
results = [
pool.apply_async(
process_model,
(ratio, label, dev_ckpt, schnell_ckpt)
)
for ratio, label in zip(MERGE_RATIOS, MERGE_LABELS)
]
for result in tqdm(results):
result.get() # This will block until the result is ready
pool.close()
pool.join()
def create_image_grid(image_paths, output_path, padding=10):
width = IMAGE_WIDTH // 2
height = IMAGE_HEIGHT // 2
images = [Image.open(path).resize((width, height)) for path in image_paths]
grid_cols = len(MERGE_RATIOS)
grid_rows = len(STEP_COUNTS)
top_pad = 250
left_pad = 200
grid_width = (width * grid_cols) + (padding * (grid_cols + 1)) + left_pad
grid_height = (height * grid_rows) + (padding * (grid_rows + 1)) + top_pad
grid_image = Image.new('RGB', (grid_width, grid_height), color=(255, 255, 255))
for idx, img in enumerate(images):
row = idx // grid_cols
col = idx % grid_cols
x_position = (col * width) + (padding * (col + 1)) + left_pad
y_position = (row * height) + (padding * (row + 1)) + top_pad
grid_image.paste(img, (x_position, y_position))
grid_image.save(output_path)
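# The top/left padding reserves space for row/column labels, which are not
# drawn. A minimal sketch (hypothetical, to be placed before grid_image.save
# inside create_image_grid) that would draw the merge labels with PIL:
#   from PIL import ImageDraw
#   draw = ImageDraw.Draw(grid_image)
#   for col, label in enumerate(MERGE_LABELS):
#       x = left_pad + padding + col * (width + padding)
#       draw.text((x, top_pad // 2), label, fill=(0, 0, 0))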
# Run the main process. The __main__ guard matters because multiprocessing
# re-imports this module in its worker processes on spawn-based platforms
# (Windows/macOS); without it, the workers would re-run the whole script.
if __name__ == "__main__":
    main()

    # Create the image grids
    print("Creating image comparison grid...")

    # Reconstruct the image paths
    all_image_paths = [
        os.path.join(
            IMAGE_OUTPUT_DIR,
            label.replace(":", "_"),
            f"prompt{prompt_idx + 1}_steps{step_count}.png"
        )
        for prompt_idx in range(len(PROMPTS))
        for step_count in STEP_COUNTS
        for label in MERGE_LABELS
    ]

    missing_images = [path for path in all_image_paths if not os.path.exists(path)]
    if missing_images:
        print(f"Warning: {len(missing_images)} images were not generated:")
        for path in missing_images[:5]:  # Show first 5
            print(f"  • {path}")
        if len(missing_images) > 5:
            print(f"  (and {len(missing_images) - 5} more...)")

    # Create grid images
    for prompt_idx in range(len(PROMPTS)):
        prompt_images = [path for path in all_image_paths if f"prompt{prompt_idx + 1}" in path]
        grid_output_path = os.path.join(IMAGE_OUTPUT_DIR, f"grid_prompt{prompt_idx + 1}.png")
        create_image_grid(prompt_images, grid_output_path)

    # Final report
    end_time = time.time()
    total_time = end_time - start_time
    num_images = len(all_image_paths) - len(missing_images)
    print("\nProcessing complete!")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Total images generated: {num_images}")
    if num_images:
        print(f"Average time per image: {total_time / num_images:.2f} seconds")
    print(f"Output directory: {IMAGE_OUTPUT_DIR}")