File size: 4,501 Bytes
0d56e9d
5f04412
 
50dc063
3707a95
3e7b971
5f04412
a0b5dc6
32dedd9
5f04412
 
 
 
c947c47
1342013
5f04412
ebaeae5
 
dce66ba
e946a29
0ddb69a
5f04412
0ddb69a
 
 
 
5f04412
0ddb69a
 
5f04412
0ddb69a
5f04412
0ddb69a
3940450
0ddb69a
 
5f04412
0ddb69a
 
5f04412
0ddb69a
 
 
3940450
0ddb69a
5f04412
0ddb69a
 
3940450
 
0ddb69a
5f04412
0ddb69a
5f04412
3707a95
 
 
 
 
 
 
 
05bcaa4
 
c2646e7
 
 
5a04c98
c2646e7
 
 
c273c9f
 
 
 
0ddb69a
5f04412
e946a29
0ddb69a
3940450
008f2f7
3940450
 
0ddb69a
c273c9f
c4c7ba9
c273c9f
3707a95
f5c95d9
 
1b00820
c273c9f
 
 
c4c7ba9
c273c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
05bcaa4
e946a29
ff3ca13
5f04412
ff3ca13
c273c9f
e946a29
0ddb69a
db1e09d
 
5f04412
3707a95
 
0ddb69a
3707a95
6b51abc
7eb4985
0ddb69a
dce66ba
7eb4985
3462636
05bcaa4
a426ca0
05bcaa4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os, requests, tiktoken

from llama_hub.youtube_transcript import YoutubeTranscriptReader
from llama_index import download_loader, PromptTemplate, ServiceContext
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.embeddings import OpenAIEmbedding
from llama_index.indices.vector_store.base import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.storage.storage_context import StorageContext
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

from pathlib import Path
from pymongo import MongoClient
from rag_base import BaseRAG

class LlamaIndexRAG(BaseRAG):
    """RAG pipeline built on LlamaIndex with MongoDB Atlas as the vector store.

    Connection/config attributes (PDF_URL, WEB_URL, YOUTUBE_URL_1/2,
    MONGODB_ATLAS_CLUSTER_URI, MONGODB_COLLECTION_NAME, MONGODB_INDEX_NAME)
    are inherited from BaseRAG — presumably; confirm against rag_base.
    """

    # Database name for the LlamaIndex collection in the Atlas cluster.
    MONGODB_DB_NAME = "llamaindex_db"

    def load_documents(self):
        """Load all source documents (PDF, web page, two YouTube transcripts).

        Returns:
            list: LlamaIndex Document objects from every source, in order
            PDF -> web -> YouTube.

        Raises:
            requests.HTTPError: if the PDF download returns an error status.
        """
        docs = []

        # PDF: download once into ./data, then parse with the PDFReader loader.
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()
        out_dir = Path("data")
        # Idempotent, race-free directory creation (replaces exists()+makedirs).
        out_dir.mkdir(parents = True, exist_ok = True)

        out_path = out_dir / "gpt-4.pdf"

        if not out_path.exists():
            r = requests.get(self.PDF_URL, timeout = 60)
            # Fail loudly rather than caching an HTML error page as a "PDF".
            r.raise_for_status()
            out_path.write_bytes(r.content)

        docs.extend(loader.load_data(file = Path(out_path)))

        # Web page
        SimpleWebPageReader = download_loader("SimpleWebPageReader")
        loader = SimpleWebPageReader()
        docs.extend(loader.load_data(urls = [self.WEB_URL]))

        # YouTube transcripts
        loader = YoutubeTranscriptReader()
        docs.extend(loader.load_data(ytlinks = [self.YOUTUBE_URL_1,
                                                self.YOUTUBE_URL_2]))

        return docs

    def get_callback_manager(self, config):
        """Build a CallbackManager holding a single TokenCountingHandler.

        Args:
            config (dict): must contain "model_name" (used to pick the
                tiktoken encoding).

        Returns:
            CallbackManager: manager whose handlers[0] is the token counter
            (retrieval() relies on that position).
        """
        token_counter = TokenCountingHandler(
            tokenizer = tiktoken.encoding_for_model(config["model_name"]).encode
        )

        # Start from zero so counts reflect only this run.
        token_counter.reset_counts()

        return CallbackManager([token_counter])

    def get_callback(self, token_counter):
        """Format the token counter's totals as a human-readable summary."""
        return ("Tokens Used: " +
                str(token_counter.total_llm_token_count) + "\n" +
                "Prompt Tokens: " +
                str(token_counter.prompt_llm_token_count) + "\n" +
                "Completion Tokens: " +
                str(token_counter.completion_llm_token_count))

    def get_llm(self, config):
        """Create the OpenAI LLM from config ("model_name", "temperature")."""
        return OpenAI(
            model = config["model_name"],
            temperature = config["temperature"]
        )

    def get_vector_store(self):
        """Create a MongoDB Atlas vector store client.

        NOTE(review): each call opens a new MongoClient; callers should reuse
        the returned store where possible.
        """
        return MongoDBAtlasVectorSearch(
            MongoClient(self.MONGODB_ATLAS_CLUSTER_URI),
            db_name = self.MONGODB_DB_NAME,
            collection_name = self.MONGODB_COLLECTION_NAME,
            index_name = self.MONGODB_INDEX_NAME
        )

    def get_service_context(self, config):
        """Build a ServiceContext (LLM, embeddings, chunking, callbacks).

        Args:
            config (dict): needs "model_name", "temperature", "chunk_size",
                "chunk_overlap".
        """
        return ServiceContext.from_defaults(
            callback_manager = self.get_callback_manager(config),
            chunk_overlap = config["chunk_overlap"],
            chunk_size = config["chunk_size"],
            embed_model = OpenAIEmbedding(),  # OpenAI embeddings for indexing/query
            llm = self.get_llm(config)
        )

    def get_storage_context(self):
        """Build a StorageContext backed by the Atlas vector store."""
        return StorageContext.from_defaults(
            vector_store = self.get_vector_store()
        )

    def store_documents(self, config, docs):
        """Embed *docs* and persist them into the Atlas vector store.

        Fix: the original built a local StorageContext and never used it,
        then built a second one inline — opening two MongoClient connections.
        """
        VectorStoreIndex.from_documents(
            docs,
            service_context = self.get_service_context(config),
            storage_context = self.get_storage_context()
        )

    def ingestion(self, config):
        """Full ingestion pass: load every source document, then store it."""
        docs = self.load_documents()

        self.store_documents(config, docs)

    def retrieval(self, config, prompt):
        """Answer *prompt* against the stored index.

        Args:
            config (dict): needs "k" (similarity_top_k) plus the keys used by
                get_service_context().
            prompt (str): the user question.

        Returns:
            tuple: (query response, token-usage summary string).

        Raises:
            KeyError: if the LLAMAINDEX_TEMPLATE environment variable is
                not set.
        """
        index = VectorStoreIndex.from_vector_store(
            vector_store = self.get_vector_store()
        )

        service_context = self.get_service_context(config)

        query_engine = index.as_query_engine(
            service_context = service_context,
            similarity_top_k = config["k"],
            text_qa_template = PromptTemplate(os.environ["LLAMAINDEX_TEMPLATE"])
        )

        completion = query_engine.query(prompt)
        # handlers[0] is the TokenCountingHandler installed by
        # get_callback_manager().
        callback = self.get_callback(
            service_context.callback_manager.handlers[0])

        return completion, callback