# lpn / app.py (4.89 kB)
# Last commit: 4cac868 "feat: add marker on the heatmap" (clement-bonnet)
import gradio as gr
from PIL import Image, ImageDraw
from inference import generate_image
# Maps each radio-button label to its zero-based task index (labels are listed in index order).
TASK_TO_INDEX = {
    label: index
    for index, label in enumerate(("Task 1", "Task 2", "Task 3", "Task 4"))
}
def create_marker_overlay(
    image_path: str,
    x: int,
    y: int,
    *,
    marker_size: int = 10,
    marker_color: str = "red",
) -> Image.Image:
    """
    Creates an image with a crosshair marker at the specified coordinates

    The file at ``image_path`` is opened and copied; the marker is drawn on the
    copy, so the file on disk is never modified.

    Args:
        image_path: Path of the base image to load.
        x: Horizontal pixel coordinate of the marker centre.
        y: Vertical pixel coordinate of the marker centre.
        marker_size: Half-length, in pixels, of each arm of the crosshair.
        marker_color: Fill colour passed to ``ImageDraw.line``.

    Returns:
        A copy of the base image with the crosshair drawn on it.
    """
    # Open inside a context manager so the underlying file handle is released
    # even for formats that Pillow keeps open after loading. The copy owns its
    # own pixel data and stays usable after the source image is closed.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()

    draw = ImageDraw.Draw(marked_image)

    # Crosshair: one horizontal and one vertical stroke centred on (x, y)
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=2)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=2)
    return marked_image
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """
    Resolve the image files that belong to the selected task

    Given the task index ``choice``, returns a 3-tuple of the reference pattern
    path, the index itself (unchanged), and the matching heatmap path.
    """
    return (
        f"imgs/pattern_{choice}.png",
        choice,
        f"imgs/heatmap_{choice}.png",
    )
def process_coord_click(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, Image.Image]:
    """
    Handle a click on the coordinate selector

    Reads the clicked (x, y) position from ``evt.index``, scales both values
    for the generator, and returns a pair: the result of ``generate_image`` and
    the heatmap for ``image_idx`` with a marker drawn at the clicked position.
    """
    # NOTE(review): 1155 is presumably the side length of the heatmap images in
    # pixels; confirm against imgs/heatmap_*.png before changing the assets.
    extent = 1155

    click_x = evt.index[0]
    click_y = evt.index[1]

    # The generator receives the click scaled by the extent above
    output = generate_image(image_idx, click_x / extent, click_y / extent)

    # The marker is drawn at the raw (unscaled) click position
    overlay = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", click_x, click_y)
    return output, overlay
# Build the UI: a two-column layout with the task picker, reference image and
# generated output on the left, and the clickable heatmap on the right.
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation was lost; confirm that the reference and output images sit
# directly in the left column rather than inside the radio container.
with gr.Blocks(
    # Custom CSS targeting the elem_classes assigned to the columns below
    css="""
.radio-container {
width: 450px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.coordinate-container {
width: 600px !important;
height: 600px !important;
}
.coordinate-container img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
}
"""
) as demo:
    gr.Markdown(
        """
# Interactive Image Generation
Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
"""
    )
    with gr.Row():
        # Left column: Radio selection, reference image, and output
        with gr.Column(scale=1):
            # State variable to track selected image index; written by the
            # radio change handler and read by the click handler below
            selected_idx = gr.State(value=0)
            # Radio buttons with container class
            with gr.Column(elem_classes="radio-container"):
                task_select = gr.Radio(
                    choices=["Task 1", "Task 2", "Task 3", "Task 4"],
                    value="Task 1",
                    label="Select Task",
                    interactive=True,
                )
            # Reference image component that updates based on selection;
            # starts on the pattern for the default task (index 0)
            reference_image = gr.Image(
                value="imgs/pattern_0.png",
                show_label=False,
                interactive=False,
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
            )
            # Generated image output moved below reference image; empty until
            # the first click on the coordinate selector
            output_image = gr.Image(
                label="Generated Output",
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
                interactive=False,
            )
        # Right column: Larger coordinate selector
        with gr.Column(scale=1):
            # Coordinate selector with container class for proper scaling
            with gr.Column(elem_classes="coordinate-container"):
                # Starts on the heatmap for the default task (index 0);
                # presumably interactive=False and sources=[] prevent uploads
                # so the image only serves as a click target (confirm in Gradio docs)
                coord_selector = gr.Image(
                    value="imgs/heatmap_0.png",
                    label="Click to select (x, y) coordinates in the latent space",
                    show_label=True,
                    interactive=False,
                    sources=[],
                    container=True,
                    show_download_button=False,
                    show_fullscreen_button=False,
                )
    # Handle radio button selection: translate the label to its index, then
    # swap the reference image, stored index and heatmap in one update
    task_select.change(
        fn=lambda x: update_reference_image(TASK_TO_INDEX[x]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )
    # Handle coordinate selection - now updates both output image and coord_selector
    # (the selector is replaced by the heatmap with the marker drawn on it)
    coord_selector.select(
        process_coord_click,
        inputs=[selected_idx],
        outputs=[output_image, coord_selector],
        trigger_mode="multiple",
    )
# Start the Gradio server; runs whenever this module is executed or imported
demo.launch()