# AI-Docwhiz / app.py
# Pin pydantic to v1 BEFORE any haystack import below; haystack requires
# pydantic v1, so this must stay the very first executable statement.
from utils.check_pydantic_version import use_pydantic_v1
use_pydantic_v1() #This function has to be run before importing haystack. as haystack requires pydantic v1 to run
from operator import index
import streamlit as st
import logging
import os
from annotated_text import annotation
from json import JSONDecodeError
from markdown import markdown
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline, start_preprocessor_node, start_retriever, start_reader
from utils.ui import reset_results, set_initial_state
import pandas as pd
import haystack
from datetime import datetime
import streamlit_authenticator as stauth
import pickle
# Single hard-coded login ('admin') for streamlit-authenticator; the matching
# hashed password is loaded from a pickle generated offline.
names = ['admin']
usernames = ['admin']
# NOTE(review): pickle.load is only safe because this file ships with the app;
# never point this at untrusted input.
with open('hashed_password.pkl','rb') as f:
    hashed_passwords = pickle.load(f)
# Whether the file upload should be enabled or not
DISABLE_FILE_UPLOAD = bool(os.getenv("DISABLE_FILE_UPLOAD"))
# Define a function to handle file uploads
# Define a function to handle file uploads
def upload_files():
    """Render the sidebar file uploader and return whatever the user selected.

    Relies on the module-level ``upload_container`` created by the main script.
    """
    return upload_container.file_uploader(
        "upload",
        type=["pdf", "txt", "docx"],
        accept_multiple_files=True,
        label_visibility="hidden",
    )
# Define a function to process a single file
def process_file(data_file, preprocesor, document_store):
# read file and add content
file_contents = data_file.read().decode("utf-8")
docs = [{
'content': str(file_contents),
'meta': {'name': str(data_file.name)}
}]
try:
names = [item.meta.get('name') for item in document_store.get_all_documents()]
#if args.store == 'inmemory':
# doc = converter.convert(file_path=files, meta=None)
if data_file.name in names:
print(f"{data_file.name} already processed")
else:
print(f'preprocessing uploaded doc {data_file.name}.......')
#print(data_file.read().decode("utf-8"))
preprocessed_docs = preprocesor.process(docs)
print('writing to document store.......')
document_store.write_documents(preprocessed_docs)
print('updating emebdding.......')
document_store.update_embeddings(retriever)
except Exception as e:
print(e)
# Define a function to upload the documents to haystack document store
def upload_document():
print(f'Uploading document store at {datetime.now()}')
upload_status = 0
if data_files is not None:
for data_file in data_files:
# Upload file
if data_file:
try:
#raw_json = upload_doc(data_file)
# Call the process_file function for each uploaded file
if args.store == 'inmemory':
processed_data = process_file(data_file, preprocesor, document_store)
upload_container.write(str(data_file.name) + "    βœ… ")
except Exception as e:
upload_container.write(str(data_file.name) + "    ❌ ")
upload_container.write("_This file could not be parsed, see the logs for more information._")
# Define a function to reset the documents in haystack document store
def reset_documents():
print('\nReseting documents list at ' + str(datetime.now()) + '\n')
document_store.delete_documents()
# -----------------------------------------------------------------------------
# Main script: build the Haystack components, authenticate the user, and drive
# the Streamlit search UI. The try/except SystemExit wrapper lets argparse's
# sys.exit (e.g. on bad CLI args) terminate the Streamlit process cleanly.
# -----------------------------------------------------------------------------
try:
    args = parser.parse_args()

    # Pipeline components shared with the callbacks defined above.
    preprocesor = start_preprocessor_node()
    document_store = start_document_store(type=args.store)
    retriever = start_retriever(document_store)
    reader = start_reader()

    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Credentials come from the module-level names/usernames/hashed_passwords;
    # the auth cookie stays valid for one day.
    authenticator = stauth.Authenticate(names, usernames, hashed_passwords, "document_search", "random_text", cookie_expiry_days=1)
    name, authentication_status, username = authenticator.login("Login", "main")

    # authentication_status is a tri-state: False = wrong credentials,
    # None = nothing entered yet, True = logged in.
    if authentication_status is False:
        st.error("Username/Password is incorrect")
    if authentication_status is None:
        st.warning("Please enter your username and password")
    if authentication_status:
        # Sidebar for Task Selection
        st.sidebar.header('Options:')

        # OpenAI Key Input — the Generative (RAG) task is only offered once a key is present.
        openai_key = st.sidebar.text_input("Enter OpenAI Key:", type="password")
        if openai_key:
            task_options = ['Extractive', 'Generative']
        else:
            task_options = ['Extractive']
        task_selection = st.sidebar.radio('Select the task:', task_options)

        # Check the task and initialize pipeline accordingly
        if task_selection == 'Extractive':
            pipeline_extractive = initialize_pipeline("extractive", document_store, retriever, reader)
        elif task_selection == 'Generative' and openai_key:  # Check for openai_key to ensure user has entered it
            pipeline_rag = initialize_pipeline("rag", document_store, retriever, reader, openai_key=openai_key)

        set_initial_state()
        st.write('# ' + args.name)

        # File upload block
        if not DISABLE_FILE_UPLOAD:
            upload_container = st.sidebar.container()
            upload_container.write("## File Upload:")
            data_files = upload_files()
            # Button to update files in the documentStore
            upload_container.button('Upload Files', on_click=upload_document, args=())

        # Button to reset the documents in DocumentStore
        st.sidebar.button("Reset documents", on_click=reset_documents, args=())

        if "question" not in st.session_state:
            st.session_state.question = ""

        # Search bar
        question = st.text_input("Question", value=st.session_state.question, max_chars=100, on_change=reset_results, label_visibility="hidden")
        run_pressed = st.button("Run")
        # Re-run the pipeline on an explicit click OR when the question text changed.
        run_query = (
            run_pressed or question != st.session_state.question
        )

        # Get results for query
        if run_query and question:
            if task_selection == 'Extractive':
                reset_results()
                st.session_state.question = question
                with st.spinner("πŸ”Ž &nbsp;&nbsp; Running your pipeline"):
                    try:
                        st.session_state.results_extractive = query(pipeline_extractive, question)
                        st.session_state.task = task_selection
                    except JSONDecodeError as je:
                        st.error(
                            "πŸ‘“ &nbsp;&nbsp; An error occurred reading the results. Is the document store working?"
                        )
                    except Exception as e:
                        logging.exception(e)
                        st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")
            elif task_selection == 'Generative':
                reset_results()
                st.session_state.question = question
                with st.spinner("πŸ”Ž &nbsp;&nbsp; Running your pipeline"):
                    try:
                        st.session_state.results_generative = query(pipeline_rag, question)
                        st.session_state.task = task_selection
                    except JSONDecodeError as je:
                        st.error(
                            "πŸ‘“ &nbsp;&nbsp; An error occurred reading the results. Is the document store working?"
                        )
                    except Exception as e:
                        # Surface a dedicated hint when the OpenAI key is bad.
                        if "API key is invalid" in str(e):
                            logging.exception(e)
                            st.error("🐞 &nbsp;&nbsp; incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
                        else:
                            logging.exception(e)
                            st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")

        # Display results
        if (st.session_state.results_extractive or st.session_state.results_generative) and run_query:
            # Handle Extractive Answers
            if task_selection == 'Extractive':
                results = st.session_state.results_extractive
                st.subheader("Extracted Answers:")
                if 'answers' in results:
                    answers = results['answers']
                    threshold = 0.2
                    any_above_threshold = any(ans.score > threshold for ans in answers)
                    if not any_above_threshold:
                        # BUG FIX: the old f-string used int(threshold) * 100 which
                        # always rendered 0%; int(threshold * 100) gives the intended 20%.
                        st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher than {int(threshold * 100)}%. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
                    for count, answer in enumerate(answers):
                        if answer.answer:
                            # Highlight the extracted answer span inside its context.
                            text, context = answer.answer, answer.context
                            start_idx = context.find(text)
                            end_idx = start_idx + len(text)
                            score = round(answer.score, 3)
                            st.markdown(f"**Answer {count + 1}:**")
                            st.markdown(
                                context[:start_idx] + str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff')) + context[end_idx:],
                                unsafe_allow_html=True,
                            )
                        else:
                            st.info(
                                "πŸ€” &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                            )
            # Handle Generative Answers
            elif task_selection == 'Generative':
                results = st.session_state.results_generative
                st.subheader("Generated Answer:")
                if 'results' in results:
                    st.markdown("**Answer:**")
                    st.write(results['results'][0])
            # Handle Retrieved Documents
            if 'documents' in results:
                retrieved_documents = results['documents']
                st.subheader("Retriever Results:")
                data = []
                for i, document in enumerate(retrieved_documents):
                    # Truncate the content for tabular display.
                    truncated_content = (document.content[:150] + '...') if len(document.content) > 150 else document.content
                    data.append([i + 1, document.meta['name'], truncated_content])
                # Convert data to DataFrame and display using Streamlit
                df = pd.DataFrame(data, columns=['Ranked Context', 'Document Name', 'Content'])
                st.table(df)
except SystemExit as e:
    # argparse exits via SystemExit; propagate its exit code to the process.
    os._exit(e.code)