bstraehle committed on
Commit
9c86fb0
·
1 Parent(s): 090c3ab

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +17 -8
rag.py CHANGED
@@ -71,6 +71,14 @@ def document_storage_mongodb(chunks):
71
  collection = collection,
72
  index_name = MONGODB_INDEX_NAME)
73
 
 
 
 
 
 
 
 
 
74
  def document_retrieval_chroma():
75
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
76
  persist_directory = CHROMA_DIR)
@@ -81,14 +89,6 @@ def document_retrieval_mongodb():
81
  OpenAIEmbeddings(disallowed_special = ()),
82
  index_name = MONGODB_INDEX_NAME)
83
 
84
- def rag_batch(config):
85
- docs = document_loading()
86
-
87
- chunks = document_splitting(config, docs)
88
-
89
- document_storage_chroma(chunks)
90
- document_storage_mongodb(chunks)
91
-
92
  def get_llm(config, openai_api_key):
93
  return ChatOpenAI(model_name = config["model_name"],
94
  openai_api_key = openai_api_key,
@@ -110,6 +110,14 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
110
  db = document_retrieval_chroma()
111
  elif (rag_option == RAG_MONGODB):
112
  db = document_retrieval_mongodb()
 
 
 
 
 
 
 
 
113
 
114
  rag_chain = RetrievalQA.from_chain_type(llm,
115
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
@@ -118,6 +126,7 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
118
  verbose = False)
119
 
120
  completion = rag_chain({"query": prompt}, include_run_info = True)
 
121
  print("###" + str(completion["__run"]))
122
 
123
  return completion, rag_chain
 
71
  collection = collection,
72
  index_name = MONGODB_INDEX_NAME)
73
 
74
+ def rag_batch(config):
75
+ docs = document_loading()
76
+
77
+ chunks = document_splitting(config, docs)
78
+
79
+ document_storage_chroma(chunks)
80
+ document_storage_mongodb(chunks)
81
+
82
  def document_retrieval_chroma():
83
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
84
  persist_directory = CHROMA_DIR)
 
89
  OpenAIEmbeddings(disallowed_special = ()),
90
  index_name = MONGODB_INDEX_NAME)
91
 
 
 
 
 
 
 
 
 
92
  def get_llm(config, openai_api_key):
93
  return ChatOpenAI(model_name = config["model_name"],
94
  openai_api_key = openai_api_key,
 
110
  db = document_retrieval_chroma()
111
  elif (rag_option == RAG_MONGODB):
112
  db = document_retrieval_mongodb()
113
+
114
+ ###
115
+ retriever = db.as_retriever(search_kwargs = {"k": config["k"]})
116
+ retrieved_docs = retriever.invoke(prompt)
117
+ print(retrieved_docs[0].page_content)
118
+ print(retrieved_docs[1].page_content)
119
+ print(retrieved_docs[2].page_content)
120
+ ###
121
 
122
  rag_chain = RetrievalQA.from_chain_type(llm,
123
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
 
126
  verbose = False)
127
 
128
  completion = rag_chain({"query": prompt}, include_run_info = True)
129
+ print("###" + str(completion))
130
  print("###" + str(completion["__run"]))
131
 
132
  return completion, rag_chain