clement-bonnet commited on
Commit
d6614dc
·
verified ·
1 Parent(s): 6b54fb2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import numpy as np
4
+ # Import your JAX model here
5
+
6
+ def predict(x: float, y: float):
7
+ """
8
+ Replace this with your actual JAX model inference
9
+ """
10
+ # This is just a placeholder - replace with your model
11
+ return {
12
+ "prediction": float(np.sin(x * 3.14) * np.cos(y * 3.14)),
13
+ "x": x,
14
+ "y": y
15
+ }
16
+
17
+ # Create Gradio interface with API endpoint
18
+ demo = gr.Interface(
19
+ fn=predict,
20
+ inputs=[
21
+ gr.Number(label="X coordinate", minimum=0, maximum=1),
22
+ gr.Number(label="Y coordinate", minimum=0, maximum=1)
23
+ ],
24
+ outputs=gr.JSON(),
25
+ title="Research Visualization ML Model",
26
+ description="Click on the square to generate predictions",
27
+ allow_flagging="never"
28
+ )
29
+
30
+ # Enable API endpoint
31
+ demo.launch()