# Provenance (from the scraped page header): Hugging Face Space,
# revision d6614dc, file size 781 bytes.
import gradio as gr
import jax
import numpy as np
# Import your JAX model here
def predict(x: float, y: float):
    """
    Evaluate the placeholder model at the point (x, y).

    Replace the body with your actual JAX model inference. Keep the return
    shape (a JSON-serialisable dict) so UI and API clients stay compatible.

    Args:
        x: X coordinate (the Gradio inputs below restrict it to [0, 1]).
        y: Y coordinate (the Gradio inputs below restrict it to [0, 1]).

    Returns:
        dict with "prediction" (float) plus the echoed inputs "x" and "y",
        both coerced to float.

    Raises:
        ValueError: if either coordinate is missing (None). Gradio may pass
            None when a Number field is left empty.
    """
    # Fail with a readable message instead of an opaque TypeError from None * float.
    if x is None or y is None:
        raise ValueError("Both X and Y coordinates are required.")
    x, y = float(x), float(y)
    # Placeholder surface sin(pi*x) * cos(pi*y). Use np.pi rather than the
    # truncated literal 3.14, which shifted zeros/peaks (sin(3.14) != 0).
    prediction = float(np.sin(np.pi * x) * np.cos(np.pi * y))
    return {"prediction": prediction, "x": x, "y": y}
# Build the Gradio interface around `predict`. Gradio serves it as a web UI
# and, by default, also as a programmatic API endpoint.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="X coordinate", minimum=0, maximum=1),
        gr.Number(label="Y coordinate", minimum=0, maximum=1),
    ],
    outputs=gr.JSON(),
    title="Research Visualization ML Model",
    # NOTE(review): this UI has no clickable square (inputs are Number
    # fields). Presumably an external visualisation clicks and calls the
    # API; confirm, or reword for direct users.
    description="Click on the square to generate predictions",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # check against the pinned Gradio version.
    allow_flagging="never",
)
# Start the server. The API endpoint needs no extra flag here.
demo.launch()