Commit 1e517cc
Parent: 14172a1
Author: bstraehle

Update app.py

Files changed (1):
  1. app.py +22 -14
app.py CHANGED
@@ -119,8 +119,21 @@ def rag_chain(llm, prompt, db):
     completion = rag_chain({"query": prompt})
     return completion, rag_chain
 
-def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
+def wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms):
+    result = ""
+    generation_info = ""
+    llm_output = ""
+
+    if (rag_option == RAG_OFF):
+        if (completion.generations[0] != None and completion.generations[0][0] != None):
+            result = completion.generations[0][0].text
+            generation_info = completion.generations[0][0].generation_info
+            llm_output = completion.llm_output
+    else:
+        result = completion["result"]
+
     wandb.init(project = "openai-llm-rag")
+
     trace = Trace(
         kind = "chain",
         name = "" if (chain == None) else type(chain).__name__,
@@ -130,9 +143,9 @@ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
                     "chunk_overlap": config["chunk_overlap"] if (str(err_msg) == "" and rag_option != RAG_OFF) else "",
                     "chunk_size": config["chunk_size"] if (str(err_msg) == "" and rag_option != RAG_OFF) else "",
                    },
-        inputs = {"rag_option": rag_option if (str(err_msg) == "") else "",
-                  "prompt": prompt if (str(err_msg) == "") else "",
-                 },
+        inputs = {"rag_option": rag_option,
+                  "prompt": prompt,
+                 } if (str(err_msg) == "") else "",
         outputs = {"result": result if (str(err_msg) == "") else "",
                    "generation_info": str(generation_info) if (str(err_msg) == "") else "",
                    "llm_output": str(llm_output) if (str(err_msg) == "") else "",
@@ -151,6 +164,7 @@ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
         start_time_ms = start_time_ms,
         end_time_ms = end_time_ms
     )
+
     trace.log("test")
     wandb.finish()
 
@@ -161,41 +175,35 @@ def invoke(openai_api_key, rag_option, prompt):
         raise gr.Error("Retrieval Augmented Generation is required.")
     if (prompt == ""):
         raise gr.Error("Prompt is required.")
+
     chain = None
     completion = ""
-    result = ""
-    generation_info = ""
-    llm_output = ""
     err_msg = ""
+
     try:
         start_time_ms = round(time.time() * 1000)
         llm = ChatOpenAI(model_name = config["model_name"],
                          openai_api_key = openai_api_key,
                          temperature = config["temperature"])
+
         if (rag_option == RAG_CHROMA):
             #splits = document_loading_splitting()
             #document_storage_chroma(splits)
             db = document_retrieval_chroma(llm, prompt)
             completion, chain = rag_chain(llm, prompt, db)
-            result = completion["result"]
         elif (rag_option == RAG_MONGODB):
             #splits = document_loading_splitting()
             #document_storage_mongodb(splits)
             db = document_retrieval_mongodb(llm, prompt)
             completion, chain = rag_chain(llm, prompt, db)
-            result = completion["result"]
         else:
             completion, chain = llm_chain(llm, prompt)
-            if (completion.generations[0] != None and completion.generations[0][0] != None):
-                result = completion.generations[0][0].text
-                generation_info = completion.generations[0][0].generation_info
-                llm_output = completion.llm_output
     except Exception as e:
         err_msg = e
         raise gr.Error(e)
     finally:
         end_time_ms = round(time.time() * 1000)
-        wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
+        wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms)
     return result
 
 gr.close_all()
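
Net effect of the commit: all completion post-processing moves out of invoke() and into wandb_trace(). For the RAG_OFF path it unpacks the LLMResult generations; for the Chroma and MongoDB paths it reads completion["result"]. The inputs dict is likewise guarded by a single str(err_msg) == "" check on the whole dict instead of one per field. Below is a sketch of how the function reads after this commit, lightly tidied (is None in place of == None); it assumes Trace is the W&B Prompts class from wandb.sdk.data_types.trace_tree, uses a placeholder value for the RAG_OFF constant defined elsewhere in app.py, and omits the Trace metadata and status fields the diff does not show in full.

```python
import wandb
from wandb.sdk.data_types.trace_tree import Trace  # W&B Prompts trace API

RAG_OFF = "Off"  # assumed placeholder; the real constant lives elsewhere in app.py

def wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms):
    # Post-commit shape: result extraction happens here, not in invoke()
    result = ""
    generation_info = ""
    llm_output = ""

    if rag_option == RAG_OFF:
        # completion is a LangChain LLMResult when no retrieval chain ran
        if completion.generations[0] is not None and completion.generations[0][0] is not None:
            result = completion.generations[0][0].text
            generation_info = completion.generations[0][0].generation_info
            llm_output = completion.llm_output
    else:
        # completion is a RetrievalQA result dict on the Chroma/MongoDB paths
        result = completion["result"]

    wandb.init(project = "openai-llm-rag")

    trace = Trace(
        kind = "chain",
        name = "" if chain is None else type(chain).__name__,
        # one guard for the whole inputs dict instead of one per field
        inputs = {"rag_option": rag_option,
                  "prompt": prompt,
                 } if str(err_msg) == "" else "",
        outputs = {"result": result if str(err_msg) == "" else "",
                   "generation_info": str(generation_info) if str(err_msg) == "" else "",
                   "llm_output": str(llm_output) if str(err_msg) == "" else "",
                  },
        start_time_ms = start_time_ms,
        end_time_ms = end_time_ms,
    )

    trace.log("test")
    wandb.finish()
```

Guarding the dict once rather than per field means an errored run logs inputs as an empty string instead of a dict of empty strings.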
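
One loose end worth flagging: invoke() still ends with return result, but the commit removes every assignment to result inside invoke(), so the success path would hit a NameError unless result happens to exist at module scope. On the error path, completion may also still be the initial "" when wandb_trace evaluates completion["result"]. A hypothetical follow-up (not part of this commit) that keeps the refactor but closes both gaps is to have wandb_trace guard the extraction and return the extracted result. A self-contained toy of the pattern, with a hypothetical stand-in for the chain call:

```python
import time

# Hypothetical follow-up sketch (not in this commit): wandb_trace hands the
# extracted result back so invoke()'s trailing `return result` stays valid.
def wandb_trace(rag_option, prompt, completion, chain, err_msg,
                start_time_ms, end_time_ms):
    # ... extraction and Trace logging as in the commit, elided here;
    # guarded so the error path (completion == "") cannot fail on indexing:
    result = completion["result"] if (str(err_msg) == "" and isinstance(completion, dict)) else ""
    return result  # the new piece: return the result to the caller

def invoke(rag_option, prompt):
    completion, chain, err_msg = "", None, ""
    start_time_ms = round(time.time() * 1000)
    try:
        completion = {"result": f"answer to: {prompt}"}  # stand-in for rag_chain()
    except Exception as e:
        err_msg = e
        raise
    finally:
        end_time_ms = round(time.time() * 1000)
        # assign in finally so it runs on success and on error alike
        result = wandb_trace(rag_option, prompt, completion, chain, err_msg,
                             start_time_ms, end_time_ms)
    return result

print(invoke("Chroma", "What is RAG?"))  # prints: answer to: What is RAG?
```

Because the assignment happens in the finally block, result is defined by the time the return executes on the success path, while an exception raised in the try still propagates unchanged.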