bstraehle commited on
Commit
4fb4308
1 Parent(s): 86ffba3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -42
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import openai, os, time, wandb
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
  from langchain.chains import LLMChain, RetrievalQA
@@ -14,8 +14,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain.vectorstores import Chroma
15
  from langchain.vectorstores import MongoDBAtlasVectorSearch
16
  from pymongo import MongoClient
 
17
  from trace import wandb_trace
18
- #from wandb.sdk.data_types.trace_tree import Trace
19
 
20
  _ = load_dotenv(find_dotenv())
21
 
@@ -36,8 +36,6 @@ MONGODB_INDEX_NAME = "default"
36
  LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
37
  RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
38
 
39
- #WANDB_API_KEY = os.environ["WANDB_API_KEY"]
40
-
41
  RAG_OFF = "Off"
42
  RAG_CHROMA = "Chroma"
43
  RAG_MONGODB = "MongoDB"
@@ -116,43 +114,6 @@ def rag_chain(llm, prompt, db):
116
  completion = rag_chain({"query": prompt})
117
  return completion, rag_chain
118
 
119
- #def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
120
- # wandb.init(project = "openai-llm-rag")
121
- #
122
- # trace = Trace(
123
- # kind = "chain",
124
- # name = "" if (chain == None) else type(chain).__name__,
125
- # status_code = "success" if (str(err_msg) == "") else "error",
126
- # status_message = str(err_msg),
127
- # metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
128
- # "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
129
- # } if (str(err_msg) == "") else {},
130
- # inputs = {"rag_option": rag_option,
131
- # "prompt": prompt,
132
- # "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
133
- # str(chain.combine_documents_chain.llm_chain.prompt)),
134
- # "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
135
- # } if (str(err_msg) == "") else {},
136
- # outputs = {"result": result,
137
- # "generation_info": str(generation_info),
138
- # "llm_output": str(llm_output),
139
- # "completion": str(completion),
140
- # } if (str(err_msg) == "") else {},
141
- # model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
142
- # str(chain.combine_documents_chain.llm_chain.llm.client)),
143
- # "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
144
- # str(chain.combine_documents_chain.llm_chain.llm.model_name)),
145
- # "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
146
- # str(chain.combine_documents_chain.llm_chain.llm.temperature)),
147
- # "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
148
- # } if (str(err_msg) == "") else {},
149
- # start_time_ms = start_time_ms,
150
- # end_time_ms = end_time_ms
151
- # )
152
- #
153
- # trace.log("evaluation")
154
- # wandb.finish()
155
-
156
  def invoke(openai_api_key, rag_option, prompt):
157
  if (openai_api_key == ""):
158
  raise gr.Error("OpenAI API Key is required.")
@@ -199,14 +160,25 @@ def invoke(openai_api_key, rag_option, prompt):
199
  llm_output = completion.llm_output
200
  except Exception as e:
201
  err_msg = e
 
202
  raise gr.Error(e)
203
  finally:
204
  end_time_ms = round(time.time() * 1000)
205
 
206
- wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
 
 
 
 
 
 
 
 
 
207
  return result
208
 
209
  gr.close_all()
 
210
  demo = gr.Interface(fn=invoke,
211
  inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
212
  gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
@@ -215,4 +187,5 @@ demo = gr.Interface(fn=invoke,
215
  outputs = [gr.Textbox(label = "Completion", lines = 1)],
216
  title = "Generative AI - LLM & RAG",
217
  description = os.environ["DESCRIPTION"])
 
218
  demo.launch()
 
1
  import gradio as gr
2
+ import openai, os, time
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
  from langchain.chains import LLMChain, RetrievalQA
 
14
  from langchain.vectorstores import Chroma
15
  from langchain.vectorstores import MongoDBAtlasVectorSearch
16
  from pymongo import MongoClient
17
+
18
  from trace import wandb_trace
 
19
 
20
  _ = load_dotenv(find_dotenv())
21
 
 
36
  LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
37
  RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
38
 
 
 
39
  RAG_OFF = "Off"
40
  RAG_CHROMA = "Chroma"
41
  RAG_MONGODB = "MongoDB"
 
114
  completion = rag_chain({"query": prompt})
115
  return completion, rag_chain
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def invoke(openai_api_key, rag_option, prompt):
118
  if (openai_api_key == ""):
119
  raise gr.Error("OpenAI API Key is required.")
 
160
  llm_output = completion.llm_output
161
  except Exception as e:
162
  err_msg = e
163
+
164
  raise gr.Error(e)
165
  finally:
166
  end_time_ms = round(time.time() * 1000)
167
 
168
+ wandb_trace(rag_option,
169
+ prompt,
170
+ completion,
171
+ result,
172
+ generation_info,
173
+ llm_output,
174
+ chain,
175
+ err_msg,
176
+ start_time_ms,
177
+ end_time_ms)
178
  return result
179
 
180
  gr.close_all()
181
+
182
  demo = gr.Interface(fn=invoke,
183
  inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
184
  gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
 
187
  outputs = [gr.Textbox(label = "Completion", lines = 1)],
188
  title = "Generative AI - LLM & RAG",
189
  description = os.environ["DESCRIPTION"])
190
+
191
  demo.launch()