# Hugging Face Spaces status when this file was captured: Runtime error
import streamlit as st | |
from transformers import pipeline | |
import torch | |
import matplotlib.pyplot as plt | |
#pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned") | |
#text = st.text_area('Please type/copy/paste the Dutch article') | |
#labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis' | |
# 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech'] | |
#if text: | |
# out = pipe(text) | |
# st.json(out) | |
# load tokenizer and model, create trainer | |
#model_name = "RuudVelo/dutch_news_classifier_bert_finetuned" | |
#tokenizer = AutoTokenizer.from_pretrained(model_name) | |
#model = AutoModelForSequenceClassification.from_pretrained(model_name) | |
#trainer = Trainer(model=model) | |
#print(filename, type(filename)) | |
#print(filename.name) | |
from transformers import BertForSequenceClassification, BertTokenizer

# Hub identifier shared by the fine-tuned classifier weights and its tokenizer.
MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Streamlit re-executes this script on every widget interaction, so the
# loader is cached to avoid re-instantiating the model on each rerun.
# Newer Streamlit releases provide `st.cache_resource`; older ones only
# offer `st.cache`, hence the feature check.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_classifier(name):
    """Return the ``(model, tokenizer)`` pair for the Hub checkpoint *name*."""
    return (
        BertForSequenceClassification.from_pretrained(name),
        BertTokenizer.from_pretrained(name),
    )


model, tokenizer = _load_classifier(MODEL_NAME)
# Page title and the text box the article to classify is read from.
st.title("Dutch news article classification")
text = st.text_area('Please type/copy/paste text of the Dutch article')
#if text: | |
# encoding = tokenizer(text, return_tensors="pt") | |
# outputs = model(**encoding) | |
# predictions = outputs.logits.argmax(-1) | |
# probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
## fig = plt.figure() | |
# ax = fig.add_axes([0,0,1,1]) | |
# labels_plot = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis', | |
# 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech'] | |
# probs_plot = probabilities[0].cpu().detach().numpy() | |
# ax.barh(labels_plot,probs_plot ) | |
# st.pyplot(fig) | |
#input = st.text_input('Context') | |
# Classify the article only when the user explicitly presses the button.
if st.button('Submit'):
    if not text.strip():
        # Nothing to classify: an empty string would only score the
        # tokenizer's special tokens and plot a meaningless distribution.
        st.warning('Please enter the text of a Dutch article first.')
    else:
        with st.spinner('Generating a response...'):
            # Truncate to the model's position-embedding limit; without this
            # a long article makes the forward pass fail.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only, so skip gradient bookkeeping.
            with torch.no_grad():
                outputs = model(**encoding)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # NOTE(review): assumes this label order matches the model's
            # output indices -- confirm against the checkpoint's id2label.
            labels_plot = [
                'Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie',
                'Koningshuis', 'Opmerkelijk', 'Politiek', 'Regionaal nieuws',
                'Tech',
            ]
            # First (and only) row of the batch as a NumPy array.
            probs_plot = probabilities[0].cpu().numpy()

            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.barh(labels_plot, probs_plot)
            ax.set_title("Predicted article category probability")
            ax.set_xlabel("Probability")
            ax.set_ylabel("Predicted category")
            st.pyplot(fig)
            # pyplot keeps every figure it creates alive until closed; close
            # it so figures do not pile up across Streamlit reruns.
            plt.close(fig)
# output = genQuestion(option, input) | |
# print(output) | |
# st.write(output) | |
#encoding = tokenizer(text, return_tensors="pt") | |
#import numpy as np | |
# Background information about the training data and the model.
st.write("The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles")
st.write('\n')
# NOTE(review): this link names a different repository than the checkpoint
# loaded by the app ("..._clf_..." vs "..._classifier_...") -- confirm it is
# the intended model card.
st.write('The model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_classifier_bert_finetuned')