Create rag_llamaindex.py
Browse files- rag_llamaindex.py +90 -0
rag_llamaindex.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging, os, sys
|
2 |
+
|
3 |
+
from llama_hub.youtube_transcript import YoutubeTranscriptReader
|
4 |
+
from llama_index import download_loader, PromptTemplate
|
5 |
+
from llama_index.indices.vector_store.base import VectorStoreIndex
|
6 |
+
from llama_index.storage.storage_context import StorageContext
|
7 |
+
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
8 |
+
|
9 |
+
from pathlib import Path
|
10 |
+
from pymongo.mongo_client import MongoClient
|
11 |
+
|
12 |
+
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
|
13 |
+
WEB_URL = "https://openai.com/research/gpt-4"
|
14 |
+
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
|
15 |
+
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
|
16 |
+
|
17 |
+
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
|
18 |
+
MONGODB_DB_NAME = "llamaindex_db"
|
19 |
+
MONGODB_COLLECTION_NAME = "gpt-4"
|
20 |
+
MONGODB_INDEX_NAME = "default"
|
21 |
+
|
22 |
+
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
|
23 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
|
24 |
+
|
25 |
+
def load_documents():
|
26 |
+
docs = []
|
27 |
+
|
28 |
+
# PDF
|
29 |
+
PDFReader = download_loader("PDFReader")
|
30 |
+
loader = PDFReader()
|
31 |
+
out_dir = Path("data")
|
32 |
+
|
33 |
+
if not out_dir.exists():
|
34 |
+
os.makedirs(out_dir)
|
35 |
+
|
36 |
+
out_path = out_dir / "gpt-4.pdf"
|
37 |
+
|
38 |
+
if not out_path.exists():
|
39 |
+
r = requests.get(PDF_URL)
|
40 |
+
with open(out_path, "wb") as f:
|
41 |
+
f.write(r.content)
|
42 |
+
|
43 |
+
docs.extend(loader.load_data(file = Path(out_path)))
|
44 |
+
#print("docs = " + str(len(docs)))
|
45 |
+
|
46 |
+
# Web
|
47 |
+
SimpleWebPageReader = download_loader("SimpleWebPageReader")
|
48 |
+
loader = SimpleWebPageReader()
|
49 |
+
docs.extend(loader.load_data(urls = [WEB_URL]))
|
50 |
+
#print("docs = " + str(len(docs)))
|
51 |
+
|
52 |
+
# YouTube
|
53 |
+
loader = YoutubeTranscriptReader()
|
54 |
+
docs.extend(loader.load_data(ytlinks = [YOUTUBE_URL_1,
|
55 |
+
YOUTUBE_URL_2]))
|
56 |
+
#print("docs = " + str(len(docs)))
|
57 |
+
|
58 |
+
return docs
|
59 |
+
|
60 |
+
def store_documents(config, docs):
|
61 |
+
storage_context = StorageContext.from_defaults(
|
62 |
+
vector_store = get_vector_store())
|
63 |
+
|
64 |
+
VectorStoreIndex.from_documents(
|
65 |
+
docs,
|
66 |
+
storage_context = storage_context
|
67 |
+
)
|
68 |
+
|
69 |
+
def get_vector_store():
|
70 |
+
return MongoDBAtlasVectorSearch(
|
71 |
+
MongoClient(MONGODB_ATLAS_CLUSTER_URI),
|
72 |
+
db_name = MONGODB_DB_NAME,
|
73 |
+
collection_name = MONGODB_COLLECTION_NAME,
|
74 |
+
index_name = MONGODB_INDEX_NAME
|
75 |
+
)
|
76 |
+
|
77 |
+
def rag_ingestion(config):
|
78 |
+
docs = load_documents()
|
79 |
+
|
80 |
+
store_documents(config, docs)
|
81 |
+
|
82 |
+
def rag_retrieval(config, prompt):
|
83 |
+
index = VectorStoreIndex.from_vector_store(
|
84 |
+
vector_store = get_vector_store())
|
85 |
+
|
86 |
+
query_engine = index.as_query_engine(
|
87 |
+
similarity_top_k = config["k"]
|
88 |
+
)
|
89 |
+
|
90 |
+
return query_engine.query(prompt)
|