|
import numpy as np |
|
import os |
|
from langchain_core.prompts import PromptTemplate |
|
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings |
|
from langchain_community.document_loaders import TextLoader |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import FAISS |
|
from operator import itemgetter |
|
from langchain.schema.output_parser import StrOutputParser |
|
from langchain.schema.runnable import RunnablePassthrough |
|
from langchain_huggingface import HuggingFaceEndpoint |
|
from uuid import uuid4 |
|
from qdrant_client import QdrantClient |
|
from qdrant_client.http.models import Distance, VectorParams |
|
from langchain_qdrant import QdrantVectorStore |
|
|
|
from numpy.linalg import norm |
|
|
|
def get_rag_prompt():
    """Return the chat-formatted RAG prompt as a ``PromptTemplate``.

    The template uses the ``<|start_header_id|>`` / ``<|end_header_id|>`` /
    ``<|eot_id|>`` special tokens of the Llama 3 chat format (presumably the
    model behind the LLM endpoint -- confirm). It exposes two placeholders,
    ``{query}`` and ``{context}``. The text ends with an open assistant header
    so the model continues with its answer.
    """
    # Built from explicit newline-terminated pieces so the prompt text does not
    # depend on source-code indentation.
    template_text = (
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant. You answer user questions based on provided context. "
        "If you can't answer the question with the provided context, say you don't know.<|eot_id|>\n"
        "\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        "User Query:\n"
        "{query}\n"
        "\n"
        "Context:\n"
        "{context}<|eot_id|>\n"
        "\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
    return PromptTemplate.from_template(template_text)
|
|
|
_EMBED_BATCH_SIZE = 32  # chunks sent per add_documents()/from_documents() call


def _load_split_documents():
    """Load the essay corpus and split it into ~1000-character chunks (30-char overlap)."""
    # Explicit UTF-8 so loading does not depend on the platform's locale encoding.
    document_loader = TextLoader("./data/paul_graham_essays.txt", encoding="utf-8")
    documents = document_loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
    return text_splitter.split_documents(documents)


def _batch_count(items, batch_size=_EMBED_BATCH_SIZE):
    """Return how many batches of *batch_size* are needed to cover *items* (ceiling division)."""
    return -(-len(items) // batch_size)


def _iter_batches(items, batch_size=_EMBED_BATCH_SIZE):
    """Yield ``(batch_index, batch)`` for consecutive slices of *items*; the last may be short."""
    for batch_index, start in enumerate(range(0, len(items), batch_size)):
        yield batch_index, items[start:start + batch_size]


def _build_qdrant_retriever(hf_embeddings, split_documents):
    """Embed *split_documents* into a fresh in-memory Qdrant collection.

    A new, uniquely named collection in an in-memory client is created on
    every call, so nothing is reused between calls. Returns an MMR retriever
    with ``k=3``.
    """
    collection_name = f"pdf_to_parse_{uuid4()}"
    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name=collection_name,
        # NOTE(review): 768 must equal the output dimension of the model behind
        # HF_EMBED_ENDPOINT -- confirm when switching embedding models.
        vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    )
    vectorstore = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=hf_embeddings,
    )

    # Embed in fixed-size batches, presumably to stay under the embedding
    # endpoint's per-request batch limit -- TODO confirm the endpoint's limit.
    print(f"Number of batches: {_batch_count(split_documents)}")
    for batch_index, batch in _iter_batches(split_documents):
        print(f"processing batch {batch_index}")
        vectorstore.add_documents(batch)

    print("Loaded Vectorstore using Qdrant")
    return vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})


def _build_faiss_retriever(hf_embeddings, vectorstore_path="./data/vectorstore"):
    """Return a retriever over a FAISS index persisted at *vectorstore_path*.

    If the directory exists and is non-empty, the saved index is loaded.
    Otherwise the corpus is loaded, split, embedded in batches and saved there.
    The corpus is only read when indexing is actually required.

    Raises:
        ValueError: if indexing is required but the corpus yields no chunks
            (FAISS cannot be created from zero documents).
    """
    if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
        print(f"Reading Faiss vector store from disk - {vectorstore_path}")
        # SECURITY: load_local deserializes (unpickles) the saved docstore.
        # Acceptable only because this path is expected to hold indexes written
        # by save_local below; never point it at untrusted files.
        vectorstore = FAISS.load_local(
            vectorstore_path,
            hf_embeddings,
            allow_dangerous_deserialization=True,
        )
        print("Loaded Vectorstore using Faiss")
        return vectorstore.as_retriever()

    print("Indexing Files")
    split_documents = _load_split_documents()
    if not split_documents:
        raise ValueError(
            "No document chunks to index; cannot build a FAISS vector store from an empty corpus"
        )
    os.makedirs(vectorstore_path, exist_ok=True)

    print(f"Number of batches: {_batch_count(split_documents)}")
    vectorstore = None
    for batch_index, batch in _iter_batches(split_documents):
        print(f"processing batch {batch_index}")
        if vectorstore is None:
            # FAISS has no empty constructor here: the first batch creates the index.
            vectorstore = FAISS.from_documents(batch, hf_embeddings)
        else:
            vectorstore.add_documents(batch)

    vectorstore.save_local(vectorstore_path)
    print(f"Faiss vector store saved to disk - {vectorstore_path}")
    return vectorstore.as_retriever()


def process_documents(use_qdrant=False):
    """Build the LCEL RAG chain over the Paul Graham essay corpus.

    Reads the ``HF_LLM_ENDPOINT``, ``HF_EMBED_ENDPOINT`` and ``HF_TOKEN``
    environment variables to configure the Hugging Face LLM and embedding
    endpoints.

    Args:
        use_qdrant: If true, embed all chunks into a fresh in-memory Qdrant
            collection on every call and retrieve with MMR (``k=3``).
            Otherwise reuse the FAISS index persisted under
            ``./data/vectorstore``, building and saving it on first use, and
            retrieve with the retriever's default settings.

    Returns:
        A runnable that takes a dict with a ``"query"`` key. It retrieves
        context for that query, fills the RAG prompt, and passes it to the
        Hugging Face LLM endpoint.

    Raises:
        KeyError: if a required environment variable is not set.
        ValueError: if a FAISS index must be built but the corpus is empty.
    """
    llm_endpoint = os.environ["HF_LLM_ENDPOINT"]
    embed_endpoint = os.environ["HF_EMBED_ENDPOINT"]
    hf_token = os.environ["HF_TOKEN"]

    rag_prompt = get_rag_prompt()

    hf_llm = HuggingFaceEndpoint(
        endpoint_url=llm_endpoint,
        max_new_tokens=512,
        top_k=10,
        top_p=0.95,
        typical_p=0.95,
        temperature=0.01,
        repetition_penalty=1.03,
        huggingfacehub_api_token=hf_token,
    )

    hf_embeddings = HuggingFaceEndpointEmbeddings(
        model=embed_endpoint,
        task="feature-extraction",
        huggingfacehub_api_token=hf_token,
    )

    if use_qdrant:
        # The Qdrant store is in-memory, so the corpus must be embedded every time.
        hf_retriever = _build_qdrant_retriever(hf_embeddings, _load_split_documents())
    else:
        hf_retriever = _build_faiss_retriever(hf_embeddings)

    # Fan the input dict out into the prompt's two placeholders:
    # "context" <- retriever results for the query, "query" <- the raw query.
    lcel_rag_chain = (
        {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
        | rag_prompt
        | hf_llm
    )
    return lcel_rag_chain
|
|
|
|