File size: 4,540 Bytes
0d56e9d 5f04412 50dc063 3707a95 3e7b971 5f04412 a0b5dc6 32dedd9 5f04412 c947c47 1342013 5f04412 ebaeae5 dce66ba e946a29 0ddb69a 5f04412 0ddb69a 5f04412 0ddb69a 5f04412 0ddb69a 5f04412 0ddb69a 3940450 0ddb69a 5f04412 0ddb69a 5f04412 0ddb69a 3940450 0ddb69a 5f04412 0ddb69a 3940450 0ddb69a 5f04412 0ddb69a 5f04412 3707a95 05bcaa4 c2646e7 5a04c98 c2646e7 c273c9f 0ddb69a 5f04412 e946a29 0ddb69a 3940450 008f2f7 3940450 0ddb69a c273c9f c4c7ba9 c273c9f 3707a95 f5c95d9 1b00820 c273c9f c4c7ba9 c273c9f 05bcaa4 e946a29 ff3ca13 5f04412 ff3ca13 c273c9f e946a29 0ddb69a db1e09d 5f04412 3707a95 0ddb69a 7eb4985 3707a95 7eb4985 0ddb69a dce66ba 7eb4985 3462636 05bcaa4 a426ca0 05bcaa4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import os, requests, tiktoken
from llama_hub.youtube_transcript import YoutubeTranscriptReader
from llama_index import download_loader, PromptTemplate, ServiceContext
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.embeddings import OpenAIEmbedding
from llama_index.indices.vector_store.base import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.storage.storage_context import StorageContext
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
from pathlib import Path
from pymongo import MongoClient
from rag_base import BaseRAG
class LlamaIndexRAG(BaseRAG):
    """RAG pipeline built on LlamaIndex with a MongoDB Atlas vector store.

    The class reads several attributes that are not defined here
    (``PDF_URL``, ``WEB_URL``, ``YOUTUBE_URL_1``, ``YOUTUBE_URL_2``,
    ``MONGODB_ATLAS_CLUSTER_URI``, ``MONGODB_COLLECTION_NAME``,
    ``MONGODB_INDEX_NAME``); presumably they come from ``BaseRAG`` --
    verify against ``rag_base``.

    The ``config`` mapping passed to the methods below is indexed with the
    keys ``"model_name"``, ``"temperature"``, ``"chunk_overlap"``,
    ``"chunk_size"`` and ``"k"``.
    """

    MONGODB_DB_NAME = "llamaindex_db"

    def load_documents(self):
        """Load the PDF, web page and YouTube transcripts into one list.

        The PDF is downloaded from ``self.PDF_URL`` to ``data/gpt-4.pdf``
        only when that file does not exist yet; later calls reuse the
        local copy.

        Returns:
            The documents produced by the three loaders, in the order
            PDF, web page, YouTube.

        Raises:
            requests.HTTPError: If the PDF download answers with an HTTP
                error status.
        """
        docs = []

        # PDF
        pdf_reader_cls = download_loader("PDFReader")
        out_dir = Path("data")
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / "gpt-4.pdf"
        if not out_path.exists():
            response = requests.get(self.PDF_URL)
            # Fail before writing: otherwise an HTTP error page would be
            # saved as the "PDF" and reused on every later run.
            response.raise_for_status()
            with open(out_path, "wb") as f:
                f.write(response.content)
        docs.extend(pdf_reader_cls().load_data(file=out_path))

        # Web
        web_reader_cls = download_loader("SimpleWebPageReader")
        docs.extend(web_reader_cls().load_data(urls=[self.WEB_URL]))

        # YouTube
        docs.extend(
            YoutubeTranscriptReader().load_data(
                ytlinks=[self.YOUTUBE_URL_1, self.YOUTUBE_URL_2]
            )
        )

        return docs

    def get_callback_manager(self, config):
        """Return a callback manager holding one freshly reset token counter.

        The counter tokenizes with the tiktoken encoding that matches
        ``config["model_name"]``.
        """
        token_counter = TokenCountingHandler(
            tokenizer=tiktoken.encoding_for_model(config["model_name"]).encode
        )
        token_counter.reset_counts()
        return CallbackManager([token_counter])

    def get_callback(self, token_counter):
        """Format the LLM token counts of ``token_counter`` as three lines."""
        return (
            f"Tokens Used: {token_counter.total_llm_token_count}\n"
            f"Prompt Tokens: {token_counter.prompt_llm_token_count}\n"
            f"Completion Tokens: {token_counter.completion_llm_token_count}"
        )

    def get_llm(self, config):
        """Build the OpenAI LLM from ``config`` (model name and temperature)."""
        return OpenAI(
            model=config["model_name"],
            temperature=config["temperature"],
        )

    def get_vector_store(self):
        """Create a MongoDB Atlas vector store bound to a new ``MongoClient``.

        Every call opens its own client, so callers should avoid calling
        this more often than needed.
        """
        return MongoDBAtlasVectorSearch(
            MongoClient(self.MONGODB_ATLAS_CLUSTER_URI),
            db_name=self.MONGODB_DB_NAME,
            collection_name=self.MONGODB_COLLECTION_NAME,
            index_name=self.MONGODB_INDEX_NAME,
        )

    def get_service_context(self, config):
        """Assemble the service context used for indexing and querying.

        Combines the token-counting callback manager, the chunking
        settings from ``config``, an ``OpenAIEmbedding`` created with its
        default arguments, and the LLM from :meth:`get_llm`.
        """
        return ServiceContext.from_defaults(
            callback_manager=self.get_callback_manager(config),
            chunk_overlap=config["chunk_overlap"],
            chunk_size=config["chunk_size"],
            embed_model=OpenAIEmbedding(),
            llm=self.get_llm(config),
        )

    def get_storage_context(self):
        """Return a storage context backed by :meth:`get_vector_store`."""
        return StorageContext.from_defaults(
            vector_store=self.get_vector_store()
        )

    def store_documents(self, config, docs):
        """Index ``docs`` into the vector store.

        The index object returned by ``VectorStoreIndex.from_documents`` is
        discarded; the call is made for its effect on the vector store.
        """
        # Build the storage context exactly once; a second, unused context
        # would open an additional MongoClient for nothing.
        VectorStoreIndex.from_documents(
            docs,
            service_context=self.get_service_context(config),
            storage_context=self.get_storage_context(),
        )

    def ingestion(self, config):
        """Load all source documents and store them in the vector store."""
        docs = self.load_documents()
        self.store_documents(config, docs)

    def retrieval(self, config, prompt):
        """Answer ``prompt`` from the stored documents.

        The question-answering prompt template is read from the
        ``LLAMAINDEX_TEMPLATE`` environment variable.

        Returns:
            A ``(completion, callback)`` tuple: the query engine response
            and the token usage text from :meth:`get_callback`.

        Raises:
            KeyError: If ``LLAMAINDEX_TEMPLATE`` is not set, or ``config``
                lacks a required key.
        """
        index = VectorStoreIndex.from_vector_store(
            vector_store=self.get_vector_store()
        )
        service_context = self.get_service_context(config)

        query_engine = index.as_query_engine(
            response_mode="compact",
            service_context=service_context,
            similarity_top_k=config["k"],
            text_qa_template=PromptTemplate(os.environ["LLAMAINDEX_TEMPLATE"]),
        )
        completion = query_engine.query(prompt)

        # get_callback_manager registers the token counter as the only
        # handler, so index 0 is that counter.
        callback = self.get_callback(
            service_context.callback_manager.handlers[0]
        )

        return completion, callback