File size: 2,746 Bytes
0dcdf8e
 
d6614dc
29b5baf
d6614dc
9ef0f25
29b5baf
999b913
0dcdf8e
9ef0f25
 
 
 
1f14f97
9ef0f25
 
999b913
 
0dcdf8e
1f14f97
0dcdf8e
 
1f14f97
0dcdf8e
1f14f97
 
9ef0f25
 
 
 
0dcdf8e
9ef0f25
29b5baf
d6614dc
9ef0f25
0dcdf8e
9ef0f25
0dcdf8e
 
9ef0f25
0dcdf8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ef0f25
1f14f97
 
 
9ef0f25
1f14f97
9ef0f25
 
 
 
 
 
 
1f14f97
0dcdf8e
9ef0f25
0dcdf8e
 
 
 
 
 
 
 
999b913
9ef0f25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from functools import partial

import gradio as gr
from PIL import Image

from inference import generate_image


def process_coord_click(
    image_idx: int,
    evt: gr.SelectData,
    *,
    width: float = 400,
    height: float = 400,
) -> Image.Image:
    """
    Handle a click on the coordinate selector and generate an image for it.

    Args:
        image_idx: Index of the currently selected reference image.
        evt: Gradio select event; ``evt.index`` holds the clicked ``[x, y]``
            pixel position.
        width: Pixel width used to normalize ``x``. Defaults to 400, the
            configured width of the coordinate selector.
            NOTE(review): Gradio reports click positions in the displayed
            image's pixel space, so this should match the heatmap image's
            actual width — confirm.
        height: Pixel height used to normalize ``y`` (see ``width``).

    Returns:
        The result of ``generate_image(image_idx, x, y)``, where ``x`` and
        ``y`` are the click position divided by ``width`` / ``height``.
    """
    # Keyword-only parameters keep Gradio's positional argument passing
    # (selected index, then the injected event) unchanged.
    x_px, y_px = evt.index[0], evt.index[1]
    x, y = x_px / width, y_px / height
    print(f"Clicked at coordinates: ({x:.3f}, {y:.3f})")
    return generate_image(image_idx, x, y)


def process_image_select(evt: gr.SelectData, idx: int) -> tuple[int, str]:
    """
    Handle a click on one of the reference images.

    ``evt`` is not read. It is declared with the ``gr.SelectData`` annotation
    so this function can serve as a select handler. ``idx`` identifies which
    reference image was clicked.

    Returns:
        A ``(idx, heatmap_path)`` pair, where ``heatmap_path`` is
        ``imgs/heatmap_<idx>.png``.
    """
    heatmap_path = f"imgs/heatmap_{idx}.png"
    return idx, heatmap_path


with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Interactive Image Generation
    Click on a reference image to select it, then click on the coordinate selector to generate a new image.
    """
    )

    with gr.Row():
        # Left column: clickable reference images.
        with gr.Column(scale=1):
            # Index of the currently selected reference image. It starts at 0
            # to match the initial heatmap shown in the coordinate selector.
            selected_idx = gr.State(value=0)

            # One Image component per reference image.
            with gr.Column():
                image_0 = gr.Image(
                    value="imgs/pattern_0.png",
                    label="Task 1",
                    show_label=False,
                    interactive=True,
                    height=300,
                    width=450,
                )
                image_1 = gr.Image(
                    value="imgs/pattern_1.png",
                    label="Task 2",
                    show_label=False,
                    interactive=True,
                    height=300,
                    width=450,
                )

        # Right column: coordinate selector and generated output.
        with gr.Column(scale=1):
            # Background is replaced by the selected image's heatmap whenever
            # a reference image is selected (see the wiring below).
            coord_selector = gr.Image(
                value="imgs/heatmap_0.png",  # Initial background
                label="Click to select (x, y) coordinates",
                show_label=True,
                interactive=True,
                height=400,
                width=400,
            )

            # Generated image output
            output_image = gr.Image(label="Generated Output", height=400, width=400)

    # Selecting reference image i stores i in `selected_idx` and shows its
    # heatmap. `partial` binds the index eagerly, so each handler keeps its own.
    for ref_idx, ref_image in enumerate((image_0, image_1)):
        ref_image.select(
            partial(process_image_select, idx=ref_idx),
            outputs=[selected_idx, coord_selector],
        )

    # Each click on the selector generates an image for the selected reference.
    # trigger_mode="multiple" allows new clicks to be submitted while a
    # previous generation is still pending.
    coord_selector.select(
        process_coord_click,
        inputs=[selected_idx],
        outputs=output_image,
        trigger_mode="multiple",
    )

# Launch only when run as a script, so importing this module does not start
# a server as a side effect.
if __name__ == "__main__":
    demo.launch()