Spaces:
Running
Running
import gradio as gr | |
from PIL import Image | |
from inference import generate_image | |
# Radio-button label -> zero-based task index used for the asset file names
# (imgs/pattern_<idx>.png, imgs/heatmap_<idx>.png) and for generate_image().
TASK_TO_INDEX = {
    "Task 1": 0,
    "Task 2": 1,
    "Task 3": 2,
    "Task 4": 3,
}
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Resolve the on-disk assets for the selected task.

    Args:
        choice: Task index (the values of ``TASK_TO_INDEX``, i.e. 0-3).

    Returns:
        ``(reference_path, choice, heatmap_path)``. The reference and heatmap
        paths are ``imgs/pattern_<choice>.png`` and ``imgs/heatmap_<choice>.png``.
        ``choice`` is echoed back unchanged so the caller can store it in UI state.
    """
    heatmap_file = f"imgs/heatmap_{choice}.png"
    pattern_file = f"imgs/pattern_{choice}.png"
    return (pattern_file, choice, heatmap_file)
def process_coord_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
    """Run generation for the point the user clicked on the coordinate selector.

    Args:
        image_idx: Index of the currently selected task (held in UI state).
        evt: Gradio select event; ``evt.index[0]`` / ``evt.index[1]`` are the
            clicked x / y pixel positions.

    Returns:
        The result of ``generate_image`` for the normalized coordinates.
    """
    # Both axes are divided by the same constant to map pixels to normalized
    # coordinates. NOTE(review): 1155 presumably is the heatmap image's side
    # length in pixels -- confirm it matches imgs/heatmap_*.png.
    side_px = 1155
    px = evt.index[0]
    py = evt.index[1]
    return generate_image(image_idx, px / side_px, py / side_px)
with gr.Blocks(
    css="""
    .radio-container {
        width: 450px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .coordinate-container {
        width: 600px !important;
        height: 600px !important;
    }
    .coordinate-container img {
        width: 100% !important;
        height: 100% !important;
        object-fit: contain !important;
    }
    """
) as demo:
    gr.Markdown(
        """
        # Interactive Image Generation
        Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
        """
    )

    # Derive the radio choices, the initial selection and the initial image
    # paths from TASK_TO_INDEX / update_reference_image, so the UI cannot drift
    # out of sync with the label->index mapping or the asset path templates.
    task_names = list(TASK_TO_INDEX)
    initial_task = task_names[0]
    initial_reference, initial_idx, initial_heatmap = update_reference_image(
        TASK_TO_INDEX[initial_task]
    )

    with gr.Row():
        # Left column: radio selection, reference image, and generated output.
        with gr.Column(scale=1):
            # Index of the selected task; read by the coordinate-click handler.
            selected_idx = gr.State(value=initial_idx)

            # Radio buttons, centered via the .radio-container CSS class.
            with gr.Column(elem_classes="radio-container"):
                task_select = gr.Radio(
                    choices=task_names,
                    value=initial_task,
                    label="Select Task",
                    interactive=True,
                )

            # Reference image for the selected task (updated on radio change).
            reference_image = gr.Image(
                value=initial_reference,
                show_label=False,
                interactive=False,
                height=300,
                width=450,
            )

            # Generated image, shown below the reference image.
            output_image = gr.Image(label="Generated Output", height=300, width=450)

        # Right column: larger coordinate selector.
        with gr.Column(scale=1):
            # Container class sizes the selector via the CSS above.
            with gr.Column(elem_classes="coordinate-container"):
                coord_selector = gr.Image(
                    value=initial_heatmap,
                    label="Click to select (x, y) coordinates in the latent space",
                    show_label=True,
                    interactive=True,
                    container=True,
                )

    # Radio change -> swap the reference image and heatmap, and record the index.
    task_select.change(
        fn=lambda x: update_reference_image(TASK_TO_INDEX[x]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )

    # Click on the heatmap -> generate an image for the selected task.
    coord_selector.select(
        process_coord_click, inputs=[selected_idx], outputs=output_image, trigger_mode="multiple"
    )

demo.launch()