news-class-1 / app.py
zionia's picture
update interface to be consistent and allow file uploads
1dd5bbf verified
raw
history blame
2.83 kB
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import pandas as pd
# Hugging Face Hub repository id shared by the tokenizer and the classifier.
_MODEL_REPO = "dsfsi/PuoBERTa-News"

# Public page of the model; interpolated into the UI description text.
MODEL_URL = "https://huggingface.co/" + _MODEL_REPO
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# it is still needed.
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"

# Both objects are loaded once, at import time, and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_REPO)
# Maps the label names emitted by the classifier to the Setswana display names
# shown in the UI. Labels missing from this table are displayed unchanged.
categories = {
    "arts_culture_entertainment_and_media": (
        "Botsweretshi, setso, boitapoloso le bobegakgang"
    ),
    "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
    "disaster_accident_and_emergency_incident": (
        "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso"
    ),
    "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
    "education": "Thuto",
    "environment": "Tikologo",
    "health": "Boitekanelo",
    "politics": "Dipolotiki",
    "religion_and_belief": "Bodumedi le tumelo",
    "society": "Setšhaba",
}
# Shared classification pipeline; created on first use and then reused, so the
# (comparatively expensive) pipeline construction is not repeated per request
# or per CSV row.
_classifier = None


def _get_classifier():
    """Return the shared classification pipeline, building it on first call."""
    global _classifier
    if _classifier is None:
        # Same construction arguments as before; only the place where the
        # pipeline is built has changed.
        _classifier = pipeline(
            "sentiment-analysis",
            tokenizer=tokenizer,
            model=model,
            return_all_scores=True,
        )
    return _classifier


def prediction(news):
    """Classify a single news article.

    Args:
        news: The article text to classify.

    Returns:
        A dict mapping each category's display name (looked up in
        ``categories``; the raw model label is used when there is no entry)
        to the score the model assigned to it.
    """
    preds = _get_classifier()(news)
    # preds[0] is the list of {"label", "score"} entries for the one input.
    return {
        categories.get(entry["label"], entry["label"]): entry["score"]
        for entry in preds[0]
    }
def file_prediction(file):
    """Classify every news article contained in an uploaded text or CSV file.

    A ``.csv`` upload (extension matched case-insensitively) is read with
    pandas and its first column is taken as the article texts; any other file
    is read as a single UTF-8 article.

    Args:
        file: The upload as handed over by the file input: either a path
            string or an object exposing the path as ``.name``. ``None``
            (nothing uploaded) is accepted.

    Returns:
        A list of ``[news_text, predictions]`` rows, matching the
        "News Text" / "Category Predictions" columns of the results table.
        ``predictions`` is a string listing every category with its score,
        highest score first. Blank articles are skipped; an empty list is
        returned when there is nothing to classify.
    """
    if file is None:
        return []

    # Work from the path rather than calling file.read(): a path string has no
    # read() method, while both supported upload shapes give us a path.
    path = str(getattr(file, "name", file))

    if path.lower().endswith(".csv"):
        try:
            frame = pd.read_csv(path)
        except pd.errors.EmptyDataError:
            # A completely empty CSV has no articles to classify.
            return []
        # Empty cells load as NaN, which the classifier cannot take as text.
        articles = frame.iloc[:, 0].dropna().astype(str).tolist()
    else:
        with open(path, encoding="utf-8") as handle:
            articles = [handle.read()]

    rows = []
    for article in articles:
        if not article.strip():
            continue
        scores = prediction(article)
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        summary = ", ".join(f"{label}: {score:.3f}" for label, score in ranked)
        rows.append([article, summary])
    return rows
# Style sheet passed to the text-input interface: white background, black
# text, light grey buttons.
_TEXT_UI_CSS = """
body {
background-color: white !important;
color: black !important;
}
.gradio-container {
background-color: white !important;
color: black !important;
}
.gr-button {
background-color: #f0f0f0 !important;
color: black !important;
}
"""

# Interface for classifying one pasted article; shows the five highest-scoring
# categories returned by `prediction`.
gradio_ui = gr.Interface(
    fn=prediction,
    inputs=gr.Textbox(lines=10, label="Paste some Setswana news here"),
    outputs=gr.Label(num_top_classes=5, label="News categories probabilities"),
    title="Setswana News Classification",
    description=(
        "Enter Setswana news article to see the category of the news.\n "
        f"For this classification, the {MODEL_URL} model was used."
    ),
    theme="default",
    css=_TEXT_UI_CSS,
)
# Interface for classifying articles from an uploaded text or CSV file; the
# output of `file_prediction` is shown in a two-column table.
gradio_file_ui = gr.Interface(
    fn=file_prediction,
    inputs=gr.File(label="Upload text or CSV file"),
    outputs=gr.Dataframe(
        headers=["News Text", "Category Predictions"],
        label="Predictions from file",
    ),
    title="Upload File for Setswana News Classification",
    description=(
        "Upload a text or CSV file with Setswana news articles. "
        "The first column in the CSV should contain the news text."
    ),
    theme="default",
)
# Combine the two interfaces into one tabbed app and start serving it.
_tab_interfaces = [gradio_ui, gradio_file_ui]
_tab_titles = ["Text Input", "File Upload"]
gradio_combined_ui = gr.TabbedInterface(_tab_interfaces, _tab_titles)
gradio_combined_ui.launch()