lpn / app.py
clement-bonnet's picture
feat: left images in full
0dcdf8e
raw
history blame
2.75 kB
from functools import partial
import gradio as gr
from PIL import Image
from inference import generate_image
def process_coord_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
    """
    Turn a click on the coordinate selector into a generated image.

    The click position is read from ``evt.index`` (first entry used as x,
    second as y), each component is divided by 400, the scaled pair is
    printed, and the result of ``generate_image(image_idx, x, y)`` is returned.
    """
    clicked_x = evt.index[0]
    clicked_y = evt.index[1]

    # NOTE(review): 400 presumably matches the size of the coordinate selector
    # image so that x and y land in [0, 1] -- confirm against the heatmap files.
    x = clicked_x / 400
    y = clicked_y / 400

    print(f"Clicked at coordinates: ({x:.3f}, {y:.3f})")
    return generate_image(image_idx, x, y)
def process_image_select(evt: gr.SelectData, idx: int) -> tuple[int, str]:
    """
    Handle the selection of a reference image.

    ``evt`` is accepted but not used here; ``idx`` is the index of the
    reference image that was selected.

    Returns ``idx`` unchanged together with the path of the heatmap image
    for that index (``imgs/heatmap_<idx>.png``).
    """
    heatmap_path = f"imgs/heatmap_{idx}.png"
    return idx, heatmap_path
# Introductory text shown at the top of the page.
_INTRO_MARKDOWN = """
# Interactive Image Generation
Click on a reference image to select it, then click on the coordinate selector to generate a new image.
"""

# Display options shared by both reference images.
_REFERENCE_IMAGE_OPTIONS = dict(
    show_label=False,
    interactive=True,
    height=300,
    width=450,
)

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left half: the reference images the user can pick from.
        with gr.Column(scale=1):
            # Holds the index of the currently selected reference image;
            # starts at 0, i.e. the first image.
            selected_idx = gr.State(value=0)

            # One Image component per reference pattern.
            with gr.Column():
                image_0 = gr.Image(
                    value="imgs/pattern_0.png",
                    label="Task 1",
                    **_REFERENCE_IMAGE_OPTIONS,
                )
                image_1 = gr.Image(
                    value="imgs/pattern_1.png",
                    label="Task 2",
                    **_REFERENCE_IMAGE_OPTIONS,
                )

        # Right half: clickable coordinate selector above the generated output.
        with gr.Column(scale=1):
            # The background starts as the heatmap for image 0 and is swapped
            # whenever a reference image is selected.
            coord_selector = gr.Image(
                value="imgs/heatmap_0.png",
                label="Click to select (x, y) coordinates",
                show_label=True,
                interactive=True,
                height=400,
                width=400,
            )

            # Shows whatever process_coord_click returns.
            output_image = gr.Image(label="Generated Output", height=400, width=400)

    # Selecting a reference image stores its index and loads its heatmap.
    # partial() binds the index when the handler is created, so each image
    # reports its own position.
    for _ref_idx, _ref_image in enumerate((image_0, image_1)):
        _ref_image.select(
            partial(process_image_select, idx=_ref_idx),
            outputs=[selected_idx, coord_selector],
        )

    # Clicking the coordinate selector generates a new output image for the
    # currently selected reference image.
    coord_selector.select(
        process_coord_click,
        inputs=[selected_idx],
        outputs=output_image,
        trigger_mode="multiple",
    )

demo.launch()