clement-bonnet commited on
Commit
0dcdf8e
·
1 Parent(s): 1f14f97

feat: left images in full

Browse files
Files changed (1) hide show
  1. app.py +38 -28
app.py CHANGED
@@ -1,10 +1,12 @@
 
 
1
  import gradio as gr
2
  from PIL import Image
3
 
4
  from inference import generate_image
5
 
6
 
7
- def process_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
8
  """
9
  Process the click event on the coordinate selector
10
  """
@@ -14,43 +16,46 @@ def process_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
14
  return generate_image(image_idx, x, y)
15
 
16
 
17
- def update_background(image_idx: int) -> Image.Image:
18
  """
19
- Update the coordinate selector background based on selected image
 
20
  """
21
- return f"imgs/heatmap_{image_idx}.png"
22
 
23
 
24
  with gr.Blocks() as demo:
25
  gr.Markdown(
26
  """
27
  # Interactive Image Generation
28
- Choose a reference image and click on the coordinate selector to generate a new image.
29
  """
30
  )
31
 
32
  with gr.Row():
33
- # Left column: Reference images only
34
  with gr.Column(scale=1):
35
- # Radio buttons for image selection
36
- image_idx = gr.Radio(
37
- choices=list(range(2)),
38
- value=0,
39
- label="Select Reference Image",
40
- type="index",
41
- )
42
 
43
- # Display reference images
44
- gallery = gr.Gallery(
45
- value=[
46
- "imgs/pattern_0.png",
47
- "imgs/pattern_1.png",
48
- ],
49
- columns=1,
50
- rows=2,
51
- min_width=600,
52
- label="Reference Images",
53
- )
 
 
 
 
 
 
 
54
 
55
  # Right column: Coordinate selector and output image
56
  with gr.Column(scale=1):
@@ -65,10 +70,15 @@ with gr.Blocks() as demo:
65
  )
66
 
67
  # Generated image output
68
- output_image = gr.Image(label="Generated Image", height=400, width=400)
69
 
70
- # Handle click events and background updates
71
- coord_selector.select(process_click, inputs=[image_idx], outputs=output_image, trigger_mode="multiple")
72
- image_idx.change(update_background, inputs=[image_idx], outputs=[coord_selector])
 
 
 
 
 
73
 
74
  demo.launch()
 
1
+ from functools import partial
2
+
3
  import gradio as gr
4
  from PIL import Image
5
 
6
  from inference import generate_image
7
 
8
 
9
+ def process_coord_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
10
  """
11
  Process the click event on the coordinate selector
12
  """
 
16
  return generate_image(image_idx, x, y)
17
 
18
 
19
+ def process_image_select(evt: gr.SelectData, idx: int) -> tuple[int, str]:
20
  """
21
+ Process the reference image selection
22
+ Returns the selected image index and corresponding heatmap
23
  """
24
+ return idx, f"imgs/heatmap_{idx}.png"
25
 
26
 
27
  with gr.Blocks() as demo:
28
  gr.Markdown(
29
  """
30
  # Interactive Image Generation
31
+ Click on a reference image to select it, then click on the coordinate selector to generate a new image.
32
  """
33
  )
34
 
35
  with gr.Row():
36
+ # Left column: Interactive reference images
37
  with gr.Column(scale=1):
38
+ # State variable to track selected image index
39
+ selected_idx = gr.State(value=0)
 
 
 
 
 
40
 
41
+ # Two separate Image components for reference images
42
+ with gr.Column():
43
+ image_0 = gr.Image(
44
+ value="imgs/pattern_0.png",
45
+ label="Task 1",
46
+ show_label=False,
47
+ interactive=True,
48
+ height=300,
49
+ width=450,
50
+ )
51
+ image_1 = gr.Image(
52
+ value="imgs/pattern_1.png",
53
+ label="Task 2",
54
+ show_label=False,
55
+ interactive=True,
56
+ height=300,
57
+ width=450,
58
+ )
59
 
60
  # Right column: Coordinate selector and output image
61
  with gr.Column(scale=1):
 
70
  )
71
 
72
  # Generated image output
73
+ output_image = gr.Image(label="Generated Output", height=400, width=400)
74
 
75
+ # Handle image selection for each reference image
76
+ image_0.select(partial(process_image_select, idx=0), outputs=[selected_idx, coord_selector])
77
+ image_1.select(partial(process_image_select, idx=1), outputs=[selected_idx, coord_selector])
78
+
79
+ # Handle coordinate selection
80
+ coord_selector.select(
81
+ process_coord_click, inputs=[selected_idx], outputs=output_image, trigger_mode="multiple"
82
+ )
83
 
84
  demo.launch()