# Spaces:
# Running
# Running
import functools
import pickle
from collections import Counter
from datetime import date, timedelta

import altair as alt
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification
@functools.lru_cache(maxsize=None)
def _load_inference_model(modelpath):
    """Load the tokenizer and sequence classifier stored at *modelpath*.

    The result is cached per path, so the checkpoint is read from disk once
    per process instead of once per request.
    """
    tokenizer = AutoTokenizer.from_pretrained(modelpath)
    model = AutoModelForSequenceClassification.from_pretrained(modelpath)
    return tokenizer, model


def inference_sentence(text):
    """Predict the emotion category of a single sentence.

    Uses the checkpoint at the module-level ``inference_modelpath``.

    Args:
        text: The sentence to classify.

    Returns:
        The string "Predicted emotion:" followed by a newline and the label
        that ``model.config.id2label`` gives for the highest-scoring class.
    """
    tokenizer, model = _load_inference_model(inference_modelpath)
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length instead of letting the forward pass fail on them.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; gradients are not needed
        logits = model(**inputs).logits
    predicted_class_id = logits.argmax().item()
    output = model.config.id2label[predicted_class_id]
    return "Predicted emotion:\n" + output
def freq(file_output):
    """Build a bar chart of how often each emotion category was predicted.

    Args:
        file_output: Path to a comma-separated predictions file. The first
            line is skipped as a header; in every other non-empty line the
            second field is the predicted class id, "0" to "5".

    Returns:
        An interactive Altair bar chart with one bar per emotion category,
        ordered neutral, anger, fear, joy, love, sadness.
    """
    # Class id i corresponds to emotions[i].
    emotions = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    colors = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']

    with open(file_output, 'r', encoding='utf-8') as f:
        records = f.read().split("\n")
    # Skip the header and any blank lines. Filtering blanks (instead of
    # slicing off the final line) keeps the last record even when the file
    # does not end with a newline.
    counts = Counter(record.split(",")[1] for record in records[1:] if record)
    # A class that was never predicted gets a zero-height bar, not a KeyError.
    frequencies = [counts.get(str(class_id), 0) for class_id in range(len(emotions))]

    simple = pd.DataFrame({
        'Emotion category': emotions,
        'Frequency': frequencies})
    # Round the y-axis up to the next multiple of 10; never below 10, so an
    # empty predictions file still produces a usable axis.
    y_max = max((max(frequencies) + 9) // 10 * 10, 10)
    plot = alt.Chart(simple).mark_bar().encode(
        x=alt.X("Emotion category", sort=emotions),
        y=alt.Y("Frequency", axis=alt.Axis(grid=False),
                scale=alt.Scale(domain=[0, y_max])),
        color=alt.Color("Emotion category",
                        scale=alt.Scale(domain=emotions, range=colors),
                        legend=None),
        tooltip=['Emotion category', 'Frequency'],
    ).properties(width=600).configure_axis(grid=False).interactive()
    return plot
def dist(file_output):
    """Build a line chart of predicted-emotion frequencies per calendar day.

    Args:
        file_output: Path to a comma-separated predictions file. The first
            line is skipped as a header; in every other non-empty line the
            first field starts with a YYYYMMDD date and the second field is
            the predicted class id, "0" to "5".

    Returns:
        An interactive Altair layered chart with one line per emotion
        category. Invisible points carry the tooltips and a mouse-over
        selection; the line of the selected emotion is drawn thicker.

    Raises:
        KeyError: If a line's class id is not one of "0" to "5".
    """
    # Class id i corresponds to emotions[i].
    emotions = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    colors = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']
    label_names = {str(class_id): name for class_id, name in enumerate(emotions)}

    with open(file_output, 'r', encoding='utf-8') as f:
        records = f.read().split("\n")

    # Count predictions per (ISO date string, emotion). The header is skipped
    # and blank lines are ignored, so the last record is kept whether or not
    # the file ends with a newline.
    counts = Counter()
    days = set()
    for record in records[1:]:
        if not record:
            continue
        fields = record.split(",")
        stamp = fields[0]
        day = date(int(stamp[:4]), int(stamp[4:6]), int(stamp[6:8]))
        days.add(day)
        counts[(str(day), label_names[fields[1]])] += 1

    # Cover every calendar day from the earliest to the latest prediction so
    # days without predictions are plotted as zeros. Using min/max rather than
    # the first/last record means the file need not be sorted by date.
    if days:
        start_date, end_date = min(days), max(days)
        date_range = [str(start_date + timedelta(days=offset))
                      for offset in range((end_date - start_date).days + 1)]
    else:
        date_range = []

    # One row per (day, emotion); Counter yields 0 for unseen pairs.
    cells = [(day, emotion) for day in date_range for emotion in emotions]
    frequency = [counts[cell] for cell in cells]
    data = pd.DataFrame({
        'Date': [day for day, _ in cells],
        'Frequency': frequency,
        'Emotion category': [emotion for _, emotion in cells]})
    # Round the y-axis up to the next multiple of 10, never below 10.
    y_max = max((max(frequency, default=0) + 9) // 10 * 10, 10)

    highlight = alt.selection(
        type='single', on='mouseover', fields=["Emotion category"], nearest=True)
    base = alt.Chart(data).encode(
        x="Date:T",
        y=alt.Y("Frequency", scale=alt.Scale(domain=[0, y_max])),
        color=alt.Color("Emotion category",
                        scale=alt.Scale(domain=emotions, range=colors),
                        legend=alt.Legend(orient='bottom', direction='horizontal')))
    # Fully transparent points: they only host the tooltips and the selection.
    points = base.mark_circle().encode(
        opacity=alt.value(0),
        tooltip=[
            alt.Tooltip('Emotion category', title='Emotion category'),
            alt.Tooltip('Date:T', title='Date'),
            alt.Tooltip('Frequency', title='Frequency'),
        ]).add_selection(highlight)
    # Width 1 where the hover selection does not match, 3 where it does.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3)))
    plot = (points + lines).properties(width=600, height=350).interactive()
    return plot
def showcase(dataset):
    """Produce the showcase outputs for one of the bundled crisis datasets.

    Args:
        dataset: Dataset name chosen in the dropdown: "The Voice of Holland",
            "Floodings", "COVID-19" or "Childcare Benefits".

    Returns:
        A tuple of six ``gr.update`` objects, in the order of the outputs
        wired to the Run button: the output heading, the predictions file,
        the emotion-frequency bar chart, the distribution-over-time chart,
        the peaks plot and the topics plot. Every one is made visible.

    Raises:
        ValueError: If *dataset* is not one of the names above, e.g. when Run
            is clicked before anything is selected (the value is then None).
    """
    # Each dataset's files share a suffix under output/.
    suffixes = {
        "The Voice of Holland": "tvoh",
        "Floodings": "floodings",
        "COVID-19": "covid",
        "Childcare Benefits": "toeslagen",
    }
    suffix = suffixes.get(dataset)
    if suffix is None:
        raise ValueError(f"Unknown dataset {dataset!r}; please select a dataset first.")

    file_output = f"output/predictions_{suffix}.txt"
    freq_output = freq(file_output)   # bar chart of emotion frequencies
    dist_output = dist(file_output)   # line chart of emotions over time

    # Precomputed figures read from local files. pickle can execute code
    # while loading, so these paths must only ever point at trusted files.
    with open(f"output/peaks_{suffix}.p", "rb") as f:
        peaks_output = pickle.load(f)
    with open(f"output/topics_{suffix}.p", "rb") as f:
        topics_output = pickle.load(f)

    return (
        gr.update(visible=True),
        gr.update(value=file_output, visible=True),
        gr.update(value=freq_output, visible=True),
        gr.update(value=dist_output, visible=True),
        gr.update(value=peaks_output, visible=True),
        gr.update(value=topics_output, visible=True),
    )
inference_modelpath = "model/checkpoint-128"

# Markdown/HTML shown in the interface. Paragraphs are separated by blank
# lines because consecutive Markdown lines otherwise merge into one paragraph.
_INTRO_MD = """
<div style="text-align: center"><h1>EmotioNL: A framework for Dutch emotion detection</h1></div>
<div style="display: block;margin-left: auto;margin-right: auto;width: 60%;"><img alt="EmotioNL logo" src="https://users.ugent.be/~lundbruy/EmotioNL.png" width="100%"></div>
<div style="display: block;margin-left: auto;margin-right: auto;width: 75%;">This demo was made to demonstrate the EmotioNL model, a transformer-based classification model that analyses emotions in Dutch texts. The model uses <a href="https://github.com/iPieter/RobBERT">RobBERT</a>, which was further fine-tuned on the <a href="https://lt3.ugent.be/resources/emotionl/">EmotioNL dataset</a>. The resulting model is a classifier that, given a sentence, predicts one of the following emotion categories: <i>anger</i>, <i>fear</i>, <i>joy</i>, <i>love</i>, <i>sadness</i> or <i>neutral</i>. The demo can be used either in <b>sentence mode</b>, which allows you to enter a sentence for which an emotion will be predicted; or in <b>showcase mode</b>, which allows you to see the full functionality with example data.</div>
"""

_SHOWCASE_RUN_MD = """
**<font size="4">Run the demo on the data of a specific crisis case</font>**

Select the desired dataset and click the button to run the demo.
"""

_SHOWCASE_OUTPUT_HINT_MD = """
**<font size="4">Output</font>**

After having clicked on the run button, scroll down to see the output (running may take a while):
"""

_DATASETS_MD = """
**The Voice of Holland:** 18,502 tweets about a scandal about sexual misconduct in the Dutch reality TV singing competition 'The Voice of Holland'.

**Floodings:** 9,923 tweets about the floodings that affected Belgium and the Netherlands in the Summer of 2021.

**COVID-19:** 609,206 tweets about the COVID-19 pandemic, posted in the first eight months of the crisis.

**Childcare Benefits:** 66,961 tweets about the political scandal concerning false allegations of fraud regarding childcare allowance in the Netherlands.
"""

_OUTPUTS_MD = """
**Predictions:** file with the predicted emotion label for each instance in the dataset.

**Emotion frequencies:** bar plot with the prediction frequencies of each emotion category (anger, fear, joy, love, sadness or neutral).

**Emotion distribution over time:** line plot that visualises the frequency of predicted emotions over time for each emotion category.

**Peaks:** step graph that only shows the significant fluctuations (upwards and downwards) in emotion frequencies over time.

**Topics:** a bar plot that shows the emotion distribution for different topics in the dataset. Topics are extracted using [BERTopic](https://maartengr.github.io/BERTopic/index.html).
"""

_FOOTER_MD = """
<font size="2">Both this demo and the dataset have been created by [LT3](https://lt3.ugent.be/), the Language and Translation Technology Team of Ghent University. The EmotioNL project has been carried out with support from the Research Foundation – Flanders (FWO). For any questions, please contact luna.debruyne@ugent.be.</font>

<div style="display: grid;grid-template-columns:150px auto;"> <img style="margin-right: 1em" alt="LT3 logo" src="https://lt3.ugent.be/static/images/logo_v2_single.png" width="136" height="58"> <img style="margin-right: 1em" alt="FWO logo" src="https://www.fwo.be/images/logo_desktop.png" height="58"></div>
"""

with gr.Blocks() as demo:
    # A wide main column flanked by two narrow, empty margin columns. Column
    # ``scale``/``min_width`` only apply between sibling columns in a Row; at
    # the top level of Blocks the columns would stack vertically instead.
    with gr.Row():
        with gr.Column(scale=1, min_width=50):  # left margin
            gr.Markdown("\n")
        with gr.Column(scale=5):  # main content
            gr.Markdown(_INTRO_MD)

            with gr.Tab("Sentence"):
                gr.Markdown("\n")
                with gr.Row():
                    with gr.Column():
                        # Not named ``input`` to avoid shadowing the builtin.
                        input_text = gr.Textbox(
                            label="Enter a sentence",
                            value="Jaaah! Volgende vakantie Barcelona en na het zomerseizoen naar de Algarve",
                            lines=1)
                        send_btn = gr.Button("Send")
                    output = gr.Textbox()
                send_btn.click(fn=inference_sentence, inputs=input_text, outputs=output)

            with gr.Tab("Showcase"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown(_SHOWCASE_RUN_MD)
                    with gr.Column():
                        gr.Markdown("\n")
                    with gr.Column():
                        gr.Markdown(_SHOWCASE_OUTPUT_HINT_MD)
                with gr.Row():
                    with gr.Column():
                        dataset = gr.Dropdown(
                            ["The Voice of Holland", "Floodings", "COVID-19", "Childcare Benefits"],
                            show_label=False)
                        run_btn = gr.Button("Run", variant="primary")
                    with gr.Column():
                        gr.Markdown(_DATASETS_MD)
                    with gr.Column():
                        gr.Markdown(_OUTPUTS_MD)
                with gr.Row():
                    gr.Markdown("\n___\n")
                with gr.Row():
                    with gr.Column():
                        # Hidden until showcase() returns updates that make
                        # them visible; order matches the click outputs below.
                        output_markdown = gr.Markdown(
                            "\n**<font size=\"4\">Output</font>**\n", visible=False)
                        output_file = gr.File(label="Predictions", visible=False)
                        output_plot = gr.Plot(show_label=False, visible=False).style(container=True)
                        output_dist = gr.Plot(show_label=False, visible=False)
                        output_peaks = gr.Plot(show_label=False, visible=False)
                        output_topics = gr.Plot(show_label=False, visible=False)
                run_btn.click(
                    fn=showcase,
                    inputs=[dataset],
                    outputs=[output_markdown, output_file, output_plot,
                             output_dist, output_peaks, output_topics])

            with gr.Row():
                with gr.Column():
                    gr.Markdown(_FOOTER_MD)
        with gr.Column(scale=1, min_width=50):  # right margin
            gr.Markdown("\n")

demo.launch()