AIE4-16 / utilities /utilities.py
rchrdgwr's picture
updates to chainlit app - 2 vector stores
46fc427
import numpy as np
import os
from langchain_core.prompts import PromptTemplate
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from operator import itemgetter
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_huggingface import HuggingFaceEndpoint
from uuid import uuid4
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain_qdrant import QdrantVectorStore
from numpy.linalg import norm
def get_rag_prompt():
    """Build the prompt template used by the RAG chain.

    The template is written with Llama-3 chat special tokens
    (``<|start_header_id|>``, ``<|eot_id|>``). Presumably this matches the model
    served behind the LLM endpoint; confirm before swapping models.

    It declares two input variables:
        query:   the user's question.
        context: the retrieved passages the answer must be grounded in.

    Returns:
        A ``PromptTemplate`` built from the template string.
    """
    template_text = (
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        "User Query:\n"
        "{query}\n"
        "Context:\n"
        "{context}<|eot_id|>\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
    return PromptTemplate.from_template(template_text)
def _load_split_documents():
    """Load the essays text file and split it into overlapping ~1000-character chunks."""
    documents = TextLoader("./data/paul_graham_essays.txt").load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
    return text_splitter.split_documents(documents)


def _iter_batches(items, batch_size):
    """Yield consecutive slices of *items* of at most *batch_size*, logging progress."""
    # Ceiling division: a partial final batch still counts as a batch.
    n_batches = (len(items) + batch_size - 1) // batch_size
    print(f"Number of batches: {n_batches}")
    for batch_number, start in enumerate(range(0, len(items), batch_size)):
        print(f"processing batch {batch_number}")
        yield items[start:start + batch_size]


def _build_qdrant_retriever(split_documents, embeddings, batch_size, vector_size):
    """Index *split_documents* into a new in-memory Qdrant collection; return an MMR retriever."""
    # uuid4 suffix gives every call its own collection name.
    collection_name = f"pdf_to_parse_{uuid4()}"
    # ":memory:" keeps the index in this process only, so every call re-embeds everything.
    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name=collection_name,
        # Must equal the embedding model's output dimension. The default of 768 presumably
        # matches the model behind HF_EMBED_ENDPOINT; confirm when changing models.
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
    vectorstore = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
    )
    # Embedding in batches keeps each request to the embedding endpoint small.
    for batch in _iter_batches(split_documents, batch_size):
        vectorstore.add_documents(batch)
    print("Loaded Vectorstore using Qdrant")
    return vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})


def _build_faiss_retriever(embeddings, batch_size):
    """Load the on-disk FAISS index if present; otherwise build and save it. Return a retriever.

    Raises:
        ValueError: if the index must be built but the source file yields no chunks.
    """
    vectorstore_path = "./data/vectorstore"
    if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
        print(f"Reading Faiss vector store from disk - {vectorstore_path}")
        vectorstore = FAISS.load_local(
            vectorstore_path,
            embeddings,
            # load_local unpickles the stored docstore (.pkl). This is acceptable only because
            # this function writes that directory itself; never point it at untrusted files.
            allow_dangerous_deserialization=True,
        )
        print("Loaded Vectorstore using Faiss")
        return vectorstore.as_retriever()

    print("Indexing Files")
    split_documents = _load_split_documents()
    if not split_documents:
        # Without this check `vectorstore` would never be assigned and save_local would fail.
        raise ValueError("No documents to index for the FAISS vector store")
    os.makedirs(vectorstore_path, exist_ok=True)
    vectorstore = None
    for batch in _iter_batches(split_documents, batch_size):
        if vectorstore is None:
            # FAISS has no empty constructor here; the first batch creates the index.
            vectorstore = FAISS.from_documents(batch, embeddings)
        else:
            vectorstore.add_documents(batch)
    vectorstore.save_local(vectorstore_path)
    print(f"Faiss vector store saved to disk - {vectorstore_path}")
    return vectorstore.as_retriever()


def process_documents(use_qdrant=False, *, batch_size=32, vector_size=768):
    """Build the LCEL RAG chain over the essays corpus.

    Args:
        use_qdrant: If True, embed the chunks into a fresh in-memory Qdrant collection and
            retrieve with MMR (k=3). Otherwise use a FAISS index persisted under
            ``./data/vectorstore``. That index is loaded if present and built if not.
        batch_size: Number of chunks sent to the vector store per ``add_documents`` call.
        vector_size: Qdrant collection vector dimension. It must match the embedding model.

    Returns:
        A runnable that takes ``{"query": str}``, retrieves context for the query, fills the
        RAG prompt and passes it to the Hugging Face endpoint LLM.

    Raises:
        KeyError: if HF_LLM_ENDPOINT, HF_EMBED_ENDPOINT or HF_TOKEN is not set.
        ValueError: if a FAISS index must be built but no document chunks were produced.
    """
    hf_llm_endpoint = os.environ["HF_LLM_ENDPOINT"]
    hf_embed_endpoint = os.environ["HF_EMBED_ENDPOINT"]
    hf_token = os.environ["HF_TOKEN"]

    rag_prompt = get_rag_prompt()

    hf_llm = HuggingFaceEndpoint(
        endpoint_url=hf_llm_endpoint,
        max_new_tokens=512,
        top_k=10,
        top_p=0.95,
        typical_p=0.95,
        temperature=0.01,
        repetition_penalty=1.03,
        huggingfacehub_api_token=hf_token,
    )
    hf_embeddings = HuggingFaceEndpointEmbeddings(
        model=hf_embed_endpoint,
        task="feature-extraction",
        huggingfacehub_api_token=hf_token,
    )

    if use_qdrant:
        hf_retriever = _build_qdrant_retriever(
            _load_split_documents(), hf_embeddings, batch_size, vector_size
        )
    else:
        # Documents are loaded only if the FAISS index actually has to be built.
        hf_retriever = _build_faiss_retriever(hf_embeddings, batch_size)

    # The query feeds both the retriever (its results fill {context}) and the {query} slot.
    lcel_rag_chain = (
        {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
        | rag_prompt
        | hf_llm
    )
    return lcel_rag_chain