import gradio as gr
from transformers import pipeline

# Hugging Face Hub checkpoint to load; judging by the repository name this is
# a DistilBERT model fine-tuned for AG News topic classification.
model_name = "cornelliusyudhawijaya/AG_News_Classification_DistillBert"

# Built once at import time so every request reuses the same loaded
# model and tokenizer instead of reloading them per call.
classifier = pipeline("text-classification", model=model_name, tokenizer=model_name)

# Maps the numeric class id (the trailing number of the pipeline's label
# string, as parsed in classify_text) to a human-readable category name.
# NOTE(review): the id -> name order is assumed to match the model's
# training labels -- confirm against the checkpoint's config.
label_names = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}

def classify_text(text):
    """Classify *text* and return a one-line, human-readable summary.

    Args:
        text: The raw text to classify.

    Returns:
        A string of the form ``"Label: <name>, Score: <score>"`` where the
        score is formatted with four decimal places.
    """
    # truncation=True is forwarded to the tokenizer so that input longer
    # than the model's maximum sequence length is cut down instead of
    # raising at inference time.
    result = classifier(text, truncation=True)[0]

    raw_label = result['label']
    try:
        # Labels are expected to end in a numeric class id after the last
        # underscore; translate that id to its readable name.
        label_name = label_names[int(raw_label.split('_')[-1])]
    except (ValueError, KeyError):
        # The label has no numeric suffix, or the id is not in label_names:
        # show the model's own label rather than failing the request.
        label_name = raw_label

    return f"Label: {label_name}, Score: {result['score']:.4f}"


# Single-textbox Gradio UI: each submission is passed to classify_text and
# the string it returns is shown as plain text.
iface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
    outputs="text",
    title="News Classification",
    description="Enter text to classify the news.",
)

# Start the web app only when this file is executed as a script, so that
# importing the module does not launch a server.
if __name__ == "__main__":
    iface.launch()