Lin
Bug fixes
63f6c0c
raw
history blame
5.32 kB
import os
import gradio as gr
from background_replacer import replace_background
developer_mode = os.getenv('DEV_MODE', False)
DEFAULT_POSITIVE_PROMPT = "on the pavement, poolside, idyllic infinity pool, Hawaiian hilltops, commercial product photography"
DEFAULT_NEGATIVE_PROMPT = ""
EXAMPLES = []
INTRO = "Test."
def generate(
image,
positive_prompt,
negative_prompt,
seed,
depth_map_feather_threshold,
depth_map_dilation_iterations,
depth_map_blur_radius,
progress=gr.Progress(track_tqdm=True)
):
if image is None:
return [None, None, None, None]
options = {
'seed': seed,
'depth_map_feather_threshold': depth_map_feather_threshold,
'depth_map_dilation_iterations': depth_map_dilation_iterations,
'depth_map_blur_radius': depth_map_blur_radius,
}
return replace_background(image, positive_prompt, negative_prompt, options)
custom_css = """
#image-upload {
flex-grow: 1;
}
#params .tabs {
display: flex;
flex-direction: column;
flex-grow: 1;
}
#params .tabitem[style="display: block;"] {
flex-grow: 1;
display: flex !important;
}
#params .gap {
flex-grow: 1;
}
#params .form {
flex-grow: 1 !important;
}
#params .form > :last-child{
flex-grow: 1;
}
.md ol, .md ul {
margin-left: 1rem;
}
.md img {
margin-bottom: 1rem;
}
"""
with gr.Blocks(css=custom_css) as iface:
gr.Markdown(INTRO)
with gr.Row():
with gr.Column():
image_upload = gr.Image(
label="Product image",
type="pil",
elem_id="image-upload"
)
caption = gr.Label(
label="Caption",
visible=developer_mode
)
with gr.Column(elem_id="params"):
with gr.Tab('Prompts'):
positive_prompt = gr.Textbox(
label="Positive Prompt: describe what you'd like to see",
lines=3,
value=DEFAULT_POSITIVE_PROMPT
)
negative_prompt = gr.Textbox(
label="Negative Prompt: describe what you want to avoid",
lines=3,
value=DEFAULT_NEGATIVE_PROMPT
)
if developer_mode:
with gr.Tab('Options'):
seed = gr.Number(
label="Seed",
precision=0,
value=0,
elem_id="seed",
visible=developer_mode
)
depth_map_feather_threshold = gr.Slider(
label="Depth map feather threshold",
value=128,
minimum=0,
maximum=255,
visible=developer_mode
)
depth_map_dilation_iterations = gr.Number(
label="Depth map dilation iterations",
precision=0,
value=10,
minimum=0,
visible=developer_mode
)
depth_map_blur_radius = gr.Number(
label="Depth map blur radius",
precision=0,
value=10,
minimum=0,
visible=developer_mode
)
else:
seed = gr.Number(value=-1, visible=False)
depth_map_feather_threshold = gr.Slider(
value=128, visible=False)
depth_map_dilation_iterations = gr.Number(
precision=0, value=10, visible=False)
depth_map_blur_radius = gr.Number(
precision=0, value=10, visible=False)
gen_button = gr.Button(value="Generate!", variant="primary")
with gr.Tab('Results'):
results = gr.Gallery(
show_label=False,
object_fit="contain",
columns=4
)
if developer_mode:
with gr.Tab('Generated'):
generated = gr.Gallery(
show_label=False,
object_fit="contain",
columns=4
)
with gr.Tab('Pre-processing'):
pre_processing = gr.Gallery(
show_label=False,
object_fit="contain",
columns=4
)
else:
generated = gr.Gallery(visible=False)
pre_processing = gr.Gallery(visible=False)
gr.Examples(
examples=EXAMPLES,
inputs=[image_upload, positive_prompt, negative_prompt],
)
gen_button.click(
fn=generate,
inputs=[
image_upload,
positive_prompt,
negative_prompt,
seed,
depth_map_feather_threshold,
depth_map_dilation_iterations,
depth_map_blur_radius
],
outputs=[
results,
generated,
pre_processing,
caption
],
)
iface.queue(max_size=10, api_open=False).launch(show_api=False)