BramVanroy's picture
Update app.py
0d8aa23 verified
from transformers import pipeline
import gradio as gr
import spaces
import plotly.express as px
# Model identifier handed to the transformers pipeline below
# (presumably a Hugging Face Hub repo id — confirm).
model_id = "BramVanroy/sonar-coarse-classification"
# Build the text-classification pipeline once, at import time, so every request
# reuses the same loaded model.
# NOTE(review): this rebinds the name `pipeline`, shadowing the factory function
# imported from transformers; after this line `pipeline` is the pipeline
# instance. Consider a distinct name (e.g. `classifier`) together with updating
# `pipe_predict`, which calls this name.
pipeline = pipeline("text-classification", model=model_id)
@spaces.GPU
def pipe_predict(text):
    """Run the module-level classification pipeline on *text*.

    Wrapped with ``spaces.GPU`` (presumably so the call is given a GPU when
    running on Hugging Face Spaces — confirm against the ``spaces`` docs).

    Returns:
        The pipeline's output with a score entry for every label.
    """
    # top_k=None asks the pipeline for all labels rather than only the best one.
    all_label_scores = pipeline(text, top_k=None)
    return all_label_scores
def predict(text):
    """Perform text classification and generate results.

    Args:
        text: Raw text from the input box. ``None`` is treated as empty text.

    Returns:
        A Plotly horizontal bar chart with one bar per label (score on the
        x-axis), with labels in alphabetical order.

    Raises:
        gr.Error: If *text* is empty or only whitespace, so the message is
            reported through Gradio rather than as an unrelated exception.
    """
    # Treat None like an empty string: calling .strip() on None would raise
    # AttributeError instead of the intended user-facing gr.Error below.
    text = (text or "").strip()
    if not text:
        raise gr.Error("Text field cannot be empty!")

    outputs = pipe_predict(text)
    # Sort by label name so bars keep a stable order regardless of the scores.
    outputs = sorted(outputs, key=lambda item: item["label"])

    # Column-oriented mapping (column name -> values) passed to plotly express.
    data = {
        "label": [item["label"] for item in outputs],
        "score": [item["score"] for item in outputs],
    }
    fig = px.bar(data, x="score", y="label", orientation="h")
    fig.update_layout(bargap=0.9)
    return fig
def create_app():
    """Create the Gradio app.

    Builds a ``gr.Blocks`` UI containing a title, a disclaimer, a five-line
    text input, a plot for the prediction probabilities, and a button that runs
    ``predict`` on the input text and shows the returned figure in the plot.

    Returns:
        The ``gr.Blocks`` instance. It is not launched here.
    """
    # NOTE(review): the original indentation of this block was lost; the nesting
    # below (textbox and plot inside the Row, button after it) is a
    # reconstruction — confirm against the intended layout.
    with gr.Blocks() as demo:
        gr.Markdown("# SONAR Coarse Classification")
        gr.Markdown("This is a test app for the SONAR Coarse Classification model. Not recommended for real-world use.")
        with gr.Row():
            input_text = gr.Textbox(
                label="Input Text",
                lines=5,
                placeholder="Paste your text here...",
            )
            prediction_output = gr.Plot(label="Prediction Probabilities")
        submit_button = gr.Button("Classify Text")
        # On click, pass the textbox value to predict() and render the figure
        # it returns in the plot component.
        submit_button.click(
            predict,
            inputs=[input_text],
            outputs=[prediction_output],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    create_app().launch()