# ReNO / app.py
# Last commit by lucaeyring: 416974a ("updated description"), 18 kB
import torch
import gc
import gradio as gr
from main import setup, execute_task
from arguments import parse_args
import os
import shutil
import glob
import time
import threading
import argparse
def list_iter_images(save_dir):
    """Return the per-iteration PNG images in ``save_dir``, in iteration order.

    Every ``*.png`` directly inside ``save_dir`` is returned except
    ``best_image.png``, which holds the final result rather than an
    intermediate iteration.

    ``glob`` yields files in arbitrary filesystem order, so the result is
    sorted. Files whose stem is an integer (``0.png``, ``10.png``, ...) come
    first in numeric order, followed by any other PNGs sorted by name.

    Args:
        save_dir: Directory to scan (not recursive).

    Returns:
        A list of image paths; empty if nothing matches or the directory
        does not exist.
    """
    image_extension = 'png'
    all_images = glob.glob(os.path.join(save_dir, f'*.{image_extension}'))
    # The final result is shown separately, so keep it out of the iteration list.
    image_paths = [img for img in all_images if os.path.basename(img) != 'best_image.png']

    def _iteration_key(path):
        # Numeric stems sort numerically (9.png before 10.png); the rest by name.
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdecimal():
            return (0, int(stem), stem)
        return (1, 0, stem)

    return sorted(image_paths, key=_iteration_key)
def clean_dir(save_dir):
    """Empty ``save_dir`` while keeping the directory itself.

    Files and symbolic links are unlinked and subdirectories are removed
    recursively. If one entry cannot be removed, the failure is printed and
    the remaining entries are still processed. A status line is printed for
    each outcome: missing directory, empty directory, or cleaned directory.
    """
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return

    entries = os.listdir(save_dir)
    if not entries:
        print(f"{save_dir} exists but is empty.")
        return

    for entry in entries:
        target = os.path.join(save_dir, entry)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                # Regular file or symlink: unlink removes the link, not its target.
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f"Failed to delete {target}. Reason: {err}")
    print(f"All files in {save_dir} have been deleted.")
def start_over(gallery_state):
    """Reset the output widgets before a new run.

    Releases cached CUDA memory and runs the garbage collector. Returns four
    values: a cleared gallery state (always ``None``), an empty output image,
    an empty status, and an update that hides the iterations gallery.
    """
    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()
    hidden_gallery = gr.update(visible=False)
    return None, None, None, hidden_gallery
def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Build the run configuration and load (or reuse) the requested model.

    This is a generator that yields ``(status_message, loaded_setup)`` pairs
    for Gradio. ``loaded_setup`` is the list
    ``[args, trainer, device, dtype, shape, enable_grad, settings, pipe]``
    built from ``setup``'s return values, or ``None`` if loading failed.

    If ``loaded_model_setup`` holds a previous configuration whose arguments
    equal the new ones in everything except the prompt, that setup is reused.
    Only its prompt is updated, and ``setup`` is not called again.

    Args:
        loaded_model_setup: Previously yielded setup list, or ``None``.
        prompt: Text prompt; must contain non-whitespace text.
        model: Model name chosen in the UI (e.g. "sdxl-turbo", "flux").
        seed, num_iterations, learning_rate: Copied into the run arguments.
        enable_hps, enable_imagereward, enable_pickscore, enable_clip:
            Per-reward switches.
        hps_w, imgrw_w, pcks_w, clip_w: Reward weights; each is applied only
            when its reward is enabled.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Raises:
        gr.Error: If the prompt is missing or blank.
    """
    # Validate first, so a blank prompt does not show a "Loading ..." toast
    # immediately followed by an error.
    if not prompt or not prompt.strip():
        raise gr.Error("You forgot to provide a prompt !")
    gr.Info(f"Loading {model} model ...")
    print(f"LOADED_MODEL SETUP: {loaded_model_setup}")
    # Clear CUDA memory before (re)loading a model.
    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()

    # Start from the parse_args() defaults and override what the UI controls.
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.seed = seed
    args.n_iters = num_iterations
    args.lr = learning_rate
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    # Each reward is toggled independently. Its weight is only written when the
    # reward is enabled; otherwise the value from parse_args() is left as is.
    args.enable_hps = enable_hps is True
    if args.enable_hps:
        args.hps_weighting = hps_w
    args.enable_imagereward = enable_imagereward is True
    if args.enable_imagereward:
        args.imagereward_weighting = imgrw_w
    args.enable_pickscore = enable_pickscore is True
    if args.enable_pickscore:
        args.pickscore_weighting = pcks_w
    args.enable_clip = enable_clip is True
    if args.enable_clip:
        args.clip_weighting = clip_w

    # Model-specific overrides.
    if model == "flux":
        args.cpu_offloading = True
        args.enable_multi_apply = True
        args.multi_step_model = "flux"
    if model == "hyper-sd":
        args.cpu_offloading = True

    # Reuse the loaded setup when every argument except the prompt is unchanged.
    if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
        previous_args = loaded_model_setup[0]
        new_args_dict = {k: v for k, v in vars(args).items() if k != 'prompt'}
        prev_args_dict = {k: v for k, v in vars(previous_args).items() if k != 'prompt'}
        if new_args_dict == prev_args_dict:
            print(f"Arguments (excluding prompt) are the same, reusing loaded setup for {model} model.")
            loaded_model_setup[0].prompt = prompt
            yield f"{model} model already loaded with the same configuration.", loaded_model_setup
            # Stop here. Without this return the generator fell through and
            # loaded the model again, despite reporting it as already loaded.
            return

    # The arguments differ (or nothing is loaded yet): run the full setup.
    try:
        args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup(args, loaded_model_setup)
    except Exception as e:
        print(f"Failed to load {model} model: {e}.")
        yield f"Failed to load {model} model: {e}. You can try again, as it usually finally loads on the second try :)", None
        return
    new_loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings, pipe]
    yield f"{model} model loaded successfully!", new_loaded_setup
def generate_image(setup_args, num_iterations):
    """Run the ReNO optimization and stream intermediate results.

    This is a generator that yields ``(image_path, status_message, iter_images)``
    tuples for Gradio:

    * While the optimization runs, it yields one tuple per newly reported
      step, holding that step's PNG path (or ``None`` if the file is missing).
    * At the end, it yields the path of ``best_image.png`` plus the list of
      per-iteration images, or ``(None, <error message>, None)``.

    ``execute_task`` runs in a background thread and reports progress through
    ``progress_callback``. This generator polls that progress every 0.1 s.

    Args:
        setup_args: The list ``[args, trainer, device, dtype, shape,
            enable_grad, settings, pipe]`` yielded by ``setup_model``, or
            ``None`` if loading failed.
        num_iterations: Total iteration count, used only in status text.
    """
    import traceback  # function-local: only used to report worker-thread failures

    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()
    if not setup_args:
        # setup_model yields None when loading fails. Indexing it below would
        # raise a TypeError before any error handler is in place.
        yield (None, "No model is loaded. Please try submitting again.", None)
        return
    gr.Info("Executing iterations task ...")
    args = setup_args[0]
    trainer = setup_args[1]
    device = setup_args[2]
    dtype = setup_args[3]
    shape = setup_args[4]
    enable_grad = setup_args[5]
    settings = setup_args[6]
    print(f"SETTINGS: {settings}")
    pipe = setup_args[7]
    # Per-run output folder. It is emptied so that images left over from an
    # earlier run with the same settings and prompt are not picked up.
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
    clean_dir(save_dir)
    oom_message = "CUDA out of memory. Please reduce your batch size or image resolution."
    try:
        torch.cuda.empty_cache()  # Free up cached memory
        gc.collect()
        steps_completed = []
        # Written by the worker thread and read by the polling loop below.
        # "message" is set before "error_occurred", so the loop never sees the
        # flag without a message.
        error_status = {"error_occurred": False, "message": None}

        def progress_callback(step):
            # Record each step once; ignore repeated or out-of-order reports.
            if not steps_completed or step > steps_completed[-1]:
                steps_completed.append(step)
                print(f"Progress: Step {step} completed.")

        def run_main():
            try:
                execute_task(
                    args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback
                )
            except torch.cuda.OutOfMemoryError as e:
                # Must precede RuntimeError, its base class.
                print(f"CUDA Out of Memory Error: {e}")
                error_status["message"] = oom_message
                error_status["error_occurred"] = True
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"Runtime Error: {e}")
                    error_status["message"] = oom_message
                else:
                    traceback.print_exc()
                    error_status["message"] = f"An error occurred: {str(e)}"
                error_status["error_occurred"] = True
            except Exception as e:
                # Report the failure to the user. Otherwise the thread just dies
                # and the UI claims generation completed without an image.
                traceback.print_exc()
                error_status["message"] = f"An unexpected error occurred: {str(e)}"
                error_status["error_occurred"] = True

        main_thread = threading.Thread(target=run_main)
        main_thread.start()
        last_step_yielded = 0
        while main_thread.is_alive() and not error_status["error_occurred"]:
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                # Image files appear to be numbered one below the reported step
                # (e.g. step 1 -> 0.png). TODO confirm against execute_task.
                png_number = last_step_yielded - 1
                image_path = os.path.join(save_dir, f"{png_number}.png")
                if os.path.exists(image_path):
                    yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
                else:
                    yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
            else:
                time.sleep(0.1)  # Sleep to prevent busy waiting
        # The worker has finished or is finishing after recording an error.
        main_thread.join()
        if error_status["error_occurred"]:
            torch.cuda.empty_cache()  # Free up cached memory
            gc.collect()
            yield (None, error_status["message"], None)
        else:
            final_image_path = os.path.join(save_dir, "best_image.png")
            if os.path.exists(final_image_path):
                iter_images = list_iter_images(save_dir)
                torch.cuda.empty_cache()  # Free up cached memory
                gc.collect()
                time.sleep(0.5)
                yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
            else:
                torch.cuda.empty_cache()  # Free up cached memory
                gc.collect()
                yield (None, "Image generation completed, but no final image was found.", None)
        torch.cuda.empty_cache()  # Free up cached memory
        gc.collect()
    except torch.cuda.OutOfMemoryError as e:
        print(f"Global CUDA Out of Memory Error: {e}")
        yield (None, f"{e}", None)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print(f"Runtime Error: {e}")
            yield (None, f"{e}", None)
        else:
            yield (None, f"An error occurred: {str(e)}", None)
    except Exception as e:
        print(f"Unexpected Error: {e}")
        yield (None, f"An unexpected error occurred: {str(e)}", None)
def show_gallery_output(gallery_state):
    """Return a Gradio update for the iterations gallery.

    When ``gallery_state`` holds images, the gallery is shown with them as its
    value. When it is ``None``, the gallery is hidden and its value cleared.
    """
    return gr.update(value=gallery_state, visible=gallery_state is not None)
def combined_function(gallery_state, loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Submit handler: reset the UI, load the model, run ReNO, show the gallery.

    This is a generator. Each yield is a 6-tuple in the order
    ``(gallery_state, output_image, status, iter_gallery, loaded_model_setup,
    model_status)``.
    """
    # Step 1: Start over
    gallery_state, output_image, status, iter_gallery_update = start_over(gallery_state)
    model_status = ""  # No model status yet
    yield gallery_state, output_image, status, iter_gallery_update, loaded_model_setup, model_status

    # Step 2: Set up (or reuse) the model, streaming its status messages.
    model_status, new_loaded_model_setup = None, None
    for model_status, new_loaded_model_setup in setup_model(
            loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w,
            enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate):
        yield gallery_state, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status
    if new_loaded_model_setup is None:
        # Loading failed. setup_model's last status already holds the error,
        # so stop instead of running generation without a model.
        return

    # Step 3: Generate the image, streaming intermediate results.
    output_image, status, gallery_state_update = None, None, None
    for output_image, status, gallery_state_update in generate_image(new_loaded_model_setup, n_iter):
        yield gallery_state_update, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status

    # Step 4: Show the gallery of iterations (hidden if there are none).
    iter_gallery_update = show_gallery_output(gallery_state_update)
    yield gallery_state_update, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status
# Static text and styling used by the Gradio interface below.
title = "# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description = (
    "Enter a prompt to generate an image using ReNO. The method enhances text-to-image generation by optimizing "
    "the initial noise using reward models as detailed in the paper. The demo uses a lower learning rate (2.5) compared to the paper's default (5.0) "
    "for smoother trajectories - if you are looking for more dramatic changes, you can increase this value. You can also "
    "adjust the reward weights to e.g. prioritize either prompt following (increase ImageReward) or aesthetic quality "
    "(increase HPS/PickScore) based on your preferences.\n\nThe first time you load this demo, it will take a bit "
    "to download and initialize the required model. Once loaded, each optimization run takes about 25-60 seconds."
)
# Styling for the compact model-status textbox (elem_id="model-status-id").
css = (
    "\n"
    "#model-status-id{\n"
    "height: 126px;\n"
    "}\n"
    "#model-status-id .progress-text{\n"
    "font-size: 10px!important;\n"
    "}\n"
    "#model-status-id .progress-level-inner{\n"
    "font-size: 8px!important;\n"
    "}\n"
)
with gr.Blocks(css=css, analytics_enabled=False) as demo:
    # Per-session state: the loaded model setup (reused across runs when only
    # the prompt changes) and the per-iteration image list.
    loaded_model_setup = gr.State()
    gallery_state = gr.State()
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/ExplainableML/ReNO'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://arxiv.org/abs/2406.04312v1'>
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
</a>
</div>
""")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sdxl-turbo")
                    seed = gr.Number(label="seed", value=0)
                model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=2.5, label="Learning Rate")
                with gr.Accordion("Advanced Settings", open=True):
                    with gr.Column():
                        # Every reward starts enabled, so its weight slider must start
                        # interactive; allow_weighting only runs when a checkbox changes.
                        with gr.Row():
                            enable_hps = gr.Checkbox(label="HPS ON", value=True, scale=1)
                            hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=True, scale=3)
                        with gr.Row():
                            enable_imagereward = gr.Checkbox(label="ImageReward ON", value=True, scale=1)
                            imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=True, scale=3)
                        with gr.Row():
                            enable_pickscore = gr.Checkbox(label="PickScore ON", value=True, scale=1)
                            pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=0.5, value=0.05, interactive=True, scale=3)
                        with gr.Row():
                            enable_clip = gr.Checkbox(label="CLIP ON", value=True, scale=1)
                            clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=True, scale=3)
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        "A red dog and a green cat",
                        "A blue scooter is parked near a curb in front of a green vintage car",
                        "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
                        # The comma after this entry was missing, which fused it with the next prompt.
                        "An orange chair to the right of a black airplane",
                        "A toaster riding a bike",
                        "A brain riding a rocketship towards the moon",
                    ],
                    inputs=[prompt],
                )
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)

    def allow_weighting(weight_type):
        """Make a reward-weight slider editable only while its reward is enabled."""
        return gr.update(interactive=weight_type is True)

    # Link each reward checkbox to its weight slider.
    for _checkbox, _slider in (
        (enable_hps, hps_w),
        (enable_imagereward, imgrw_w),
        (enable_pickscore, pcks_w),
        (enable_clip, clip_w),
    ):
        _checkbox.change(
            fn=allow_weighting,
            inputs=[_checkbox],
            outputs=[_slider],
            queue=False,
        )

    # One streaming handler runs the whole pipeline:
    # reset -> load/reuse model -> optimize -> show iterations gallery.
    submit_btn.click(
        fn=combined_function,
        inputs=[
            gallery_state, loaded_model_setup, prompt, chosen_model, seed, n_iter,
            enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore,
            pcks_w, enable_clip, clip_w, learning_rate,
        ],
        outputs=[
            gallery_state, output_image, status, iter_gallery, loaded_model_setup, model_status,
        ],
    )
# Launch the app only when run as a script, so importing this module (for
# example for testing or Gradio reload mode) does not start a server.
if __name__ == "__main__":
    demo.queue().launch(show_error=True, show_api=False)