bstraehle commited on
Commit
40e55f0
·
1 Parent(s): 4e80daf

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +8 -5
rag.py CHANGED
@@ -14,6 +14,9 @@ from langchain.vectorstores import MongoDBAtlasVectorSearch
14
 
15
  from pymongo import MongoClient
16
 
 
 
 
17
  PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
18
  WEB_URL = "https://openai.com/research/gpt-4"
19
  YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
@@ -71,11 +74,11 @@ def document_storage_mongodb(documents):
71
  collection = collection,
72
  index_name = MONGODB_INDEX_NAME)
73
 
74
- def document_retrieval_chroma(llm, prompt):
75
  return Chroma(embedding_function = OpenAIEmbeddings(),
76
  persist_directory = CHROMA_DIR)
77
 
78
- def document_retrieval_mongodb(llm, prompt):
79
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
80
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
81
  OpenAIEmbeddings(disallowed_special = ()),
@@ -99,9 +102,9 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
99
  llm = get_llm(config, openai_api_key)
100
 
101
  if (rag_option == RAG_CHROMA):
102
- db = document_retrieval_chroma(llm, prompt)
103
- else:
104
- db = document_retrieval_mongodb(llm, prompt)
105
 
106
  rag_chain = RetrievalQA.from_chain_type(llm,
107
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
 
14
 
15
  from pymongo import MongoClient
16
 
17
+ RAG_CHROMA = "Chroma"
18
+ RAG_MONGODB = "MongoDB"
19
+
20
  PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
21
  WEB_URL = "https://openai.com/research/gpt-4"
22
  YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
 
74
  collection = collection,
75
  index_name = MONGODB_INDEX_NAME)
76
 
77
+ def document_retrieval_chroma():
78
  return Chroma(embedding_function = OpenAIEmbeddings(),
79
  persist_directory = CHROMA_DIR)
80
 
81
+ def document_retrieval_mongodb():
82
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
83
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
84
  OpenAIEmbeddings(disallowed_special = ()),
 
102
  llm = get_llm(config, openai_api_key)
103
 
104
  if (rag_option == RAG_CHROMA):
105
+ db = document_retrieval_chroma()
106
+ elif (rag_option == RAG_MONGODB):
107
+ db = document_retrieval_mongodb()
108
 
109
  rag_chain = RetrievalQA.from_chain_type(llm,
110
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},