File size: 3,372 Bytes
d6614dc
29b5baf
d6614dc
9ef0f25
29b5baf
999b913
433f4b7
 
 
808cfce
 
 
 
 
 
 
 
 
 
0dcdf8e
9ef0f25
 
 
 
f6ee8cd
9ef0f25
999b913
 
433f4b7
 
f6ee8cd
433f4b7
 
 
 
f6ee8cd
 
 
 
 
 
 
 
 
433f4b7
 
9ef0f25
 
 
808cfce
9ef0f25
29b5baf
d6614dc
9ef0f25
f6ee8cd
9ef0f25
0dcdf8e
 
9ef0f25
433f4b7
 
 
 
 
 
 
 
808cfce
f6ee8cd
808cfce
 
 
433f4b7
808cfce
 
 
9ef0f25
f6ee8cd
 
9ef0f25
f6ee8cd
 
 
 
 
 
 
 
 
 
 
9ef0f25
808cfce
 
433f4b7
808cfce
 
 
0dcdf8e
 
 
 
 
999b913
9ef0f25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
from PIL import Image

from inference import generate_image


# Radio-button label -> zero-based task index ("Task 1" -> 0, ..., "Task 4" -> 3).
# Insertion order matches the label numbering.
TASK_TO_INDEX = {f"Task {n + 1}": n for n in range(4)}


def update_reference_image(choice: int) -> tuple[str, int, str]:
    """
    Map a task index to the assets displayed for that task.

    Returns a ``(pattern_path, choice, heatmap_path)`` tuple:
    - the reference pattern image path under ``imgs/``,
    - the index itself, echoed back so the caller can store it,
    - the heatmap image path under ``imgs/``.
    """
    return (
        f"imgs/pattern_{choice}.png",
        choice,
        f"imgs/heatmap_{choice}.png",
    )


def process_coord_click(
    image_idx: int,
    evt: gr.SelectData,
    *,
    image_width: float = 1155,
    image_height: float = 1155,
) -> Image.Image:
    """
    Generate an image from a click on the coordinate-selector image.

    Args:
        image_idx: Index of the currently selected task.
        evt: Gradio select event; ``evt.index`` holds the clicked position
            as ``[x, y]`` in pixels.
        image_width: Pixel width used to normalise ``x``. Defaults to 1155,
            the value previously hard-coded here; presumably the native width
            of the heatmap images — TODO confirm.
        image_height: Pixel height used to normalise ``y`` (same default).
            Kept separate from ``image_width`` so non-square heatmaps work.

    Returns:
        The result of ``generate_image`` for the normalised ``(x, y)``.
    """
    # Keyword-only parameters are not filled by Gradio's positional input
    # injection, so the event wiring (one input + SelectData) is unaffected.
    x_px, y_px = evt.index[0], evt.index[1]
    # Scale pixel coordinates by the image dimensions before generation.
    return generate_image(image_idx, x_px / image_width, y_px / image_height)


with gr.Blocks(
    css="""
    .radio-container {
        width: 450px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .coordinate-container {
        width: 600px !important;
        height: 600px !important;
    }
    .coordinate-container img {
        width: 100% !important;
        height: 100% !important;
        object-fit: contain !important;
    }
"""
) as demo:
    gr.Markdown(
        """
    # Interactive Image Generation
    Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
    """
    )

    # Derive the start-up state from the same helper the radio change handler
    # uses, so the initial images and index cannot drift from what selecting
    # the default task would produce.
    default_task = "Task 1"
    default_pattern, default_idx, default_heatmap = update_reference_image(
        TASK_TO_INDEX[default_task]
    )

    with gr.Row():
        # Left column: task selection, reference image, and generated output.
        with gr.Column(scale=1):
            # Index of the selected task; fed to the click handler as input.
            selected_idx = gr.State(value=default_idx)

            # Radio buttons, centred by the .radio-container CSS rule.
            with gr.Column(elem_classes="radio-container"):
                task_select = gr.Radio(
                    # Choices come from TASK_TO_INDEX so every selectable label
                    # is guaranteed to have an index in the change handler.
                    choices=list(TASK_TO_INDEX),
                    value=default_task,
                    label="Select Task",
                    interactive=True,
                )

            # Reference pattern for the selected task (updated on selection).
            reference_image = gr.Image(
                value=default_pattern,
                show_label=False,
                interactive=False,
                height=300,
                width=450,
            )

            # Output of generate_image, shown below the reference image.
            output_image = gr.Image(label="Generated Output", height=300, width=450)

        # Right column: larger coordinate selector.
        with gr.Column(scale=1):
            # Sized/scaled by the .coordinate-container CSS rules.
            with gr.Column(elem_classes="coordinate-container"):
                coord_selector = gr.Image(
                    value=default_heatmap,
                    label="Click to select (x, y) coordinates in the latent space",
                    show_label=True,
                    interactive=True,
                    container=True,
                )

    # On task change: swap the reference image, store the new index, and swap
    # the heatmap shown in the coordinate selector.
    task_select.change(
        fn=lambda x: update_reference_image(TASK_TO_INDEX[x]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )

    # On click in the coordinate selector: generate an image for the selected
    # task at the clicked location.
    coord_selector.select(
        process_coord_click, inputs=[selected_idx], outputs=output_image, trigger_mode="multiple"
    )

# Start the server only when run as a script, so importing this module
# (e.g. to reuse the handlers or in tests) does not launch a blocking app.
if __name__ == "__main__":
    demo.launch()