RuudVelo's picture
Update app.py
1af96b9
raw
history blame
3.07 kB
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt
#pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
#text = st.text_area('Please type/copy/paste the Dutch article')
#labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis'
# 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech']
#if text:
# out = pipe(text)
# st.json(out)
# load tokenizer and model, create trainer
#model_name = "RuudVelo/dutch_news_classifier_bert_finetuned"
#tokenizer = AutoTokenizer.from_pretrained(model_name)
#model = AutoModelForSequenceClassification.from_pretrained(model_name)
#trainer = Trainer(model=model)
#print(filename, type(filename))
#print(filename.name)
from transformers import BertForSequenceClassification, BertTokenizer
# Hugging Face Hub repo holding both the fine-tuned weights and the tokenizer.
MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Streamlit re-executes this entire script on every widget interaction, so
# without caching the BERT model and tokenizer are reloaded on every click.
# Newer Streamlit releases provide `st.cache_resource`. Older ones only have
# `st.cache`, which needs allow_output_mutation=True so it does not try to
# hash the returned model object on every run.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def _load_model_and_tokenizer(model_name):
    """Load the sequence-classification model and tokenizer for *model_name*.

    The result is cached by Streamlit, so the load happens once per server
    process rather than once per script rerun.
    """
    loaded_model = BertForSequenceClassification.from_pretrained(model_name)
    loaded_tokenizer = BertTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(MODEL_NAME)
# Page header, followed by the free-text box whose contents are classified
# when the user presses Submit.
st.title("Dutch news article classification")
_ARTICLE_PROMPT = 'Please type/copy/paste text of the Dutch article'
text = st.text_area(label=_ARTICLE_PROMPT)
#if text:
# encoding = tokenizer(text, return_tensors="pt")
# outputs = model(**encoding)
# predictions = outputs.logits.argmax(-1)
# probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
## fig = plt.figure()
# ax = fig.add_axes([0,0,1,1])
# labels_plot = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis',
# 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech']
# probs_plot = probabilities[0].cpu().detach().numpy()
# ax.barh(labels_plot,probs_plot )
# st.pyplot(fig)
#input = st.text_input('Context')
# Classify the entered article when the user presses Submit and plot the
# per-category probabilities as a horizontal bar chart.
if st.button('Submit'):
    if not text.strip():
        # Without this guard the model would classify an empty sequence and
        # show a meaningless probability chart.
        st.warning('Please enter the text of a Dutch news article first.')
    else:
        with st.spinner('Generating a response...'):
            # BERT's position embeddings only cover
            # `max_position_embeddings` tokens. Longer articles make the
            # forward pass fail, so truncate to that length.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only: no gradients are needed, so skip autograd
            # bookkeeping.
            with torch.no_grad():
                outputs = model(**encoding)
                probabilities = torch.nn.functional.softmax(
                    outputs.logits, dim=-1
                )

            # NOTE(review): assumed to match the label index order used at
            # training time (alphabetical). TODO confirm against
            # model.config.id2label.
            labels_plot = ['Binnenland', 'Buitenland', 'Cultuur & Media',
                           'Economie', 'Koningshuis', 'Opmerkelijk',
                           'Politiek', 'Regionaal nieuws', 'Tech']
            # Single input, so take the first (only) row of probabilities.
            probs_plot = probabilities[0].cpu().numpy()

            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.barh(labels_plot, probs_plot)
            ax.set_title("Predicted article category probability")
            ax.set_xlabel("Probability")
            ax.set_ylabel("Predicted category")
            st.pyplot(fig)
            # pyplot keeps every figure alive until it is closed. Release it
            # so repeated submissions don't accumulate figures in memory.
            plt.close(fig)
#encoding = tokenizer(text, return_tensors="pt")
#import numpy as np
# Footer: attribution for the training data, and a link to the model card of
# the model this app loads (RuudVelo/dutch_news_clf_bert_finetuned).
st.write("The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles")
st.write('\n')
st.write('The model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_clf_bert_finetuned')