Roozeec's picture
update
6323541
raw
history blame
No virus
2.22 kB
import streamlit as st
import wna_googlenews as wna
import pandas as pd
from transformers import pipeline
# Page layout and heading. set_page_config is kept ahead of every other st.*
# call, which older Streamlit releases require.
st.set_page_config(layout="wide")

APP_TITLE = "WNA Google News App"
APP_SUBTITLE = "Search for News and classify the headlines with sentiment analysis"
st.title(APP_TITLE)
st.subheader(APP_SUBTITLE)

# Free-text search query, sent to the news search when "Search" is pressed.
query = st.text_input("Enter Query")
# Hugging Face model ids offered in the sidebar; models[0] is the default.
# Disabled alternative: "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
models = ["SamLowe/roberta-base-go_emotions"]

# Default request settings; the sidebar selectboxes assign every one of these keys.
settings = dict(
    lang="fr",
    region="FR",
    period="1d",
    model=models[0],
)
def _selectbox_with_default(label, options, current):
    """Render ``st.selectbox`` with *current* preselected and return the choice.

    Falls back to the first option (Streamlit's own default, ``index=0``)
    when *current* is not one of *options*, so a stale default never raises.
    """
    options = list(options)
    index = options.index(current) if current in options else 0
    return st.selectbox(label, options, index=index)


with st.sidebar:
    st.title("Settings")

    # Language and country of the news search. The defaults declared in
    # `settings` are preselected rather than being silently replaced by the
    # first option of each list.
    st.header("Language and Country")
    settings["lang"] = _selectbox_with_default(
        "Select Language", ["en", "fr"], settings["lang"]
    )
    settings["region"] = _selectbox_with_default(
        "Select Country", ["US", "FR"], settings["region"]
    )

    # Search period; every option uses the same "<n>d" day-count form.
    st.header("Period")
    settings["period"] = _selectbox_with_default(
        "Select Period", ["1d", "7d", "30d"], settings["period"]
    )

    # Model id handed to the transformers text-classification pipeline.
    st.header("Models")
    settings["model"] = _selectbox_with_default(
        "Select Model", models, settings["model"]
    )
def _format_prediction(prediction):
    """Return ``"<label> <score as a percentage rounded to 2 places>"``.

    *prediction* is one ``{"label": ..., "score": ...}`` dict from the
    text-classification pipeline.
    """
    return f"{prediction['label']} {round(prediction['score'] * 100, 2)}"


def _top_two(predictions):
    """Return the formatted (best, second) labels from *predictions*.

    Sorted by descending score here instead of trusting the pipeline's own
    ordering; ``second`` is ``""`` when the model emits only one label.
    """
    ranked = sorted(predictions, key=lambda p: p["score"], reverse=True)
    best = _format_prediction(ranked[0])
    second = _format_prediction(ranked[1]) if len(ranked) > 1 else ""
    return best, second


if st.button("Search"):
    # Fetch headlines for the query using the sidebar settings.
    news = wna.get_news(settings, query)

    # Tolerate an empty/missing result instead of raising KeyError, and skip
    # missing titles, which the classifier cannot process.
    titles = (
        news["title"].dropna().tolist()
        if news is not None and "title" in news.columns
        else []
    )

    rows = []
    if titles:
        # Only load the model when there is something to classify.
        # NOTE(review): the model is reloaded on every click; caching it with
        # st.cache_resource would avoid that, if the deployed Streamlit
        # version provides it — confirm before adopting.
        classifier = pipeline(
            task="text-classification", model=settings["model"], top_k=None
        )
        for title in titles:
            # Unwrap the single-input result to its list of {label, score} dicts.
            predictions = classifier(title)[0]
            rows.append([title, *_top_two(predictions)])

    # Build the table once instead of appending row by row via .loc.
    results = pd.DataFrame(rows, columns=["sentence", "best", "second"])

    st.write("Number of sentences:", len(results))
    st.write("Language:", settings["lang"], "Country:", settings["region"])
    st.dataframe(results)