Update app.py
Browse files
app.py
CHANGED
@@ -1,45 +1,32 @@
|
|
1 |
import gradio as gr
|
2 |
import openai, os, time, wandb
|
3 |
|
|
|
4 |
from langchain.chains import LLMChain, RetrievalQA
|
5 |
from langchain.chat_models import ChatOpenAI
|
6 |
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
|
7 |
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
|
8 |
from langchain.document_loaders.generic import GenericLoader
|
9 |
from langchain.document_loaders.parsers import OpenAIWhisperParser
|
10 |
-
|
11 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
from langchain.vectorstores import Chroma
|
15 |
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
16 |
-
|
17 |
from pymongo import MongoClient
|
18 |
-
|
19 |
from wandb.sdk.data_types.trace_tree import Trace
|
20 |
|
21 |
-
from dotenv import load_dotenv, find_dotenv
|
22 |
_ = load_dotenv(find_dotenv())
|
23 |
|
24 |
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
|
25 |
|
26 |
-
|
27 |
-
|
|
|
28 |
MONGODB_DB_NAME = "langchain_db"
|
29 |
MONGODB_COLLECTION_NAME = "gpt-4"
|
30 |
-
MONGODB_COLLECTION = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
|
31 |
MONGODB_INDEX_NAME = "default"
|
32 |
|
33 |
-
description = os.environ["DESCRIPTION"]
|
34 |
-
|
35 |
-
config = {
|
36 |
-
"chunk_overlap": 150,
|
37 |
-
"chunk_size": 1500,
|
38 |
-
"k": 3,
|
39 |
-
"model_name": "gpt-4",
|
40 |
-
"temperature": 0,
|
41 |
-
}
|
42 |
-
|
43 |
template = """If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Always say "Thanks for using the 🧠 app - Bernd" at the end of the answer. """
|
44 |
|
45 |
llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
|
@@ -61,52 +48,68 @@ RAG_OFF = "Off"
|
|
61 |
RAG_CHROMA = "Chroma"
|
62 |
RAG_MONGODB = "MongoDB"
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def document_loading_splitting():
|
65 |
# Document loading
|
66 |
docs = []
|
|
|
67 |
# Load PDF
|
68 |
loader = PyPDFLoader(PDF_URL)
|
69 |
docs.extend(loader.load())
|
|
|
70 |
# Load Web
|
71 |
loader = WebBaseLoader(WEB_URL)
|
72 |
docs.extend(loader.load())
|
|
|
73 |
# Load YouTube
|
74 |
loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
|
75 |
YOUTUBE_URL_2,
|
76 |
YOUTUBE_URL_3], YOUTUBE_DIR),
|
77 |
OpenAIWhisperParser())
|
78 |
docs.extend(loader.load())
|
|
|
79 |
# Document splitting
|
80 |
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
|
81 |
chunk_size = config["chunk_size"])
|
82 |
-
|
83 |
-
|
|
|
84 |
|
85 |
-
def document_storage_chroma(
|
86 |
-
Chroma.from_documents(documents =
|
87 |
embedding = OpenAIEmbeddings(disallowed_special = ()),
|
88 |
persist_directory = CHROMA_DIR)
|
89 |
|
90 |
-
def document_storage_mongodb(
|
91 |
-
MongoDBAtlasVectorSearch.from_documents(documents =
|
92 |
embedding = OpenAIEmbeddings(disallowed_special = ()),
|
93 |
-
collection =
|
94 |
index_name = MONGODB_INDEX_NAME)
|
95 |
|
96 |
def document_retrieval_chroma(llm, prompt):
|
97 |
-
|
98 |
-
|
99 |
-
return db
|
100 |
|
101 |
def document_retrieval_mongodb(llm, prompt):
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
return db
|
107 |
|
108 |
def llm_chain(llm, prompt):
|
109 |
-
llm_chain = LLMChain(llm = llm,
|
|
|
|
|
110 |
completion = llm_chain.generate([{"question": prompt}])
|
111 |
return completion, llm_chain
|
112 |
|
@@ -127,18 +130,17 @@ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_out
|
|
127 |
name = "" if (chain == None) else type(chain).__name__,
|
128 |
status_code = "success" if (str(err_msg) == "") else "error",
|
129 |
status_message = str(err_msg),
|
130 |
-
metadata = {
|
131 |
-
|
132 |
-
|
133 |
-
} if (str(err_msg) == "") else {},
|
134 |
inputs = {"rag_option": rag_option,
|
135 |
"prompt": prompt,
|
136 |
-
|
137 |
outputs = {"result": result,
|
138 |
"generation_info": str(generation_info),
|
139 |
"llm_output": str(llm_output),
|
140 |
"completion": str(completion),
|
141 |
-
|
142 |
model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
|
143 |
str(chain.combine_documents_chain.llm_chain.llm.client)),
|
144 |
"model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
|
@@ -148,7 +150,7 @@ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_out
|
|
148 |
"prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
|
149 |
str(chain.combine_documents_chain.llm_chain.prompt)),
|
150 |
"retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
|
151 |
-
|
152 |
start_time_ms = start_time_ms,
|
153 |
end_time_ms = end_time_ms
|
154 |
)
|
@@ -173,6 +175,7 @@ def invoke(openai_api_key, rag_option, prompt):
|
|
173 |
|
174 |
try:
|
175 |
start_time_ms = round(time.time() * 1000)
|
|
|
176 |
llm = ChatOpenAI(model_name = config["model_name"],
|
177 |
openai_api_key = openai_api_key,
|
178 |
temperature = config["temperature"])
|
@@ -180,26 +183,31 @@ def invoke(openai_api_key, rag_option, prompt):
|
|
180 |
if (rag_option == RAG_CHROMA):
|
181 |
#splits = document_loading_splitting()
|
182 |
#document_storage_chroma(splits)
|
|
|
183 |
db = document_retrieval_chroma(llm, prompt)
|
184 |
completion, chain = rag_chain(llm, prompt, db)
|
185 |
result = completion["result"]
|
186 |
elif (rag_option == RAG_MONGODB):
|
187 |
#splits = document_loading_splitting()
|
188 |
#document_storage_mongodb(splits)
|
|
|
189 |
db = document_retrieval_mongodb(llm, prompt)
|
190 |
completion, chain = rag_chain(llm, prompt, db)
|
191 |
result = completion["result"]
|
192 |
else:
|
193 |
completion, chain = llm_chain(llm, prompt)
|
|
|
194 |
if (completion.generations[0] != None and completion.generations[0][0] != None):
|
195 |
result = completion.generations[0][0].text
|
196 |
generation_info = completion.generations[0][0].generation_info
|
|
|
197 |
llm_output = completion.llm_output
|
198 |
except Exception as e:
|
199 |
err_msg = e
|
200 |
raise gr.Error(e)
|
201 |
finally:
|
202 |
end_time_ms = round(time.time() * 1000)
|
|
|
203 |
wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
|
204 |
return result
|
205 |
|
@@ -207,8 +215,9 @@ gr.close_all()
|
|
207 |
demo = gr.Interface(fn=invoke,
|
208 |
inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
|
209 |
gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
|
210 |
-
gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1)
|
|
|
211 |
outputs = [gr.Textbox(label = "Completion", lines = 1)],
|
212 |
title = "Generative AI - LLM & RAG",
|
213 |
-
description =
|
214 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import openai, os, time, wandb
|
3 |
|
4 |
+
from dotenv import load_dotenv, find_dotenv
|
5 |
from langchain.chains import LLMChain, RetrievalQA
|
6 |
from langchain.chat_models import ChatOpenAI
|
7 |
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
|
8 |
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
|
9 |
from langchain.document_loaders.generic import GenericLoader
|
10 |
from langchain.document_loaders.parsers import OpenAIWhisperParser
|
|
|
11 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
from langchain.vectorstores import Chroma
|
15 |
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
|
|
16 |
from pymongo import MongoClient
|
|
|
17 |
from wandb.sdk.data_types.trace_tree import Trace
|
18 |
|
|
|
19 |
_ = load_dotenv(find_dotenv())
|
20 |
|
21 |
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
|
22 |
|
23 |
+
DESCRIPTION = os.environ["DESCRIPTION"]
|
24 |
+
|
25 |
+
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
|
26 |
MONGODB_DB_NAME = "langchain_db"
|
27 |
MONGODB_COLLECTION_NAME = "gpt-4"
|
|
|
28 |
MONGODB_INDEX_NAME = "default"
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
template = """If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Always say "Thanks for using the 🧠 app - Bernd" at the end of the answer. """
|
31 |
|
32 |
llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
|
|
|
48 |
RAG_CHROMA = "Chroma"
|
49 |
RAG_MONGODB = "MongoDB"
|
50 |
|
51 |
+
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
|
52 |
+
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
|
53 |
+
|
54 |
+
config = {
|
55 |
+
"chunk_overlap": 150,
|
56 |
+
"chunk_size": 1500,
|
57 |
+
"k": 3,
|
58 |
+
"model_name": "gpt-4",
|
59 |
+
"temperature": 0,
|
60 |
+
}
|
61 |
+
|
62 |
def document_loading_splitting():
|
63 |
# Document loading
|
64 |
docs = []
|
65 |
+
|
66 |
# Load PDF
|
67 |
loader = PyPDFLoader(PDF_URL)
|
68 |
docs.extend(loader.load())
|
69 |
+
|
70 |
# Load Web
|
71 |
loader = WebBaseLoader(WEB_URL)
|
72 |
docs.extend(loader.load())
|
73 |
+
|
74 |
# Load YouTube
|
75 |
loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
|
76 |
YOUTUBE_URL_2,
|
77 |
YOUTUBE_URL_3], YOUTUBE_DIR),
|
78 |
OpenAIWhisperParser())
|
79 |
docs.extend(loader.load())
|
80 |
+
|
81 |
# Document splitting
|
82 |
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
|
83 |
chunk_size = config["chunk_size"])
|
84 |
+
split_documents = text_splitter.split_documents(docs)
|
85 |
+
|
86 |
+
return split_documents
|
87 |
|
88 |
+
def document_storage_chroma(documents):
|
89 |
+
Chroma.from_documents(documents = documents,
|
90 |
embedding = OpenAIEmbeddings(disallowed_special = ()),
|
91 |
persist_directory = CHROMA_DIR)
|
92 |
|
93 |
+
def document_storage_mongodb(documents):
|
94 |
+
MongoDBAtlasVectorSearch.from_documents(documents = documents,
|
95 |
embedding = OpenAIEmbeddings(disallowed_special = ()),
|
96 |
+
collection = collection,
|
97 |
index_name = MONGODB_INDEX_NAME)
|
98 |
|
99 |
def document_retrieval_chroma(llm, prompt):
|
100 |
+
return Chroma(embedding_function = OpenAIEmbeddings(),
|
101 |
+
persist_directory = CHROMA_DIR)
|
|
|
102 |
|
103 |
def document_retrieval_mongodb(llm, prompt):
|
104 |
+
return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
|
105 |
+
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
|
106 |
+
OpenAIEmbeddings(disallowed_special = ()),
|
107 |
+
index_name = MONGODB_INDEX_NAME)
|
|
|
108 |
|
109 |
def llm_chain(llm, prompt):
|
110 |
+
llm_chain = LLMChain(llm = llm,
|
111 |
+
prompt = LLM_CHAIN_PROMPT,
|
112 |
+
verbose = False)
|
113 |
completion = llm_chain.generate([{"question": prompt}])
|
114 |
return completion, llm_chain
|
115 |
|
|
|
130 |
name = "" if (chain == None) else type(chain).__name__,
|
131 |
status_code = "success" if (str(err_msg) == "") else "error",
|
132 |
status_message = str(err_msg),
|
133 |
+
metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
|
134 |
+
"chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
|
135 |
+
} if (str(err_msg) == "") else {},
|
|
|
136 |
inputs = {"rag_option": rag_option,
|
137 |
"prompt": prompt,
|
138 |
+
} if (str(err_msg) == "") else {},
|
139 |
outputs = {"result": result,
|
140 |
"generation_info": str(generation_info),
|
141 |
"llm_output": str(llm_output),
|
142 |
"completion": str(completion),
|
143 |
+
} if (str(err_msg) == "") else {},
|
144 |
model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
|
145 |
str(chain.combine_documents_chain.llm_chain.llm.client)),
|
146 |
"model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
|
|
|
150 |
"prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
|
151 |
str(chain.combine_documents_chain.llm_chain.prompt)),
|
152 |
"retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
|
153 |
+
} if (str(err_msg) == "") else {},
|
154 |
start_time_ms = start_time_ms,
|
155 |
end_time_ms = end_time_ms
|
156 |
)
|
|
|
175 |
|
176 |
try:
|
177 |
start_time_ms = round(time.time() * 1000)
|
178 |
+
|
179 |
llm = ChatOpenAI(model_name = config["model_name"],
|
180 |
openai_api_key = openai_api_key,
|
181 |
temperature = config["temperature"])
|
|
|
183 |
if (rag_option == RAG_CHROMA):
|
184 |
#splits = document_loading_splitting()
|
185 |
#document_storage_chroma(splits)
|
186 |
+
|
187 |
db = document_retrieval_chroma(llm, prompt)
|
188 |
completion, chain = rag_chain(llm, prompt, db)
|
189 |
result = completion["result"]
|
190 |
elif (rag_option == RAG_MONGODB):
|
191 |
#splits = document_loading_splitting()
|
192 |
#document_storage_mongodb(splits)
|
193 |
+
|
194 |
db = document_retrieval_mongodb(llm, prompt)
|
195 |
completion, chain = rag_chain(llm, prompt, db)
|
196 |
result = completion["result"]
|
197 |
else:
|
198 |
completion, chain = llm_chain(llm, prompt)
|
199 |
+
|
200 |
if (completion.generations[0] != None and completion.generations[0][0] != None):
|
201 |
result = completion.generations[0][0].text
|
202 |
generation_info = completion.generations[0][0].generation_info
|
203 |
+
|
204 |
llm_output = completion.llm_output
|
205 |
except Exception as e:
|
206 |
err_msg = e
|
207 |
raise gr.Error(e)
|
208 |
finally:
|
209 |
end_time_ms = round(time.time() * 1000)
|
210 |
+
|
211 |
wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
|
212 |
return result
|
213 |
|
|
|
215 |
demo = gr.Interface(fn=invoke,
|
216 |
inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
|
217 |
gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
|
218 |
+
gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1),
|
219 |
+
],
|
220 |
outputs = [gr.Textbox(label = "Completion", lines = 1)],
|
221 |
title = "Generative AI - LLM & RAG",
|
222 |
+
description = DESCRIPTION)
|
223 |
demo.launch()
|