# (HuggingFace Spaces page header removed — "Spaces: / Runtime error" was
# scrape residue from the hosting page, not part of this script.)
# This demo needs to be run from the repo folder. | |
# python demo/fake_gan/run.py | |
import os | |
import random | |
import gradio as gr | |
import itertools | |
from PIL import Image, ImageFont, ImageDraw | |
import DirectedDiffusion | |
# prompt | |
# boundingbox | |
# prompt indices for region | |
# number of trailing attention | |
# number of DD steps | |
# gaussian coefficient | |
# seed | |
# Quick-start presets fed to the gr.Examples widget at the bottom of the UI.
# Each row follows the `args` input order declared below:
#   [prompt,
#    bounding box(es): space-separated "left,right,top,bottom" image fractions,
#    1-based prompt token indices per box (space-separated groups),
#    number of trailing attention maps,
#    number of Directed Diffusion steps,
#    Gaussian coefficient (noise scale),
#    random seed]
EXAMPLES = [
    [
        "A painting of a tiger, on the wall in the living room",
        "0.2,0.6,0.0,0.5",
        "1,5",
        5,
        15,
        1.0,
        2094889,
    ],
    [
        "a dog diving into a pool in sunny day",
        "0.0,0.5,0.0,0.5",
        "1,2",
        10,
        20,
        5.0,
        2483964026826,
    ],
    [
        # Two regions, two index groups: one per object.
        "A red cube above a blue sphere",
        "0.4,0.7,0.0,0.5 0.4,0.7,0.5,1.0",
        "2,3 6,7",
        10,
        20,
        1.0,
        1213698,
    ],
    [
        "The sun shining on a house",
        "0.0,0.5,0.0,0.5",
        "1,2",
        10,
        20,
        1.0,
        2483964026826,
    ],
    [
        "a diver swimming through a school of fish",
        "0.5,1.0,0.0,0.5",
        "1,2",
        10,
        10,
        5.0,
        2483964026826,
    ],
    [
        "A stone castle surrounded by lakes and trees",
        "0.3,0.7,0.0,1.0",
        "1,2,3",
        10,
        5,
        1.0,
        2483964026826,
    ],
    [
        "A dog hiding behind the chair",
        "0.5,0.9,0.0,1.0",
        "1,2",
        10,
        5,
        2.5,
        248396402123,
    ],
    [
        "A dog sitting next to a mirror",
        "0.0,0.5,0.0,1.0 0.5,1.0,0.0,1.0",
        "1,2 6,7",
        20,
        5,
        1.0,
        24839640268232521,
    ],
]
# Load the Stable Diffusion v1.4 components once at import time; the bundle
# is shared by every request handled by this demo.
model_bundle = DirectedDiffusion.AttnEditorUtils.load_all_models(
    model_path_diffusion="CompVis/stable-diffusion-v1-4"
)
# Local-checkpoint alternative to the Hub download above:
# model_bundle = DirectedDiffusion.AttnEditorUtils.load_all_models(
#     model_path_diffusion="../DirectedDiffusion/assets/models/stable-diffusion-v1-4"
# )
# NOTE(review): ALL_OUTPUT is declared (and named `global` in run_it) but never
# read or written anywhere in this file — appears vestigial.
ALL_OUTPUT = {}
def directed_diffusion(
    in_prompt,
    in_bb,
    in_token_ids,
    in_slider_trailings,
    in_slider_ddsteps,
    in_slider_gcoef,
    in_seed,
    is_draw_bbox,
):
    """Run one Directed Diffusion generation and return the PIL image.

    `in_bb` holds one or more space-separated boxes, each a comma-separated
    "left,right,top,bottom" tuple of image fractions; `in_token_ids` is the
    matching space-separated list of comma-separated 1-based prompt token
    index groups.  When `is_draw_bbox` is set (and DD is actually active,
    i.e. `in_slider_ddsteps > 0`) the regions are outlined in red on the
    output image.
    """

    def parse_groups(arg, cast):
        # "a,b c,d" -> [[cast(a), cast(b)], [cast(c), cast(d)]]
        return [[cast(tok) for tok in group.split(",")] for group in arg.split(" ")]

    regions = parse_groups(in_bb, float)
    editor_bundle = {
        "edit_index": parse_groups(in_token_ids, int),
        "roi": regions,
        # Per-region copies of the scalar slider values.
        "num_trailing_attn": [in_slider_trailings] * len(regions),
        "num_affected_steps": in_slider_ddsteps,
        "noise_scale": [in_slider_gcoef] * len(regions),
    }
    img = DirectedDiffusion.Diffusion.stablediffusion(
        model_bundle,
        attn_editor_bundle=editor_bundle,
        guidance_scale=7.5,
        prompt=in_prompt,
        steps=50,
        seed=in_seed,
        is_save_attn=False,
        is_save_recons=False,
    )
    if is_draw_bbox and in_slider_ddsteps > 0:
        draw = ImageDraw.Draw(img)
        for left, right, top, bottom in regions:
            # Fractions -> pixel coordinates on the 512x512 output.
            # PIL's rectangle expects [x0, y0, x1, y1] = [left, top, right, bottom].
            x0, x1, y0, y1 = (int(v * 512) for v in (left, right, top, bottom))
            draw.rectangle(xy=[x0, y0, x1, y1], outline=(255, 0, 0, 255), width=5)
    return img
def run_it(
    in_prompt,
    in_bb,
    in_token_ids,
    in_slider_trailings,
    in_slider_ddsteps,
    in_slider_gcoef,
    in_seed,
    is_draw_bbox,
    is_grid_search,
    progress=gr.Progress(),
):
    """Gradio click handler: generate one image, or a labelled sweep.

    When `is_grid_search` is unchecked the three slider values are used
    as-is (a single run); when checked they are ignored and a fixed grid of
    (#DD steps x Gaussian coef x #trailings) combinations is generated.
    Returns a list of (PIL.Image, caption) pairs for the output gallery.
    """
    # Single-point "grid" by default; the checkbox swaps in the fixed sweep.
    num_affected_steps = [in_slider_ddsteps]
    noise_scale = [in_slider_gcoef]
    num_trailing_attn = [in_slider_trailings]
    if is_grid_search:
        num_affected_steps = [5, 10]
        noise_scale = [1.0, 1.5, 2.5]
        num_trailing_attn = [10, 20, 30, 40]
    # Cartesian product in (dd_steps, gcoef, trailings) order.
    param_list = list(
        itertools.product(num_affected_steps, noise_scale, num_trailing_attn)
    )
    results = []
    progress(0, desc="Starting...")
    for dd_steps, gcoef, trailings in progress.tqdm(param_list):
        print("=========== Arguments ============")
        print("Prompt:", in_prompt)
        print("BoundingBox:", in_bb)
        print("Token indices:", in_token_ids)
        print("Num Trialings:", trailings)
        print("Num DD steps:", dd_steps)
        print("Gaussian coef:", gcoef)
        print("Seed:", in_seed)
        print("===================================")
        img = directed_diffusion(
            in_prompt=in_prompt,
            in_bb=in_bb,
            in_token_ids=in_token_ids,
            in_slider_trailings=trailings,
            in_slider_ddsteps=dd_steps,
            in_slider_gcoef=gcoef,
            in_seed=in_seed,
            is_draw_bbox=is_draw_bbox,
        )
        results.append(
            (
                img,
                "#Trailing:{},#DDSteps:{},GaussianCoef:{}".format(
                    trailings, dd_steps, gcoef
                ),
            )
        )
    return results
# ---------------------------------------------------------------------------
# Gradio UI: prompt/region controls on the left, result gallery on the right.
# NOTE(review): the `.style(...)` calls (Textbox, Button, Gallery) were removed
# in Gradio 4.x — this layout requires a 3.x gradio runtime; verify the pinned
# gradio version (a newer runtime would raise at startup).
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Paper title, authors, and project links.
    gr.Markdown(
        """
        ### Directed Diffusion: Direct Control of Object Placement through Attention Guidance
        **\*Wan-Duo Kurt Ma, \^J. P. Lewis, \^\*W. Bastiaan Kleijn, \^Thomas Leung**
        *\*Victoria University of Wellington, \^Google Research*
        Let's pin the object in the prompt as you wish!
        For more information, please checkout our project page ([link](https://hohonu-vicml.github.io/DirectedDiffusion.Page/)), repository ([link](https://github.com/hohonu-vicml/DirectedDiffusion)), and the paper ([link](https://arxiv.org/abs/2302.13153))
        """
    )
    with gr.Row(variant="panel"):
        # Left column: all generation inputs.
        with gr.Column(variant="compact"):
            in_prompt = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            ).style(
                container=False,
            )
            with gr.Row(variant="compact"):
                # Space-separated "left,right,top,bottom" fraction tuples.
                in_bb = gr.Textbox(
                    label="Bounding box",
                    show_label=True,
                    max_lines=1,
                    placeholder="e.g., 0.1,0.5,0.3,0.6",
                )
                # 1-based token positions in the prompt, one group per box.
                in_token_ids = gr.Textbox(
                    label="Token indices",
                    show_label=True,
                    max_lines=1,
                    placeholder="e.g., 1,2,3",
                )
                in_seed = gr.Number(
                    value=2483964026821236, label="Random seed", interactive=True
                )
            with gr.Row(variant="compact"):
                is_grid_search = gr.Checkbox(
                    value=False,
                    label="Grid search? (If checked then sliders are ignored)",
                )
                is_draw_bbox = gr.Checkbox(
                    value=True,
                    label="To draw the bounding box?",
                )
            with gr.Row(variant="compact"):
                # The three sliders map to run_it's per-run parameters.
                in_slider_trailings = gr.Slider(
                    minimum=0, maximum=30, value=10, step=1, label="#trailings"
                )
                in_slider_ddsteps = gr.Slider(
                    minimum=0, maximum=30, value=10, step=1, label="#DDSteps"
                )
                in_slider_gcoef = gr.Slider(
                    minimum=0, maximum=10, value=1.0, step=0.1, label="GaussianCoef"
                )
            with gr.Row(variant="compact"):
                btn_run = gr.Button("Generate image").style(full_width=True)
                #btn_clean = gr.Button("Clean Gallery").style(full_width=True)
            gr.Markdown(
                """ Note:
                1) Please click one of the examples below for quick setup.
                2) if #DDsteps==0, it means the SD process runs without DD.
                3) The bounding box is the tuple of four scalars representing the fractional boundary of an image: left,right,top,bottom
                4) The token indices are the word positions in the prompt associated with the edited region, 1-indexed.
                """
            )
        # Right column: output gallery.
        with gr.Column(variant="compact"):
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")
    # Input order must match both run_it's signature and each EXAMPLES row.
    args = [
        in_prompt,
        in_bb,
        in_token_ids,
        in_slider_trailings,
        in_slider_ddsteps,
        in_slider_gcoef,
        in_seed,
        is_draw_bbox,
        is_grid_search,
    ]
    btn_run.click(run_it, inputs=args, outputs=gallery)
    #btn_clean.click(clean_gallery, outputs=gallery)
    examples = gr.Examples(
        examples=EXAMPLES,
        inputs=args,
    )
if __name__ == "__main__":
    # queue() enables the progress bar / long-running job support.
    demo.queue().launch()