bstraehle committed on
Commit 26b6a5b
1 Parent(s): 4ca4f77

Update app.py

Files changed (1)
  1. app.py +12 -10
app.py CHANGED
@@ -115,13 +115,13 @@ def rag_chain(llm, prompt, db):
     completion = rag_chain({"query": prompt})
     return completion
 
-def wandb_trace(rag_option, prompt, prompt_template, result, completion, chain_name):
+def wandb_trace(rag_option, prompt, prompt_template, result, completion, chain_name, status_msg):
     wandb.init(project = "openai-llm-rag")
     trace = Trace(
         name = chain_name,
         kind = "chain",
-        #status_code = status,
-        #status_message = status_message,
+        status_code = "TODO",
+        status_message = status_msg,
         metadata={
             "chunk_overlap": config["chunk_overlap"],
             "chunk_size": config["chunk_size"],
@@ -129,8 +129,8 @@ def wandb_trace(rag_option, prompt, prompt_template, result, completion, chain_n
             "model": config["model"],
             "temperature": config["temperature"],
         },
-        #start_time_ms = start_time_ms,
-        #end_time_ms = end_time_ms,
+        start_time_ms = 123,
+        end_time_ms = 456,
         inputs = {"rag_option": rag_option, "prompt": prompt, "prompt_template": prompt_template},
         outputs = {"result": str(result), "completion": str(completion)},
     )
@@ -144,6 +144,8 @@ def invoke(openai_api_key, rag_option, prompt):
         raise gr.Error("Retrieval Augmented Generation is required.")
     if (prompt == ""):
         raise gr.Error("Prompt is required.")
+    completion = ""
+    status_msg = ""
     try:
         llm = ChatOpenAI(model_name = config["model"],
                          openai_api_key = openai_api_key,
@@ -155,7 +157,7 @@ def invoke(openai_api_key, rag_option, prompt):
             completion = rag_chain(llm, prompt, db)
             result = completion["result"]
             prompt_template = rag_template
-            chain_name = type(RetrievalQA)
+            chain_name = RetrievalQA.__class__.__name__
         elif (rag_option == "MongoDB"):
             #splits = document_loading_splitting()
             #document_storage_mongodb(splits)
@@ -163,17 +165,17 @@ def invoke(openai_api_key, rag_option, prompt):
             completion = rag_chain(llm, prompt, db)
             result = completion["result"]
             prompt_template = rag_template
-            chain_name = type(RetrievalQA)
+            chain_name = RetrievalQA.__class__.__name__
         else:
             result = llm_chain(llm, prompt)
             completion = result
             prompt_template = llm_template
-            chain_name = type(LLMChain)
+            chain_name = LLMChain.__class__.__name__
     except Exception as e:
-        completion = e
+        status_msg = e
        raise gr.Error(e)
     finally:
-        wandb_trace(rag_option, prompt, prompt_template, result, completion, chain_name)
+        wandb_trace(rag_option, prompt, prompt_template, result, completion, chain_name, status_msg)
     return result
 
 description = """<strong>Overview:</strong> Context-aware multimodal reasoning application using a <strong>large language model (LLM)</strong> with