clement-bonnet commited on
Commit
433f4b7
·
1 Parent(s): 808cfce

feat: 4 tasks

Browse files
Files changed (2) hide show
  1. app.py +23 -11
  2. inference.py +1 -1
app.py CHANGED
@@ -4,6 +4,9 @@ from PIL import Image
4
  from inference import generate_image
5
 
6
 
 
 
 
7
  def update_reference_image(choice: int) -> tuple[str, int, str]:
8
  """
9
  Update the reference image display based on radio button selection
@@ -24,7 +27,15 @@ def process_coord_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
24
  return generate_image(image_idx, x, y)
25
 
26
 
27
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
28
  gr.Markdown(
29
  """
30
  # Interactive Image Generation
@@ -38,19 +49,20 @@ with gr.Blocks() as demo:
38
  # State variable to track selected image index
39
  selected_idx = gr.State(value=0)
40
 
41
- # Radio buttons for task selection
42
- task_select = gr.Radio(
43
- choices=["Task 1", "Task 2"],
44
- value="Task 1",
45
- label="Select Task",
46
- interactive=True,
47
- )
 
48
 
49
  # Single reference image component that updates based on selection
50
  reference_image = gr.Image(
51
  value="imgs/pattern_0.png",
52
  show_label=False,
53
- interactive=False, # No need for click interaction now
54
  height=300,
55
  width=450,
56
  )
@@ -60,7 +72,7 @@ with gr.Blocks() as demo:
60
  # Coordinate selector with dynamic background
61
  coord_selector = gr.Image(
62
  value="imgs/heatmap_0.png",
63
- label="Click to select (x, y) coordinates in the latent space",
64
  show_label=True,
65
  interactive=True,
66
  height=400,
@@ -72,7 +84,7 @@ with gr.Blocks() as demo:
72
 
73
  # Handle radio button selection
74
  task_select.change(
75
- fn=lambda x: update_reference_image(0 if x == "Task 1" else 1),
76
  inputs=[task_select],
77
  outputs=[reference_image, selected_idx, coord_selector],
78
  )
 
4
  from inference import generate_image
5
 
6
 
7
+ TASK_TO_INDEX = {"Task 1": 0, "Task 2": 1, "Task 3": 2, "Task 4": 3}
8
+
9
+
10
  def update_reference_image(choice: int) -> tuple[str, int, str]:
11
  """
12
  Update the reference image display based on radio button selection
 
27
  return generate_image(image_idx, x, y)
28
 
29
 
30
+ with gr.Blocks(
31
+ css="""
32
+ .radio-container {
33
+ width: 450px !important;
34
+ margin-left: auto !important;
35
+ margin-right: auto !important;
36
+ }
37
+ """
38
+ ) as demo:
39
  gr.Markdown(
40
  """
41
  # Interactive Image Generation
 
49
  # State variable to track selected image index
50
  selected_idx = gr.State(value=0)
51
 
52
+ # Radio buttons with container class
53
+ with gr.Column(elem_classes="radio-container"):
54
+ task_select = gr.Radio(
55
+ choices=["Task 1", "Task 2", "Task 3", "Task 4"],
56
+ value="Task 1",
57
+ label="Select Task",
58
+ interactive=True,
59
+ )
60
 
61
  # Single reference image component that updates based on selection
62
  reference_image = gr.Image(
63
  value="imgs/pattern_0.png",
64
  show_label=False,
65
+ interactive=False,
66
  height=300,
67
  width=450,
68
  )
 
72
  # Coordinate selector with dynamic background
73
  coord_selector = gr.Image(
74
  value="imgs/heatmap_0.png",
75
+ label="Click to select (x, y) coordinates",
76
  show_label=True,
77
  interactive=True,
78
  height=400,
 
84
 
85
  # Handle radio button selection
86
  task_select.change(
87
+ fn=lambda x: update_reference_image(TASK_TO_INDEX[x]),
88
  inputs=[task_select],
89
  outputs=[reference_image, selected_idx, coord_selector],
90
  )
inference.py CHANGED
@@ -23,7 +23,7 @@ from utils import patch_target, ax_to_pil
23
 
24
 
25
  checkpoint_name = "quiet-thunder-789--checkpoint:v0"
26
- BLUE_LOCATION_INPUTS = {0: 13, 1: 9}
27
 
28
  local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
29
  with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
 
23
 
24
 
25
  checkpoint_name = "quiet-thunder-789--checkpoint:v0"
26
+ BLUE_LOCATION_INPUTS = {0: 13, 1: 9, 2: 9, 3: 6}
27
 
28
  local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
29
  with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f: