bstraehle committed
Commit 99bbf81
Parent: 8484d1a

Update app.py

Files changed (1): app.py +51 -42
app.py CHANGED
@@ -1,45 +1,32 @@
 import gradio as gr
 import openai, os, time, wandb
 
+from dotenv import load_dotenv, find_dotenv
 from langchain.chains import LLMChain, RetrievalQA
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import PyPDFLoader, WebBaseLoader
 from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
 from langchain.document_loaders.generic import GenericLoader
 from langchain.document_loaders.parsers import OpenAIWhisperParser
-
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.vectorstores import MongoDBAtlasVectorSearch
-
 from pymongo import MongoClient
-
 from wandb.sdk.data_types.trace_tree import Trace
 
-from dotenv import load_dotenv, find_dotenv
 _ = load_dotenv(find_dotenv())
 
 WANDB_API_KEY = os.environ["WANDB_API_KEY"]
 
-MONGODB_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
-client = MongoClient(MONGODB_URI)
+DESCRIPTION = os.environ["DESCRIPTION"]
+
+MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
 MONGODB_DB_NAME = "langchain_db"
 MONGODB_COLLECTION_NAME = "gpt-4"
-MONGODB_COLLECTION = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
 MONGODB_INDEX_NAME = "default"
 
-description = os.environ["DESCRIPTION"]
-
-config = {
-    "chunk_overlap": 150,
-    "chunk_size": 1500,
-    "k": 3,
-    "model_name": "gpt-4",
-    "temperature": 0,
-}
-
 template = """If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Always say "Thanks for using the 🧠 app - Bernd" at the end of the answer. """
 
 llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
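Note on MONGODB_INDEX_NAME: retrieval assumes an Atlas Search vector index named "default" already exists on the langchain_db.gpt-4 collection. A minimal sketch of such an index definition follows; the field name "embedding", the 1536 dimensions (OpenAI ada-002 embeddings), and the cosine metric are assumptions, not part of this commit:

# Hypothetical Atlas Search index definition for MONGODB_INDEX_NAME ("default").
# Field name, dimensions, and similarity metric are assumed, not taken from the diff.
atlas_index_definition = {
    "mappings": {
        "dynamic": True,
        "fields": {
            "embedding": {
                "type": "knnVector",
                "dimensions": 1536,
                "similarity": "cosine",
            }
        }
    }
}

This definition is created once via the Atlas UI or API; the application code only ever references the index by name.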
@@ -61,52 +48,68 @@ RAG_OFF = "Off"
 RAG_CHROMA = "Chroma"
 RAG_MONGODB = "MongoDB"
 
+client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
+collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
+
+config = {
+    "chunk_overlap": 150,
+    "chunk_size": 1500,
+    "k": 3,
+    "model_name": "gpt-4",
+    "temperature": 0,
+}
+
 def document_loading_splitting():
     # Document loading
     docs = []
+
     # Load PDF
     loader = PyPDFLoader(PDF_URL)
     docs.extend(loader.load())
+
     # Load Web
     loader = WebBaseLoader(WEB_URL)
     docs.extend(loader.load())
+
     # Load YouTube
     loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
                                                YOUTUBE_URL_2,
                                                YOUTUBE_URL_3], YOUTUBE_DIR),
                            OpenAIWhisperParser())
     docs.extend(loader.load())
+
     # Document splitting
     text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
                                                    chunk_size = config["chunk_size"])
-    splits = text_splitter.split_documents(docs)
-    return splits
+    split_documents = text_splitter.split_documents(docs)
+
+    return split_documents
 
-def document_storage_chroma(splits):
-    Chroma.from_documents(documents = splits,
+def document_storage_chroma(documents):
+    Chroma.from_documents(documents = documents,
                           embedding = OpenAIEmbeddings(disallowed_special = ()),
                           persist_directory = CHROMA_DIR)
 
-def document_storage_mongodb(splits):
-    MongoDBAtlasVectorSearch.from_documents(documents = splits,
+def document_storage_mongodb(documents):
+    MongoDBAtlasVectorSearch.from_documents(documents = documents,
                                             embedding = OpenAIEmbeddings(disallowed_special = ()),
-                                            collection = MONGODB_COLLECTION,
+                                            collection = collection,
                                             index_name = MONGODB_INDEX_NAME)
 
 def document_retrieval_chroma(llm, prompt):
-    db = Chroma(embedding_function = OpenAIEmbeddings(),
-                persist_directory = CHROMA_DIR)
-    return db
+    return Chroma(embedding_function = OpenAIEmbeddings(),
+                  persist_directory = CHROMA_DIR)
 
 def document_retrieval_mongodb(llm, prompt):
-    db = MongoDBAtlasVectorSearch.from_connection_string(MONGODB_URI,
-                                                         MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
-                                                         OpenAIEmbeddings(disallowed_special = ()),
-                                                         index_name = MONGODB_INDEX_NAME)
-    return db
+    return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
+                                                           MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
+                                                           OpenAIEmbeddings(disallowed_special = ()),
+                                                           index_name = MONGODB_INDEX_NAME)
 
 def llm_chain(llm, prompt):
-    llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT, verbose = False)
+    llm_chain = LLMChain(llm = llm,
+                         prompt = LLM_CHAIN_PROMPT,
+                         verbose = False)
     completion = llm_chain.generate([{"question": prompt}])
     return completion, llm_chain
 
@@ -127,18 +130,17 @@ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_out
         name = "" if (chain == None) else type(chain).__name__,
         status_code = "success" if (str(err_msg) == "") else "error",
         status_message = str(err_msg),
-        metadata = {
-            "chunk_overlap": config["chunk_overlap"] if (rag_option != RAG_OFF) else "",
-            "chunk_size": config["chunk_size"] if (rag_option != RAG_OFF) else "",
-        } if (str(err_msg) == "") else {},
+        metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
+                    "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
+                   } if (str(err_msg) == "") else {},
         inputs = {"rag_option": rag_option,
                   "prompt": prompt,
-        } if (str(err_msg) == "") else {},
+                  } if (str(err_msg) == "") else {},
         outputs = {"result": result,
                    "generation_info": str(generation_info),
                    "llm_output": str(llm_output),
                    "completion": str(completion),
-        } if (str(err_msg) == "") else {},
+                   } if (str(err_msg) == "") else {},
         model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
                                  str(chain.combine_documents_chain.llm_chain.llm.client)),
                       "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
@@ -148,7 +150,7 @@ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_out
                       "prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
                                  str(chain.combine_documents_chain.llm_chain.prompt)),
                       "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
-        } if (str(err_msg) == "") else {},
+                      } if (str(err_msg) == "") else {},
         start_time_ms = start_time_ms,
         end_time_ms = end_time_ms
         )
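The `... if (str(err_msg) == "") else {}` guard now appears on metadata, inputs, outputs, and model_dict alike. If it keeps spreading, a tiny helper could state the intent once; this is a sketch, not something the commit introduces:

# Hypothetical helper: attach a dict to the trace only when no error occurred.
def if_success(values, err_msg):
    return values if str(err_msg) == "" else {}

# e.g. inputs = if_success({"rag_option": rag_option, "prompt": prompt}, err_msg)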
@@ -173,6 +175,7 @@ def invoke(openai_api_key, rag_option, prompt):
 
     try:
         start_time_ms = round(time.time() * 1000)
+
         llm = ChatOpenAI(model_name = config["model_name"],
                          openai_api_key = openai_api_key,
                          temperature = config["temperature"])
@@ -180,26 +183,31 @@ def invoke(openai_api_key, rag_option, prompt):
         if (rag_option == RAG_CHROMA):
             #splits = document_loading_splitting()
             #document_storage_chroma(splits)
+
             db = document_retrieval_chroma(llm, prompt)
             completion, chain = rag_chain(llm, prompt, db)
             result = completion["result"]
         elif (rag_option == RAG_MONGODB):
             #splits = document_loading_splitting()
             #document_storage_mongodb(splits)
+
             db = document_retrieval_mongodb(llm, prompt)
             completion, chain = rag_chain(llm, prompt, db)
             result = completion["result"]
         else:
             completion, chain = llm_chain(llm, prompt)
+
             if (completion.generations[0] != None and completion.generations[0][0] != None):
                 result = completion.generations[0][0].text
                 generation_info = completion.generations[0][0].generation_info
+
             llm_output = completion.llm_output
     except Exception as e:
         err_msg = e
         raise gr.Error(e)
     finally:
         end_time_ms = round(time.time() * 1000)
+
         wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
     return result
 
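invoke() delegates to rag_chain(llm, prompt, db), which sits outside this diff. Judging from the RetrievalQA import, the completion["result"] access above, and the chain.retriever and chain.combine_documents_chain attributes read in wandb_trace(), it plausibly looks like the following; this is a reconstruction, with RAG_CHAIN_PROMPT and config["k"] assumed from elsewhere in the file:

# Plausible shape of the unchanged rag_chain() helper (an assumption, not part of this commit).
def rag_chain(llm, prompt, db):
    rag_chain = RetrievalQA.from_chain_type(llm,
                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
                                            return_source_documents = True)
    completion = rag_chain({"query": prompt})
    return completion, rag_chain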
@@ -207,8 +215,9 @@ gr.close_all()
 demo = gr.Interface(fn=invoke,
                     inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
                               gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
-                              gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1)],
+                              gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1),
+                             ],
                     outputs = [gr.Textbox(label = "Completion", lines = 1)],
                     title = "Generative AI - LLM & RAG",
-                    description = description)
+                    description = DESCRIPTION)
 demo.launch()
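For a quick end-to-end check without the UI, invoke() can be called directly. A smoke-test sketch; reading OPENAI_API_KEY from the environment is an assumption, since the app itself takes the key from the textbox:

# Hypothetical smoke test, bypassing Gradio.
if __name__ == "__main__":
    print(invoke(os.environ["OPENAI_API_KEY"], RAG_OFF, "What is GPT-4?"))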
 