|
import os |
|
import gc |
|
import glob |
|
from multiprocessing import Pool |
|
import time |
|
from tqdm import tqdm |
|
import torch |
|
from safetensors.torch import load_file |
|
from diffusers import FluxTransformer2DModel, FluxPipeline |
|
from huggingface_hub import snapshot_download |
|
from PIL import Image |
|
|
|
|
|
# Device the generation pipeline is placed on.
DEVICE = torch.device("cpu")

# When True, the pipeline uses diffusers' sequential CPU offload.
USE_CPU_OFFLOAD = True

# dtype for the transformer weights and the generation pipeline.
DTYPE = torch.bfloat16

# Number of worker processes in the Pool; each task renders one merge ratio.
NUM_WORKERS = 1

# Seed used for every generation so outputs are comparable across ratios and step counts.
SEED = 0

# Output resolution in pixels.
IMAGE_WIDTH = 880
IMAGE_HEIGHT = 656

PROMPTS = [
    "a tiny astronaut hatching from an egg on the moon",
    'photo of a man on a beach holding a sign that says "Premature optimization is the root of all evil - test your shit!"',
]

# Numbers of denoising steps each prompt is rendered with.
STEP_COUNTS = [4, 8, 16, 32, 50]

# (schnell_weight, dev_weight) pairs. process_model treats (1, 0) and (0, 1)
# as the unmerged schnell / dev checkpoints; everything else is blended.
MERGE_RATIOS = [
    (1, 0), (12, 1), (10, 1), (7, 1), (5.5, 1), (4, 1), (3.5, 1), (3, 1), (2.5, 1), (2, 1), (1.5, 1), (0, 1)
]

# Display names, index-aligned with MERGE_RATIOS; ":" is replaced by "_" when
# a label is used as a directory name.
MERGE_LABELS = [
    "Pure Schnell", "12:1", "10:1", "7:1", "5.5:1", "4:1", "3.5:1", "3:1", "2.5:1", "2:1", "1.5:1", "Pure Dev"
]

# Explicit raise rather than assert: asserts are stripped under `python -O`,
# and a mismatch would silently drop ratios from the zip() in main.
if len(MERGE_RATIOS) != len(MERGE_LABELS):
    raise ValueError(
        f"MERGE_RATIOS has {len(MERGE_RATIOS)} entries but MERGE_LABELS has "
        f"{len(MERGE_LABELS)}; they must be index-aligned"
    )

IMAGE_OUTPUT_DIR = "./outputs"
MODEL_OUTPUT_DIR = "./merged_models"
# When True, every loaded/merged transformer is also saved under MODEL_OUTPUT_DIR.
SAVE_MODELS = False

os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
def cleanup():
    """Collect unreachable Python objects, then release cached CUDA allocator memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


# Wall-clock reference for the timing summary printed at the end of the run.
start_time = time.time()
|
|
|
|
|
def merge_models(dev_shards, schnell_shards, ratio):
    """Blend two sharded FLUX transformer checkpoints into a single state dict.

    Args:
        dev_shards: sorted FLUX.1-dev ``.safetensors`` shard paths.
        schnell_shards: sorted FLUX.1-schnell shard paths. Shard ``i`` must hold
            the same non-guidance keys as ``dev_shards[i]``.
        ratio: ``(schnell_weight, dev_weight)``. Each shared tensor becomes
            ``(schnell * schnell_weight + dev * dev_weight) / (schnell_weight + dev_weight)``.

    Returns:
        dict of parameter name -> merged tensor, cast back to the dev tensor's
        dtype. Keys containing ``"guidance"`` are taken from dev unchanged.

    Raises:
        ValueError: if the weights sum to zero, or the shard lists are empty or
            differ in length.
        KeyError: if a dev key is missing from the matching schnell shard.
    """
    schnell_weight, dev_weight = ratio
    total_weight = schnell_weight + dev_weight
    if total_weight == 0:
        raise ValueError(f"merge ratio {ratio!r} has zero total weight")
    if not dev_shards or len(dev_shards) != len(schnell_shards):
        raise ValueError(
            f"expected matching, non-empty shard lists; got {len(dev_shards)} dev "
            f"and {len(schnell_shards)} schnell shards"
        )

    merged_state_dict = {}
    guidance_state_dict = {}

    shard_pairs = list(zip(dev_shards, schnell_shards))
    for i, (dev_path, schnell_path) in enumerate(
        tqdm(shard_pairs, desc="Processing shards...", dynamic_ncols=True)
    ):
        state_dict_dev = load_file(dev_path)
        state_dict_schnell = load_file(schnell_path)

        # Iterate over a snapshot of the keys because entries are popped below;
        # popping drops each source tensor as soon as it has been consumed.
        for k in tqdm(list(state_dict_dev), desc=f"\tProcessing keys of shard {i}...", dynamic_ncols=True):
            dev_tensor = state_dict_dev.pop(k)
            if "guidance" in k:
                # Guidance parameters are taken from dev as-is, never blended.
                guidance_state_dict[k] = dev_tensor
                continue
            try:
                schnell_tensor = state_dict_schnell.pop(k)
            except KeyError:
                raise KeyError(
                    f"{k!r} from {dev_path} is not in {schnell_path}; dev and schnell shards are not aligned"
                ) from None
            # Blend in float32: doing the scale/add/divide directly in bf16
            # rounds at every step, which is noticeable for weights like 12 or 5.5.
            blended = (schnell_tensor.float() * schnell_weight + dev_tensor.float() * dev_weight) / total_weight
            merged_state_dict[k] = blended.to(dev_tensor.dtype)

        del state_dict_dev, state_dict_schnell

    merged_state_dict.update(guidance_state_dict)
    return merged_state_dict
|
|
|
|
|
|
|
def create_merged_model(dev_ckpt, schnell_ckpt, ratio):
    """Build a FLUX transformer whose weights blend the dev and schnell checkpoints.

    Args:
        dev_ckpt: local FLUX.1-dev snapshot directory (must contain ``transformer/``).
        schnell_ckpt: local FLUX.1-schnell snapshot directory (must contain ``transformer/``).
        ratio: ``(schnell_weight, dev_weight)`` forwarded to ``merge_models``.

    Returns:
        The merged ``FluxTransformer2DModel`` in ``DTYPE``.
    """
    # Read the architecture from the same local dev snapshot whose shards are
    # merged below, rather than re-resolving the Hub repo id.
    config = FluxTransformer2DModel.load_config(dev_ckpt, subfolder="transformer")
    # from_config materialises randomly initialised weights; cast them to DTYPE
    # right away so a full-precision copy doesn't coexist with the merged dict.
    model = FluxTransformer2DModel.from_config(config).to(DTYPE)

    # glob.escape keeps checkpoint paths containing "[" or "*" from being
    # treated as patterns.
    dev_shards = sorted(glob.glob(os.path.join(glob.escape(dev_ckpt), "transformer", "*.safetensors")))
    schnell_shards = sorted(glob.glob(os.path.join(glob.escape(schnell_ckpt), "transformer", "*.safetensors")))

    merged_state_dict = merge_models(dev_shards, schnell_shards, ratio)
    # Strict load: every parameter must come from the merged checkpoint.
    model.load_state_dict(merged_state_dict)
    del merged_state_dict
    cleanup()

    return model
|
|
|
|
|
def generate_image(pipeline, prompt, num_steps, output_path, guidance_scale=3.5):
    """Render ``prompt`` with ``pipeline`` and save it, unless ``output_path`` already exists.

    Args:
        pipeline: a loaded ``FluxPipeline``.
        prompt: text prompt.
        num_steps: number of inference steps.
        output_path: file the first generated image is saved to.
        guidance_scale: value passed to the pipeline (default 3.5).
    """
    if os.path.exists(output_path):
        print("Image already exists, skipping...")
        return

    # A dedicated generator makes each image depend only on SEED, without
    # reseeding torch's global RNG as torch.manual_seed() would. Seeded with
    # the same value, it yields the same CPU random stream.
    generator = torch.Generator(device="cpu").manual_seed(SEED)

    image = pipeline(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        height=IMAGE_HEIGHT,
        width=IMAGE_WIDTH,
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    image.save(output_path)
|
|
|
|
|
def _load_transformer(ratio, dev_ckpt, schnell_ckpt):
    """Return the transformer for ``ratio``: an unmerged checkpoint for (1, 0) / (0, 1), else a blend."""
    if ratio == (1, 0):
        return FluxTransformer2DModel.from_pretrained(schnell_ckpt, subfolder="transformer", torch_dtype=DTYPE)
    if ratio == (0, 1):
        return FluxTransformer2DModel.from_pretrained(dev_ckpt, subfolder="transformer", torch_dtype=DTYPE)
    return create_merged_model(dev_ckpt, schnell_ckpt, ratio)


def process_model(ratio, label, dev_ckpt, schnell_ckpt):
    """Load or merge the transformer for one ratio and render every prompt/step-count image.

    Args:
        ratio: ``(schnell_weight, dev_weight)`` pair.
        label: display name; ``:`` is replaced by ``_`` to form the output
            sub-directory name.
        dev_ckpt: local FLUX.1-dev snapshot; also supplies every non-transformer
            pipeline component.
        schnell_ckpt: local FLUX.1-schnell snapshot.
    """
    output_name = label.replace(":", "_")
    image_output_dir = os.path.join(IMAGE_OUTPUT_DIR, output_name)
    os.makedirs(image_output_dir, exist_ok=True)

    # One job per (prompt, step count), in prompt-major order.
    jobs = [
        (prompt, step_count, os.path.join(image_output_dir, f"prompt{prompt_idx + 1}_steps{step_count}.png"))
        for prompt_idx, prompt in enumerate(PROMPTS)
        for step_count in STEP_COUNTS
    ]

    # Skip only if every expected image exists. Counting whatever files are in
    # the directory is fooled by stray or renamed files.
    if all(os.path.exists(path) for _, _, path in jobs):
        print(f"\nModel {label} already complete, skipping...")
        return
    print(f"\nProcessing {label} model...")

    model = _load_transformer(ratio, dev_ckpt, schnell_ckpt)
    pipeline = None
    try:
        if SAVE_MODELS:
            model_output_dir = os.path.join(MODEL_OUTPUT_DIR, output_name)
            print(f"Saving model to {model_output_dir}...")
            # The keyword is max_shard_size; the misspelt max_shared_size is not
            # a save_pretrained parameter and did not set the shard size.
            model.save_pretrained(model_output_dir, max_shard_size="50GB", safe_serialization=True)

        pipeline = FluxPipeline.from_pretrained(
            dev_ckpt,
            transformer=model,
            torch_dtype=DTYPE,
        )
        if USE_CPU_OFFLOAD:
            # Sequential offload manages device placement itself; moving the
            # whole pipeline to DEVICE first would load it all onto a GPU.
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(DEVICE)

        for prompt, step_count, output_path in jobs:
            generate_image(pipeline, prompt, step_count, output_path)
    finally:
        # Release the pipeline and weights even if a generation failed.
        del pipeline, model
        cleanup()
|
|
|
|
|
def main():
    """Download both checkpoints, then run ``process_model`` for every ratio in a worker pool.

    Tasks are awaited in submission order; an exception raised inside a worker
    is re-raised here by ``AsyncResult.get()``.
    """
    dev_ckpt = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-dev",
        ignore_patterns=["flux1-dev.sft", "flux1-dev.safetensors"],
        local_dir="./models/dev/",
    )
    # Only the transformer is fetched for schnell; the pipeline's other
    # components are loaded from the dev snapshot.
    schnell_ckpt = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-schnell",
        allow_patterns="transformer/*",
        local_dir="./models/schnell/",
    )

    task_args = [
        (ratio, label, dev_ckpt, schnell_ckpt)
        for ratio, label in zip(MERGE_RATIOS, MERGE_LABELS)
    ]

    with Pool(NUM_WORKERS) as pool:
        pending = [pool.apply_async(process_model, args) for args in task_args]
        for async_result in tqdm(pending):
            async_result.get()
        pool.close()
        pool.join()
|
|
|
|
|
def create_image_grid(image_paths, output_path, padding=10, grid_cols=None, grid_rows=None):
    """Tile images into one comparison picture and save it to ``output_path``.

    Each image is resized to half of IMAGE_WIDTH x IMAGE_HEIGHT and placed
    row-major: ``image_paths[idx]`` lands in row ``idx // grid_cols`` and
    column ``idx % grid_cols``.

    Args:
        image_paths: image paths in row-major order. A path that does not exist
            leaves its cell blank (white) instead of aborting the whole grid.
        output_path: where the grid image is written.
        padding: pixel gap between cells and around the edges.
        grid_cols: number of columns; defaults to ``len(MERGE_RATIOS)``.
        grid_rows: number of rows; defaults to ``len(STEP_COUNTS)``.

    Raises:
        ValueError: if there are more paths than grid cells.
    """
    width = IMAGE_WIDTH // 2
    height = IMAGE_HEIGHT // 2
    if grid_cols is None:
        grid_cols = len(MERGE_RATIOS)
    if grid_rows is None:
        grid_rows = len(STEP_COUNTS)
    if len(image_paths) > grid_cols * grid_rows:
        raise ValueError(
            f"{len(image_paths)} images do not fit a {grid_rows}x{grid_cols} grid"
        )

    # Extra blank margins on the top and left edges. Nothing is drawn there
    # here. NOTE(review): presumably reserved for row/column labels.
    top_pad = 250
    left_pad = 200

    grid_width = (width * grid_cols) + (padding * (grid_cols + 1)) + left_pad
    grid_height = (height * grid_rows) + (padding * (grid_rows + 1)) + top_pad
    grid_image = Image.new("RGB", (grid_width, grid_height), color=(255, 255, 255))

    for idx, path in enumerate(image_paths):
        if not os.path.exists(path):
            continue
        row, col = divmod(idx, grid_cols)
        x_position = (col * width) + (padding * (col + 1)) + left_pad
        y_position = (row * height) + (padding * (row + 1)) + top_pad
        # Open one image at a time and close it once its pixels are copied.
        with Image.open(path) as img:
            grid_image.paste(img.resize((width, height)), (x_position, y_position))

    grid_image.save(output_path)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Guard the script body: under the "spawn"/"forkserver" multiprocessing
    # start methods every Pool worker re-imports this module, and an unguarded
    # main() call would then try to start new pools during bootstrapping.
    main()

    print("Creating image comparison grid...")

    # Expected outputs grouped per prompt. Each inner list is row-major:
    # one row per step count, one column per merge label.
    image_paths_by_prompt = [
        [
            os.path.join(
                IMAGE_OUTPUT_DIR,
                label.replace(":", "_"),
                f"prompt{prompt_idx + 1}_steps{step_count}.png"
            )
            for step_count in STEP_COUNTS
            for label in MERGE_LABELS
        ]
        for prompt_idx in range(len(PROMPTS))
    ]
    all_image_paths = [path for paths in image_paths_by_prompt for path in paths]

    missing_images = [path for path in all_image_paths if not os.path.exists(path)]
    if missing_images:
        print(f"Warning: {len(missing_images)} images were not generated:")
        for path in missing_images[:5]:
            print(f" • {path}")
        if len(missing_images) > 5:
            print(f" (and {len(missing_images) - 5} more...)")

    # Use the per-prompt lists directly. Filtering the flat list by the
    # substring "prompt1" would also match prompt10, prompt11, ...
    for prompt_idx, prompt_images in enumerate(image_paths_by_prompt):
        grid_output_path = os.path.join(IMAGE_OUTPUT_DIR, f"grid_prompt{prompt_idx + 1}.png")
        create_image_grid(prompt_images, grid_output_path)

    end_time = time.time()
    total_time = end_time - start_time
    # Count only images that actually exist on disk.
    num_images = len(all_image_paths) - len(missing_images)

    print("\nProcessing complete!")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Total images generated: {num_images}")
    if num_images:
        print(f"Average time per image: {total_time / num_images:.2f} seconds")
    print(f"Output directory: {IMAGE_OUTPUT_DIR}")
|
|