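# Gradio demo for comparing LangChain retrieval strategies (simple, parent/child,
# multi-query, contextual compression, ensemble, long-context reorder) over a PDF,
# with answers generated through the Hugging Face Inference API.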
import gradio as gr

import warnings
warnings.filterwarnings("ignore")

import os
from collections import defaultdict
from itertools import chain

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceEndpoint
from langchain.storage import InMemoryStore
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import ParentDocumentRetriever, BM25Retriever
from langchain.retrievers.document_compressors import LLMChainExtractor, LLMChainFilter, EmbeddingsFilter
from langchain_community.document_loaders import PyMuPDFLoader

# Hugging Face Inference API token (read scope); must be set in the environment.
HF_READ_API_KEY = os.environ["HF_READ_API_KEY"]


def get_text(docs):
    # Format retrieved documents as numbered result strings for display.
    return [f"Result {i + 1}\n{d.page_content}\n" for i, d in enumerate(docs)]


def load_pdf(path):
    # PyMuPDFLoader accepts a local file path or a URL (web paths are downloaded
    # to a temporary file before parsing).
    loader = PyMuPDFLoader(path)
    docs = loader.load()
    return docs, 'PDF loaded successfully'


def multi_query_retrieval(query, llm, retriever):
    # Generate alternative phrasings of the query with the LLM, retrieve for each
    # phrasing, then deduplicate the combined results.
    DEFAULT_QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI assistant. Generate 3 different versions of the given question to retrieve relevant docs.
Provide these alternative questions separated by newlines.
Original question: {question}""",
    )
    mq_llm_chain = LLMChain(llm=llm, prompt=DEFAULT_QUERY_PROMPT)

    generated_queries = [q.strip() for q in mq_llm_chain.invoke(query)['text'].split("\n") if q.strip()]
    all_queries = [query] + generated_queries

    all_retrieved_docs = []
    for q in all_queries:
        all_retrieved_docs.extend(retriever.get_relevant_documents(q))

    # Keep only the first occurrence of each document.
    unique_retrieved_docs = [doc for i, doc in enumerate(all_retrieved_docs) if doc not in all_retrieved_docs[:i]]
    return get_text(unique_retrieved_docs)
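# Note: the hand-rolled version above mirrors LangChain's built-in MultiQueryRetriever;
# a rough equivalent sketch (assuming langchain >= 0.1) would be:
#   from langchain.retrievers.multi_query import MultiQueryRetriever
#   mq_retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=llm)
#   docs = mq_retriever.get_relevant_documents(query)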


def compressed_retrieval(query, llm, retriever, extractor_type='chain', embedding_model=None):
    # Retrieve candidates, then compress/filter them so only query-relevant content remains.
    retrieved_docs = retriever.get_relevant_documents(query)
    if extractor_type == 'chain':
        # The LLM rewrites each document, keeping only passages relevant to the query.
        extractor = LLMChainExtractor.from_llm(llm)
    elif extractor_type == 'filter':
        # The LLM makes a keep/drop decision per document.
        extractor = LLMChainFilter.from_llm(llm)
    elif extractor_type == 'embeddings':
        if embedding_model is None:
            raise ValueError("Embeddings model must be provided for embeddings extractor.")
        # Drop documents whose embedding similarity to the query falls below the threshold.
        extractor = EmbeddingsFilter(embeddings=embedding_model, similarity_threshold=0.5)
    else:
        raise ValueError("Invalid extractor_type. Options are 'chain', 'filter', or 'embeddings'.")
    compressed_docs = extractor.compress_documents(retrieved_docs, query)
    return get_text(compressed_docs)
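# Note: LangChain can also wrap the base retriever with a ContextualCompressionRetriever
# instead of compressing after the fact; a rough equivalent sketch:
#   from langchain.retrievers import ContextualCompressionRetriever
#   cc_retriever = ContextualCompressionRetriever(base_compressor=extractor, base_retriever=retriever)
#   docs = cc_retriever.get_relevant_documents(query)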


def unique_by_key(iterable, key_func):
    # Yield elements whose key has not been seen before (order-preserving deduplication).
    seen = set()
    for element in iterable:
        key = key_func(element)
        if key not in seen:
            seen.add(key)
            yield element


def ensemble_retrieval(query, retrievers_list, c=60):
    # Combine results from several retrievers with Reciprocal Rank Fusion (RRF):
    # each document scores sum(weight / (rank + c)) over the retrievers that return it.
    retrieved_docs_by_retriever = [retriever.get_relevant_documents(query) for retriever in retrievers_list]
    weights = [1 / len(retrievers_list)] * len(retrievers_list)
    rrf_score = defaultdict(float)
    for doc_list, weight in zip(retrieved_docs_by_retriever, weights):
        for rank, doc in enumerate(doc_list, start=1):
            rrf_score[doc.page_content] += weight / (rank + c)

    # Deduplicate by content and sort by fused score, highest first.
    all_docs = chain.from_iterable(retrieved_docs_by_retriever)
    sorted_docs = sorted(
        unique_by_key(all_docs, lambda doc: doc.page_content),
        key=lambda doc: rrf_score[doc.page_content],
        reverse=True,
    )
    return get_text(sorted_docs)
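# Note: LangChain's EnsembleRetriever implements the same weighted RRF; a rough
# equivalent sketch with explicit equal weights:
#   from langchain.retrievers import EnsembleRetriever
#   ensemble = EnsembleRetriever(retrievers=retrievers_list,
#                                weights=[1 / len(retrievers_list)] * len(retrievers_list))
#   docs = ensemble.get_relevant_documents(query)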


def long_context_reorder_retrieval(query, retriever):
    # Reorder results to counter the "lost in the middle" effect: the most relevant
    # documents end up at the beginning and end of the context, the least relevant
    # in the middle.
    retrieved_docs = retriever.get_relevant_documents(query)
    retrieved_docs.reverse()
    reordered_results = []
    for i, doc in enumerate(retrieved_docs):
        if i % 2 == 1:
            reordered_results.append(doc)
        else:
            reordered_results.insert(0, doc)
    return get_text(reordered_results)
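# Note: this matches LangChain's LongContextReorder document transformer; a rough
# equivalent sketch:
#   from langchain_community.document_transformers import LongContextReorder
#   docs = LongContextReorder().transform_documents(retrieved_docs)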


def process_query(docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p):
    # Gradio Number inputs may arrive as floats; the splitter and endpoint expect ints.
    chunking_parameters = {'chunk_size': int(chunk_size), 'chunk_overlap': int(chunk_overlap)}
    inference_model_params = {'max_new_tokens': int(max_new_tokens), 'temperature': temperature, 'top_p': top_p}

    # Chunk the loaded PDF and index the chunks in an in-memory Chroma vector store.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunking_parameters['chunk_size'], chunk_overlap=chunking_parameters['chunk_overlap'])
    texts = text_splitter.split_documents(docs)

    hf = HuggingFaceEmbeddings(model_name=embedding_model)
    vector_db_from_docs = Chroma.from_documents(texts, hf)
    simple_retriever = vector_db_from_docs.as_retriever(search_kwargs={"k": 5})

    # LLM served via the Hugging Face Inference API.
    llm_model = HuggingFaceEndpoint(repo_id=inference_model,
                                    max_new_tokens=inference_model_params['max_new_tokens'],
                                    temperature=inference_model_params['temperature'],
                                    top_p=inference_model_params['top_p'],
                                    huggingfacehub_api_token=HF_READ_API_KEY)

    if retrieval_method == "Simple":
        retrieved_docs = simple_retriever.get_relevant_documents(query)
        result = get_text(retrieved_docs)
    elif retrieval_method == "Parent & Child":
        # Note: parent and child chunks share the same splitter here, so both levels
        # use the chunk size chosen in the UI.
        parent_text_splitter = child_text_splitter = text_splitter
        vector_db = Chroma(collection_name="parent_child", embedding_function=hf)
        store = InMemoryStore()
        pr_retriever = ParentDocumentRetriever(
            vectorstore=vector_db,
            docstore=store,
            child_splitter=child_text_splitter,
            parent_splitter=parent_text_splitter,
        )
        pr_retriever.add_documents(docs)
        retrieved_docs = pr_retriever.get_relevant_documents(query)
        result = get_text(retrieved_docs)
    elif retrieval_method == "Multi Query":
        result = multi_query_retrieval(query, llm_model, simple_retriever)
    elif retrieval_method == "Contextual Compression (chain extraction)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='chain')
    elif retrieval_method == "Contextual Compression (query filter)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='filter')
    elif retrieval_method == "Contextual Compression (embeddings filter)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='embeddings', embedding_model=hf)
    elif retrieval_method == "Ensemble":
        # Dense (Chroma) + sparse (BM25) retrieval, fused with RRF.
        bm25_retriever = BM25Retriever.from_documents(docs)
        all_retrievers = [simple_retriever, bm25_retriever]
        result = ensemble_retrieval(query, all_retrievers)
    elif retrieval_method == "Long Context Reorder":
        result = long_context_reorder_retrieval(query, simple_retriever)
    else:
        raise ValueError(f"Unknown retrieval method: {retrieval_method}")

    # Answer the query using the retrieved chunks as context.
    prompt_template = PromptTemplate.from_template(
        "Answer the query {query} with the following context:\n {context}. If you cannot use the context to answer the query, say 'I cannot answer the query with the provided context.'"
    )
    answer = llm_model.invoke(prompt_template.format(query=query, context="\n".join(result)))

    return "\n".join(result), answer.strip()


embedding_model_list = ['sentence-transformers/all-MiniLM-L6-v2', 'BAAI/bge-small-en-v1.5', 'BAAI/bge-large-en-v1.5']
inference_model_list = ['google/gemma-2b-it', 'google/gemma-7b-it', 'microsoft/phi-2', 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
retrieval_method_list = ["Simple", "Parent & Child", "Multi Query",
                         "Contextual Compression (chain extraction)", "Contextual Compression (query filter)",
                         "Contextual Compression (embeddings filter)", "Ensemble", "Long Context Reorder"]


with gr.Blocks() as demo:
    gr.Markdown("## Compare Retrieval Methods for PDFs")
    with gr.Row():
        with gr.Column():
            pdf_url = gr.Textbox(label="Enter URL to PDF", value="https://www.berkshirehathaway.com/letters/2023ltr.pdf")
            load_button = gr.Button("Load and process PDF")
            status = gr.Textbox(label="Status")
            # Keep the loaded documents in per-session state so they survive between clicks.
            docs = gr.State()
            load_button.click(load_pdf, inputs=[pdf_url], outputs=[docs, status])

            query = gr.Textbox(label="Enter your query", value="What does Warren Buffett think about Coca Cola?")
            with gr.Row():
                embedding_model = gr.Dropdown(embedding_model_list, label="Select Embedding Model", value=embedding_model_list[0])
                inference_model = gr.Dropdown(inference_model_list, label="Select Inference Model", value=inference_model_list[3])
                retrieval_method = gr.Dropdown(retrieval_method_list, label="Select Retrieval Method", value=retrieval_method_list[0])

            with gr.Row():
                chunk_size = gr.Number(label="Chunk Size", value=1000, precision=0)
                chunk_overlap = gr.Number(label="Chunk Overlap", value=200, precision=0)

            with gr.Row():
                max_new_tokens = gr.Number(label="Max New Tokens", value=100, precision=0)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Top P", value=0.9)

            search_button = gr.Button("Retrieve")
        with gr.Column():
            answer = gr.Textbox(label="Answer")
            retrieval_output = gr.Textbox(label="Retrieval Results")

    search_button.click(process_query,
                        inputs=[docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p],
                        outputs=[retrieval_output, answer])


if __name__ == "__main__":
    demo.launch()