# lpn / app.py
# Last commit a006f1b (clement-bonnet): "feat: draw marker before generating image"
# File size reported by the hosting page: 6.18 kB
import gradio as gr
from PIL import Image, ImageDraw
from inference import generate_image
# Maps each radio-button label to the zero-based task index that is used to
# pick the matching pattern/heatmap assets and is passed on to the generator.
TASK_TO_INDEX = {"Task 1": 0, "Task 2": 1, "Task 3": 2, "Task 4": 3}
def create_marker_overlay(
    image_path: str,
    x: int,
    y: int,
    *,
    marker_size: int = 10,
    marker_color: str = "red",
    line_width: int = 2,
) -> Image.Image:
    """Return a copy of an image with a crosshair marker drawn at ``(x, y)``.

    Args:
        image_path: Path of the image file to load.
        x: Horizontal pixel coordinate of the marker centre.
        y: Vertical pixel coordinate of the marker centre.
        marker_size: Half-length, in pixels, of each crosshair arm.
        marker_color: Fill colour used for both crosshair lines.
        line_width: Stroke width, in pixels, of both crosshair lines.

    Returns:
        A new in-memory image carrying the marker; the file on disk is not
        modified.
    """
    # Copy while the file is still open: ``copy()`` forces the pixel data to
    # be loaded, and leaving the ``with`` block then releases the file handle
    # instead of keeping it open until garbage collection.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()

    draw = ImageDraw.Draw(marked_image)
    # Horizontal arm, then vertical arm, both centred on (x, y).
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=line_width)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=line_width)
    return marked_image
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Resolve the assets that belong to the selected task.

    Args:
        choice: Index of the selected task.

    Returns:
        A ``(pattern_path, choice, heatmap_path)`` tuple: the reference
        pattern image path, the index passed through unchanged, and the
        heatmap image path for the same index.
    """
    pattern_file, heatmap_file = (
        f"imgs/{asset}_{choice}.png" for asset in ("pattern", "heatmap")
    )
    return pattern_file, choice, heatmap_file
def update_marker(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, tuple[int, int]]:
    """Redraw the coordinate selector with a marker at the clicked position.

    Args:
        image_idx: Index of the currently selected task; picks the heatmap file.
        evt: Selection event; its ``index`` supplies the clicked x and y.

    Returns:
        The heatmap with the marker drawn on it, and the ``(x, y)`` pair so
        that the next step in the event chain can reuse the coordinates.
    """
    clicked = (evt.index[0], evt.index[1])
    marked_heatmap = create_marker_overlay(f"imgs/heatmap_{image_idx}.png", *clicked)
    return marked_heatmap, clicked
def generate_output_image(image_idx: int, coords: tuple[int, int]) -> Image.Image:
    """Run the generator for the selected task at the clicked position.

    Args:
        image_idx: Index of the currently selected task.
        coords: ``(x, y)`` position previously stored by the marker step.

    Returns:
        Whatever ``generate_image`` produces for the scaled coordinates.
    """
    # NOTE(review): 1155 is presumably the heatmap's side length in pixels,
    # which would scale clicks into the [0, 1] range — confirm against the assets.
    heatmap_side = 1155
    px, py = coords
    return generate_image(image_idx, px / heatmap_side, py / heatmap_side)
# Build the demo UI. The custom CSS fixes the width of the task selector,
# makes the clickable heatmap a 600x600 box, and styles the documentation panel.
# NOTE(review): multi-line string contents are kept flush-left on purpose so
# their text is unchanged.
with gr.Blocks(
    css="""
.radio-container {
width: 450px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.coordinate-container {
width: 600px !important;
height: 600px !important;
}
.coordinate-container img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
}
.documentation {
margin-top: 2rem !important;
padding: 1rem !important;
background-color: #f8f9fa !important;
border-radius: 8px !important;
}
"""
) as demo:
    # Page title and a one-line usage hint.
    gr.Markdown(
        """
# Interactive Image Generation
Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
"""
    )
    with gr.Row():
        # Left column: task selection, reference pattern, generated output.
        with gr.Column(scale=1):
            # Index of the selected task; starts at 0 to match the "Task 1" default below.
            selected_idx = gr.State(value=0)
            # Last clicked (x, y) position; holds no value until the first click.
            coords = gr.State()
            # NOTE(review): only the radio group is placed in the fixed-width
            # container here — confirm this nesting against the intended layout.
            with gr.Column(elem_classes="radio-container"):
                # Labels must match the keys of TASK_TO_INDEX, which the
                # change handler uses to look up the task index.
                task_select = gr.Radio(
                    choices=["Task 1", "Task 2", "Task 3", "Task 4"],
                    value="Task 1",
                    label="Select Task",
                    interactive=True,
                )
            gr.Markdown("### Reference Pattern")
            # Read-only display of the reference pattern for the selected task.
            reference_image = gr.Image(
                value="imgs/pattern_0.png",
                show_label=False,
                interactive=False,
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
            )
            gr.Markdown("### Generated Output")
            # Empty until the first click on the coordinate selector.
            output_image = gr.Image(
                show_label=False,
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
                interactive=False,
            )
        # Right column: the clickable heatmap.
        with gr.Column(scale=1):
            gr.Markdown("### Coordinate Selector")
            gr.Markdown("Click anywhere in the image below to select (x, y) coordinates in the latent space")
            with gr.Column(elem_classes="coordinate-container"):
                # Non-interactive with no upload sources: the image cannot be
                # replaced by the user, it only receives select (click) events.
                coord_selector = gr.Image(
                    value="imgs/heatmap_0.png",
                    show_label=False,
                    interactive=False,
                    sources=[],
                    container=True,
                    show_download_button=False,
                    show_fullscreen_button=False,
                )
    # Documentation section rendered below the two columns.
    with gr.Column(elem_classes="documentation"):
        gr.Markdown(
            """
## Method Documentation
### How It Works
This interactive demo showcases our novel image generation method that uses coordinate-based control. The process works as follows:
1. **Task Selection**: Choose one of four different pattern generation tasks
2. **Reference Pattern**: View the target pattern for the selected task
3. **Coordinate Selection**: Click anywhere in the heatmap to specify where in the latent space you want to generate from
4. **Generation**: The model generates a new image based on your selected coordinates
### Sample Results
Here are some example outputs from our method:
![LPN Diagram](imgs/lpn_diagram.png)
### Technical Details
Our approach uses a novel coordinate-conditioning mechanism that allows precise control over the generated patterns. The heatmap visualization shows the distribution of pattern characteristics across the latent space.
For more information, please refer to our [paper](https://arxiv.org/pdf/2411.08706) or GitHub [repository](https://github.com/clement-bonnet/lpn).
"""
        )
    # Event handlers
    # Changing the task swaps the reference pattern, stores the new task index
    # and resets the coordinate selector to that task's unmarked heatmap.
    task_select.change(
        fn=lambda x: update_reference_image(TASK_TO_INDEX[x]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )
    # A click is handled in two chained steps: first draw the marker and store
    # the clicked coordinates in state, then generate the output image from
    # the stored task index and coordinates.
    coord_selector.select(
        fn=update_marker,
        inputs=[selected_idx],
        outputs=[coord_selector, coords],
        trigger_mode="multiple",
    ).then(
        fn=generate_output_image,
        inputs=[selected_idx, coords],
        outputs=output_image,
    )
# Start the Gradio app.
demo.launch()