# EmotioNL / app.py
# Hugging Face Spaces file-view header captured with this file (not program code):
# author lunadebruyne, commit a735135 "Update app.py", 16.4 kB.
import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import altair as alt
import matplotlib.pyplot as plt
import datetime
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification
"""
description_sentence = "<h3>Demo EmotioNL</h3>\nThis demo allows you to analyse the emotion in a sentence."
description_dataset = "<h3>Demo EmotioNL</h3>\nThis demo allows you to analyse the emotions in a dataset.\nThe data should be in tsv-format with two named columns: the first column (id) should contain the sentence IDs, and the second column (text) should contain the actual texts. Optionally, there is a third column named 'date', which specifies the date associated with the text (e.g., tweet date). This column is necessary when the options 'emotion distribution over time' and 'peaks' are selected."
inference_modelpath = "model/checkpoint-128"
def inference_sentence(text):
tokenizer = AutoTokenizer.from_pretrained(inference_modelpath)
model = AutoModelForSequenceClassification.from_pretrained(inference_modelpath)
for text in tqdm([text]):
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad(): # run model
logits = model(**inputs).logits
predicted_class_id = logits.argmax().item()
output = model.config.id2label[predicted_class_id]
return output
def frequencies(preds):
preds_dict = {"neutral": 0, "anger": 0, "fear": 0, "joy": 0, "love": 0, "sadness": 0}
for pred in preds:
preds_dict[pred] = preds_dict[pred] + 1
bars = list(preds_dict.keys())
height = list(preds_dict.values())
x_pos = np.arange(len(bars))
plt.bar(x_pos, height, color=['lightgrey', 'firebrick', 'rebeccapurple', 'orange', 'palevioletred', 'cornflowerblue'])
plt.xticks(x_pos, bars)
return plt
def inference_dataset(file_object, option_list):
tokenizer = AutoTokenizer.from_pretrained(inference_modelpath)
model = AutoModelForSequenceClassification.from_pretrained(inference_modelpath)
data_path = open(file_object.name, 'r')
df = pd.read_csv(data_path, delimiter='\t', header=0, names=['id', 'text'])
ids = df["id"].tolist()
texts = df["text"].tolist()
preds = []
for text in tqdm(texts): # progressbar
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad(): # run model
logits = model(**inputs).logits
predicted_class_id = logits.argmax().item()
prediction = model.config.id2label[predicted_class_id]
preds.append(prediction)
predictions_content = list(zip(ids, texts, preds))
# write predictions to file
output = "output.txt"
f = open(output, 'w')
f.write("id\ttext\tprediction\n")
for line in predictions_content:
f.write(str(line[0]) + '\t' + str(line[1]) + '\t' + str(line[2]) + '\n')
output1 = output
output2 = output3 = output4 = output5 = "This option was not selected."
if "emotion frequencies" in option_list:
output2 = frequencies(preds)
else:
output2 = None
if "emotion distribution over time" in option_list:
output3 = "This option was selected."
if "peaks" in option_list:
output4 = "This option was selected."
if "topics" in option_list:
output5 = "This option was selected."
return [output1, output2, output3, output4, output5]
iface_sentence = gr.Interface(
fn=inference_sentence,
description = description_sentence,
inputs = gr.Textbox(
label="Enter a sentence",
lines=1),
outputs="text")
inputs = [gr.File(
label="Upload a dataset"),
gr.CheckboxGroup(
["emotion frequencies", "emotion distribution over time", "peaks", "topics"],
label = "Select options")]
outputs = [gr.File(),
gr.Plot(label="Emotion frequencies"),
gr.Textbox(label="Emotion distribution over time"),
gr.Textbox(label="Peaks"),
gr.Textbox(label="Topics")]
iface_dataset = gr.Interface(
fn = inference_dataset,
description = description_dataset,
inputs=inputs,
outputs = outputs)
iface = gr.TabbedInterface([iface_sentence, iface_dataset], ["Sentence", "Dataset"])
iface.queue().launch()
"""
def inference_sentence(text):
    """Return a placeholder message echoing the sentence that would be classified.

    Stand-in for real model inference in the "Sentence" tab. No model is run;
    the input is only prefixed with a fixed notice and a newline.

    Args:
        text: The sentence typed into the textbox.

    Returns:
        The notice "This sentence will be processed:", a newline, then ``text``.
    """
    header = "This sentence will be processed:\n"
    return "".join((header, text))
def unavailable(input_file, input_checks):
    """Tell the user that submitting their own dataset is disabled.

    Both arguments are ignored. They exist only so the function matches the
    ``inputs=[input_file, input_checks]`` wiring used for the dataset tab's
    submit button.

    Returns:
        A ``gr.update`` that fills the message textbox, relabels it "Oops!"
        and makes it visible.
    """
    return gr.update(
        label="Oops!",
        visible=True,
        value="Submitting your own data is currently unavailable, but you can try out the showcase mode 😊",
    )
def file(input_file, input_checks):
    """Write the (placeholder) predictions file and reveal the first follow-up button.

    No model is run yet: ``output.txt`` is (over)written with a fixed
    placeholder line, and ``input_file`` is not read.

    Args:
        input_file: Value of the "Upload a dataset" ``gr.File`` input (unused).
        input_checks: The options ticked in the CheckboxGroup. ``None`` is
            treated as "nothing selected".

    Returns:
        A 5-tuple of ``gr.update`` objects, in this order:
        the predictions file component (set to ``output.txt`` and shown),
        then the "next" buttons for frequencies, distribution over time,
        peaks and topics. Only the button of the FIRST ticked option, in that
        fixed order, is made visible. The later handlers (freq, dist, peaks)
        each reveal the following one.
    """
    output = "output.txt"
    # `with` guarantees the handle is closed (and the text flushed) even if
    # the write raises; an explicit encoding avoids locale-dependent output.
    with open(output, "w", encoding="utf-8") as f:
        f.write("The predictions come here.")

    checks = input_checks or []
    # Follow-up steps in pipeline order; mirrors the CheckboxGroup choices.
    steps = (
        "emotion frequencies",
        "emotion distribution over time",
        "peaks",
        "topics",
    )
    first_selected = next((step for step in steps if step in checks), None)
    button_updates = tuple(
        gr.update(visible=(step == first_selected)) for step in steps
    )
    return (gr.update(value=output, visible=True),) + button_updates
def freq(output_file, input_checks):
    """Build the emotion-frequency bar chart and reveal the next follow-up button.

    The counts are hard-coded demo values; ``output_file`` is not read yet.

    Args:
        output_file: Value of the predictions ``gr.File`` component (unused).
        input_checks: The options ticked in the CheckboxGroup.

    Returns:
        A 4-tuple: a ``gr.update`` showing the Altair bar chart, then
        visibility updates for the distribution, peaks and topics buttons.
        Of those three buttons, only the one for the first ticked option
        (in that order) is shown.
    """
    emotions = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    palette = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']
    counts = pd.DataFrame({
        'Emotion category': emotions,
        'Frequency': [10, 8, 2, 15, 3, 4],
    })
    # Top of the y-axis: the largest count rounded up to a multiple of 10.
    y_top = (max(counts['Frequency']) + 9) // 10 * 10
    chart = (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("Emotion category", sort=emotions),
            y=alt.Y("Frequency", axis=alt.Axis(grid=False),
                    scale=alt.Scale(domain=[0, y_top])),
            color=alt.Color("Emotion category",
                            scale=alt.Scale(domain=emotions, range=palette),
                            legend=None),
            tooltip=['Emotion category', 'Frequency'],
        )
        .properties(width=600)
        .configure_axis(grid=False)
        .interactive()
    )
    # Remaining steps in pipeline order; only the first ticked one gets its
    # "next" button shown.
    remaining = ("emotion distribution over time", "peaks", "topics")
    upcoming = next((step for step in remaining if step in input_checks), None)
    return (
        gr.update(value=chart, visible=True),
        *(gr.update(visible=(step == upcoming)) for step in remaining),
    )
def dist(output_file, input_checks):
    """Build the emotion-distribution-over-time line chart and reveal the next button.

    The per-date counts are hard-coded demo values; ``output_file`` is not
    read yet.

    Args:
        output_file: Value of the predictions ``gr.File`` component (unused).
        input_checks: The options ticked in the CheckboxGroup.

    Returns:
        A 3-tuple: a ``gr.update`` showing the Altair line chart, then
        visibility updates for the peaks and topics buttons. Only the first
        ticked option of those two has its button shown.
    """
    emotions = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    palette = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']
    dates = ['1/1', '2/1', '3/1']
    # Long format: one row per (date, emotion), dates outer, emotions inner.
    data = pd.DataFrame({
        'Date': [date for date in dates for _ in emotions],
        'Frequency': [3, 5, 1, 8, 2, 3, 4, 7, 1, 12, 4, 2, 3, 6, 3, 10, 3, 4],
        'Emotion category': emotions * len(dates),
    })
    # Top of the y-axis: the largest count rounded up to a multiple of 10.
    y_top = (max(data['Frequency']) + 9) // 10 * 10

    # NOTE(review): alt.selection(type='single') and add_selection are the
    # Altair 4 API (deprecated in Altair 5 in favour of alt.selection_point and
    # add_params). Kept as-is because the installed Altair version is not
    # pinned in this file — confirm before migrating.
    highlight = alt.selection(
        type='single', on='mouseover', fields=["Emotion category"], nearest=True)
    base = alt.Chart(data).encode(
        x="Date",
        y=alt.Y("Frequency", scale=alt.Scale(domain=[0, y_top])),
        color=alt.Color("Emotion category",
                        scale=alt.Scale(domain=emotions, range=palette),
                        legend=alt.Legend(orient='bottom', direction='horizontal')))
    # Invisible points carry the tooltip and drive the hover selection.
    points = base.mark_circle().encode(
        opacity=alt.value(0),
        tooltip=[
            alt.Tooltip('Emotion category', title='Emotion category'),
            alt.Tooltip('Date', title='Date'),
            alt.Tooltip('Frequency', title='Frequency'),
        ]).add_selection(highlight)
    # Lines outside the selection get size 1, selected lines size 3.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3)))
    plot = (points + lines).properties(width=600, height=350).interactive()

    remaining = ("peaks", "topics")
    upcoming = next((step for step in remaining if step in input_checks), None)
    # gr.update (as used by every other handler) instead of gr.Plot.update:
    # component-level .update() methods were removed in Gradio 4, while
    # gr.update works on Gradio 3 and 4.
    return (
        gr.update(value=plot, visible=True),
        *(gr.update(visible=(step == upcoming)) for step in remaining),
    )
def peaks(output_file, input_checks):
    """Render the significant per-emotion fluctuations ("peaks") as an HTML snippet.

    The peak dates are hard-coded demo values (the same three dates for every
    emotion); ``output_file`` is not read yet.

    Args:
        output_file: Value of the predictions ``gr.File`` component (unused).
        input_checks: The options ticked in the CheckboxGroup.

    Returns:
        A 2-tuple: a ``gr.update`` that shows the HTML, and a visibility
        update for the topics button (shown only if "topics" was ticked).
    """
    # Legend colour per emotion, in display order (same palette as freq/dist).
    colours = {
        "neutral": "#999999",
        "anger": "#b22222",
        "fear": "#663399",
        "joy": "#ffcc00",
        "love": "#db7093",
        "sadness": "#6495ed",
    }
    # date -> "up" / "down"; placeholder until real peak detection exists.
    demo_peaks = {"13/3/2020": "up", "20/3/2020": "up", "21/3/2020": "down"}
    peaks_by_emotion = {emotion: demo_peaks for emotion in colours}

    def describe(emotion_peaks):
        # "up" -> up arrow; any other direction is rendered as a down arrow.
        # NOTE(review): once dates come from user data, escape them
        # (html.escape) before inserting them into the HTML below.
        return ", ".join(
            f"{date} (↑)" if direction == "up" else f"{date} (↓)"
            for date, direction in emotion_peaks.items()
        )

    # One coloured-dot CSS class per emotion, e.g. `.dot_anger`.
    dot_css = "".join(
        f".dot_{emotion} {{height: 11px;width: 11px;background-color: {colour};border-radius: 50%;display: inline-block;}}"
        for emotion, colour in colours.items()
    )
    # Each paragraph is now properly closed with </p>. The previous markup
    # ended every `<p class="tab">` with a second opening `<p>` tag.
    sections = "".join(
        f'<p><span class="dot_{emotion}"></span> {emotion}:</p>'
        f'<p class="tab">{describe(peaks_by_emotion[emotion])}</p>'
        for emotion in colours
    )
    page = "".join([
        '<html>',
        '<head>',
        '<meta name="viewport" content="width=device-width, initial-scale=1">',
        '<style>',
        dot_css,
        '.tab {padding-left: 1em;}',
        '</style>',
        '</head>',
        '<body>',
        '<div>',
        '<p>These significant fluctuations were found:</p>',
        sections,
        '</div>',
        '</body>',
        '</html>',
    ])
    show_topics = "topics" in input_checks
    return gr.update(value=page, visible=True), gr.update(visible=show_topics)
def topics(output_file, input_checks):
    """Show the (placeholder) topic-analysis result.

    Both arguments are ignored for now. This is the last step of the dataset
    pipeline, so no further "next" button is revealed.

    Returns:
        A ``gr.update`` that fills the topics textbox and makes it visible.
    """
    summary = "Some topics are found."
    return gr.update(visible=True, value=summary)
# Gradio UI with a "Sentence" tab and a "Dataset" tab. The dataset tab reveals
# its analyses step by step: each handler (file, freq, dist, peaks) shows the
# button for the next option the user ticked.
with gr.Blocks() as demo:
    with gr.Tab("Sentence"):
        gr.Markdown("""
# Demo EmotioNL
This demo allows you to analyse the emotion in a sentence.
""")
        with gr.Row():
            with gr.Column():
                # Named `input_text` rather than `input`, which would shadow
                # the builtin for the rest of the module.
                input_text = gr.Textbox(
                    label="Enter a sentence",
                    lines=1)
                sentence_btn = gr.Button("Send")
                output = gr.Textbox()
        sentence_btn.click(fn=inference_sentence, inputs=input_text, outputs=output)
    with gr.Tab("Dataset"):
        gr.Markdown("""
# Demo EmotioNL
This demo allows you to analyse the emotions in a dataset. The data should be in tsv-format with two named columns: the first column (id) should contain the sentence IDs, and the second column (text) should contain the actual texts. Optionally, there is a third column named 'date', which specifies the date associated with the text (e.g., tweet date). This column is necessary when the options 'emotion distribution over time' and 'peaks' are selected.
""")
        with gr.Row():
            with gr.Column():
                input_file = gr.File(
                    label="Upload a dataset")
                input_checks = gr.CheckboxGroup(
                    ["emotion frequencies", "emotion distribution over time", "peaks", "topics"],
                    label="Select options")
                send_btn = gr.Button("Submit data")
                demo_btn = gr.Button("... or showcase with example data")
            with gr.Column():
                # Only written by `unavailable` (see the commented-out handler below).
                message = gr.Textbox(label="Message", visible=False)
                output_file = gr.File(label="Predictions", visible=False)
                next_button_freq = gr.Button("Show emotion frequencies", visible=False)
                # NOTE(review): .style() is Gradio 3 API (removed in Gradio 4) —
                # confirm the Gradio version this Space pins.
                output_plot = gr.Plot(show_label=False, visible=False).style(container=True)
                next_button_dist = gr.Button("Show emotion distribution over time", visible=False)
                output_dist = gr.Plot(show_label=False, visible=False)
                next_button_peaks = gr.Button("Show peaks", visible=False)
                output_peaks = gr.HTML(visible=False)
                next_button_topics = gr.Button("Show topics", visible=False)
                output_topics = gr.Textbox(show_label=False, visible=False)
        send_btn.click(
            fn=file,
            inputs=[input_file, input_checks],
            outputs=[output_file, next_button_freq, next_button_dist, next_button_peaks, next_button_topics])
        # The showcase button was previously not wired to any handler, so
        # clicking it did nothing. It runs the same pipeline, which currently
        # only produces hard-coded example output.
        demo_btn.click(
            fn=file,
            inputs=[input_file, input_checks],
            outputs=[output_file, next_button_freq, next_button_dist, next_button_peaks, next_button_topics])
        next_button_freq.click(
            fn=freq,
            inputs=[output_file, input_checks],
            outputs=[output_plot, next_button_dist, next_button_peaks, next_button_topics])
        next_button_dist.click(
            fn=dist,
            inputs=[output_file, input_checks],
            outputs=[output_dist, next_button_peaks, next_button_topics])
        next_button_peaks.click(
            fn=peaks,
            inputs=[output_file, input_checks],
            outputs=[output_peaks, next_button_topics])
        next_button_topics.click(
            fn=topics,
            inputs=[output_file, input_checks],
            outputs=output_topics)
        # To disable uploads of user data, replace send_btn's handler above with:
        # send_btn.click(fn=unavailable, inputs=[input_file, input_checks], outputs=message)

demo.launch()