# RAG-Diffusion / app.py
# znchen - "Update app.py" - eb81c16
import os
import cv2
import gradio as gr
import numpy as np
import random
import base64
import requests
import json
import time
import spaces
from gradio_box_promptable_image import BoxPromptableImage
from gen_box_func import generate_parameters, visualize
import torch
from RAG_pipeline_flux import RAG_FluxPipeline
# Largest seed offered by the UI slider and drawn by random.randint below.
MAX_SEED = 999999

# Load the FLUX.1-dev weights in bfloat16 once, at import time, and move the
# pipeline to the GPU; every request handled by rag_gen reuses this instance.
pipe = RAG_FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# NOTE: a stray module-level `global run_nums` used to follow here. At module
# scope `global` has no effect (and the name did not match `run_num`), so it
# was dropped.
def update_run_num():
    """Increment the persisted run counter and return the new value.

    The counter is stored as a decimal integer in ``assets/run_num.txt``
    (resolved against the current working directory).  An empty or
    whitespace-only file is treated as a counter of zero.

    Returns:
        int: The counter value after the increment.

    Raises:
        FileNotFoundError: If the counter file does not exist.
        ValueError: If the file contains something other than an integer.
    """
    with open("assets/run_num.txt", "r+") as f:
        stored = f.read().strip()
        run_num = (int(stored) if stored else 0) + 1
        f.seek(0)
        f.write(str(run_num))
        # Cut off whatever remains of the previous contents.  Without this,
        # a new value that is shorter than the old text (e.g. "007" -> "8",
        # or a file with leading whitespace) leaves stale characters behind
        # and corrupts the counter.
        f.truncate()
    return run_num
# init: bump the persisted counter once at import time and keep the resulting
# value in the module-level ``run_num`` (rag_gen rebinds it after every run).
run_num = update_run_num()
def read_run_num():
    """Return the persisted run counter without modifying it.

    Reads ``assets/run_num.txt`` (resolved against the current working
    directory).  An empty or whitespace-only file is reported as zero.

    Returns:
        int: The stored counter value.

    Raises:
        FileNotFoundError: If the counter file does not exist.
        ValueError: If the file contains something other than an integer.
    """
    # Read-only mode: this function never writes, so it must not require
    # write permission on the counter file (the previous "r+" did).
    with open("assets/run_num.txt", "r") as f:
        stored = f.read().strip()
    return int(stored) if stored else 0
# Built-in layouts selectable by name.  Every record has six items; items 2 and
# 5 are the marker values 2.0 and 3.0 that get_box_inputs checks for, items
# 0, 1, 3, 4 are the box coordinates it extracts.
# NOTE(review): the coordinates look like (left, top, right, bottom) pixels on
# the 1024x1024 canvas used by rag_gen - confirm against generate_parameters.
_LAYOUT_PRESETS = {
    "layout1": [
        [0.05*1024, 0.05*1024, 2.0, (0.05+0.40)*1024, (0.05+0.9)*1024, 3.0],
        [0.5*1024, 0.05*1024, 2.0, (0.5+0.45)*1024, (0.05+0.9)*1024, 3.0],
    ],
    "layout2": [
        [20.0, 425.0, 2.0, 551.0, 1008.0, 3.0],
        [615.0, 84.0, 2.0, 1000.0, 389.0, 3.0],
    ],
    "layout3": [
        [0.2*1024, 0.1*1024, 2.0, (0.2+0.6)*1024, (0.1+0.4)*1024, 3.0],
        [0.2*1024, 0.6*1024, 2.0, (0.2+0.6)*1024, (0.6+0.35)*1024, 3.0],
    ],
    "layout4": [
        [9.0, 153.0, 2.0, 343.0, 959.0, 3.0],
        [376.0, 145.0, 2.0, 692.0, 959.0, 3.0],
        [715.0, 143.0, 2.0, 1015.0, 956.0, 3.0],
    ],
}


def get_box_inputs(prompts):
    """Convert a layout selection into a list of 4-tuples of box coordinates.

    Args:
        prompts: Either the name of a built-in layout (``"layout1"`` ..
            ``"layout4"``) or an iterable of 6-item records
            ``[a, b, 2.0, c, d, 3.0]``.  The empty string means "nothing
            selected" and yields an empty list.

    Returns:
        list[tuple]: One ``(a, b, c, d)`` tuple per record whose items 2 and 5
        equal ``2.0`` and ``3.0``; records with other marker values are
        skipped.

    Raises:
        ValueError: If ``prompts`` is a non-empty string that is not a known
            layout name.  (Previously such a string was iterated character by
            character and failed with an unrelated ``IndexError``.)
    """
    if isinstance(prompts, str):
        if prompts in _LAYOUT_PRESETS:
            prompts = _LAYOUT_PRESETS[prompts]
        elif prompts:
            known = ", ".join(sorted(_LAYOUT_PRESETS))
            raise ValueError(
                f"Unknown layout {prompts!r}; expected one of: {known}"
            )
        # An empty string falls through: iterating it produces no boxes.

    return [
        (record[0], record[1], record[3], record[4])
        for record in prompts
        if record[2] == 2.0 and record[5] == 3.0
    ]
@spaces.GPU
def rag_gen(
    box_point,
    box_image,
    prompt,
    coarse_prompt,
    detailed_prompt,
    HB_replace,
    SR_delta,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
):
    """Generate one 1024x1024 image with the RAG FLUX pipeline.

    Args:
        box_point: Layout selection forwarded to ``get_box_inputs`` (a layout
            name or a list of box records).
        box_image: Layout preview image from the UI; accepted for interface
            compatibility but not used by the generation.
        prompt: Global text prompt.
        coarse_prompt: Regional prompts separated by the literal ``BREAK``;
            split into ``HB_prompt_list``.
        detailed_prompt: Passed to the pipeline unchanged as ``SR_prompt``.
        HB_replace: Forwarded to the pipeline as ``HB_replace``.
        SR_delta: Forwarded to the pipeline as ``SR_delta``.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ignore ``seed`` and draw a random one from
            ``0..MAX_SEED``.

    Returns:
        tuple: ``(rag_image, seed)`` - the first image produced by the
        pipeline and the seed that was actually used.
    """
    global run_num

    print("points", box_point)
    box_inputs = get_box_inputs(box_point)

    # The canvas size is fixed; it matches the height/width passed to pipe().
    canvas_height, canvas_width = 1024, 1024

    HB_prompt_list = coarse_prompt.split("BREAK")
    (
        HB_m_offset_list,
        HB_n_offset_list,
        HB_m_scale_list,
        HB_n_scale_list,
        SR_hw_split_ratio,
    ) = generate_parameters(box_inputs, canvas_width, canvas_height)

    # NOTE(review): the return value of visualize() is not used anywhere.  The
    # call is kept because its side effects (if any) are not visible from
    # here - confirm and drop it if it is a pure function.
    visualize(
        HB_m_offset_list,
        HB_n_offset_list,
        HB_m_scale_list,
        HB_n_scale_list,
        SR_hw_split_ratio,
        canvas_width,
        canvas_height,
    )

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        # Wrap into 0..MAX_SEED inclusive.  The former `seed % MAX_SEED`
        # silently turned the slider maximum (MAX_SEED) into 0, although the
        # random branch above can legitimately produce MAX_SEED.
        # int(): presumably the slider may deliver the value as a float.
        seed = int(seed) % (MAX_SEED + 1)

    rag_image = pipe(
        SR_delta=SR_delta,
        SR_hw_split_ratio=SR_hw_split_ratio,
        SR_prompt=detailed_prompt,
        HB_prompt_list=HB_prompt_list,
        HB_m_offset_list=HB_m_offset_list,
        HB_n_offset_list=HB_n_offset_list,
        HB_m_scale_list=HB_m_scale_list,
        HB_n_scale_list=HB_n_scale_list,
        HB_replace=HB_replace,
        seed=seed,
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # Count this inference in the persisted run counter.
    run_num = update_run_num()

    return rag_image, seed
# Custom stylesheet handed to gr.Blocks: column widths, the Run button colour
# and the style of the "Layout Example" caption.
css="""
#col-left {
margin: 0 auto;
max-width: 400px;
}
#col-right {
margin: 0 auto;
max-width: 600px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
#button {
color: blue;
}
#custom-label {
color: purple;
font-size: 16px;
font-weight: bold;
}
"""

# The "assets" folder that sits next to this script.  Both names refer to the
# same path and are kept so existing references to either keep working.
assets_root_path = os.path.join(os.path.dirname(__file__), 'assets')
example_path = assets_root_path
def load_description(fp):
    """Read the UTF-8 text file at ``fp`` and return its whole content as a string."""
    with open(fp, mode='r', encoding='utf-8') as handle:
        return handle.read()
# Static HTML snippets rendered inside the layout below.
_STEP1_HTML = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
</div>
<div>
Step 1. Choose
<span style="color: purple; font-weight: bold;">layout example</span>
</div>
</div>
"""

_TIP_HTML = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 16px;">
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 12px;">
<strong>
<span style="color: gray; font-weight: bold;">Tip: You can get a more ideal picture by adjusting HB_replace and SR_delta</span>
</strong>
</div>
</div>
"""

_STEP2_HTML = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Step 2. Press “Run” to get results
</div>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 10px;">
<div>
Errors may be displayed due to insufficient computing power
</div>
</div>
"""

# Ready-made examples.  Each row supplies one value per generation input, in
# the same order as the parameters of rag_gen: the layout name goes into the
# hidden ``box_point`` textbox, the PNG into the layout preview.
_LAYOUT_EXAMPLES = [
    [
        "layout1",
        "assets/case1.png",
        "a man is holding a bag, a man is talking on a cell phone.",  # prompt
        "A man holding a bag. BREAK a man holding a cell phone to his ear.",  # coarse_prompt
        "A man holding a bag, gripping it firmly, with a casual yet purposeful stance. BREAK a man, engaged in conversation, holding a cell phone to his ear.",  # detailed_prompt
        3,  # HB_replace
        1.0,  # SR_delta
        20,  # num_inference_steps
        3.5,  # guidance_scale
        1234,  # seed
        False,  # randomize_seed
    ],
    [
        "layout2",
        "assets/case2.png",
        "A woman looking at the moon",  # prompt
        "a woman BREAK a moon",  # coarse_prompt
        "A woman, standing gracefully, her gaze fixed on the sky with a sense of wonder. BREAK The moon, luminous and full, casting a soft glow across the tranquil night.",  # detailed_prompt
        3,  # HB_replace
        0.8,  # SR_delta
        20,  # num_inference_steps
        3.5,  # guidance_scale
        1233,  # seed
        False,  # randomize_seed
    ],
    [
        "layout3",
        "assets/case3.png",
        "a turtle on the bottom of a phone",  # prompt
        "Phone BREAK Turtle",  # coarse_prompt
        "The phone, placed above the turtle, potentially with its screen or back visible, its sleek design prominent. BREAK The turtle, below the phone, with its shell textured and detailed, eyes slightly protruding as it looks upward.",  # detailed_prompt
        2,  # HB_replace
        0.8,  # SR_delta
        20,  # num_inference_steps
        3.5,  # guidance_scale
        1233,  # seed
        False,  # randomize_seed
    ],
    [
        "layout4",
        "assets/case4.png",
        "From left to right, a blonde ponytail Europe girl in white shirt, a brown curly hair African girl in blue shirt printed with a bird, an Asian young man with black short hair in suit are walking in the campus happily.",  # prompt
        "A blonde ponytail European girl in a white shirt BREAK A brown curly hair African girl in a blue shirt printed with a bird BREAK An Asian young man with black short hair in a suit",  # coarse_prompt
        "A blonde ponytail European girl in a crisp white shirt, walking with a light smile. Her ponytail swings slightly as she enjoys the lively atmosphere of the campus. BREAK A brown curly hair African girl, her vibrant blue shirt adorned with a bird print. Her joyful expression matches her energetic stride as her curls bounce lightly in the breeze. BREAK An Asian young man in a sharp suit, his black short hair neatly styled, walking confidently alongside the two girls. His suit contrasts with the casual campus environment, adding an air of professionalism to the scene.",  # detailed_prompt
        2,  # HB_replace
        1.0,  # SR_delta
        20,  # num_inference_steps
        3.5,  # guidance_scale
        1234,  # seed
        False,  # randomize_seed
    ],
]

with gr.Blocks(css=css) as demo:
    gr.HTML(load_description("assets/title.md"))

    with gr.Row():
        # Left column: the three text prompts.
        with gr.Column(elem_id="col-left"):
            gr.HTML(_STEP1_HTML)
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt",
                lines=2
            )
            coarse_prompt = gr.Textbox(
                label="Regional Fundamental Prompt(BREAK is a delimiter).",
                placeholder="Enter your prompt",
                lines=2
            )
            detailed_prompt = gr.Textbox(
                label="Regional Highly Descriptive Prompt(BREAK is a delimiter).",
                placeholder="Enter your prompt",
                lines=2
            )

        # Middle column: read-only preview of the selected layout.
        with gr.Column(elem_id="col-left"):
            default_image_path = "assets/images_template.png"
            box_image = gr.Image(
                show_label=False,
                interactive=False,
                label="Layout",
                value=default_image_path)
            gr.HTML(_TIP_HTML)

        # Right column: result image, advanced settings and the Run button.
        with gr.Column(elem_id="col-right"):
            gr.HTML(_STEP2_HTML)
            result = gr.Image(label="Result", show_label=True)
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Random seed", value=True)
                with gr.Row():
                    HB_replace = gr.Slider(
                        label="HB_replace(The times of hard binding. More can make the position control more precise, but may lead to obvious boundaries.)",
                        minimum=0,
                        maximum=8,
                        step=1,
                        value=2,
                    )
                with gr.Row():
                    SR_delta = gr.Slider(
                        label="SR_delta(The fusion strength of image latent and regional-aware local latent. This is a flexible parameter, you can try 0.25, 0.5, 0.75, 1.0.)",
                        minimum=0.0,
                        maximum=1,
                        # 0.05 instead of 0.1: the label recommends 0.25 and
                        # 0.75, which a 0.1 grid cannot reach.  The example
                        # values 0.8 and 1.0 stay on the grid.
                        step=0.05,
                        value=1,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=15,
                        step=0.1,
                        value=3.5,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
            with gr.Row():
                button = gr.Button("Run", elem_id="button")

    # Hidden field that carries the layout name chosen through the examples.
    box_point = gr.Textbox(visible=False)

    # Single source of truth for the component order: it must match the
    # parameter order of rag_gen and the column order of _LAYOUT_EXAMPLES, and
    # is shared by the click handler and the examples so they cannot drift.
    generation_inputs = [
        box_point,
        box_image,
        prompt,
        coarse_prompt,
        detailed_prompt,
        HB_replace,
        SR_delta,
        num_inference_steps,
        guidance_scale,
        seed,
        randomize_seed,
    ]

    gr.on(
        triggers=[
            button.click,
        ],
        fn=rag_gen,
        inputs=generation_inputs,
        outputs=[result, seed],
        api_name="run",
    )

    with gr.Column():
        gr.HTML('<div id="custom-label">Layout Example ⬇️</div>')
        gr.Examples(
            examples=_LAYOUT_EXAMPLES,
            inputs=generation_inputs,
            outputs=None,
            fn=rag_gen,
            cache_examples=False,
        )
if __name__ == "__main__":
    # Enable request queuing with at most 20 waiting jobs, then start the
    # server on port 7860 with share=True.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(share=True, server_port=7860)