storytelling / src /create_statistics.py
jitesh's picture
Adds "Check Logs" mode to analyse the log data
bfa0c67
raw
history blame
4.32 kB
import random
import numpy as np
import plotly.express as px
import streamlit as st
import xlsxwriter
import pandas as pd
from .lib import initialise_storytelling, set_input
import io
def run_create_statistics(gen, container_guide, container_param, container_button):
    """Render the statistics mode: collect parameters, run ``gen.get_stats`` and report.

    Parameter widgets are drawn into ``container_param``. When the ``Analyse``
    button in ``container_button`` is pressed, ``gen.get_stats`` is called and
    the page then shows:

    * the generated stories;
    * the ``gen.stats_df`` table;
    * violin and box plots of ``num_reactions`` against ``reaction_weight``;
    * an Excel download of ``gen.stats_df``.

    Until the button is pressed, a hint is shown in ``container_guide``.

    Args:
        gen: Story generator. It must provide ``get_stats(...)``. After that
            call it must also provide ``data`` and ``stats_df``.
            ``data`` is iterated as a sequence of stories, each a sequence of
            dicts with ``turn``, ``sentence``, ``emotion`` and
            ``confidence_score`` keys. ``stats_df`` is used as a DataFrame with
            at least ``sentence_no``, ``reaction_weight`` and ``num_reactions``
            columns.
        container_guide: Streamlit container used for guidance text.
        container_param: Streamlit container holding the parameter widgets.
        container_button: Streamlit container holding the ``Analyse`` button.
    """
    first_sentence, _first_emotion, length = initialise_storytelling(
        gen, container_guide, container_param, container_button)
    num_generation = set_input(container_param,
                               label='Number of generation', min_value=1, max_value=100, value=5, step=1,
                               key_slider='num_generation_slider', key_input='num_generation_input',)
    num_tests = set_input(container_param,
                          label='Number of tests', min_value=1, max_value=1000, value=3, step=1,
                          key_slider='num_tests_slider', key_input='num_tests_input',)
    reaction_weight = _select_reaction_weight(container_param)

    if not container_button.button('Analyse'):
        container_guide.markdown(
            '### You selected statistics. Now set your parameters and click the `Analyse` button.')
        return

    gen.get_stats(story_till_now=first_sentence,
                  num_generation=num_generation, length=length,
                  reaction_weight=reaction_weight, num_tests=num_tests)
    _show_stories(gen.data)
    st.table(data=gen.stats_df)
    _plot_reactions(gen.stats_df)
    _offer_excel_download(gen.stats_df)


def _select_reaction_weight(container_param):
    """Ask for the reaction weight mode; return the fixed weight, or -1 for "Random"."""
    reaction_weight_mode = container_param.radio(
        "Reaction Weight w:", ["Random", "Fixed"])
    if reaction_weight_mode == "Fixed":
        return set_input(container_param,
                         label='Reaction Weight w', min_value=0.0, max_value=1.0, value=0.5, step=0.01,
                         key_slider='w_slider', key_input='w_input',)
    # -1 is forwarded to gen.get_stats. It is presumably a sentinel that asks
    # the generator to draw the weight at random (TODO: confirm in get_stats).
    return -1


def _show_stories(stories):
    """Render each story as rows of turn / sentence / emotion-plus-confidence columns."""
    for story_no, story in enumerate(stories):
        st.markdown(f'### Story {story_no}:', unsafe_allow_html=False)
        for sentence in story:
            col_turn, col_sentence, col_emo = st.columns([1, 8, 2])
            col_turn.markdown(sentence['turn'], unsafe_allow_html=False)
            col_sentence.markdown(sentence['sentence'], unsafe_allow_html=False)
            col_emo.markdown(
                f'{sentence["emotion"]} {np.round(sentence["confidence_score"], 3)}',
                unsafe_allow_html=False)


def _plot_reactions(stats_df, sentence_no=3):
    """Plot num_reactions vs reaction_weight (violin, then box) for one sentence number."""
    data = stats_df[stats_df.sentence_no == sentence_no]
    for plot in (px.violin, px.box):
        fig = plot(data_frame=data, x="reaction_weight",
                   y="num_reactions", hover_data=data.columns)
        st.plotly_chart(fig, use_container_width=True)


def _offer_excel_download(stats_df):
    """Serialise ``stats_df`` to an in-memory .xlsx file and show a download button."""
    buffer = io.BytesIO()
    # Leaving the ``with`` block closes the writer, which writes the workbook
    # into ``buffer``. An explicit ``writer.save()`` is redundant, and that
    # method was removed in pandas 2.0.
    with pd.ExcelWriter(buffer, engine='xlsxwriter') as writer:
        stats_df.to_excel(writer, sheet_name='AllData')
    st.download_button(
        label="Download data",
        # Pass the finished bytes rather than a stream whose position is at EOF.
        data=buffer.getvalue(),
        file_name='data.xlsx',
        # Correct MIME type for .xlsx (application/vnd.ms-excel is legacy .xls).
        mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    )