Update app.py
app.py
CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-import
+import openai, os, time, wandb
 
 from langchain.chains import LLMChain, RetrievalQA
 from langchain.chat_models import ChatOpenAI
@@ -57,6 +57,10 @@ YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
 YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
 YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
 
+RAG_OFF = "Off"
+RAG_CHROMA = "Chroma"
+RAG_MONGODB = "MongoDB"
+
 def document_loading_splitting():
     # Document loading
     docs = []
@@ -116,28 +120,28 @@ def rag_chain(llm, prompt, db):
     return completion, rag_chain
 
 def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
-    if (rag_option ==
+    if (rag_option == RAG_OFF or str(status_msg) != ""):
         result = completion
     else:
         result = completion["result"]
-        docs_meta =
+        docs_meta = str([doc.metadata for doc in completion["source_documents"]])
     wandb.init(project = "openai-llm-rag")
     trace = Trace(
         kind = "chain",
         name = type(chain).__name__ if (chain != None) else "",
-        status_code = "
+        status_code = "success" if (str(status_msg) == "") else "error",
         status_message = str(status_msg),
         metadata = {
-            "chunk_overlap": "" if (rag_option ==
-            "chunk_size": "" if (rag_option ==
-            "k": "" if (rag_option ==
+            "chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
+            "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
+            "k": "" if (rag_option == RAG_OFF) else config["k"],
             "model": config["model"],
             "temperature": config["temperature"],
         },
         inputs = {"rag_option": rag_option if (str(status_msg) == "") else "",
                   "prompt": str(prompt if (str(status_msg) == "") else ""),
-                  "prompt_template": str((llm_template if (rag_option ==
-                  "docs_meta": "" if (rag_option ==
+                  "prompt_template": str((llm_template if (rag_option == RAG_OFF) else rag_template) if (str(status_msg) == "") else ""),
+                  "docs_meta": "" if (rag_option == RAG_OFF or str(status_msg) != "") else docs_meta},
         outputs = {"result": result},
         start_time_ms = start_time_ms,
         end_time_ms = end_time_ms
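Note: the Trace object built here comes from W&B's prompts tooling; its import is elsewhere in app.py and outside this diff. A minimal, self-contained sketch of the same logging pattern, assuming the trace-tree import path from the wandb SDK and placeholder values throughout:

# Minimal sketch, assuming Trace comes from wandb's trace-tree module
# (the actual import in app.py is outside this diff).
import time
import wandb
from wandb.sdk.data_types.trace_tree import Trace

wandb.init(project = "openai-llm-rag")
start_time_ms = round(time.time() * 1000)
result = "..."                       # placeholder for the chain output
end_time_ms = round(time.time() * 1000)
trace = Trace(
    kind = "chain",
    name = "RetrievalQA",
    status_code = "success",         # "error" when status_msg is non-empty
    status_message = "",
    metadata = {"model": "gpt-4", "temperature": 0},
    inputs = {"prompt": "What is GPT-4?"},
    outputs = {"result": result},
    start_time_ms = start_time_ms,
    end_time_ms = end_time_ms,
)
trace.log("evaluation")              # writes the trace tree to the active run
wandb.finish()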
@@ -161,13 +165,13 @@ def invoke(openai_api_key, rag_option, prompt):
     llm = ChatOpenAI(model_name = config["model"],
                      openai_api_key = openai_api_key,
                      temperature = config["temperature"])
-    if (rag_option ==
+    if (rag_option == RAG_CHROMA):
         #splits = document_loading_splitting()
         #document_storage_chroma(splits)
         db = document_retrieval_chroma(llm, prompt)
         completion, chain = rag_chain(llm, prompt, db)
         result = completion["result"]
-    elif (rag_option ==
+    elif (rag_option == RAG_MONGODB):
         #splits = document_loading_splitting()
         #document_storage_mongodb(splits)
         db = document_retrieval_mongodb(llm, prompt)
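Note: both branches end in rag_chain, which returns the completion dict that invoke and wandb_trace read. A stripped-down sketch of that retrieval path, assuming the pre-0.1 langchain API that app.py already imports; the persist directory and embedding choice are stand-ins, not the app's actual helpers:

# Minimal sketch, assuming the pre-0.1 langchain API; the persist directory
# and embeddings below are assumptions, not the app's document_retrieval_chroma.
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

llm = ChatOpenAI(model_name = "gpt-4", temperature = 0)
db = Chroma(persist_directory = "./chroma_db",        # hypothetical path
            embedding_function = OpenAIEmbeddings())
rag = RetrievalQA.from_chain_type(llm,
                                  chain_type = "stuff",
                                  retriever = db.as_retriever(search_kwargs = {"k": 3}),
                                  return_source_documents = True)
completion = rag({"query": "What is GPT-4?"})
print(completion["result"])                           # the answer text
print([doc.metadata for doc in completion["source_documents"]])  # docs_meta in wandb_trace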
@@ -187,7 +191,7 @@ def invoke(openai_api_key, rag_option, prompt):
 gr.close_all()
 demo = gr.Interface(fn=invoke,
                     inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
-                              gr.Radio([
+                              gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label="Retrieval Augmented Generation", value = RAG_OFF),
                               gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
                     outputs = [gr.Textbox(label = "Completion", lines = 1)],
                     title = "Generative AI - LLM & RAG",
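Note: defining RAG_OFF, RAG_CHROMA, and RAG_MONGODB once keeps the Radio choices and the string comparisons in invoke from drifting apart. The same wiring in isolation, with a placeholder handler instead of the app's invoke:

# Minimal sketch of the Radio-to-handler wiring; the handler below is a
# placeholder, not the app's invoke().
import gradio as gr

RAG_OFF, RAG_CHROMA, RAG_MONGODB = "Off", "Chroma", "MongoDB"

def handler(openai_api_key, rag_option, prompt):
    return f"rag_option={rag_option!r}, prompt={prompt!r}"

demo = gr.Interface(fn = handler,
                    inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
                              gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB],
                                       label = "Retrieval Augmented Generation",
                                       value = RAG_OFF),
                              gr.Textbox(label = "Prompt", lines = 1)],
                    outputs = [gr.Textbox(label = "Completion", lines = 1)])
demo.launch()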