Spaces:
Runtime error
Runtime error
File size: 2,374 Bytes
e3d46c8 8158997 70318b5 0788ae6 56b99e8 e3d46c8 56b99e8 0788ae6 a85495d 5320ec6 ae0b0ef 5320ec6 6cf9356 2823250 093cd61 68d6aa3 b28d4fd bceabb4 035678c bceabb4 68d6aa3 0e4d9b1 bceabb4 68d6aa3 c85cea8 0788ae6 ed6dd13 0e4d9b1 5772d0d a5f0a48 db9840e a5f0a48 bceabb4 23f016c 3044367 5e56153 45322b9 5e56153 f5a8947 5e56153 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt
import numpy as np
from transformers import BertForSequenceClassification, BertTokenizer
# Hugging Face Hub id shared by the fine-tuned classifier and its tokenizer.
MODEL_ID = "RuudVelo/dutch_news_clf_bert_finetuned"

# Both objects are created at module level, i.e. every time this script runs.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# load is presumably repeated per click — consider Streamlit's resource
# caching once the installed Streamlit version is confirmed.
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)
model = BertForSequenceClassification.from_pretrained(MODEL_ID)
# --- Page header -----------------------------------------------------------
APP_TITLE = "Dutch news article classification"
APP_INTRO = "This app classifies a Dutch news article into one of 9 pre-defined* article categories"
COVER_IMAGE = 'dataset-cover_articles.jpeg'
INPUT_PROMPT = 'Please type/copy/paste text of the Dutch article and click Submit'

st.title(APP_TITLE)
st.write(APP_INTRO)
st.image(COVER_IMAGE, width=150)

# --- Article input ---------------------------------------------------------
# `text` holds whatever the user typed; it is read by the Submit handler below.
text = st.text_area(INPUT_PROMPT)
# Submit handler: classify the article in `text`, plot the per-category
# probabilities as a horizontal bar chart and report the top category.
if st.button('Submit'):
    # Category names, indexed by the model's predicted class id.
    # NOTE(review): assumed to be in the same order as the model's label ids
    # (config.id2label) — confirm against the model card.
    labels_plot = ['Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie', 'Koningshuis',
                   'Opmerkelijk', 'Politiek', 'Regionaal nieuws', 'Tech']

    if not text.strip():
        # Nothing to classify; a prediction on blank input would be meaningless.
        st.warning('Please enter the text of an article before clicking Submit.')
    else:
        with st.spinner('Generating a response...'):
            # Truncate to the model's position-embedding limit: without this,
            # a long article yields more tokens than the model can embed and
            # the forward pass fails.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                outputs = model(**encoding)

            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            probs_plot = probabilities[0].cpu().numpy() * 100

            # Plain Python int so it can index both the label list and the
            # probability array without relying on tensor-to-index coercion.
            number = int(outputs.logits.argmax(-1)[0].item())

            fig = plt.figure(figsize=(10, 4))
            try:
                ax = fig.add_axes([0, 0, 1, 1])
                ax.barh(labels_plot, probs_plot)
                ax.set_title("Predicted article category probability", fontsize=20)
                ax.set_xlabel("Probability (%)", fontsize=16)
                ax.set_ylabel("Predicted category", fontsize=16)
                # Enlarge the category labels without replacing the tick
                # labels that barh already set.
                ax.tick_params(axis='y', labelsize=14)
                st.pyplot(fig)
            finally:
                # pyplot keeps every figure alive until it is closed; close it
                # so repeated submissions do not accumulate figures.
                plt.close(fig)

            st.write('The predicted category is: **{}** with a probability of: **{:.1f}%**'.format(
                labels_plot[number], float(probs_plot[number])))
# Footer: explanatory notes rendered below the classifier, one paragraph each.
FOOTER_NOTES = (
    "The pre-defined categories are Binnenland, Buitenland, Cultuur & Media, Economie , Koningshuis, Opmerkelijk, Politiek, Regionaal nieuws en Tech",
    "The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles",
    'Model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_clf_bert_finetuned',
)
for note in FOOTER_NOTES:
    st.write(note)
|