saadob12's picture
Update app.py
a154d32
raw
history blame contribute delete
No virus
4.71 kB
import streamlit as st
import torch
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class preProcess:
    """Read a CSV file, validate its column count, and linearize its
    contents into the plain-text format the summarization model expects.
    """

    def __init__(self, filename, titlename):
        # `filename` may be a path or a file-like object (pd.read_csv accepts both).
        self.filename = filename
        # Newline-terminated so the title stays visually separated from the data
        # when the two are joined in combine_title_data().
        self.title = titlename + '\n'

    def read_data(self):
        """Load the CSV into a pandas DataFrame and return it."""
        df = pd.read_csv(self.filename)
        return df

    def check_columns(self, df):
        """Return True when the file has 1-4 columns.

        Otherwise report the problem through st.error and return False.
        """
        if len(df.columns) > 4:
            # Fixed message: the guard allows up to 4 columns, but the original
            # text said 'more than 3 coloumns' (wrong number + typo).
            st.error('File has more than 4 columns.')
            return False
        if len(df.columns) == 0:
            st.error('File has no column.')
            return False
        return True

    def format_data(self, df):
        """Linearize the DataFrame into a single string.

        Format: '<prefix> col1 - col2 ... values r1c1 r1c2 , r2c1 r2c2 ...'
        where the prefix is ' x-y values ' for 1-2 columns and ' labels '
        for 3 or more.
        """
        columns = [list(df[col]) for col in df.columns]
        # One space-joined string per row, rows joined by ' , '.
        rows = [' '.join(map(str, tup)) for tup in zip(*columns)]
        prefix = ' x-y values ' if len(df.columns) < 3 else ' labels '
        return prefix + ' - '.join(list(df.columns)) + ' values ' + ' , '.join(rows)

    def combine_title_data(self, df):
        """Prepend the (newline-terminated) title to the linearized data."""
        data = self.format_data(df)
        title_data = ' '.join([self.title, data])
        return title_data
class Model:
    """Wrap a pretrained T5 chart-to-text checkpoint and generate a
    summary for a linearized data string.
    """

    # Chart-type phrases that postprocessing rewrites to 'statistic'.
    # Ordered to mirror the original if/elif chain: only the first match
    # in a summary is replaced.
    _CHART_TERMS = ('barchart', 'bar graph', 'bar plot', 'scatter plot',
                    'scatter graph', 'scatterchart', 'line plot',
                    'line graph', 'linechart')

    def __init__(self, text, mode):
        self.padding = 'max_length'
        self.truncation = True
        # Task prefix the checkpoints were fine-tuned with.
        self.prefix = 'C2T: '
        # Fixed: the original had a stray chained assignment
        # (`self.device = device = ...`) leaving an unused local.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.text = text
        if mode.lower() == 'simple':
            self.tokenizer = AutoTokenizer.from_pretrained('saadob12/t5_C2T_big')
            self.model = AutoModelForSeq2SeqLM.from_pretrained('saadob12/t5_C2T_big').to(self.device)
        elif mode.lower() == 'analytical':
            self.tokenizer = AutoTokenizer.from_pretrained('saadob12/t5_autochart_2')
            self.model = AutoModelForSeq2SeqLM.from_pretrained('saadob12/t5_autochart_2').to(self.device)

    @staticmethod
    def _replace_chart_terms(summary):
        """Rewrite chart-type wording in *summary* to 'statistic'.

        Fixes two bugs in the original inline version: (1) str.replace
        returns a new string, and every call's result was discarded, so
        no replacement ever took effect; (2) the 'scatterchart' branch
        replaced the non-matching text 'scatter chart'.
        """
        for term in Model._CHART_TERMS:
            if term in summary:
                summary = summary.replace(term, 'statistic')
                break  # first match only, like the original elif chain
        # Generic fallback for any remaining 'graph' mention.
        if 'graph' in summary:
            summary = summary.replace('graph', 'statistic')
        return summary

    def generate(self):
        """Tokenize the prefixed input, run beam search, and return the
        cleaned-up summary string.
        """
        tokens = self.tokenizer.encode(self.prefix + self.text,
                                       truncation=self.truncation,
                                       padding=self.padding,
                                       return_tensors='pt').to(self.device)
        generated = self.model.generate(tokens, num_beams=4, max_length=256)
        tgt_text = self.tokenizer.decode(generated[0], skip_special_tokens=True,
                                         clean_up_tokenization_spaces=True)
        summary = str(tgt_text).strip('[]""')
        return self._replace_chart_terms(summary)
# ---- Streamlit front end: gather inputs, preprocess the CSV, summarize ----
st.title('Chart and Data Summarization')
st.write('This application generates a summary of a datafile (.csv) (or the underlying data of a chart). Right now, it only generates summaries of files with maximum of four columns. If the file contains more than four columns, the app will throw an error.')

mode = st.selectbox('What kind of summary do you want?',
                    ('Simple', 'Analytical'))
st.write('You selected: ' + mode + ' summary.')

title = st.text_input('Add appropriate Title of the .csv file', 'State minimum wage rates in the United States as of January 1 , 2020')
st.write('Title of the file is: ' + title)

uploaded_file = st.file_uploader("Upload only .csv file")

# Only proceed once every input widget has produced a value.
inputs_ready = (uploaded_file is not None
                and mode is not None
                and title is not None)
if inputs_ready:
    st.write('Preprocessing file...')
    preprocessor = preProcess(uploaded_file, title)
    frame = preprocessor.read_data()
    if preprocessor.check_columns(frame):
        st.write('Your file contents:\n')
        st.write(frame)
        linearized = preprocessor.combine_title_data(frame)
        st.write('Linearized input format of the data file:\n ')
        st.markdown('**' + linearized + '**')
        st.write('Loading model...')
        summarizer = Model(linearized, mode)
        st.write('Model loading done!\nGenerating Summary...')
        generated_summary = summarizer.generate()
        st.write('Generated Summary:\n')
        st.markdown('**' + generated_summary + '**')