leavoigt's picture
Update appStore/target.py
84050ab
raw
history blame
No virus
3.16 kB
# set path
import glob, os, sys;
sys.path.append('../utils')
#import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from utils.target_classifier import load_targetClassifier, target_classification
import logging
logger = logging.getLogger(__name__)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck
from io import BytesIO
import xlsxwriter
import plotly.express as px
from utils.target_classifier import label_dict
# Module-level configuration for this page.
# classifier_identifier selects which classifier's settings are fetched and is
# also used by app() to build the '<identifier>_classifier' session-state key.
classifier_identifier = 'target'
# app() reads the 'model_name' and 'threshold' entries from this mapping.
params = get_classifier_params(classifier_identifier)
@st.cache_data
def to_excel(df, sectorlist):
    """Export ``df`` to an in-memory .xlsx workbook with drop-down validation.

    The dataframe is written without its index to a sheet named 'Sheet1'.
    Drop-down (list) validations are then attached to the data rows:

    * column S: 'No' / 'Yes' / 'Discard'
    * columns T, U, V, W and X: every entry of ``sectorlist`` plus 'Blank'

    Args:
        df: pandas DataFrame to export.
        sectorlist: sequence of labels offered in the T-X drop-downs.

    Returns:
        bytes: the raw content of the generated workbook.
    """
    len_df = len(df)
    output = BytesIO()
    sector_choices = list(sectorlist) + ['Blank']

    # The context manager finalises the workbook on exit; ExcelWriter.save()
    # no longer exists in pandas >= 2.0.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # Skip validation for an empty frame: there are no data rows and the
        # range would be degenerate.
        if len_df > 0:
            # Row 1 holds the header, so data occupies rows 2 .. len_df + 1.
            last_row = len_df + 1
            worksheet.data_validation(
                'S2:S{}'.format(last_row),
                {'validate': 'list',
                 'source': ['No', 'Yes', 'Discard']})
            # Each sector column gets the same list; every range stays inside
            # its own column (the original 'W2:U..' range was a typo).
            for column in ('T', 'U', 'V', 'W', 'X'):
                worksheet.data_validation(
                    '{0}2:{0}{1}'.format(column, last_row),
                    {'validate': 'list',
                     'source': sector_choices})

    return output.getvalue()
def app():
    """Run the target classifier on the dataframe held in session state.

    Does nothing unless ``st.session_state`` contains 'key1'. Otherwise it
    shows a preview of that dataframe, loads the classifier named by
    ``params['model_name']`` and stores it in session state, passes the
    dataframe through ``target_classification`` with ``params['threshold']``,
    displays the result and writes it back to ``st.session_state.key1``.

    NOTE(review): block nesting was reconstructed from a source whose
    indentation had been stripped -- confirm against the repository copy.
    """
    ### Main app code ###
    with st.container():
        # Nothing to classify until the dataset has been stored.
        if 'key1' not in st.session_state:
            return

        # Fetch the existing dataset and show a preview
        dataset = st.session_state.key1
        st.write("This is key 1 - utils")
        st.write(dataset.head())

        # Load the classifier model and keep it in the session
        session_key = '{}_classifier'.format(classifier_identifier)
        st.session_state[session_key] = load_targetClassifier(
            classifier_name=params['model_name'])

        # test: report if the classifier did not end up in the session
        if "target_classifier" not in st.session_state:
            st.write("target classifier not saved :(")

        classified = target_classification(
            haystack_doc=dataset,
            threshold=params['threshold'])
        st.write("This is the second part")
        st.write(classified)
        st.session_state.key1 = classified
def target_display():
    """Render the dataframe stored under 'key1' in the Streamlit session."""
    # Look the dataframe up and hand it straight to the page
    st.write(st.session_state['key1'])