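# NOTE: the text "Spaces: / Runtime error / Runtime error" was captured
# together with this source (apparently a hosting page's status header);
# it is kept here as a comment so the module parses.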
import logging
import os
import tempfile
import uuid

import streamlit as st
from dotenv import load_dotenv
from qdrant_client import models

from embed import embed_documents_into_qdrant
from preprocess import split_documents, update_metadata, load_documents_OCR
from retrieve import retrieve_documents_from_collection
from summarize import summarize_documents
from utils import setup_openai_embeddings, setup_qdrant_client, delete_collection, is_document_embedded
# Load variables from a local .env file into os.environ before anything reads
# them via os.getenv. With override=False (python-dotenv's default), variables
# already set in the process environment keep their values.
load_dotenv(override=False)
def main():
    """Entry point: render the sidebar and dispatch to the selected page.

    The sidebar holds the PDF uploader, the "Add Docs to Data Bank" /
    "Add Docs to Current Chat" actions and the page selector. The selected
    page function is called with the current list of uploaded files.
    """
    st.sidebar.title("PDF Management")
    uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

    # Name of this session's private Qdrant collection; None until the user
    # first adds documents to the current chat.
    if 'uploaded_collection_name' not in st.session_state:
        st.session_state['uploaded_collection_name'] = None

    if uploaded_files:
        if st.sidebar.button("Add Docs to Data Bank"):
            files_info = save_uploaded_files(uploaded_files)
            try:
                embed_documents_to_data_bank(files_info)
            finally:
                # save_uploaded_files creates its temp files with
                # delete=False; remove them once they have been processed.
                _remove_temp_files(files_info)
        if st.sidebar.button("Add Docs to Current Chat"):
            files_info = save_uploaded_files(uploaded_files)
            try:
                add_docs_to_current_chat(files_info)
            finally:
                _remove_temp_files(files_info)

    pages = {
        "Lex Document Summarization": page_summarization,
        "Chat with RSCA": page_qna,
        "Chat with Uploaded Docs": page_chat_with_uploaded_docs,
        "Chat with VOO": page_chat_with_voo,
    }
    st.sidebar.title("Page Navigation")
    page = st.sidebar.radio("Select a page", tuple(pages.keys()))

    # Initialize session state for summarization results if not already set,
    # so summaries survive reruns.
    if 'summaries' not in st.session_state:
        st.session_state['summaries'] = {}

    # Call the page function based on the user selection.
    if page:
        pages[page](uploaded_files)


def _remove_temp_files(files_info):
    """Best-effort deletion of the temp paths in ``(temp_path, original_name)`` pairs."""
    for temp_path, _original_name in files_info:
        try:
            os.remove(temp_path)
        except OSError:
            # Already gone or not removable; a leftover temp file is not
            # worth failing the user's request over.
            pass
def save_uploaded_files(uploaded_files):
    """Write every uploaded file to its own temporary ``.pdf`` file.

    The temp files are created with ``delete=False``, so they outlive this
    call; removing them is the caller's responsibility.

    Args:
        uploaded_files: Iterable of objects exposing ``.name`` and
            ``.getvalue()`` (the raw bytes).

    Returns:
        A list of ``(temp_path, original_filename)`` tuples in input order.
    """
    def _persist(upload):
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as handle:
            handle.write(upload.getvalue())
        return handle.name, upload.name

    return [_persist(upload) for upload in uploaded_files]
def page_summarization(uploaded_files):
    """Page for document summarization.

    Shows one "Summarize <file>" button per uploaded PDF. Pressing it OCRs the
    file with ``load_documents_OCR``, summarizes it with
    ``summarize_documents`` and stores the text in
    ``st.session_state['summaries']`` (keyed by file name), so the summary
    stays visible on later reruns.

    Args:
        uploaded_files: Files from the sidebar uploader; when None/empty only
            the title is rendered.
    """
    st.title("Lex Document Summarization")
    if not uploaded_files:
        return
    summaries = st.session_state['summaries']
    for uploaded_file in uploaded_files:
        original_name = uploaded_file.name
        summary_button = st.button(f"Summarize {original_name}", key=original_name)
        if not (summary_button or original_name in summaries):
            continue
        with st.container():
            st.write(f"Summary for {original_name}:")
            if summary_button:  # Only summarize if the button is pressed
                # Write the PDF to disk only when it is actually needed (not on
                # every rerun) and delete it afterwards: save_uploaded_files
                # creates its temp files with delete=False.
                temp_path, _ = save_uploaded_files([uploaded_file])[0]
                try:
                    documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                    summary = summarize_documents(documents, os.getenv('OPENAI_API_KEY'))
                    summaries[original_name] = summary  # Store summary in session state
                except Exception as e:
                    st.error(f"Failed to summarize {original_name}: {str(e)}")
                finally:
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass  # best-effort cleanup of the temp copy
            if original_name in summaries:
                st.text_area("", value=summaries[original_name], height=1000, key=f"summary_{original_name}")
            else:
                st.error(f"No summary found for {original_name}. Please click the summarize button.")
def page_qna(uploaded_files):
    """Render the "Chat with RSCA" Q&A page.

    ``uploaded_files`` is not used here; every page accepts it so ``main``
    can dispatch to all pages the same way.
    """
    st.title("Chat with RSCA")
    question = st.text_area("Enter your question here:", height=300)
    if not st.button('Get Answer'):
        return
    if not question:
        st.error("Please enter a question to get an answer.")
        return
    st.write(handle_query(question))
def page_chat_with_uploaded_docs(uploaded_files):
    """Render the page for chatting with documents added to the current chat.

    ``uploaded_files`` is not used; it is accepted for the uniform page
    signature. When this session has a collection, a button to delete it is
    shown as well.
    """
    st.title("Chat with Uploaded Documents")
    question = st.text_area("Enter your question here:", height=300)
    if st.button('Get Answer'):
        if not question:
            st.error("Please enter a question to get an answer.")
        else:
            st.write(handle_uploaded_docs_query(question, st.session_state['uploaded_collection_name']))

    session_collection = st.session_state['uploaded_collection_name']
    # The button is only rendered when a session collection exists.
    if session_collection and st.button('Delete Embedded Collection'):
        # NOTE(review): delete_collection in this file catches its own
        # exceptions, so the success message below is shown even if the
        # delete actually failed.
        delete_collection(session_collection, os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
        st.session_state['uploaded_collection_name'] = None
        st.success(f"Deleted collection {session_collection}")
def page_chat_with_voo(uploaded_files):
    """Render the "Chat with VOO" page.

    ``uploaded_files`` is unused; it exists only for the uniform page
    signature used by ``main``.
    """
    st.title("Chat with VOO")
    prompt = st.text_area("Enter your question here:", height=300)
    clicked = st.button('Get Answer')
    if clicked and prompt:
        st.write(handle_voo_query(prompt))
    elif clicked:
        st.error("Please enter a question to get an answer.")
def embed_documents_to_data_bank(files_info):
    """Embed each file not already reported as embedded into 'Lex-v1'.

    Args:
        files_info: ``(temp_path, original_name)`` pairs as produced by
            ``save_uploaded_files``.

    Per-file outcomes are reported in the Streamlit UI. Errors raised while
    loading, splitting or embedding a file are shown with ``st.error`` and
    the loop moves on to the next file.
    """
    for temp_path, original_name in files_info:
        if is_document_embedded(original_name):
            st.info(f"{original_name} is already embedded.")
            continue
        try:
            chunks = split_documents(
                update_metadata(
                    load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API')),
                    original_name,
                )
            )
            if not chunks:
                st.error(f"No documents found or extracted from {original_name}")
                continue
            embed_documents_into_qdrant(
                chunks,
                os.getenv('OPENAI_API_KEY'),
                os.getenv('QDRANT_URL'),
                os.getenv('QDRANT_API_KEY'),
                'Lex-v1',
            )
            st.success(f"Embedded {original_name} into Data Bank")
        except Exception as exc:
            st.error(f"Failed to embed {original_name}: {str(exc)}")
def add_docs_to_current_chat(files_info):
    """Embed uploaded PDFs into this session's private Qdrant collection.

    On first use a collection named ``session-<uuid4>`` is created (vector
    size 1536, cosine distance) and its name is stored in
    ``st.session_state['uploaded_collection_name']``; later calls reuse it.

    Args:
        files_info: ``(temp_path, original_name)`` pairs as produced by
            ``save_uploaded_files``.
    """
    collection_name = st.session_state['uploaded_collection_name']
    if not collection_name:
        collection_name = f"session-{uuid.uuid4()}"
        try:
            client = setup_qdrant_client(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
            client.create_collection(
                collection_name=collection_name,
                # NOTE(review): 1536 presumably matches the OpenAI embedding
                # model used by embed_documents_into_qdrant — confirm.
                vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
            )
        except Exception as e:
            st.error(f"Failed to create a collection for this chat: {str(e)}")
            return
        # Record the name only after the collection exists. Otherwise a failed
        # creation would leave the session pointing at a collection that was
        # never created, and creation would never be retried.
        st.session_state['uploaded_collection_name'] = collection_name

    for temp_path, original_name in files_info:
        # NOTE(review): is_document_embedded() receives only the file name,
        # not this session's collection, so it may be checking a different
        # collection (e.g. the data bank) — confirm this is the intended gate.
        if is_document_embedded(original_name):
            st.info(f"{original_name} is already embedded.")
            continue
        try:
            documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
            documents = update_metadata(documents, original_name)
            documents = split_documents(documents)
            if documents:
                embed_documents_into_qdrant(
                    documents,
                    os.getenv('OPENAI_API_KEY'),
                    os.getenv('QDRANT_URL'),
                    os.getenv('QDRANT_API_KEY'),
                    collection_name=collection_name,
                )
                st.success(f"Embedded {original_name}")
            else:
                st.error(f"No documents found or extracted from {original_name}")
        except Exception as e:
            st.error(f"Failed to embed {original_name}: {str(e)}")
def handle_query(query):
    """Answer ``query`` from the data-bank collection 'Lex-v1'.

    Returns:
        The retrieved answer. If retrieval yields nothing truthy, returns
        "No relevant answer found.". If anything raises, returns an
        "Error processing the query: ..." string; exceptions never escape.
    """
    try:
        result = retrieve_documents_from_collection(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
            'Lex-v1',
        )
        if not result:
            result = "No relevant answer found."
        return result
    except Exception as err:
        return f"Error processing the query: {str(err)}"
def handle_uploaded_docs_query(query, collection_name):
    """Answer ``query`` from this session's uploaded-documents collection.

    Args:
        query: The user's question.
        collection_name: The session's Qdrant collection name, or None/""
            when no documents have been added to the chat yet.

    Returns:
        One of the following; errors are reported as text, never raised:
        the retrieved answer; "No relevant answer found." when retrieval
        yields nothing; a hint message when no collection exists yet; or an
        "Error processing the query: ..." string.
    """
    if not collection_name:
        # Nothing has been added to this chat yet. Querying Qdrant with a
        # missing collection name would only produce a cryptic error.
        return ("No documents have been added to the current chat yet. "
                "Upload PDFs and click \"Add Docs to Current Chat\" first.")
    try:
        answer = retrieve_documents_from_collection(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
            collection_name,
        )
        return answer or "No relevant answer found."
    except Exception as e:
        return f"Error processing the query: {str(e)}"
def handle_voo_query(query):
    """Answer ``query`` from the VOO collection 'Lex-v2'.

    Returns:
        The retrieved answer. If retrieval yields nothing truthy, returns
        "No relevant answer found.". If anything raises, returns an
        "Error processing the query: ..." string; exceptions never escape.
    """
    try:
        result = retrieve_documents_from_collection(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
            'Lex-v2',
        )
        if not result:
            result = "No relevant answer found."
        return result
    except Exception as err:
        return f"Error processing the query: {str(err)}"
def delete_collection(collection_name, qdrant_url, qdrant_api_key):
    """Delete a Qdrant collection, best-effort.

    NOTE: this module-level definition shadows the ``delete_collection``
    imported from ``utils`` at the top of the file.

    Args:
        collection_name: Name of the collection to drop.
        qdrant_url: Qdrant server URL.
        qdrant_api_key: Qdrant API key.

    Returns:
        True if the delete call completed without raising. False if it
        raised; the error is logged with its traceback instead of being
        printed to stdout. Errors from creating the client itself still
        propagate.
    """
    client = setup_qdrant_client(qdrant_url, qdrant_api_key)
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception:
        logging.getLogger(__name__).exception("Failed to delete collection %s", collection_name)
        return False
    return True
if __name__ == "__main__":
    # Entry point when this file is executed as a script (e.g. via
    # `streamlit run`): build the whole UI.
    main()