bstraehle committed on
Commit
1dcde2f
1 Parent(s): dc7f3cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -106,7 +106,7 @@ def document_retrieval_mongodb(llm, prompt):
106
  def llm_chain(llm, prompt):
107
  llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
108
  completion = llm_chain.run({"question": prompt})
109
- return completion
110
 
111
  def rag_chain(llm, prompt, db):
112
  rag_chain = RetrievalQA.from_chain_type(llm,
@@ -114,9 +114,9 @@ def rag_chain(llm, prompt, db):
114
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
115
  return_source_documents = True)
116
  completion = rag_chain({"query": prompt})
117
- return completion
118
 
119
- def wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_time_ms):
120
  wandb.init(project = "openai-llm-rag")
121
  if (rag_option == "Off" or str(status_msg) != ""):
122
  result = completion
@@ -126,7 +126,7 @@ def wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_t
126
  document_1 = completion["source_documents"][1]
127
  document_2 = completion["source_documents"][2]
128
  trace = Trace(
129
- kind = "llm",
130
  name = "LLMChain" if (rag_option == "Off") else "RetrievalQA",
131
  status_code = "SUCCESS" if (str(status_msg) == "") else "ERROR",
132
  status_message = str(status_msg),
@@ -145,7 +145,8 @@ def wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_t
145
  "document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
146
  outputs = {"result": result},
147
  start_time_ms = start_time_ms,
148
- end_time_ms = end_time_ms
 
149
  )
150
  trace.log("test")
151
  wandb.finish()
@@ -159,6 +160,7 @@ def invoke(openai_api_key, rag_option, prompt):
159
  raise gr.Error("Prompt is required.")
160
  completion = ""
161
  result = ""
 
162
  status_msg = ""
163
  try:
164
  start_time_ms = round(time.time() * 1000)
@@ -169,23 +171,23 @@ def invoke(openai_api_key, rag_option, prompt):
169
  #splits = document_loading_splitting()
170
  #document_storage_chroma(splits)
171
  db = document_retrieval_chroma(llm, prompt)
172
- completion = rag_chain(llm, prompt, db)
173
  result = completion["result"]
174
  elif (rag_option == "MongoDB"):
175
  #splits = document_loading_splitting()
176
  #document_storage_mongodb(splits)
177
  db = document_retrieval_mongodb(llm, prompt)
178
- completion = rag_chain(llm, prompt, db)
179
  result = completion["result"]
180
  else:
181
- result = llm_chain(llm, prompt)
182
  completion = result
183
  except Exception as e:
184
  status_msg = e
185
  raise gr.Error(e)
186
  finally:
187
  end_time_ms = round(time.time() * 1000)
188
- wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_time_ms)
189
  return result
190
 
191
  gr.close_all()
 
106
  def llm_chain(llm, prompt):
107
  llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
108
  completion = llm_chain.run({"question": prompt})
109
+ return completion, llm_chain
110
 
111
  def rag_chain(llm, prompt, db):
112
  rag_chain = RetrievalQA.from_chain_type(llm,
 
114
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
115
  return_source_documents = True)
116
  completion = rag_chain({"query": prompt})
117
+ return completion, rag_chain
118
 
119
+ def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
120
  wandb.init(project = "openai-llm-rag")
121
  if (rag_option == "Off" or str(status_msg) != ""):
122
  result = completion
 
126
  document_1 = completion["source_documents"][1]
127
  document_2 = completion["source_documents"][2]
128
  trace = Trace(
129
+ kind = "chain",
130
  name = "LLMChain" if (rag_option == "Off") else "RetrievalQA",
131
  status_code = "SUCCESS" if (str(status_msg) == "") else "ERROR",
132
  status_message = str(status_msg),
 
145
  "document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
146
  outputs = {"result": result},
147
  start_time_ms = start_time_ms,
148
+ end_time_ms = end_time_ms,
149
+ model_dict={"chain": chain}
150
  )
151
  trace.log("test")
152
  wandb.finish()
 
160
  raise gr.Error("Prompt is required.")
161
  completion = ""
162
  result = ""
163
+ chain = ""
164
  status_msg = ""
165
  try:
166
  start_time_ms = round(time.time() * 1000)
 
171
  #splits = document_loading_splitting()
172
  #document_storage_chroma(splits)
173
  db = document_retrieval_chroma(llm, prompt)
174
+ completion, chain = rag_chain(llm, prompt, db)
175
  result = completion["result"]
176
  elif (rag_option == "MongoDB"):
177
  #splits = document_loading_splitting()
178
  #document_storage_mongodb(splits)
179
  db = document_retrieval_mongodb(llm, prompt)
180
+ completion, chain = rag_chain(llm, prompt, db)
181
  result = completion["result"]
182
  else:
183
+ result, chain = llm_chain(llm, prompt)
184
  completion = result
185
  except Exception as e:
186
  status_msg = e
187
  raise gr.Error(e)
188
  finally:
189
  end_time_ms = round(time.time() * 1000)
190
+ wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms)
191
  return result
192
 
193
  gr.close_all()