# Page header captured together with the source (not program code):
# Spaces status "Runtime error".
import os | |
import gradio as gr | |
from background_replacer import replace_background | |
def parse_env_flag(raw):
    """Interpret an environment-variable value as an on/off switch.

    ``os.getenv`` returns a *string* whenever the variable is set, so values
    such as ``"0"`` or ``"false"`` would be truthy if used directly. This
    helper maps the usual "off" spellings (and an unset/empty value) to
    ``False`` and everything else to ``True``.

    Args:
        raw: The raw value, typically ``os.getenv(NAME)``; may be ``None``.

    Returns:
        bool: ``True`` when the flag should be considered enabled.
    """
    if raw is None:
        return False
    if isinstance(raw, bool):
        return raw
    return str(raw).strip().lower() not in ("", "0", "false", "no", "off")


# Developer mode toggles the extra tabs/controls in the UI. Stored as a real
# bool so it can be handed to Gradio's ``visible=`` arguments safely.
developer_mode = parse_env_flag(os.getenv('DEV_MODE'))

# Initial contents of the two prompt text boxes.
DEFAULT_POSITIVE_PROMPT = "on the pavement, poolside, idyllic infinity pool, Hawaiian hilltops, commercial product photography"
DEFAULT_NEGATIVE_PROMPT = ""

# Rows offered by the examples widget; each row supplies values for
# (image, positive prompt, negative prompt). Currently empty.
EXAMPLES = []

# Markdown rendered at the top of the page.
INTRO = "Test."
def generate(
    image,
    positive_prompt,
    negative_prompt,
    seed,
    depth_map_feather_threshold,
    depth_map_dilation_iterations,
    depth_map_blur_radius,
    progress=gr.Progress(track_tqdm=True)
):
    """Click handler: run background replacement for the uploaded image.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        positive_prompt: Text passed through to ``replace_background``.
        negative_prompt: Text passed through to ``replace_background``.
        seed: Forwarded in the options dict under ``'seed'``.
        depth_map_feather_threshold: Forwarded in the options dict.
        depth_map_dilation_iterations: Forwarded in the options dict.
        depth_map_blur_radius: Forwarded in the options dict.
        progress: Gradio progress tracker; not used in the body
            (presumably declared so Gradio mirrors tqdm progress - confirm).

    Returns:
        A list of four ``None`` values when no image was supplied; otherwise
        whatever ``replace_background`` returns.
    """
    if image is not None:
        # Bundle the tuning knobs into the single options mapping that
        # replace_background takes as its fourth argument.
        settings = dict(
            seed=seed,
            depth_map_feather_threshold=depth_map_feather_threshold,
            depth_map_dilation_iterations=depth_map_dilation_iterations,
            depth_map_blur_radius=depth_map_blur_radius,
        )
        return replace_background(
            image, positive_prompt, negative_prompt, settings
        )

    # Nothing uploaded: one empty value per output slot.
    return [None] * 4
# Stylesheet passed to gr.Blocks(css=...) when the interface is built.
# The "#image-upload" and "#params" selectors match the elem_id values given
# to the image input and to the parameters column; the rules make those
# areas grow to fill the available flex space.
# The ".md" rules adjust list indentation and image spacing (presumably for
# rendered Markdown content - confirm against the Gradio version in use).
custom_css = """
#image-upload {
flex-grow: 1;
}
#params .tabs {
display: flex;
flex-direction: column;
flex-grow: 1;
}
#params .tabitem[style="display: block;"] {
flex-grow: 1;
display: flex !important;
}
#params .gap {
flex-grow: 1;
}
#params .form {
flex-grow: 1 !important;
}
#params .form > :last-child{
flex-grow: 1;
}
.md ol, .md ul {
margin-left: 1rem;
}
.md img {
margin-bottom: 1rem;
}
"""
# Build the Gradio interface: an upload column, a parameters column, a
# "Generate!" button, result galleries and the click wiring.
#
# NOTE(review): the nesting below was reconstructed from the order of the
# calls and the CSS selectors (the captured source had lost its indentation);
# confirm the layout against the running app.

# Gradio's ``visible=`` expects a real boolean, while ``developer_mode`` may
# be a raw environment string; coerce it once and use it everywhere below.
dev_ui_enabled = bool(developer_mode)

with gr.Blocks(css=custom_css) as iface:
    gr.Markdown(INTRO)

    with gr.Row():
        # Left column: the product image plus a caption label that is only
        # shown in developer mode.
        with gr.Column():
            image_upload = gr.Image(
                label="Product image",
                type="pil",
                elem_id="image-upload",
            )
            caption = gr.Label(
                label="Caption",
                visible=dev_ui_enabled,
            )

        # Right column: prompts and (in developer mode) tuning options.
        with gr.Column(elem_id="params"):
            with gr.Tab('Prompts'):
                positive_prompt = gr.Textbox(
                    label="Positive Prompt: describe what you'd like to see",
                    lines=3,
                    value=DEFAULT_POSITIVE_PROMPT,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt: describe what you want to avoid",
                    lines=3,
                    value=DEFAULT_NEGATIVE_PROMPT,
                )

            if dev_ui_enabled:
                # This branch only runs in developer mode, so the controls
                # rely on the default visibility instead of repeating the flag.
                with gr.Tab('Options'):
                    seed = gr.Number(
                        label="Seed",
                        precision=0,
                        value=0,
                        elem_id="seed",
                    )
                    depth_map_feather_threshold = gr.Slider(
                        label="Depth map feather threshold",
                        value=128,
                        minimum=0,
                        maximum=255,
                    )
                    depth_map_dilation_iterations = gr.Number(
                        label="Depth map dilation iterations",
                        precision=0,
                        value=10,
                        minimum=0,
                    )
                    depth_map_blur_radius = gr.Number(
                        label="Depth map blur radius",
                        precision=0,
                        value=10,
                        minimum=0,
                    )
            else:
                # Hidden stand-ins so the click handler always receives the
                # same seven inputs, carrying fixed default values.
                seed = gr.Number(value=-1, visible=False)
                depth_map_feather_threshold = gr.Slider(
                    value=128, visible=False)
                depth_map_dilation_iterations = gr.Number(
                    precision=0, value=10, visible=False)
                depth_map_blur_radius = gr.Number(
                    precision=0, value=10, visible=False)

    gen_button = gr.Button(value="Generate!", variant="primary")

    with gr.Tab('Results'):
        results = gr.Gallery(
            show_label=False,
            object_fit="contain",
            columns=4,
        )

    if dev_ui_enabled:
        with gr.Tab('Generated'):
            generated = gr.Gallery(
                show_label=False,
                object_fit="contain",
                columns=4,
            )
        with gr.Tab('Pre-processing'):
            pre_processing = gr.Gallery(
                show_label=False,
                object_fit="contain",
                columns=4,
            )
    else:
        # Hidden galleries keep the handler's four outputs valid.
        generated = gr.Gallery(visible=False)
        pre_processing = gr.Gallery(visible=False)

    # Only render the examples widget when there is something to show; an
    # empty EXAMPLES list would otherwise produce an empty widget.
    if EXAMPLES:
        gr.Examples(
            examples=EXAMPLES,
            inputs=[image_upload, positive_prompt, negative_prompt],
        )

    # The order of inputs/outputs mirrors generate()'s parameters and the
    # four values it returns.
    gen_button.click(
        fn=generate,
        inputs=[
            image_upload,
            positive_prompt,
            negative_prompt,
            seed,
            depth_map_feather_threshold,
            depth_map_dilation_iterations,
            depth_map_blur_radius,
        ],
        outputs=[
            results,
            generated,
            pre_processing,
            caption,
        ],
    )
# Enable request queuing (at most 10 waiting jobs, API access closed), then
# start the app with the API page hidden.
queued_app = iface.queue(max_size=10, api_open=False)
queued_app.launch(show_api=False)