Spaces:
Sleeping
Sleeping
from transformers import pipeline | |
import gradio as gr | |
import spaces | |
import plotly.express as px | |
model_id = "BramVanroy/sonar-coarse-classification"

# Load the model once at import time. The instance is bound to `classifier`
# rather than `pipeline` so it does not shadow the `transformers.pipeline`
# factory imported at the top of the file.
classifier = pipeline("text-classification", model=model_id)


def pipe_predict(text):
    """Run the text-classification pipeline on *text*.

    Args:
        text: The input string to classify.

    Returns:
        The raw pipeline output for *text*. ``top_k=None`` asks the pipeline
        for a score for every label rather than only the best one.
    """
    return classifier(text, top_k=None)
def predict(text):
    """Classify *text* and plot the score of each label.

    Leading and trailing whitespace is stripped before classification.

    Args:
        text: Raw contents of the input textbox.

    Returns:
        A Plotly horizontal bar figure with one bar per label, ordered by
        label name.

    Raises:
        gr.Error: If *text* is empty or contains only whitespace.
    """
    cleaned = text.strip()
    if not cleaned:
        raise gr.Error("Text field cannot be empty!")

    ranked = sorted(pipe_predict(cleaned), key=lambda entry: entry["label"])
    labels = [entry["label"] for entry in ranked]
    scores = [entry["score"] for entry in ranked]

    figure = px.bar(
        {"label": labels, "score": scores},
        x="score",
        y="label",
        orientation="h",
    )
    # A gap of 0.9 leaves each bar only a tenth of its slot, so bars are thin.
    figure.update_layout(bargap=0.9)
    return figure
def create_app():
    """Build and return the Gradio Blocks interface for the classifier.

    The layout contains:
    - a title and a disclaimer;
    - a row holding the input textbox and the output plot;
    - a button that sends the textbox contents to ``predict`` and renders
      the returned figure in the plot.
    """
    with gr.Blocks() as app_ui:
        gr.Markdown("# SONAR Coarse Classification")
        gr.Markdown(
            "This is a test app for the SONAR Coarse Classification model. "
            "Not recommended for real-world use."
        )

        with gr.Row():
            text_box = gr.Textbox(
                placeholder="Paste your text here...",
                lines=5,
                label="Input Text",
            )
            score_plot = gr.Plot(label="Prediction Probabilities")

        classify_btn = gr.Button("Classify Text")
        classify_btn.click(
            fn=predict,
            inputs=[text_box],
            outputs=[score_plot],
        )

    return app_ui
if __name__ == "__main__":
    # When run as a script, build the interface and start the Gradio server.
    create_app().launch()