bstraehle committed on
Commit
0ddb69a
·
1 Parent(s): 4873e9b

Update rag_llamaindex.py

Browse files
Files changed (1) hide show
  1. rag_llamaindex.py +51 -50
rag_llamaindex.py CHANGED
@@ -22,69 +22,70 @@ MONGODB_INDEX_NAME = "default"
22
  logging.basicConfig(stream = sys.stdout, level = logging.INFO)
23
  logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
24
 
25
- def load_documents():
26
- docs = []
 
27
 
28
- # PDF
29
- PDFReader = download_loader("PDFReader")
30
- loader = PDFReader()
31
- out_dir = Path("data")
32
 
33
- if not out_dir.exists():
34
- os.makedirs(out_dir)
35
 
36
- out_path = out_dir / "gpt-4.pdf"
37
 
38
- if not out_path.exists():
39
- r = requests.get(PDF_URL)
40
- with open(out_path, "wb") as f:
41
- f.write(r.content)
42
 
43
- docs.extend(loader.load_data(file = Path(out_path)))
44
- #print("docs = " + str(len(docs)))
45
 
46
- # Web
47
- SimpleWebPageReader = download_loader("SimpleWebPageReader")
48
- loader = SimpleWebPageReader()
49
- docs.extend(loader.load_data(urls = [WEB_URL]))
50
- #print("docs = " + str(len(docs)))
51
 
52
- # YouTube
53
- loader = YoutubeTranscriptReader()
54
- docs.extend(loader.load_data(ytlinks = [YOUTUBE_URL_1,
55
- YOUTUBE_URL_2]))
56
- #print("docs = " + str(len(docs)))
57
 
58
- return docs
59
 
60
- def store_documents(config, docs):
61
- storage_context = StorageContext.from_defaults(
62
- vector_store = get_vector_store())
63
 
64
- VectorStoreIndex.from_documents(
65
- docs,
66
- storage_context = storage_context
67
- )
68
 
69
- def get_vector_store():
70
- return MongoDBAtlasVectorSearch(
71
- MongoClient(MONGODB_ATLAS_CLUSTER_URI),
72
- db_name = MONGODB_DB_NAME,
73
- collection_name = MONGODB_COLLECTION_NAME,
74
- index_name = MONGODB_INDEX_NAME
75
- )
76
 
77
- def rag_ingestion_llamaindex(config):
78
- docs = load_documents()
79
 
80
- store_documents(config, docs)
81
 
82
- def rag_retrieval(config, prompt):
83
- index = VectorStoreIndex.from_vector_store(
84
- vector_store = get_vector_store())
85
 
86
- query_engine = index.as_query_engine(
87
- similarity_top_k = config["k"]
88
- )
89
 
90
- return query_engine.query(prompt)
 
22
  logging.basicConfig(stream = sys.stdout, level = logging.INFO)
23
  logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
24
 
25
+ Class LlamaIndexRAG:
26
+ def load_documents():
27
+ docs = []
28
 
29
+ # PDF
30
+ PDFReader = download_loader("PDFReader")
31
+ loader = PDFReader()
32
+ out_dir = Path("data")
33
 
34
+ if not out_dir.exists():
35
+ os.makedirs(out_dir)
36
 
37
+ out_path = out_dir / "gpt-4.pdf"
38
 
39
+ if not out_path.exists():
40
+ r = requests.get(PDF_URL)
41
+ with open(out_path, "wb") as f:
42
+ f.write(r.content)
43
 
44
+ docs.extend(loader.load_data(file = Path(out_path)))
45
+ #print("docs = " + str(len(docs)))
46
 
47
+ # Web
48
+ SimpleWebPageReader = download_loader("SimpleWebPageReader")
49
+ loader = SimpleWebPageReader()
50
+ docs.extend(loader.load_data(urls = [WEB_URL]))
51
+ #print("docs = " + str(len(docs)))
52
 
53
+ # YouTube
54
+ loader = YoutubeTranscriptReader()
55
+ docs.extend(loader.load_data(ytlinks = [YOUTUBE_URL_1,
56
+ YOUTUBE_URL_2]))
57
+ #print("docs = " + str(len(docs)))
58
 
59
+ return docs
60
 
61
+ def store_documents(config, docs):
62
+ storage_context = StorageContext.from_defaults(
63
+ vector_store = get_vector_store())
64
 
65
+ VectorStoreIndex.from_documents(
66
+ docs,
67
+ storage_context = storage_context
68
+ )
69
 
70
+ def get_vector_store():
71
+ return MongoDBAtlasVectorSearch(
72
+ MongoClient(MONGODB_ATLAS_CLUSTER_URI),
73
+ db_name = MONGODB_DB_NAME,
74
+ collection_name = MONGODB_COLLECTION_NAME,
75
+ index_name = MONGODB_INDEX_NAME
76
+ )
77
 
78
+ def rag_ingestion_llamaindex(config):
79
+ docs = load_documents()
80
 
81
+ store_documents(config, docs)
82
 
83
+ def rag_retrieval(config, prompt):
84
+ index = VectorStoreIndex.from_vector_store(
85
+ vector_store = get_vector_store())
86
 
87
+ query_engine = index.as_query_engine(
88
+ similarity_top_k = config["k"]
89
+ )
90
 
91
+ return query_engine.query(prompt)