bstraehle commited on
Commit
542a800
1 Parent(s): 1dcde2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -104,7 +104,7 @@ def document_retrieval_mongodb(llm, prompt):
104
  return db
105
 
106
  def llm_chain(llm, prompt):
107
- llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
108
  completion = llm_chain.run({"question": prompt})
109
  return completion, llm_chain
110
 
@@ -112,11 +112,15 @@ def rag_chain(llm, prompt, db):
112
  rag_chain = RetrievalQA.from_chain_type(llm,
113
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
114
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
115
- return_source_documents = True)
 
116
  completion = rag_chain({"query": prompt})
117
  return completion, rag_chain
118
 
119
  def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
 
 
 
120
  wandb.init(project = "openai-llm-rag")
121
  if (rag_option == "Off" or str(status_msg) != ""):
122
  result = completion
@@ -145,8 +149,7 @@ def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms
145
  "document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
146
  outputs = {"result": result},
147
  start_time_ms = start_time_ms,
148
- end_time_ms = end_time_ms,
149
- model_dict={"chain": chain}
150
  )
151
  trace.log("test")
152
  wandb.finish()
 
104
  return db
105
 
106
  def llm_chain(llm, prompt):
107
+ llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT, verbose = True)
108
  completion = llm_chain.run({"question": prompt})
109
  return completion, llm_chain
110
 
 
112
  rag_chain = RetrievalQA.from_chain_type(llm,
113
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
114
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
115
+ return_source_documents = True,
116
+ verbose = True)
117
  completion = rag_chain({"query": prompt})
118
  return completion, rag_chain
119
 
120
  def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
121
+ #print(chain.inputKey)
122
+ #print(chain.outputKey)
123
+ #print(chain.retriever)
124
  wandb.init(project = "openai-llm-rag")
125
  if (rag_option == "Off" or str(status_msg) != ""):
126
  result = completion
 
149
  "document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
150
  outputs = {"result": result},
151
  start_time_ms = start_time_ms,
152
+ end_time_ms = end_time_ms
 
153
  )
154
  trace.log("test")
155
  wandb.finish()