# lpn / app.py
# Author: clement-bonnet
# Commit d6614dc ("Create app.py"), 781 bytes
import gradio as gr
import jax
import numpy as np
# Import your JAX model here
def predict(x: float, y: float) -> dict:
    """Return a prediction for the point (x, y).

    Placeholder: evaluates sin(pi * x) * cos(pi * y) with NumPy.
    TODO: replace this with the actual JAX model inference.

    Args:
        x: X coordinate of the query point (the UI limits it to [0, 1]).
        y: Y coordinate of the query point (the UI limits it to [0, 1]).

    Returns:
        A JSON-serialisable dict with keys "prediction" (a Python float),
        "x" and "y" (the inputs echoed back unchanged).
    """
    # Use the exact constant: the previous literal 3.14 shifted every zero
    # and peak of the surface (e.g. sin(3.14) ~ 0.0016 instead of ~0).
    value = np.sin(x * np.pi) * np.cos(y * np.pi)
    return {
        # float() turns the NumPy scalar into a plain Python float so the
        # JSON output component can serialise it.
        "prediction": float(value),
        "x": x,
        "y": y,
    }
# Create Gradio interface with API endpoint
demo = gr.Interface(
    fn=predict,
    inputs=[
        # Start each field at the centre of the [0, 1] range: an empty
        # gr.Number field may otherwise be passed to predict() as None,
        # which fails on the arithmetic there.
        gr.Number(label="X coordinate", value=0.5, minimum=0, maximum=1),
        gr.Number(label="Y coordinate", value=0.5, minimum=0, maximum=1),
    ],
    outputs=gr.JSON(),
    title="Research Visualization ML Model",
    # NOTE(review): the description mentions clicking a square, but this UI
    # only shows two number fields; presumably the square lives in an
    # external front end that calls the API. Confirm and reword if needed.
    description="Click on the square to generate predictions",
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour
    # of flagging_mode="never"; confirm against the pinned Gradio version.
    allow_flagging="never",
)

# Start the web server (UI plus the Interface's API endpoint) only when this
# file is run as a script, e.g. `python app.py`, so that importing the module
# does not block on launching a server.
if __name__ == "__main__":
    demo.launch()