Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import pipeline | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from transformers import BertForSequenceClassification, BertTokenizer | |
# Streamlit re-executes this script from top to bottom on every widget
# interaction; cache the loaded checkpoint so the weights are read only once
# per server process instead of on every rerun. ``st.cache_resource`` is the
# current API; older Streamlit releases only provide ``st.cache``.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def _load_classifier(checkpoint):
    """Load the fine-tuned BERT classifier and its tokenizer.

    Args:
        checkpoint: Hugging Face Hub repository id to load both objects from.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = BertForSequenceClassification.from_pretrained(checkpoint)
    loaded_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_classifier("RuudVelo/dutch_news_clf_bert_finetuned")
# --- Page header ------------------------------------------------------------
st.title("Dutch news article classification")
# The '*' points at the list of categories written further down the page.
st.write(
    "This app classifies a Dutch news article into one of "
    "9 pre-defined* article categories"
)
st.image("dataset-cover_articles.jpeg", width=150)

# --- Input ------------------------------------------------------------------
# Free-text box holding the article to classify.
text = st.text_area(
    "Please type/copy/paste text of the Dutch article and click Submit"
)
# Run the classifier only when the user presses Submit.
if st.button('Submit'):
    if not text.strip():
        # An empty input would only classify BERT's special tokens; ask the
        # user for an article instead of showing a meaningless prediction.
        st.warning('Please enter the text of a Dutch news article first.')
    else:
        with st.spinner('Generating a response...'):
            # BERT has a fixed number of position embeddings; without
            # truncation, longer articles make the forward pass fail.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = model(**encoding)

            # A single text was encoded, so row 0 holds its scores.
            number = int(outputs.logits[0].argmax())
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            probs_plot = probabilities[0].cpu().numpy() * 100

            # Display names indexed by the model's output position.
            # NOTE(review): assumes this order matches the checkpoint's
            # id2label mapping — confirm against the model config.
            labels_plot = ['Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie', 'Koningshuis',
                           'Opmerkelijk', 'Politiek', 'Regionaal nieuws', 'Tech']

            fig, ax = plt.subplots(figsize=(10, 4))
            ax.barh(labels_plot, probs_plot)
            ax.set_title("Predicted article category probability", fontsize=20)
            ax.set_xlabel("Probability (%)", fontsize=16)
            ax.set_ylabel("Predicted category", fontsize=16)
            # Enlarge the category names without replacing the tick labels
            # (set_yticklabels without fixed ticks triggers a matplotlib warning).
            ax.tick_params(axis='y', labelsize=14)
            st.pyplot(fig)
            # Release the figure from pyplot's registry; otherwise every
            # submission leaves another open figure in the server process.
            plt.close(fig)

            # Index with the scalar class index (not a tensor) so the value
            # formats as a single float.
            st.write('The predicted category is: **{}** with a probability of: **{:.1f}%**'.format(
                labels_plot[number], probs_plot[number]))
# --- Footer -----------------------------------------------------------------
# Always shown: the category list referenced by the '*' in the header, the
# origin of the training data, and where to find model performance details.
for _footer_line in (
    "The pre-defined categories are Binnenland, Buitenland, Cultuur & Media, Economie , Koningshuis, Opmerkelijk, Politiek, Regionaal nieuws en Tech",
    "The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles",
    'Model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_clf_bert_finetuned',
):
    st.write(_footer_line)