bstraehle commited on
Commit
503e34f
1 Parent(s): bf1b617

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -69,26 +69,31 @@ def document_storage_chroma(splits):
69
  persist_directory = CHROMA_DIR)
70
 
71
  def document_storage_mongodb(splits):
 
72
  vector_db = Chroma.from_documents(documents = splits,
73
  embedding = OpenAIEmbeddings(disallowed_special = ()),
74
  persist_directory = CHROMA_DIR)
75
 
76
  def document_retrieval_chroma(llm, prompt):
77
- vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
78
  persist_directory = CHROMA_DIR)
79
- rag_chain = RetrievalQA.from_chain_type(llm,
80
- chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
81
- retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
82
- return_source_documents = True)
83
- result = rag_chain({"query": prompt})
84
- return result["result"]
85
 
86
  def document_retrieval_mongodb(llm, prompt):
87
- vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
88
- persist_directory = CHROMA_DIR)
 
 
 
 
 
 
 
 
 
89
  rag_chain = RetrievalQA.from_chain_type(llm,
90
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
91
- retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
92
  return_source_documents = True)
93
  result = rag_chain({"query": prompt})
94
  return result["result"]
@@ -107,13 +112,14 @@ def invoke(openai_api_key, rag_option, prompt):
107
  #splits = document_loading_splitting()
108
  if (rag_option == "Chroma"):
109
  #document_storage_chroma(splits)
110
- result = document_retrieval_chroma(llm, prompt)
 
111
  elif (rag_option == "MongoDB"):
112
  #document_storage_mongodb(splits)
113
- result = document_retrieval_mongodb(llm, prompt)
 
114
  else:
115
- chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
116
- result = chain.run({"question": prompt})
117
  except Exception as e:
118
  raise gr.Error(e)
119
  return result
 
69
  persist_directory = CHROMA_DIR)
70
 
71
  def document_storage_mongodb(splits):
72
+ #TODO
73
  vector_db = Chroma.from_documents(documents = splits,
74
  embedding = OpenAIEmbeddings(disallowed_special = ()),
75
  persist_directory = CHROMA_DIR)
76
 
77
  def document_retrieval_chroma(llm, prompt):
78
+ db = Chroma(embedding_function = OpenAIEmbeddings(),
79
  persist_directory = CHROMA_DIR)
80
+ return db
 
 
 
 
 
81
 
82
  def document_retrieval_mongodb(llm, prompt):
83
+ #TODO
84
+ db = Chroma(embedding_function = OpenAIEmbeddings(),
85
+ persist_directory = CHROMA_DIR)
86
+ return db
87
+
88
+ def llm_chain(llm, prompt):
89
+ llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
90
+ result = llm_chain.run({"question": prompt})
91
+ return result
92
+
93
+ def rag_chain(llm, prompt, db):
94
  rag_chain = RetrievalQA.from_chain_type(llm,
95
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
96
+ retriever = db.as_retriever(search_kwargs = {"k": 3}),
97
  return_source_documents = True)
98
  result = rag_chain({"query": prompt})
99
  return result["result"]
 
112
  #splits = document_loading_splitting()
113
  if (rag_option == "Chroma"):
114
  #document_storage_chroma(splits)
115
+ db = document_retrieval_chroma(llm, prompt)
116
+ result = rag_chain(llm, prompt, db)
117
  elif (rag_option == "MongoDB"):
118
  #document_storage_mongodb(splits)
119
+ db = document_retrieval_mongodb(llm, prompt)
120
+ result = rag_chain(llm, prompt, db)
121
  else:
122
+ result = llm_chain(llm, prompt)
 
123
  except Exception as e:
124
  raise gr.Error(e)
125
  return result