clement-bonnet commited on
Commit
29b5baf
·
verified ·
1 Parent(s): 583dd45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -22
app.py CHANGED
@@ -1,31 +1,122 @@
 
1
  import gradio as gr
2
- import jax
3
  import numpy as np
4
- # Import your JAX model here
5
 
6
- def predict(x: float, y: float):
 
7
  """
8
- Replace this with your actual JAX model inference
9
  """
10
  # This is just a placeholder - replace with your model
11
- return {
12
- "prediction": float(np.sin(x * 3.14) * np.cos(y * 3.14)),
13
- "x": x,
14
- "y": y
15
- }
 
 
16
 
17
- # Create Gradio interface with API endpoint
18
- demo = gr.Interface(
19
- fn=predict,
20
- inputs=[
21
- gr.Number(label="X coordinate", minimum=0, maximum=1),
22
- gr.Number(label="Y coordinate", minimum=0, maximum=1)
23
- ],
24
- outputs=gr.JSON(),
25
- title="Research Visualization ML Model",
26
- description="Click on the square to generate predictions",
27
- allow_flagging="never"
28
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Enable API endpoint
31
  demo.launch()
 
1
+ import os
2
  import gradio as gr
 
3
  import numpy as np
4
+ from PIL import Image
5
 
6
+ # Placeholder for your actual model
7
+ def generate_image(image_idx: int, x: float, y: float) -> Image.Image:
8
  """
9
+ Replace this with your actual model inference
10
  """
11
  # This is just a placeholder - replace with your model
12
+ # Creating a simple gradient image as example output
13
+ width, height = 256, 256
14
+ gradient = np.zeros((height, width, 3), dtype=np.uint8)
15
+ gradient[:, :, 0] = np.linspace(0, 255 * x, width)
16
+ gradient[:, :, 1] = np.linspace(0, 255 * y, height)[:, np.newaxis]
17
+ gradient[:, :, 2] = image_idx * 30 # vary blue channel based on selected image
18
+ return Image.fromarray(gradient)
19
 
20
+ def process_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
21
+ """
22
+ Process the click event on the coordinate selector
23
+ """
24
+ # Extract coordinates from click event
25
+ x, y = evt.index[0], evt.index[1]
26
+ # Normalize coordinates to [0, 1]
27
+ x, y = x/100, y/100
28
+ # Generate image using the model
29
+ return generate_image(image_idx, x, y)
30
+
31
+ with gr.Blocks() as demo:
32
+ gr.Markdown("""
33
+ # Interactive Image Generation
34
+ Choose a reference image and click on the coordinate selector to generate a new image.
35
+ """)
36
+
37
+ with gr.Row():
38
+ # Left column: Reference images and coordinate selector
39
+ with gr.Column(scale=1):
40
+ # Radio buttons for image selection
41
+ image_idx = gr.Radio(
42
+ choices=[i for i in range(4)], # Replace with your actual number of images
43
+ value=0,
44
+ label="Select Reference Image",
45
+ type="index"
46
+ )
47
+
48
+ # Display reference images
49
+ gallery = gr.Gallery(
50
+ value=[
51
+ "image_0.jpg",
52
+ "image_0.jpg",
53
+ "image_0.jpg",
54
+ "image_0.jpg",
55
+ ],
56
+ columns=2,
57
+ rows=2,
58
+ height=300,
59
+ label="Reference Images"
60
+ )
61
+
62
+ # Coordinate selector (displayed as heatmap for click interaction)
63
+ coord_selector = gr.Plot(
64
+ value=None,
65
+ label="Click to select (x, y) coordinates"
66
+ )
67
+
68
+ # Initialize the coordinate selector
69
+ def create_selector():
70
+ import plotly.graph_objects as go
71
+ fig = go.Figure()
72
+
73
+ # Add a square shape
74
+ fig.add_trace(go.Scatter(
75
+ x=[0, 100, 100, 0, 0],
76
+ y=[0, 0, 100, 100, 0],
77
+ mode='lines',
78
+ line=dict(color='black'),
79
+ showlegend=False
80
+ ))
81
+
82
+ # Update layout
83
+ fig.update_layout(
84
+ width=300,
85
+ height=300,
86
+ margin=dict(l=0, r=0, t=0, b=0),
87
+ xaxis=dict(
88
+ range=[-5, 105],
89
+ showgrid=False,
90
+ zeroline=False,
91
+ visible=False
92
+ ),
93
+ yaxis=dict(
94
+ range=[-5, 105],
95
+ showgrid=False,
96
+ zeroline=False,
97
+ visible=False,
98
+ scaleanchor='x'
99
+ ),
100
+ plot_bgcolor='white'
101
+ )
102
+ return fig
103
+
104
+ # Initialize the coordinate selector
105
+ coord_selector.value = create_selector()
106
+
107
+ # Right column: Generated image
108
+ with gr.Column(scale=1):
109
+ output_image = gr.Image(
110
+ label="Generated Output",
111
+ height=300
112
+ )
113
+
114
+ # Handle click events
115
+ coord_selector.select(
116
+ process_click,
117
+ inputs=[image_idx],
118
+ outputs=output_image
119
+ )
120
 
121
+ # Launch the app
122
  demo.launch()