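"""Build an LCEL RAG chain over Paul Graham's essays.

Uses HuggingFace inference endpoints for both the LLM and the embeddings, and
either an in-memory Qdrant collection or a FAISS index persisted to disk as
the vector store.
"""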
import os
from operator import itemgetter
from uuid import uuid4

from langchain_core.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

def get_rag_prompt():
    # Llama 3 chat template. The special tokens are kept at column 0 so no
    # stray indentation is sent to the model.
    rp = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>

<|start_header_id|>user<|end_header_id|>
User Query:
{query}

Context:
{context}<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
"""

    rag_prompt = PromptTemplate.from_template(rp)
    return rag_prompt

def process_documents(use_qdrant=False):
    HF_LLM_ENDPOINT = os.environ["HF_LLM_ENDPOINT"]
    HF_EMBED_ENDPOINT = os.environ["HF_EMBED_ENDPOINT"]
    HF_TOKEN = os.environ["HF_TOKEN"]

    rag_prompt = get_rag_prompt()
    document_loader = TextLoader("./data/paul_graham_essays.txt")
    documents = document_loader.load()
    
    # Chunk the essays into ~1000-character pieces with a 30-character overlap.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
    split_documents = text_splitter.split_documents(documents)

    # Low-temperature, near-deterministic generation settings for the hosted LLM.
    hf_llm = HuggingFaceEndpoint(
        endpoint_url=HF_LLM_ENDPOINT,
        max_new_tokens=512,
        top_k=10,
        top_p=0.95,
        typical_p=0.95,
        temperature=0.01,
        repetition_penalty=1.03,
        huggingfacehub_api_token=HF_TOKEN
    )

    hf_embeddings = HuggingFaceEndpointEmbeddings(
        model=HF_EMBED_ENDPOINT,
        task="feature-extraction",
        huggingfacehub_api_token=HF_TOKEN,
    )
    if use_qdrant:
        collection_name = f"pdf_to_parse_{uuid4()}"
        client = QdrantClient(":memory:")
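        # NOTE: the vector size must match the output dimension of the embedding
        # model behind HF_EMBED_ENDPOINT (assumed to be 768 here).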
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=768, distance=Distance.COSINE),
        )

        vectorstore = QdrantVectorStore(
            client=client,
            collection_name=collection_name,
            embedding=hf_embeddings,
        )
        
        # Add documents in batches of 32 to keep each embedding request small.
        num_batches = (len(split_documents) + 31) // 32
        print(f"Number of batches: {num_batches}")

        for i in range(0, len(split_documents), 32):
            print(f"processing batch {i // 32 + 1} of {num_batches}")
            vectorstore.add_documents(split_documents[i:i + 32])

        print("Loaded Vectorstore using Qdrant")
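        # MMR (maximal marginal relevance) retrieval returns 3 diverse chunks.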
        hf_retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    else:
        vectorstore_path = "./data/vectorstore"
        if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
            print(f"Reading Faiss vector store from disk - {vectorstore_path}")
            vectorstore = FAISS.load_local(
                vectorstore_path,
                hf_embeddings,
                # Required: the index is persisted as a pickle (`.pkl`) file.
                allow_dangerous_deserialization=True,
            )
            print("Loaded Vectorstore using Faiss")
        else:
            print("Indexing Files")
            os.makedirs(vectorstore_path, exist_ok=True)
            num_batches = (len(split_documents) + 31) // 32
            print(f"Number of batches: {num_batches}")
            for i in range(0, len(split_documents), 32):
                print(f"processing batch {i // 32 + 1} of {num_batches}")
                if i == 0:
                    # The first batch creates the index; later batches extend it.
                    vectorstore = FAISS.from_documents(split_documents[i:i + 32], hf_embeddings)
                    continue
                vectorstore.add_documents(split_documents[i:i + 32])
            vectorstore.save_local(vectorstore_path)
            print(f"Faiss vector store saved to disk - {vectorstore_path}")

        hf_retriever = vectorstore.as_retriever()
    
    # LCEL pipeline: pull "query" from the input, retrieve context for it,
    # fill the prompt template, then call the LLM.
    lcel_rag_chain = {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")} | rag_prompt | hf_llm
    return lcel_rag_chain
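

# Minimal usage sketch (an illustrative assumption, not part of the original
# script): requires HF_LLM_ENDPOINT, HF_EMBED_ENDPOINT, and HF_TOKEN in the
# environment and ./data/paul_graham_essays.txt on disk.
if __name__ == "__main__":
    rag_chain = process_documents(use_qdrant=False)
    answer = rag_chain.invoke({"query": "What does Paul Graham say about doing things that don't scale?"})
    print(answer)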