bstraehle commited on
Commit
d693fc5
1 Parent(s): 5858e19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import json, openai, os, time, wandb
3
 
4
  from langchain.chains import LLMChain, RetrievalQA
5
  from langchain.chat_models import ChatOpenAI
@@ -57,6 +57,10 @@ YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
57
  YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
58
  YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
59
 
 
 
 
 
60
  def document_loading_splitting():
61
  # Document loading
62
  docs = []
@@ -116,28 +120,28 @@ def rag_chain(llm, prompt, db):
116
  return completion, rag_chain
117
 
118
  def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
119
- if (rag_option == "Off" or str(status_msg) != ""):
120
  result = completion
121
  else:
122
  result = completion["result"]
123
- docs_meta = json.loads([doc.metadata for doc in completion["source_documents"]])
124
  wandb.init(project = "openai-llm-rag")
125
  trace = Trace(
126
  kind = "chain",
127
  name = type(chain).__name__ if (chain != None) else "",
128
- status_code = "SUCCESS" if (str(status_msg) == "") else "ERROR",
129
  status_message = str(status_msg),
130
  metadata = {
131
- "chunk_overlap": "" if (rag_option == "Off") else config["chunk_overlap"],
132
- "chunk_size": "" if (rag_option == "Off") else config["chunk_size"],
133
- "k": "" if (rag_option == "Off") else config["k"],
134
  "model": config["model"],
135
  "temperature": config["temperature"],
136
  },
137
  inputs = {"rag_option": rag_option if (str(status_msg) == "") else "",
138
  "prompt": str(prompt if (str(status_msg) == "") else ""),
139
- "prompt_template": str((llm_template if (rag_option == "Off") else rag_template) if (str(status_msg) == "") else ""),
140
- "docs_meta": "" if (rag_option == "Off" or str(status_msg) != "") else docs_meta},
141
  outputs = {"result": result},
142
  start_time_ms = start_time_ms,
143
  end_time_ms = end_time_ms
@@ -161,13 +165,13 @@ def invoke(openai_api_key, rag_option, prompt):
161
  llm = ChatOpenAI(model_name = config["model"],
162
  openai_api_key = openai_api_key,
163
  temperature = config["temperature"])
164
- if (rag_option == "Chroma"):
165
  #splits = document_loading_splitting()
166
  #document_storage_chroma(splits)
167
  db = document_retrieval_chroma(llm, prompt)
168
  completion, chain = rag_chain(llm, prompt, db)
169
  result = completion["result"]
170
- elif (rag_option == "MongoDB"):
171
  #splits = document_loading_splitting()
172
  #document_storage_mongodb(splits)
173
  db = document_retrieval_mongodb(llm, prompt)
@@ -187,7 +191,7 @@ def invoke(openai_api_key, rag_option, prompt):
187
  gr.close_all()
188
  demo = gr.Interface(fn=invoke,
189
  inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
190
- gr.Radio(["Off", "Chroma", "MongoDB"], label="Retrieval Augmented Generation", value = "Off"),
191
  gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
192
  outputs = [gr.Textbox(label = "Completion", lines = 1)],
193
  title = "Generative AI - LLM & RAG",
 
1
  import gradio as gr
2
+ import openai, os, time, wandb
3
 
4
  from langchain.chains import LLMChain, RetrievalQA
5
  from langchain.chat_models import ChatOpenAI
 
57
  YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
58
  YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
59
 
60
+ RAG_OFF = "Off"
61
+ RAG_CHROMA = "Chroma"
62
+ RAG_MONGODB = "MongoDB"
63
+
64
  def document_loading_splitting():
65
  # Document loading
66
  docs = []
 
120
  return completion, rag_chain
121
 
122
  def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
123
+ if (rag_option == RAG_OFF or str(status_msg) != ""):
124
  result = completion
125
  else:
126
  result = completion["result"]
127
+ docs_meta = str([doc.metadata for doc in completion["source_documents"]])
128
  wandb.init(project = "openai-llm-rag")
129
  trace = Trace(
130
  kind = "chain",
131
  name = type(chain).__name__ if (chain != None) else "",
132
+ status_code = "success" if (str(status_msg) == "") else "error",
133
  status_message = str(status_msg),
134
  metadata = {
135
+ "chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
136
+ "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
137
+ "k": "" if (rag_option == RAG_OFF) else config["k"],
138
  "model": config["model"],
139
  "temperature": config["temperature"],
140
  },
141
  inputs = {"rag_option": rag_option if (str(status_msg) == "") else "",
142
  "prompt": str(prompt if (str(status_msg) == "") else ""),
143
+ "prompt_template": str((llm_template if (rag_option == RAG_OFF) else rag_template) if (str(status_msg) == "") else ""),
144
+ "docs_meta": "" if (rag_option == RAG_OFF or str(status_msg) != "") else docs_meta},
145
  outputs = {"result": result},
146
  start_time_ms = start_time_ms,
147
  end_time_ms = end_time_ms
 
165
  llm = ChatOpenAI(model_name = config["model"],
166
  openai_api_key = openai_api_key,
167
  temperature = config["temperature"])
168
+ if (rag_option == RAG_CHROMA):
169
  #splits = document_loading_splitting()
170
  #document_storage_chroma(splits)
171
  db = document_retrieval_chroma(llm, prompt)
172
  completion, chain = rag_chain(llm, prompt, db)
173
  result = completion["result"]
174
+ elif (rag_option == RAG_MONGODB):
175
  #splits = document_loading_splitting()
176
  #document_storage_mongodb(splits)
177
  db = document_retrieval_mongodb(llm, prompt)
 
191
  gr.close_all()
192
  demo = gr.Interface(fn=invoke,
193
  inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
194
+ gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label="Retrieval Augmented Generation", value = RAG_OFF),
195
  gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
196
  outputs = [gr.Textbox(label = "Completion", lines = 1)],
197
  title = "Generative AI - LLM & RAG",