import os
import tempfile
import logging
import datetime

import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from huggingface_hub import InferenceClient
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv

load_dotenv()
# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Retrieve HF API key and initialize the HF client
# HF_API_KEY = os.getenv("HF_API_KEY")
# if not HF_API_KEY:
#     raise ValueError("Hugging Face API key not found. Please set HF_API_KEY in the .env file.")
client = InferenceClient()

FALLBACK_MESSAGE = "Sorry, I didn’t understand your question. Do you want to connect with a live agent?"
FEEDBACK_ERROR = "Our servers are busy. Please try again later."

# Retrieval parameters
FAISS_TOP_K = 20   # Number of top chunks to retrieve from FAISS
RERANK_TOP_K = 6   # Number of top chunks to use for generating the answer
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100

# Threshold for re-ranker score
MIN_SCORE_THRESHOLD = -8.0

# PDF upload constraints
MAX_PDF_SIZE = 200 * 1024 * 1024
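
# --- Session-state model initialization (sketch) ---
# The functions below read st.session_state.embeddings and st.session_state.reranker.
# If they are not initialized elsewhere in the app, a minimal setup could look like this;
# the model names are illustrative assumptions, not requirements of the pipeline.
if "embeddings" not in st.session_state:
    # Sentence-embedding model used to build the FAISS index (assumed model name)
    st.session_state.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
if "reranker" not in st.session_state:
    # Cross-encoder used to re-rank retrieved chunks (assumed model name)
    st.session_state.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None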
def build_vectorstore(pdf_file):
    """Loads and processes a single PDF file, then builds a new FAISS vectorstore."""
    if pdf_file is None:
        logger.warning("No PDF file provided to build_vectorstore.")
        return 0, {}

    # Clear existing vectorstore, if any
    st.session_state.vectorstore = None

    # Check PDF size
    if pdf_file.size > MAX_PDF_SIZE:
        st.error(f"PDF exceeds the maximum allowed size of {MAX_PDF_SIZE / (1024 * 1024):.0f} MB.")
        logger.warning("Uploaded PDF exceeds the maximum allowed size.")
        return 0, {}

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(pdf_file.getbuffer())
            tmp_path = tmp_file.name
        logger.info(f"Temporary PDF saved at {tmp_path}")

        loader = PyPDFLoader(tmp_path)
        docs = loader.load()
        # PyPDFLoader exposes no `metadata` attribute; document metadata lives on the loaded pages
        metadata = docs[0].metadata if docs else {}
        os.remove(tmp_path)
        logger.info("PDF loaded and temporary file removed.")

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len
        )
        chunks = text_splitter.split_documents(docs)
        num_chunks = len(chunks)
        logger.info(f"PDF split into {num_chunks} chunks.")

        # Build FAISS vectorstore
        st.session_state.vectorstore = FAISS.from_documents(chunks, st.session_state.embeddings)
        logger.info("FAISS vectorstore built successfully.")
        return num_chunks, metadata
    except Exception as e:
        st.error("An error occurred while processing the PDF.")
        logger.error(f"Error in build_vectorstore: {e}", exc_info=True)
        return 0, {}
def retrieve(question, k=FAISS_TOP_K):
    """Retrieves the top-k documents from the FAISS vectorstore by vector similarity."""
    vs = st.session_state.vectorstore
    if vs is None:
        logger.warning("No vectorstore available; a PDF must be processed first.")
        return []
    retrieved_docs = vs.similarity_search(question, k=k)
    if not retrieved_docs:
        logger.info("No documents retrieved for the question.")
        return []
    logger.info(f"Retrieved {len(retrieved_docs)} documents from FAISS.")
    return retrieved_docs

def sanitize_input(user_input):
    """Validates user input."""
    if not user_input:
        st.error("Input cannot be empty.")
        return False
    if len(user_input) > 1000:
        st.error("Input exceeds the maximum allowed length of 1000 characters.")
        return False
    return True
def rerank_documents(question, documents):
    """Re-ranks the retrieved documents with a cross-encoder and returns (document, score) pairs sorted by relevance."""
    if not documents:
        logger.info("No documents to rerank.")
        return []
    reranker = st.session_state.reranker
    model_inputs = [(question, doc.page_content) for doc in documents]
    try:
        rerank_scores = reranker.predict(model_inputs, batch_size=16)
        logger.info("Re-ranker scores generated successfully.")
    except Exception as e:
        logger.error(f"Error during reranking: {e}")
        return []
    # Pair docs with reranker scores and sort by score, descending
    reranked_pairs = sorted(zip(documents, rerank_scores), key=lambda x: x[1], reverse=True)
    logger.info("Documents reranked successfully.")
    return reranked_pairs
def call_hf_api(prompt):
    """Calls the Hugging Face InferenceClient chat-completion API and returns the generated response from the LLM."""
    messages = [
        {"role": "system", "content": (
            "You are a knowledgeable and helpful assistant. When answering, please:\n"
            "- Use headings and subheadings for organization. Utilize bullet points or numbered lists for clarity.\n"
            "- Provide factual information based solely on the provided context.\n"
            "- Mention that schedules or dates may change if applicable.\n"
            "- Clearly list accessibility or facilities details when relevant.\n"
            "- Your output should only contain the relevant answer to the question.\n"
            "- If the answer is not found in the context or the user input is not a valid question, respond only with: 'Sorry, I didn’t understand your question. Do you want to connect with a live agent?'\n"
            "- End responses by asking if the user needs further assistance.\n"
            "- Maintain a professional and concise tone."
        )},
        {"role": "user", "content": prompt}
    ]
    try:
        completion = client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.3",
            messages=messages,
            max_tokens=2000
        )
        response = completion.choices[0].message.content.strip()
        logger.info("LLM response generated successfully.")
        return response
    except Exception as e:
        logger.error(f"Error during LLM call: {e}")
        return FEEDBACK_ERROR
def get_answer(question):
    """Main logic: question -> retrieve chunks -> re-rank them -> pass them to the LLM for answering."""
    if len(question.split()) < 3:
        # Question has fewer than 3 words; not a valid question.
        return FALLBACK_MESSAGE
    logger.info(f"Processing question: {question}")

    retrieved_docs = retrieve(question, k=FAISS_TOP_K)
    if not retrieved_docs:
        logger.info("No relevant documents found. Triggering fallback.")
        return FALLBACK_MESSAGE

    reranked_pairs = rerank_documents(question, retrieved_docs)
    if not reranked_pairs:
        logger.info("Re-ranking failed or no documents after reranking. Triggering fallback.")
        return FALLBACK_MESSAGE

    # Check the top re-ranker score
    top_doc, top_score = reranked_pairs[0]
    logger.info(f"Top re-ranker score: {top_score}")
    if top_score < MIN_SCORE_THRESHOLD:
        logger.info("Top re-ranker score below threshold. Triggering fallback.")
        return FALLBACK_MESSAGE

    # Select the top RERANK_TOP_K chunks for the final context
    top_chunks = reranked_pairs[:RERANK_TOP_K]
    combined_context = "\n\n".join([doc.page_content for doc, _ in top_chunks])

    # Construct the final prompt
    final_prompt = (
        f"Context:\n{combined_context}\n\n"
        f"User Question:\n{question}\n\n"
        "Instructions:\n"
        "Answer the user's question based only on the provided context. Ensure your response is:\n"
        "- Clear and concise.\n"
        "- Well-structured with headings or subheadings.\n"
        "- Formatted using bullet points, numbered lists, or short paragraphs.\n"
        "- Inclusive of logical inferences if implied by the context.\n"
        "- Free from unrelated information.\n"
        "- Concluding with an offer for further assistance.\n\n"
        "Answer:"
    )
    final_answer = call_hf_api(final_prompt).strip()
    logger.info("Final answer generated.")
    return final_answer if final_answer else FALLBACK_MESSAGE
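
# --- Example wiring (sketch) ---
# One way the functions above might be hooked into the Streamlit UI, if this is not
# already done elsewhere in the app; widget labels and layout are illustrative
# assumptions, not part of the original pipeline.
uploaded_pdf = st.file_uploader("Upload a PDF", type="pdf")
if uploaded_pdf is not None and st.session_state.get("vectorstore") is None:
    num_chunks, pdf_metadata = build_vectorstore(uploaded_pdf)
    if num_chunks:
        st.success(f"PDF indexed into {num_chunks} chunks.")

question = st.text_input("Ask a question about the document")
if question and st.session_state.get("vectorstore") is not None:
    # Validate the input before running retrieval, re-ranking, and generation
    if sanitize_input(question):
        st.markdown(get_answer(question))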