Spaces:
Running
Running
from functools import partial | |
import gradio as gr | |
from PIL import Image | |
from inference import generate_image | |
def process_coord_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
    """
    Handle a click on the coordinate-selector heatmap.

    The clicked ``[x, y]`` position from the select event is divided by 400
    on both axes. The result is forwarded, together with the currently
    selected reference index, to ``generate_image``.

    Args:
        image_idx: Index of the selected reference image (held in ``gr.State``).
        evt: Gradio select event; ``evt.index[0]`` / ``evt.index[1]`` are the
            clicked x / y positions.

    Returns:
        Whatever ``generate_image`` produces (annotated as a PIL image).
    """
    # NOTE(review): 400 mirrors the selector's display width/height. If
    # evt.index is reported in the image's native pixel space, this assumes
    # the heatmap files are 400x400 px — confirm.
    scale = 400
    px = evt.index[0]
    py = evt.index[1]
    x, y = px / scale, py / scale
    print(f"Clicked at coordinates: ({x:.3f}, {y:.3f})")
    return generate_image(image_idx, x, y)
def process_image_select(evt: gr.SelectData, idx: int) -> tuple[int, str]:
    """
    Handle a click on one of the reference images.

    The event payload is not inspected. The reference index is fixed when the
    handler is registered (via ``functools.partial`` in this file).

    Args:
        evt: Gradio select event (unused).
        idx: Index of the reference image that was clicked.

    Returns:
        ``(idx, heatmap_path)``: the new selected index and the path of the
        heatmap that matches it.
    """
    heatmap_path = f"imgs/heatmap_{idx}.png"
    return idx, heatmap_path
# UI layout: reference patterns on the left; clickable heatmap and generated
# output on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        "\n# Interactive Image Generation\n"
        "Click on a reference image to select it, then click on the coordinate selector to generate a new image.\n"
    )
    with gr.Row():
        # Left column: clickable reference images.
        with gr.Column(scale=1):
            # Index of the currently selected reference image. Starts at 0 to
            # match the heatmap initially shown in the coordinate selector.
            selected_idx = gr.State(value=0)
            # Two separate Image components, one per reference image.
            with gr.Column():
                # These images are fixed assets used only as click targets.
                # interactive=False keeps .select() working but removes the
                # upload/clear controls, so users cannot wipe or replace them.
                image_0 = gr.Image(
                    value="imgs/pattern_0.png",
                    label="Task 1",
                    show_label=False,
                    interactive=False,
                    height=300,
                    width=450,
                )
                image_1 = gr.Image(
                    value="imgs/pattern_1.png",
                    label="Task 2",
                    show_label=False,
                    interactive=False,
                    height=300,
                    width=450,
                )
        # Right column: coordinate selector and generated output.
        with gr.Column(scale=1):
            # Heatmap background. process_image_select swaps it when a
            # reference image is clicked. It is display-only for the same
            # reason as the reference images; clicks still reach .select().
            coord_selector = gr.Image(
                value="imgs/heatmap_0.png",  # Initial background
                label="Click to select (x, y) coordinates",
                show_label=True,
                interactive=False,
                height=400,
                width=400,
            )
            # Generated image output.
            output_image = gr.Image(label="Generated Output", height=400, width=400)

    # Clicking a reference image updates the selected index and the heatmap.
    image_0.select(partial(process_image_select, idx=0), outputs=[selected_idx, coord_selector])
    image_1.select(partial(process_image_select, idx=1), outputs=[selected_idx, coord_selector])

    # Clicking the heatmap generates an image for the selected reference.
    # trigger_mode="multiple" accepts new clicks while a generation is pending.
    coord_selector.select(
        process_coord_click, inputs=[selected_idx], outputs=output_image, trigger_mode="multiple"
    )

demo.launch()