Update rag.py
Browse files
rag.py
CHANGED
@@ -71,6 +71,14 @@ def document_storage_mongodb(chunks):
|
|
71 |
collection = collection,
|
72 |
index_name = MONGODB_INDEX_NAME)
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
def document_retrieval_chroma():
|
75 |
return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
|
76 |
persist_directory = CHROMA_DIR)
|
@@ -81,14 +89,6 @@ def document_retrieval_mongodb():
|
|
81 |
OpenAIEmbeddings(disallowed_special = ()),
|
82 |
index_name = MONGODB_INDEX_NAME)
|
83 |
|
84 |
-
def rag_batch(config):
|
85 |
-
docs = document_loading()
|
86 |
-
|
87 |
-
chunks = document_splitting(config, docs)
|
88 |
-
|
89 |
-
document_storage_chroma(chunks)
|
90 |
-
document_storage_mongodb(chunks)
|
91 |
-
|
92 |
def get_llm(config, openai_api_key):
|
93 |
return ChatOpenAI(model_name = config["model_name"],
|
94 |
openai_api_key = openai_api_key,
|
@@ -110,6 +110,14 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
|
|
110 |
db = document_retrieval_chroma()
|
111 |
elif (rag_option == RAG_MONGODB):
|
112 |
db = document_retrieval_mongodb()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
115 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
@@ -118,6 +126,7 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
|
|
118 |
verbose = False)
|
119 |
|
120 |
completion = rag_chain({"query": prompt}, include_run_info = True)
|
|
|
121 |
print("###" + str(completion["__run"]))
|
122 |
|
123 |
return completion, rag_chain
|
|
|
71 |
collection = collection,
|
72 |
index_name = MONGODB_INDEX_NAME)
|
73 |
|
74 |
+
def rag_batch(config):
|
75 |
+
docs = document_loading()
|
76 |
+
|
77 |
+
chunks = document_splitting(config, docs)
|
78 |
+
|
79 |
+
document_storage_chroma(chunks)
|
80 |
+
document_storage_mongodb(chunks)
|
81 |
+
|
82 |
def document_retrieval_chroma():
|
83 |
return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
|
84 |
persist_directory = CHROMA_DIR)
|
|
|
89 |
OpenAIEmbeddings(disallowed_special = ()),
|
90 |
index_name = MONGODB_INDEX_NAME)
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
def get_llm(config, openai_api_key):
|
93 |
return ChatOpenAI(model_name = config["model_name"],
|
94 |
openai_api_key = openai_api_key,
|
|
|
110 |
db = document_retrieval_chroma()
|
111 |
elif (rag_option == RAG_MONGODB):
|
112 |
db = document_retrieval_mongodb()
|
113 |
+
|
114 |
+
###
|
115 |
+
retriever = db.as_retriever(search_kwargs = {"k": config["k"]})
|
116 |
+
retrieved_docs = retriever.invoke(prompt)
|
117 |
+
print(retrieved_docs[0].page_content)
|
118 |
+
print(retrieved_docs[1].page_content)
|
119 |
+
print(retrieved_docs[2].page_content)
|
120 |
+
###
|
121 |
|
122 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
123 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
|
|
126 |
verbose = False)
|
127 |
|
128 |
completion = rag_chain({"query": prompt}, include_run_info = True)
|
129 |
+
print("###" + str(completion))
|
130 |
print("###" + str(completion["__run"]))
|
131 |
|
132 |
return completion, rag_chain
|