Update rag_llamaindex.py
Browse files- rag_llamaindex.py +51 -50
rag_llamaindex.py
CHANGED
@@ -22,69 +22,70 @@ MONGODB_INDEX_NAME = "default"
|
|
22 |
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
|
23 |
logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
|
24 |
|
25 |
-
|
26 |
-
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
|
35 |
|
36 |
-
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
|
43 |
-
|
44 |
-
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
59 |
|
60 |
-
def store_documents(config, docs):
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
def get_vector_store():
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
|
77 |
-
def rag_ingestion_llamaindex(config):
|
78 |
-
|
79 |
|
80 |
-
|
81 |
|
82 |
-
def rag_retrieval(config, prompt):
|
83 |
-
|
84 |
-
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
|
90 |
-
|
|
|
22 |
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
|
23 |
logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
|
24 |
|
25 |
+
Class LlamaIndexRAG:
|
26 |
+
def load_documents():
|
27 |
+
docs = []
|
28 |
|
29 |
+
# PDF
|
30 |
+
PDFReader = download_loader("PDFReader")
|
31 |
+
loader = PDFReader()
|
32 |
+
out_dir = Path("data")
|
33 |
|
34 |
+
if not out_dir.exists():
|
35 |
+
os.makedirs(out_dir)
|
36 |
|
37 |
+
out_path = out_dir / "gpt-4.pdf"
|
38 |
|
39 |
+
if not out_path.exists():
|
40 |
+
r = requests.get(PDF_URL)
|
41 |
+
with open(out_path, "wb") as f:
|
42 |
+
f.write(r.content)
|
43 |
|
44 |
+
docs.extend(loader.load_data(file = Path(out_path)))
|
45 |
+
#print("docs = " + str(len(docs)))
|
46 |
|
47 |
+
# Web
|
48 |
+
SimpleWebPageReader = download_loader("SimpleWebPageReader")
|
49 |
+
loader = SimpleWebPageReader()
|
50 |
+
docs.extend(loader.load_data(urls = [WEB_URL]))
|
51 |
+
#print("docs = " + str(len(docs)))
|
52 |
|
53 |
+
# YouTube
|
54 |
+
loader = YoutubeTranscriptReader()
|
55 |
+
docs.extend(loader.load_data(ytlinks = [YOUTUBE_URL_1,
|
56 |
+
YOUTUBE_URL_2]))
|
57 |
+
#print("docs = " + str(len(docs)))
|
58 |
|
59 |
+
return docs
|
60 |
|
61 |
+
def store_documents(config, docs):
|
62 |
+
storage_context = StorageContext.from_defaults(
|
63 |
+
vector_store = get_vector_store())
|
64 |
|
65 |
+
VectorStoreIndex.from_documents(
|
66 |
+
docs,
|
67 |
+
storage_context = storage_context
|
68 |
+
)
|
69 |
|
70 |
+
def get_vector_store():
|
71 |
+
return MongoDBAtlasVectorSearch(
|
72 |
+
MongoClient(MONGODB_ATLAS_CLUSTER_URI),
|
73 |
+
db_name = MONGODB_DB_NAME,
|
74 |
+
collection_name = MONGODB_COLLECTION_NAME,
|
75 |
+
index_name = MONGODB_INDEX_NAME
|
76 |
+
)
|
77 |
|
78 |
+
def rag_ingestion_llamaindex(config):
|
79 |
+
docs = load_documents()
|
80 |
|
81 |
+
store_documents(config, docs)
|
82 |
|
83 |
+
def rag_retrieval(config, prompt):
|
84 |
+
index = VectorStoreIndex.from_vector_store(
|
85 |
+
vector_store = get_vector_store())
|
86 |
|
87 |
+
query_engine = index.as_query_engine(
|
88 |
+
similarity_top_k = config["k"]
|
89 |
+
)
|
90 |
|
91 |
+
return query_engine.query(prompt)
|