import gradio as gr import torch import numpy as np import pickle import pandas as pd from tqdm import tqdm import altair as alt import matplotlib.pyplot as plt from datetime import date, timedelta from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification def inference_sentence(text): tokenizer = AutoTokenizer.from_pretrained(inference_modelpath) model = AutoModelForSequenceClassification.from_pretrained(inference_modelpath) for text in tqdm([text]): inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): # run model logits = model(**inputs).logits predicted_class_id = logits.argmax().item() output = model.config.id2label[predicted_class_id] return "Predicted emotion:\n" + output def freq(file_output): f = open(file_output, 'r') data = f.read().split("\n") f.close() data = [line.split(",") for line in data[1:-1]] freq_dict = {} for line in data: if line[1] not in freq_dict.keys(): freq_dict[line[1]] = 1 else: freq_dict[line[1]] += 1 simple = pd.DataFrame({ 'Emotion category': ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness'], 'Frequency': [freq_dict['0'], freq_dict['1'], freq_dict['2'], freq_dict['3'], freq_dict['4'], freq_dict['5']]}) domain = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness'] range_ = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed'] n = max(simple['Frequency']) plot = alt.Chart(simple).mark_bar().encode( x=alt.X("Emotion category", sort=['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']), y=alt.Y("Frequency", axis=alt.Axis(grid=False), scale=alt.Scale(domain=[0, (n + 9) // 10 * 10])), color=alt.Color("Emotion category", scale=alt.Scale(domain=domain, range=range_), legend=None), tooltip=['Emotion category', 'Frequency']).properties( width=600).configure_axis( grid=False).interactive() return plot def dist(file_output): f = open(file_output, 'r') data = f.read().split("\n") f.close() data = [line.split(",") for line in data[1:-1]] freq_dict = {} mapping_dict = {'0': 'neutral', '1': 'anger', '2': 'fear', '3': 'joy', '4': 'love', '5': 'sadness'} for line in data: dat = str(date(int(line[0][:4]), int(line[0][4:6]), int(line[0][6:8]))) if dat not in freq_dict.keys(): freq_dict[dat] = {} if mapping_dict[line[1]] not in freq_dict[dat].keys(): freq_dict[dat][mapping_dict[line[1]]] = 1 else: freq_dict[dat][mapping_dict[line[1]]] += 1 else: if mapping_dict[line[1]] not in freq_dict[dat].keys(): freq_dict[dat][mapping_dict[line[1]]] = 1 else: freq_dict[dat][mapping_dict[line[1]]] += 1 start_date = date(int(data[0][0][:4]), int(data[0][0][4:6]), int(data[0][0][6:8])) end_date = date(int(data[-1][0][:4]), int(data[-1][0][4:6]), int(data[-1][0][6:8])) delta = end_date - start_date # returns timedelta date_range = [str(start_date + timedelta(days=i)) for i in range(delta.days + 1)] dates = [dat for dat in date_range for i in range(6)] frequency = [freq_dict[dat][emotion] if (dat in freq_dict.keys() and emotion in freq_dict[dat].keys()) else 0 for dat in date_range for emotion in ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']] categories = [emotion for dat in date_range for emotion in ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']] data = pd.DataFrame({ 'Date': dates, 'Frequency': frequency, 'Emotion category': categories}) domain = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness'] range_ = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed'] n = max(data['Frequency']) highlight = alt.selection( type='single', on='mouseover', fields=["Emotion category"], nearest=True) base = alt.Chart(data).encode( x ="Date:T", y=alt.Y("Frequency", scale=alt.Scale(domain=[0, (n + 9) // 10 * 10])), color=alt.Color("Emotion category", scale=alt.Scale(domain=domain, range=range_), legend=alt.Legend(orient='bottom', direction='horizontal'))) points = base.mark_circle().encode( opacity=alt.value(0), tooltip=[ alt.Tooltip('Emotion category', title='Emotion category'), alt.Tooltip('Date:T', title='Date'), alt.Tooltip('Frequency', title='Frequency') ]).add_selection(highlight) lines = base.mark_line().encode( size=alt.condition(~highlight, alt.value(1), alt.value(3))) plot = (points + lines).properties(width=600, height=350).interactive() return plot def showcase(dataset): # predictions file if dataset == "The Voice of Holland": file_output = "output/predictions_tvoh.txt" elif dataset == "Floodings": file_output = "output/predictions_floodings.txt" elif dataset == "COVID-19": file_output = "output/predictions_covid.txt" elif dataset == "Childcare Benefits": file_output = "output/predictions_toeslagen.txt" # freq bar plot freq_output = freq(file_output) # dist plot dist_output = dist(file_output) # peaks if dataset == "The Voice of Holland": peaks_output = pickle.load(open('output/peaks_tvoh.p', 'rb')) elif dataset == "Floodings": peaks_output = pickle.load(open('output/peaks_floodings.p', 'rb')) elif dataset == "COVID-19": peaks_output = pickle.load(open('output/peaks_covid.p', 'rb')) elif dataset == "Childcare Benefits": peaks_output = pickle.load(open('output/peaks_toeslagen.p', 'rb')) # topics if dataset == "The Voice of Holland": topics_output = pickle.load(open('output/topics_tvoh.p', 'rb')) elif dataset == "Floodings": topics_output = pickle.load(open('output/topics_floodings.p', 'rb')) elif dataset == "COVID-19": topics_output = pickle.load(open('output/topics_covid.p', 'rb')) elif dataset == "Childcare Benefits": topics_output = pickle.load(open('output/topics_toeslagen.p', 'rb')) return gr.update(visible=True), gr.update(value=file_output, visible=True), gr.update(value=freq_output,visible=True), gr.update(value=dist_output,visible=True), gr.update(value=peaks_output,visible=True), gr.update(value=topics_output,visible=True) inference_modelpath = "model/checkpoint-128" with gr.Blocks() as demo: with gr.Column(scale=1, min_width=50): gr.Markdown(""" """) with gr.Column(scale=5): gr.Markdown("""