import gradio as gr
import torch
import numpy as np
import pickle
import pandas as pd
from tqdm import tqdm
import altair as alt
import matplotlib.pyplot as plt
from datetime import date, timedelta
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification
""" | |
description_sentence = "<h3>Demo EmotioNL</h3>\nThis demo allows you to analyse the emotion in a sentence." | |
description_dataset = "<h3>Demo EmotioNL</h3>\nThis demo allows you to analyse the emotions in a dataset.\nThe data should be in tsv-format with two named columns: the first column (id) should contain the sentence IDs, and the second column (text) should contain the actual texts. Optionally, there is a third column named 'date', which specifies the date associated with the text (e.g., tweet date). This column is necessary when the options 'emotion distribution over time' and 'peaks' are selected." | |
inference_modelpath = "model/checkpoint-128" | |
def inference_sentence(text): | |
tokenizer = AutoTokenizer.from_pretrained(inference_modelpath) | |
model = AutoModelForSequenceClassification.from_pretrained(inference_modelpath) | |
for text in tqdm([text]): | |
inputs = tokenizer(text, return_tensors="pt") | |
with torch.no_grad(): # run model | |
logits = model(**inputs).logits | |
predicted_class_id = logits.argmax().item() | |
output = model.config.id2label[predicted_class_id] | |
return output | |
def frequencies(preds): | |
preds_dict = {"neutral": 0, "anger": 0, "fear": 0, "joy": 0, "love": 0, "sadness": 0} | |
for pred in preds: | |
preds_dict[pred] = preds_dict[pred] + 1 | |
bars = list(preds_dict.keys()) | |
height = list(preds_dict.values()) | |
x_pos = np.arange(len(bars)) | |
plt.bar(x_pos, height, color=['lightgrey', 'firebrick', 'rebeccapurple', 'orange', 'palevioletred', 'cornflowerblue']) | |
plt.xticks(x_pos, bars) | |
return plt | |
def inference_dataset(file_object, option_list): | |
tokenizer = AutoTokenizer.from_pretrained(inference_modelpath) | |
model = AutoModelForSequenceClassification.from_pretrained(inference_modelpath) | |
data_path = open(file_object.name, 'r') | |
df = pd.read_csv(data_path, delimiter='\t', header=0, names=['id', 'text']) | |
ids = df["id"].tolist() | |
texts = df["text"].tolist() | |
preds = [] | |
for text in tqdm(texts): # progressbar | |
inputs = tokenizer(text, return_tensors="pt") | |
with torch.no_grad(): # run model | |
logits = model(**inputs).logits | |
predicted_class_id = logits.argmax().item() | |
prediction = model.config.id2label[predicted_class_id] | |
preds.append(prediction) | |
predictions_content = list(zip(ids, texts, preds)) | |
# write predictions to file | |
output = "output.txt" | |
f = open(output, 'w') | |
f.write("id\ttext\tprediction\n") | |
for line in predictions_content: | |
f.write(str(line[0]) + '\t' + str(line[1]) + '\t' + str(line[2]) + '\n') | |
output1 = output | |
output2 = output3 = output4 = output5 = "This option was not selected." | |
if "emotion frequencies" in option_list: | |
output2 = frequencies(preds) | |
else: | |
output2 = None | |
if "emotion distribution over time" in option_list: | |
output3 = "This option was selected." | |
if "peaks" in option_list: | |
output4 = "This option was selected." | |
if "topics" in option_list: | |
output5 = "This option was selected." | |
return [output1, output2, output3, output4, output5] | |
iface_sentence = gr.Interface( | |
fn=inference_sentence, | |
description = description_sentence, | |
inputs = gr.Textbox( | |
label="Enter a sentence", | |
lines=1), | |
outputs="text") | |
inputs = [gr.File( | |
label="Upload a dataset"), | |
gr.CheckboxGroup( | |
["emotion frequencies", "emotion distribution over time", "peaks", "topics"], | |
label = "Select options")] | |
outputs = [gr.File(), | |
gr.Plot(label="Emotion frequencies"), | |
gr.Textbox(label="Emotion distribution over time"), | |
gr.Textbox(label="Peaks"), | |
gr.Textbox(label="Topics")] | |
iface_dataset = gr.Interface( | |
fn = inference_dataset, | |
description = description_dataset, | |
inputs=inputs, | |
outputs = outputs) | |
iface = gr.TabbedInterface([iface_sentence, iface_dataset], ["Sentence", "Dataset"]) | |
iface.queue().launch() | |
""" | |
inference_modelpath = "model/checkpoint-128"

def inference_sentence(text):
    tokenizer = AutoTokenizer.from_pretrained(inference_modelpath)
    model = AutoModelForSequenceClassification.from_pretrained(inference_modelpath)
    for text in tqdm([text]):
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():  # run model
            logits = model(**inputs).logits
        predicted_class_id = logits.argmax().item()
        output = model.config.id2label[predicted_class_id]
    return "Predicted emotion:\n" + output
""" | |
def inference_sentence(text): | |
output = "This sentence will be processed:\n" + text | |
return output | |
""" | |
def unavailable(input_file, input_checks):
    output = "As we are currently updating this demo, submitting your own data is unavailable for the moment. However, you can try out the showcase mode!"
    return gr.update(value=output, label="Oops!", visible=True)
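
# Showcase mode: hide the message box, point the downloadable predictions file at the bundled
# example output, and reveal the first "next" button.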
def showcase(input_file):
    output = "showcase/example_predictions.txt"
    return gr.update(visible=False), gr.update(value=output, visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)  # next_button_freq becomes available
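
# Regular submission path: return the predictions file and reveal the "next" button for the first
# selected option. Its click handler is currently commented out below, so this function is not wired up.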
def file(input_file, input_checks):
    #output = "output.txt"
    #f = open(output, 'w')
    #f.write("The predictions come here.")
    #f.close()
    output = "showcase/example_predictions.txt"
    if "emotion frequencies" in input_checks:
        return gr.update(value=output, visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)  # next_button_freq becomes available
    elif "emotion distribution over time" in input_checks:
        return gr.update(value=output, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)  # next_button_dist becomes available
    elif "peaks" in input_checks:
        return gr.update(value=output, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)  # next_button_peaks becomes available
    elif "topics" in input_checks:
        return gr.update(value=output, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)  # next_button_topics becomes available
    else:
        return gr.update(value=output, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)  # no next_button becomes available
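
# Emotion frequencies: read the example predictions, count the predicted labels,
# and render an Altair bar chart.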
def freq(output_file, input_checks):
    #simple = pd.DataFrame({
    #    'Emotion category': ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness'],
    #    'Frequency': [10, 8, 2, 15, 3, 4]})
    f = open("showcase/example_predictions.txt", 'r')
    data = f.read().split("\n")
    f.close()
    data = [line.split("\t") for line in data[1:-1]]
    freq_dict = {}
    for line in data:
        if line[1] not in freq_dict.keys():
            freq_dict[line[1]] = 1
        else:
            freq_dict[line[1]] += 1
    simple = pd.DataFrame({
        'Emotion category': ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness'],
        'Frequency': [freq_dict['neutral'], freq_dict['anger'], freq_dict['fear'], freq_dict['joy'], freq_dict['love'], freq_dict['sadness']]})
    domain = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    range_ = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']
    n = max(simple['Frequency'])
    plot = alt.Chart(simple).mark_bar().encode(
        x=alt.X("Emotion category", sort=['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']),
        y=alt.Y("Frequency", axis=alt.Axis(grid=False), scale=alt.Scale(domain=[0, (n + 9) // 10 * 10])),
        color=alt.Color("Emotion category", scale=alt.Scale(domain=domain, range=range_), legend=None),
        tooltip=['Emotion category', 'Frequency']).properties(
        width=600).configure_axis(
        grid=False).interactive()
    if "emotion distribution over time" in input_checks or (output_file.name).startswith('/tmp/example_predictions'):
        return gr.update(value=plot, visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)  # next_button_dist becomes available
    elif "peaks" in input_checks:
        return gr.update(value=plot, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)  # next_button_peaks becomes available
    elif "topics" in input_checks:
        return gr.update(value=plot, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)  # next_button_topics becomes available
    else:
        return gr.update(value=plot, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)  # no next_button becomes available
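
# Emotion distribution over time: aggregate daily counts per emotion from the example data
# and plot them as an interactive Altair line chart.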
def dist(output_file, input_checks):
    #data = pd.DataFrame({
    #    'Date': ['1/1', '1/1', '1/1', '1/1', '1/1', '1/1', '2/1', '2/1', '2/1', '2/1', '2/1', '2/1', '3/1', '3/1', '3/1', '3/1', '3/1', '3/1'],
    #    'Frequency': [3, 5, 1, 8, 2, 3, 4, 7, 1, 12, 4, 2, 3, 6, 3, 10, 3, 4],
    #    'Emotion category': ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness', 'neutral', 'anger', 'fear', 'joy', 'love', 'sadness', 'neutral', 'anger', 'fear', 'joy', 'love', 'sadness']})
    f = open("showcase/data.txt", 'r')
    data = f.read().split("\n")
    f.close()
    data = [line.split("\t") for line in data[1:-1]]
    freq_dict = {}
    for line in data:
        dat = str(date(2000 + int(line[0].split("/")[2]), int(line[0].split("/")[1]), int(line[0].split("/")[0])))
        if dat not in freq_dict.keys():
            freq_dict[dat] = {}
            if line[1] not in freq_dict[dat].keys():
                freq_dict[dat][line[1]] = 1
            else:
                freq_dict[dat][line[1]] += 1
        else:
            if line[1] not in freq_dict[dat].keys():
                freq_dict[dat][line[1]] = 1
            else:
                freq_dict[dat][line[1]] += 1
    start_date = date(2000 + int(data[0][0].split("/")[2]), int(data[0][0].split("/")[1]), int(data[0][0].split("/")[0]))
    end_date = date(2000 + int(data[-1][0].split("/")[2]), int(data[-1][0].split("/")[1]), int(data[-1][0].split("/")[0]))
    delta = end_date - start_date  # returns timedelta
    date_range = [str(start_date + timedelta(days=i)) for i in range(delta.days + 1)]
    dates = [dat for dat in date_range for i in range(6)]
    frequency = [freq_dict[dat][emotion] if (dat in freq_dict.keys() and emotion in freq_dict[dat].keys()) else 0 for dat in date_range for emotion in ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']]
    categories = [emotion for dat in date_range for emotion in ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']]
    data = pd.DataFrame({
        'Date': dates,
        'Frequency': frequency,
        'Emotion category': categories})
    domain = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    range_ = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']
    n = max(data['Frequency'])
    highlight = alt.selection(
        type='single', on='mouseover', fields=["Emotion category"], nearest=True)
    base = alt.Chart(data).encode(
        x="Date:T",
        y=alt.Y("Frequency", scale=alt.Scale(domain=[0, (n + 9) // 10 * 10])),
        color=alt.Color("Emotion category", scale=alt.Scale(domain=domain, range=range_), legend=alt.Legend(orient='bottom', direction='horizontal')))
    points = base.mark_circle().encode(
        opacity=alt.value(0),
        tooltip=[
            alt.Tooltip('Emotion category', title='Emotion category'),
            alt.Tooltip('Date:T', title='Date'),
            alt.Tooltip('Frequency', title='Frequency')
        ]).add_selection(highlight)
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3)))
    plot = (points + lines).properties(width=600, height=350).interactive()
    if "peaks" in input_checks or (output_file.name).startswith('/tmp/example_predictions'):
        return gr.Plot.update(value=plot, visible=True), gr.update(visible=True), gr.update(visible=False)  # next_button_peaks becomes available
    elif "topics" in input_checks:
        return gr.Plot.update(value=plot, visible=True), gr.update(visible=False), gr.update(visible=True)  # next_button_topics becomes available
    else:
        return gr.Plot.update(value=plot, visible=True), gr.update(visible=False), gr.update(visible=False)  # no next_button becomes available
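
# Peaks: load a pre-computed plot of emotion peaks in the COVID-19 example data from a pickle file.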
def peaks(output_file, input_checks):
    plot = pickle.load(open('showcase/peaks_covid.p', 'rb'))
    if "topics" in input_checks or (output_file.name).startswith('/tmp/example_predictions'):
        return gr.Plot.update(value=plot, visible=True), gr.update(visible=True)  # next_button_topics becomes available
    else:
        return gr.Plot.update(value=plot, visible=True), gr.update(visible=False)  # no next_button becomes available
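
# Topics: load a pre-computed topic visualisation for the COVID-19 example data from a pickle file.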
def topics(output_file, input_checks):
    plot = pickle.load(open('showcase/vis_classes_covid.p', 'rb'))
    plot.update_layout(width=600, height=400)
    return gr.Plot.update(value=plot, visible=True)  # no next_button becomes available
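
# UI layout: a "Sentence" tab for single-sentence analysis and a "Dataset" tab that steps through
# the showcase results (predictions file, frequencies, distribution over time, peaks, topics).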
with gr.Blocks() as demo:
    with gr.Tab("Sentence"):
        gr.Markdown("""
        # Demo EmotioNL
        This demo allows you to analyse the emotion in a Dutch sentence.
        """)
        with gr.Row():
            with gr.Column():
                input = gr.Textbox(
                    label="Enter a sentence",
                    value="Jaaah! Volgende vakantie Barcelona en na het zomerseizoen naar de Algarve",
                    lines=1)
                send_btn = gr.Button("Send")
                output = gr.Textbox()
        send_btn.click(fn=inference_sentence, inputs=input, outputs=output)
    with gr.Tab("Dataset"):
        gr.Markdown("""
        # Demo EmotioNL
        This demo allows you to analyse the emotions in a dataset with Dutch sentences.

        _! As we are currently updating this demo, submitting your own data is unavailable for the moment ! However, you can try out the showcase mode._

        The data should be in tsv-format with two named columns: the first column (id) should contain the sentence IDs, and the second column (text) should contain the actual texts. Optionally, there is a third column named 'date', which specifies the date associated with the text (e.g., tweet date). This column is necessary when the options 'emotion distribution over time' and 'peaks' are selected.

        You can also try out the demo in showcase mode, which uses example data, namely a dataset with tweets about the COVID-19 pandemic.
        """)
        with gr.Row():
            with gr.Column():
                input_file = gr.File(
                    label="Upload a dataset")
                input_checks = gr.CheckboxGroup(
                    ["emotion frequencies", "emotion distribution over time", "peaks", "topics"],
                    label="Select options")
                send_btn = gr.Button("Submit data")
                demo_btn = gr.Button("... or showcase with example data")
            with gr.Column():
                message = gr.Textbox(label="Message", visible=False)
                output_file = gr.File(label="Predictions", visible=False)
                next_button_freq = gr.Button("Show emotion frequencies", visible=False)
                output_plot = gr.Plot(show_label=False, visible=False).style(container=True)
                next_button_dist = gr.Button("Show emotion distribution over time", visible=False)
                output_dist = gr.Plot(show_label=False, visible=False)
                next_button_peaks = gr.Button("Show peaks", visible=False)
                output_peaks = gr.HTML(visible=False)
                next_button_topics = gr.Button("Show topics", visible=False)
                output_topics = gr.Plot(show_label=False, visible=False)
        #send_btn.click(fn=file, inputs=[input_file, input_checks], outputs=[output_file, next_button_freq, next_button_dist, next_button_peaks, next_button_topics])
        next_button_freq.click(fn=freq, inputs=[output_file, input_checks], outputs=[output_plot, next_button_dist, next_button_peaks, next_button_topics])
        next_button_dist.click(fn=dist, inputs=[output_file, input_checks], outputs=[output_dist, next_button_peaks, next_button_topics])
        next_button_peaks.click(fn=peaks, inputs=[output_file, input_checks], outputs=[output_peaks, next_button_topics])
        next_button_topics.click(fn=topics, inputs=[output_file, input_checks], outputs=output_topics)
        send_btn.click(fn=unavailable, inputs=[input_file, input_checks], outputs=message)
        demo_btn.click(fn=showcase, inputs=[input_file], outputs=[message, output_file, next_button_freq, next_button_dist, next_button_peaks, next_button_topics])

demo.launch()