bstraehle commited on
Commit
5f04412
1 Parent(s): 9883e25

Create rag_llamaindex.py

Browse files
Files changed (1) hide show
  1. rag_llamaindex.py +90 -0
rag_llamaindex.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging, os, sys
2
+
3
+ from llama_hub.youtube_transcript import YoutubeTranscriptReader
4
+ from llama_index import download_loader, PromptTemplate
5
+ from llama_index.indices.vector_store.base import VectorStoreIndex
6
+ from llama_index.storage.storage_context import StorageContext
7
+ from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
8
+
9
+ from pathlib import Path
10
+ from pymongo.mongo_client import MongoClient
11
+
12
+ PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
13
+ WEB_URL = "https://openai.com/research/gpt-4"
14
+ YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
15
+ YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
16
+
17
+ MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
18
+ MONGODB_DB_NAME = "llamaindex_db"
19
+ MONGODB_COLLECTION_NAME = "gpt-4"
20
+ MONGODB_INDEX_NAME = "default"
21
+
22
+ logging.basicConfig(stream = sys.stdout, level = logging.INFO)
23
+ logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
24
+
25
+ def load_documents():
26
+ docs = []
27
+
28
+ # PDF
29
+ PDFReader = download_loader("PDFReader")
30
+ loader = PDFReader()
31
+ out_dir = Path("data")
32
+
33
+ if not out_dir.exists():
34
+ os.makedirs(out_dir)
35
+
36
+ out_path = out_dir / "gpt-4.pdf"
37
+
38
+ if not out_path.exists():
39
+ r = requests.get(PDF_URL)
40
+ with open(out_path, "wb") as f:
41
+ f.write(r.content)
42
+
43
+ docs.extend(loader.load_data(file = Path(out_path)))
44
+ #print("docs = " + str(len(docs)))
45
+
46
+ # Web
47
+ SimpleWebPageReader = download_loader("SimpleWebPageReader")
48
+ loader = SimpleWebPageReader()
49
+ docs.extend(loader.load_data(urls = [WEB_URL]))
50
+ #print("docs = " + str(len(docs)))
51
+
52
+ # YouTube
53
+ loader = YoutubeTranscriptReader()
54
+ docs.extend(loader.load_data(ytlinks = [YOUTUBE_URL_1,
55
+ YOUTUBE_URL_2]))
56
+ #print("docs = " + str(len(docs)))
57
+
58
+ return docs
59
+
60
+ def store_documents(config, docs):
61
+ storage_context = StorageContext.from_defaults(
62
+ vector_store = get_vector_store())
63
+
64
+ VectorStoreIndex.from_documents(
65
+ docs,
66
+ storage_context = storage_context
67
+ )
68
+
69
+ def get_vector_store():
70
+ return MongoDBAtlasVectorSearch(
71
+ MongoClient(MONGODB_ATLAS_CLUSTER_URI),
72
+ db_name = MONGODB_DB_NAME,
73
+ collection_name = MONGODB_COLLECTION_NAME,
74
+ index_name = MONGODB_INDEX_NAME
75
+ )
76
+
77
+ def rag_ingestion(config):
78
+ docs = load_documents()
79
+
80
+ store_documents(config, docs)
81
+
82
+ def rag_retrieval(config, prompt):
83
+ index = VectorStoreIndex.from_vector_store(
84
+ vector_store = get_vector_store())
85
+
86
+ query_engine = index.as_query_engine(
87
+ similarity_top_k = config["k"]
88
+ )
89
+
90
+ return query_engine.query(prompt)