import spaces
import gradio as gr
import torch
from PIL import Image
import random
import numpy as np
import torch
import os
import json
from datetime import datetime
from pipeline_rf import RectifiedFlowPipeline
# Load the rectified-flow checkpoint derived from Stable Diffusion v1.5
# (per its name, "2_rectified_flow_from_sd_1_5"), keeping weights in float32.
pipe = RectifiedFlowPipeline.from_pretrained(
    "XCLIU/2_rectified_flow_from_sd_1_5",
    torch_dtype=torch.float32,
)
# Move the pipeline to the GPU. process_image also builds its generator and
# autocast context on "cuda", so dropping this line alone does not make the
# app CPU-capable.
pipe.to("cuda")
@spaces.GPU(duration=20)
def process_image(
    image_layers, prompt, seed, randomize_seed, num_inference_steps,
    max_steps, learning_rate, optimization_steps, inverseproblem, mask_input
):
    """Inpaint the masked region of an image with the rectified-flow pipeline.

    Args:
        image_layers: Value of the ``gr.ImageMask`` editor: a dict whose
            ``"background"`` entry is the uploaded image and whose
            ``"layers"`` entry holds the drawn mask layers. May be ``None``.
        prompt: Text describing what should appear in the masked area.
        seed: Generator seed. Replaced by a random one when
            ``randomize_seed`` is true or ``seed`` is ``None``.
        randomize_seed: Whether to draw a fresh random seed.
        num_inference_steps, max_steps, learning_rate, optimization_steps,
            inverseproblem: Forwarded unchanged to ``pipe``.
        mask_input: Optional separately uploaded mask. When not ``None`` it is
            used instead of the first editor layer.

    Returns:
        A ``(image, status)`` tuple: the first image produced by ``pipe``
        (``None`` when validation fails) and a user-facing status message.
    """
    # Validate the editor payload before indexing into it. Previously the
    # image/mask dict was built first, so its `is None` check could never
    # fire. A missing upload raised TypeError, and an empty layer list raised
    # IndexError, instead of returning these messages.
    if image_layers is None:
        return None, "❌ Please upload an image and create a mask."
    image = image_layers.get("background")
    if mask_input is not None:
        mask = mask_input
    else:
        layers = image_layers.get("layers") or []
        mask = layers[0] if layers else None
    if image is None or mask is None:
        return None, "❌ Please ensure both image and mask are provided."
    if not prompt:
        return None, "❌ Please provide a prompt for inpainting."

    # Resolve the seed only once the inputs are known to be usable.
    if randomize_seed or seed is None:
        seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator("cuda").manual_seed(int(seed))

    image = image.convert("RGB")
    # Keep only the last band: the alpha channel of an RGBA mask, or the sole
    # band of a single-channel ("L") mask.
    # NOTE(review): an uploaded mask with no transparency has alpha 255
    # everywhere, so the whole image would be treated as masked. Confirm
    # this is intended.
    mask = mask.split()[-1]

    with torch.autocast("cuda"):
        result = pipe(
            prompt=prompt,
            negative_prompt="",
            input_image=image.resize((512, 512)),
            mask_image=mask.resize((512, 512)),
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
            generator=generator,
            save_masked_image=True,
            output_path="test.png",
            learning_rate=learning_rate,
            max_steps=max_steps,
            optimization_steps=optimization_steps,
            inverseproblem=inverseproblem,
        ).images[0]
    return result, f"✅ Inpainting completed with seed {seed}."
# Design the Gradio interface
with gr.Blocks() as demo:
    # NOTE(review): this Markdown is empty; its original content may have been
    # lost. Restore it or remove the component.
    gr.Markdown("")

    # Title and introduction. These literals previously contained raw line
    # breaks inside "..." (a SyntaxError); the breaks are now explicit "\n".
    # NOTE(review): "Project Page | Paper" has no link targets. The original
    # URLs/markup appear to be missing and should be restored.
    gr.Markdown("\n🍲 FlowChef 🍲\n")
    gr.Markdown(
        "Inversion/Gradient/Training-free Steering of InstaFlow (SDv1.5) "
        "for Inpainting (Inverse Problem)\n"
    )
    gr.Markdown(
        "Project Page | Paper\n"
        "(Steering Rectified Flow Models in the Vector Field for Controlled "
        "Image Generation)"
    )
    gr.Markdown(
        "💡 We recommend going through our tutorial introduction before "
        "getting started!\n"
    )
    gr.Markdown("⚡ For better performance, check out our demo on Flux!\n")

    # Most recent run, recorded by run_and_update_status for save_data().
    # These are module globals, so concurrent users overwrite each other.
    current_input_image = None
    current_mask = None
    current_output_image = None
    current_params = {}

    # Input editor and output image side by side at the top.
    with gr.Row():
        with gr.Column():
            image_input = gr.ImageMask(
                type="pil",
                label="Input Image and Mask",
                image_mode="RGBA",
                height=512,
                width=512,
            )
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # All options below the images.
    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Describe what should appear in the masked area...",
        )
        with gr.Row():
            seed = gr.Number(label="Seed (Optional)", value=None)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        num_inference_steps = gr.Slider(
            label="Inference Steps", minimum=50, maximum=200, value=100
        )

        # Advanced settings in an accordion
        with gr.Accordion("Advanced Settings", open=False):
            max_steps = gr.Slider(
                label="Max Steps", minimum=50, maximum=200, value=200
            )
            learning_rate = gr.Slider(
                label="Learning Rate", minimum=0.01, maximum=0.5, value=0.02
            )
            optimization_steps = gr.Slider(
                label="Optimization Steps", minimum=1, maximum=10, value=1
            )
            inverseproblem = gr.Checkbox(
                label="Apply mask on pixel space",
                value=False,
                info="Enables inverse problem formulation for inpainting by masking the RGB image itself. Hence, to avoid artifacts we increase the mask size manually during inference.",
            )
            mask_input = gr.Image(
                type="pil",
                label="Optional Mask",
                image_mode="RGBA",
            )

        with gr.Row():
            run_button = gr.Button("Run", variant="primary")
            # To enable saving, add this button and wire it to save_data below.
            # save_button = gr.Button("Save Data", variant="secondary")

    def run_and_update_status(
        image_with_mask, prompt, seed, randomize_seed, num_inference_steps,
        max_steps, learning_rate, optimization_steps, inverseproblem, mask_input
    ):
        """Run process_image, report its status, and record the run.

        Only the output image is returned, because it is the single output
        wired to the Run button. The status message is shown as a
        notification; previously it was discarded, so validation failures
        produced an empty output with no explanation.
        """
        global current_input_image, current_mask, current_output_image, current_params

        result_image, result_status = process_image(
            image_with_mask, prompt, seed, randomize_seed, num_inference_steps,
            max_steps, learning_rate, optimization_steps, inverseproblem, mask_input
        )
        if result_image is None:
            gr.Warning(result_status)
        else:
            gr.Info(result_status)

        # Record the run for save_data(). Guard the layer lookup: indexing
        # layers[0] directly raised IndexError when nothing was drawn.
        layers = (image_with_mask or {}).get("layers") or []
        current_input_image = image_with_mask["background"] if image_with_mask else None
        if mask_input is not None:
            current_mask = mask_input
        else:
            current_mask = layers[0] if layers else None
        current_output_image = result_image
        current_params = {
            "prompt": prompt,
            "seed": seed,
            "randomize_seed": randomize_seed,
            "num_inference_steps": num_inference_steps,
            "max_steps": max_steps,
            "learning_rate": learning_rate,
            "optimization_steps": optimization_steps,
            "inverseproblem": inverseproblem,
        }
        return result_image

    def save_data():
        """Write the most recent run to saved_results/<timestamp>/.

        Saves input.png, mask.png and output.png (each only if present) plus
        parameters.json, and returns a status message.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_dir = os.path.join("saved_results", timestamp)
        # exist_ok also creates the parent directory, and avoids the
        # FileExistsError raised when two saves land in the same second.
        os.makedirs(save_dir, exist_ok=True)
        for picture, filename in (
            (current_input_image, "input.png"),
            (current_mask, "mask.png"),
            (current_output_image, "output.png"),
        ):
            if picture is not None:
                picture.save(os.path.join(save_dir, filename))
        with open(os.path.join(save_dir, "parameters.json"), "w", encoding="utf-8") as f:
            json.dump(current_params, f, indent=4)
        return f"✅ Data saved in {save_dir}"

    run_button.click(
        fn=run_and_update_status,
        inputs=[
            image_input,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            max_steps,
            learning_rate,
            optimization_steps,
            inverseproblem,
            mask_input,
        ],
        outputs=output_image,
    )
    # save_button.click(fn=save_data)

    # NOTE(review): empty Markdown component; restore its content or remove it.
    gr.Markdown("")

    def load_example_image_with_mask(image_path):
        """Return an ImageMask value for image_path with an all-zero "L" mask layer."""
        image = Image.open(image_path)
        mask = Image.new('L', image.size, 0)
        return {"background": image, "layers": [mask], "composite": image}

    def _saved_example(run_id, prompt, seed, randomize_seed, learning_rate,
                       optimization_steps, inverseproblem):
        """Build one gr.Examples row from a run stored in ./saved_results/<run_id>/.

        The row order matches the `inputs` list passed to gr.Examples below.
        All stored runs used 200 inference steps and 200 max steps.
        """
        run_dir = f"./saved_results/{run_id}"
        return [
            f"{run_dir}/input.png",   # image_input (background image)
            f"{run_dir}/mask.png",    # mask_input
            f"{run_dir}/output.png",  # output_image (stored result)
            prompt,
            seed,
            randomize_seed,
            200,  # num_inference_steps
            200,  # max_steps
            learning_rate,
            optimization_steps,
            inverseproblem,
        ]

    # Examples only populate the inputs (and show the stored output); they are
    # not executed or cached.
    gr.Examples(
        examples=[
            _saved_example("20241129_210517", "a cat", 0, True, 0.1, 1, False),
            _saved_example("20241129_211124", " ", 0, True, 0.1, 5, False),
            _saved_example("20241129_212001", " ", 52, False, 0.02, 10, True),
            _saved_example("20241129_212052", " ", 52, False, 0.02, 10, True),
            _saved_example("20241129_212155", " ", 52, False, 0.02, 10, True),
        ],
        inputs=[
            image_input,
            mask_input,
            output_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            max_steps,
            learning_rate,
            optimization_steps,
            inverseproblem,
        ],
    )
# Start the web app only when this file is executed as a script
# (e.g. `python app.py`), so merely importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()