# askpkd/functions.py
import os
import tempfile
import logging

import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder

load_dotenv()

# Initialize Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Retrieve the HF API key and initialize the HF client
# HF_API_KEY = os.getenv("HF_API_KEY")
# if not HF_API_KEY:
#     raise ValueError("Hugging Face API key not found. Please set HF_API_KEY in the .env file.")
client = InferenceClient()
FALLBACK_MESSAGE = "Sorry, I didn’t understand your question. Do you want to connect with a live agent?"
FEEDBACK_ERROR = "Our servers are busy. Please try again later."
# Retrieval parameters
FAISS_TOP_K = 20 # Number of top chunks to retrieve from FAISS
RERANK_TOP_K = 6 # Number of top chunks to use for generating the answer
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
# Minimum acceptable top re-ranker score (cross-encoder logits are unbounded and can be negative)
MIN_SCORE_THRESHOLD = -8.0
# PDF upload constraints
MAX_PDF_SIZE = 200 * 1024 * 1024  # 200 MB
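
# The functions below expect app.py (not shown here) to have already placed an
# embeddings model and a cross-encoder re-ranker into st.session_state. A minimal
# sketch of that setup is given here; the helper name init_session_state and the
# model checkpoints are illustrative assumptions, not part of this module's API.
def init_session_state():
    """Hypothetical one-time setup of the shared models used by this module."""
    if "embeddings" not in st.session_state:
        # Assumed embedding checkpoint; swap in whatever the app actually uses
        st.session_state.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    if "reranker" not in st.session_state:
        # Assumed cross-encoder checkpoint for re-ranking
        st.session_state.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None
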
def build_vectorstore(pdf_file):
    """Loads and processes a single PDF file, then builds a new FAISS vectorstore."""
    if pdf_file is None:
        logger.warning("No PDF file provided to build_vectorstore.")
        return 0, {}
    # Clear the existing vectorstore, if any
    st.session_state.vectorstore = None
    # Check the PDF size
    if pdf_file.size > MAX_PDF_SIZE:
        st.error(f"PDF exceeds the maximum allowed size of {MAX_PDF_SIZE / (1024 * 1024):.0f} MB.")
        logger.warning("Uploaded PDF exceeds the maximum allowed size.")
        return 0, {}
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(pdf_file.getbuffer())
            tmp_path = tmp_file.name
        logger.info(f"Temporary PDF saved at {tmp_path}")
        loader = PyPDFLoader(tmp_path)
        docs = loader.load()
        # PyPDFLoader attaches per-page metadata (source, page, ...) to each document,
        # so take it from the first loaded page rather than from the loader itself
        metadata = docs[0].metadata if docs else {}
        os.remove(tmp_path)
        logger.info("PDF loaded and temporary file removed.")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len,
        )
        chunks = text_splitter.split_documents(docs)
        num_chunks = len(chunks)
        logger.info(f"PDF split into {num_chunks} chunks.")
        # Build the FAISS vectorstore from the chunks
        st.session_state.vectorstore = FAISS.from_documents(chunks, st.session_state.embeddings)
        logger.info("FAISS vectorstore built successfully.")
        return num_chunks, metadata
    except Exception as e:
        st.error("An error occurred while processing the PDF.")
        logger.error(f"Error in build_vectorstore: {e}", exc_info=True)
        return 0, {}

def retrieve(question, k=FAISS_TOP_K):
    """Retrieves the top-k documents from the FAISS vectorstore by vector similarity."""
    vs = st.session_state.vectorstore
    if vs is None:
        logger.warning("No vectorstore available; upload a PDF first.")
        return []
    retrieved_docs = vs.similarity_search(question, k=k)
    if not retrieved_docs:
        logger.info("No documents retrieved for the question.")
        return []
    logger.info(f"Retrieved {len(retrieved_docs)} documents from FAISS.")
    return retrieved_docs

def sanitize_input(user_input):
    """Validates user input."""
    if not user_input:
        st.error("Input cannot be empty.")
        return False
    if len(user_input) > 1000:
        st.error("Input exceeds the maximum allowed length of 1000 characters.")
        return False
    return True

def rerank_documents(question, documents):
    """Re-ranks the retrieved documents with a cross-encoder and returns (doc, score) pairs sorted by relevance."""
    if not documents:
        logger.info("No documents to rerank.")
        return []
    reranker = st.session_state.reranker
    model_inputs = [(question, doc.page_content) for doc in documents]
    try:
        rerank_scores = reranker.predict(model_inputs, batch_size=16)
        logger.info("Re-ranker scores generated successfully.")
    except Exception as e:
        logger.error(f"Error during reranking: {e}")
        return []
    # Pair each document with its re-ranker score and sort by score, descending
    reranked_pairs = sorted(zip(documents, rerank_scores), key=lambda x: x[1], reverse=True)
    logger.info("Documents reranked successfully.")
    return reranked_pairs

def call_hf_api(prompt):
    """Calls the Hugging Face InferenceClient chat-completion API and returns the LLM's response."""
    messages = [
        {"role": "system", "content": (
            "You are a knowledgeable and helpful assistant. When answering, please:\n"
            "- Use headings and subheadings for organization. Utilize bullet points or numbered lists for clarity.\n"
            "- Provide factual information based solely on the provided context.\n"
            "- Mention that schedules or dates may change if applicable.\n"
            "- Clearly list accessibility or facilities details when relevant.\n"
            "- Your output should only contain the relevant answer to the question.\n"
            "- If the answer is not found in the context or the user input is not a valid question, respond only with: 'Sorry, I didn’t understand your question. Do you want to connect with a live agent?'.\n"
            "- End responses by asking if the user needs further assistance.\n"
            "- Maintain a professional and concise tone."
        )},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.3",
            messages=messages,
            max_tokens=2000,
        )
        response = completion.choices[0].message.content.strip()
        logger.info("LLM response generated successfully.")
        return response
    except Exception as e:
        logger.error(f"Error during LLM call: {e}")
        return FEEDBACK_ERROR

def get_answer(question):
    """Main pipeline: question -> retrieve chunks -> re-rank them -> pass the top chunks to the LLM for an answer."""
    if len(question.split()) < 3:
        # Fewer than 3 words is not treated as a valid question
        return FALLBACK_MESSAGE
    logger.info(f"Processing question: {question}")
    retrieved_docs = retrieve(question, k=FAISS_TOP_K)
    if not retrieved_docs:
        logger.info("No relevant documents found. Triggering fallback.")
        return FALLBACK_MESSAGE
    reranked_pairs = rerank_documents(question, retrieved_docs)
    if not reranked_pairs:
        logger.info("Re-ranking failed or no documents after reranking. Triggering fallback.")
        return FALLBACK_MESSAGE
    # Check the top re-ranker score
    top_doc, top_score = reranked_pairs[0]
    logger.info(f"Top re-ranker score: {top_score}")
    if top_score < MIN_SCORE_THRESHOLD:
        logger.info("Top re-ranker score below threshold. Triggering fallback.")
        return FALLBACK_MESSAGE
    # Select the top RERANK_TOP_K chunks for the final context
    top_chunks = reranked_pairs[:RERANK_TOP_K]
    combined_context = "\n\n".join(doc.page_content for doc, _ in top_chunks)
    # Construct the final prompt
    final_prompt = (
        f"Context:\n{combined_context}\n\n"
        f"User Question:\n{question}\n\n"
        "Instructions:\n"
        "Answer the user's question based only on the provided context. Ensure your response is:\n"
        "- Clear and concise.\n"
        "- Well-structured with headings or subheadings.\n"
        "- Formatted using bullet points, numbered lists, or short paragraphs.\n"
        "- Inclusive of logical inferences if implied by the context.\n"
        "- Free from unrelated information.\n"
        "- Concluding with an offer for further assistance.\n\n"
        "Answer:"
    )
    final_answer = call_hf_api(final_prompt).strip()
    logger.info("Final answer generated.")
    return final_answer if final_answer else FALLBACK_MESSAGE
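
# A minimal sketch of how an app entry point (e.g., app.py) might wire these
# functions together; the Streamlit layout and the init_session_state helper
# above are illustrative assumptions, not the app's actual code.
def main():
    init_session_state()
    st.title("Ask your PDF")
    uploaded = st.file_uploader("Upload a PDF", type="pdf")
    if uploaded is not None:
        num_chunks, _metadata = build_vectorstore(uploaded)
        if num_chunks:
            st.success(f"Indexed {num_chunks} chunks.")
    question = st.text_input("Ask a question about the document")
    if question and sanitize_input(question):
        st.write(get_answer(question))

if __name__ == "__main__":
    main()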