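"""Document Query App: fetches PDFs from S3, OCRs and indexes them in a Chroma
vector store, and answers user queries with a Hugging Face-hosted LLM via Gradio."""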
import io
import os
import re
from typing import List

import boto3
import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import pytesseract
from PIL import Image
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import Chroma
from sentence_transformers import SentenceTransformer

# AWS access credentials
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")

# S3 bucket details
bucket_name = os.getenv("bucket_name")
prefix = os.getenv("prefix")

# Hugging Face Inference API token
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
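# All of the above are expected to be provided as environment variables (e.g., as
# deployment secrets); any that are unset will be None.
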
def extract_text_from_pdf(pdf_content):
    """Extract text from PDF content using OCR."""
    try:
        doc = fitz.open(stream=pdf_content, filetype="pdf")
        text = ""
        for page in doc:
            # Render each page to an image and OCR it with Tesseract
            pix = page.get_pixmap()
            img = Image.open(io.BytesIO(pix.tobytes()))
            text += pytesseract.image_to_string(img)
        return text
    except Exception as e:
        print("Failed to extract text from PDF:", e)
        return ""

def preprocess_text(text):
    """Preprocess text by cleaning and normalizing."""
    try:
        text = text.replace('\n', ' ').replace('\r', ' ')
        text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # drop non-ASCII characters
        text = text.lower()
        text = re.sub(r'[^\w\s]', '', text)  # strip punctuation
        text = re.sub(r'\s+', ' ', text).strip()  # collapse whitespace
        return text
    except Exception as e:
        print("Failed to preprocess text:", e)
        return ""

def process_files(file_contents: List[bytes]):
    """Process and combine text from multiple files."""
    all_text = ""
    for file_content in file_contents:
        extracted_text = extract_text_from_pdf(file_content)
        preprocessed_text = preprocess_text(extracted_text)
        all_text += preprocessed_text + " "
    return all_text

def compute_cosine_similarity_scores(query, retrieved_docs):
    """Compute cosine similarity scores between a query and retrieved documents."""
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    # Normalize the embeddings so the dot product below is a true cosine similarity
    query_embedding = model.encode(query, convert_to_tensor=True, normalize_embeddings=True)
    doc_embeddings = model.encode(retrieved_docs, convert_to_tensor=True, normalize_embeddings=True)
    cosine_scores = np.dot(doc_embeddings.cpu(), query_embedding.cpu().T)
    readable_scores = [{"doc": doc, "score": float(score)} for doc, score in zip(retrieved_docs, cosine_scores.flatten())]
    return readable_scores
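# Note: the encoder above is reloaded on every scoring call. A minimal optional
# optimization (the _scoring_model name is illustrative, not from the original app)
# would cache it once at module level:
#   _scoring_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
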
def fetch_files_from_s3():
    """Fetch files from an S3 bucket."""
    s3 = boto3.client('s3', aws_access_key_id=access_key, aws_secret_access_key=secret_key)
    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    file_contents = []
    for obj in objects.get('Contents', []):
        if not obj['Key'].endswith('/'):  # skip directory placeholders
            response = s3.get_object(Bucket=bucket_name, Key=obj['Key'])
            file_content = response['Body'].read()
            file_contents.append(file_content)
    return file_contents
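# Note: list_objects_v2 returns at most 1,000 keys per call, so buckets with more
# objects than that would need boto3's paginator to fetch everything.
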
def create_vector_store(all_text):
    """Create a vector store for similarity-based searching."""
    embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_text(all_text)
    if not texts:
        print("No text chunks created.")
        return None
    vector_store = Chroma.from_texts(
        texts,
        embeddings,
        collection_metadata={"hnsw:space": "cosine"},
        persist_directory="stores/insurance_cosine",
    )
    print("Vector DB Successfully Created!")
    return vector_store
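# The collection is persisted under stores/insurance_cosine, so later runs can be
# served by load_vector_store() below without re-embedding the documents.
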
def load_vector_store():
    """Load the vector store from the persistent directory."""
    # Chroma silently creates an empty store when the directory is missing, so
    # check for it first to let the caller fall back to rebuilding from S3.
    if not os.path.isdir("stores/insurance_cosine"):
        print("No persisted Vector DB found.")
        return None
    embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    try:
        db = Chroma(persist_directory="stores/insurance_cosine", embedding_function=embeddings)
        print("Vector DB Successfully Loaded!")
        return db
    except Exception as e:
        print("Failed to load Vector DB:", e)
        return None

def answer_query_with_similarity(query):
    """Answer a query by finding similar documents and generating a response with a language model."""
    try:
        # Load the vector store
        vector_store = load_vector_store()
        # If no vector store exists yet, fetch files from S3, process them, and create one
        if not vector_store:
            file_contents = fetch_files_from_s3()
            if not file_contents:
                print("No files fetched from S3.")
                return None
            all_text = process_files(file_contents)
            if not all_text.strip():
                print("No text extracted from documents.")
                return None
            vector_store = create_vector_store(all_text)
            if not vector_store:
                print("Failed to create Vector DB.")
                return None

        # Perform similarity search
        docs = vector_store.similarity_search(query)
        print(f"\n\nDocuments retrieved: {len(docs)}")
        if not docs:
            print("No documents match the query.")
            return None
        docs_content = [doc.page_content for doc in docs]

        # Compute cosine similarity scores (logged for inspection; not used in the prompt)
        cosine_similarity_scores = compute_cosine_similarity_scores(query, docs_content)
        print(f"Cosine similarity scores: {cosine_similarity_scores}")
        all_docs_content = " ".join(docs_content)

        # Generate a response using a language model
        template = """
        ### [INST] Instruction:
        You are an AI assistant named Goose. Your purpose is to provide accurate, relevant, and helpful information to users in a friendly, warm, and supportive manner, similar to ChatGPT. When responding to queries, please keep the following guidelines in mind:
        - When someone says hi or makes small talk, respond with a single sentence.
        - Retrieve relevant information from your knowledge base to formulate accurate and informative responses.
        - Always maintain a positive, friendly, and encouraging tone in your interactions with users.
        - Write crisp, clear answers; do not include unnecessary material.
        - Answer only the question asked; do not hallucinate or print any preamble.
        - After providing the answer, always ask in a new paragraph whether any other help is needed.
        - Bullet-point format is strongly preferred.
        Remember, your goal is to be a reliable, friendly, and supportive AI assistant that provides accurate information while creating a positive user experience, just like ChatGPT. Adapt your communication style to best suit each user's needs and preferences.
        ### Docs: {docs}
        ### Question: {question}
        """
        # Keep {docs} and {question} as template variables rather than pre-formatting
        # the string, so document text containing braces cannot break the template and
        # the chain can fill both values in at run time.
        prompt = PromptTemplate(template=template, input_variables=["docs", "question"])
        repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        llm = HuggingFaceEndpoint(
            repo_id=repo_id,
            huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
            temperature=0.1,
            top_p=0.15,
            max_new_tokens=256,
            repetition_penalty=1.1,
        )
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        answer = llm_chain.run(docs=all_docs_content, question=query).strip()
        print(f"\n\nAnswer: {answer}")
        return answer
    except Exception as e:
        print("An error occurred while getting the answer:", str(e))
        return None

def gradio_interface(query):
    answer = answer_query_with_similarity(query)
    # answer_query_with_similarity returns None on failure; show a message instead of blank output
    return answer if answer else "Sorry, no answer could be generated. Please check the logs."

interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="Document Query App"
)

if __name__ == "__main__":
    interface.launch()
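# On a hosted Space, launch() serves the app directly; when running locally,
# launch(share=True) can be used to get a temporary public link.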