Update rag.py
Browse files
rag.py
CHANGED
@@ -14,6 +14,9 @@ from langchain.vectorstores import MongoDBAtlasVectorSearch
|
|
14 |
|
15 |
from pymongo import MongoClient
|
16 |
|
|
|
|
|
|
|
17 |
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
|
18 |
WEB_URL = "https://openai.com/research/gpt-4"
|
19 |
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
|
@@ -71,11 +74,11 @@ def document_storage_mongodb(documents):
|
|
71 |
collection = collection,
|
72 |
index_name = MONGODB_INDEX_NAME)
|
73 |
|
74 |
-
def document_retrieval_chroma(
|
75 |
return Chroma(embedding_function = OpenAIEmbeddings(),
|
76 |
persist_directory = CHROMA_DIR)
|
77 |
|
78 |
-
def document_retrieval_mongodb(
|
79 |
return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
|
80 |
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
|
81 |
OpenAIEmbeddings(disallowed_special = ()),
|
@@ -99,9 +102,9 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
|
|
99 |
llm = get_llm(config, openai_api_key)
|
100 |
|
101 |
if (rag_option == RAG_CHROMA):
|
102 |
-
db = document_retrieval_chroma(
|
103 |
-
|
104 |
-
db = document_retrieval_mongodb(
|
105 |
|
106 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
107 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
|
|
14 |
|
15 |
from pymongo import MongoClient
|
16 |
|
17 |
+
# Recognised values of rag_option: rag_chain compares its rag_option argument
# against these two constants to decide which vector store to retrieve from.
RAG_CHROMA = "Chroma"
|
18 |
+
RAG_MONGODB = "MongoDB"
|
19 |
+
|
20 |
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
|
21 |
WEB_URL = "https://openai.com/research/gpt-4"
|
22 |
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
|
|
|
74 |
collection = collection,
|
75 |
index_name = MONGODB_INDEX_NAME)
|
76 |
|
77 |
+
# Return a Chroma vector store for retrieval. Takes no arguments: the store is
# constructed with a fresh OpenAIEmbeddings() as its embedding function and
# CHROMA_DIR as its persist directory.
# NOTE(review): presumably this reopens the index written at document-storage
# time under CHROMA_DIR -- confirm against the storage function.
def document_retrieval_chroma():
|
78 |
return Chroma(embedding_function = OpenAIEmbeddings(),
|
79 |
persist_directory = CHROMA_DIR)
|
80 |
|
81 |
+
def document_retrieval_mongodb():
|
82 |
return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
|
83 |
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
|
84 |
OpenAIEmbeddings(disallowed_special = ()),
|
|
|
102 |
llm = get_llm(config, openai_api_key)
|
103 |
|
104 |
if (rag_option == RAG_CHROMA):
|
105 |
+
db = document_retrieval_chroma()
|
106 |
+
elif (rag_option == RAG_MONGODB):
|
107 |
+
db = document_retrieval_mongodb()
|
108 |
|
109 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
110 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|