bstraehle committed
Commit eb978fe · 1 Parent(s): 96012de

Update app.py

Files changed (1)
  1. app.py +164 -35
app.py CHANGED
@@ -1,27 +1,157 @@
 import gradio as gr
-import os, time
+import openai, os, time, wandb
 
 from dotenv import load_dotenv, find_dotenv
-
-from rag import llm_chain, rag_chain, rag_batch
-from trace import wandb_trace
+from langchain.chains import LLMChain, RetrievalQA
+from langchain.chat_models import ChatOpenAI
+from langchain.document_loaders import PyPDFLoader, WebBaseLoader
+from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
+from langchain.document_loaders.generic import GenericLoader
+from langchain.document_loaders.parsers import OpenAIWhisperParser
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.prompts import PromptTemplate
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+from langchain.vectorstores import MongoDBAtlasVectorSearch
+from pymongo import MongoClient
+from wandb.sdk.data_types.trace_tree import Trace
 
 _ = load_dotenv(find_dotenv())
 
-RAG_BATCH = False # document loading, splitting, storage
+PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
+WEB_URL = "https://openai.com/research/gpt-4"
+YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
+YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
+YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
 
-config = {
-    "chunk_overlap": 150, # document splitting
-    "chunk_size": 1500, # document splitting
-    "k": 3, # document retrieval
-    "model_name": "gpt-4-0314", # llm
-    "temperature": 0, # llm
-}
+YOUTUBE_DIR = "/data/youtube"
+CHROMA_DIR = "/data/chroma"
+
+MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
+MONGODB_DB_NAME = "langchain_db"
+MONGODB_COLLECTION_NAME = "gpt-4"
+MONGODB_INDEX_NAME = "default"
+
+LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
+RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
+
+WANDB_API_KEY = os.environ["WANDB_API_KEY"]
 
 RAG_OFF = "Off"
 RAG_CHROMA = "Chroma"
 RAG_MONGODB = "MongoDB"
 
+client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
+collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
+
+config = {
+    "chunk_overlap": 150,
+    "chunk_size": 1500,
+    "k": 3,
+    "model_name": "gpt-4-0613",
+    "temperature": 0,
+}
+
+def document_loading_splitting():
+    # Document loading
+    docs = []
+
+    # Load PDF
+    loader = PyPDFLoader(PDF_URL)
+    docs.extend(loader.load())
+
+    # Load Web
+    loader = WebBaseLoader(WEB_URL)
+    docs.extend(loader.load())
+
+    # Load YouTube
+    loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
+                                               YOUTUBE_URL_2,
+                                               YOUTUBE_URL_3], YOUTUBE_DIR),
+                           OpenAIWhisperParser())
+    docs.extend(loader.load())
+
+    # Document splitting
+    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
+                                                   chunk_size = config["chunk_size"])
+    split_documents = text_splitter.split_documents(docs)
+
+    return split_documents
+
+def document_storage_chroma(documents):
+    Chroma.from_documents(documents = documents,
+                          embedding = OpenAIEmbeddings(disallowed_special = ()),
+                          persist_directory = CHROMA_DIR)
+
+def document_storage_mongodb(documents):
+    MongoDBAtlasVectorSearch.from_documents(documents = documents,
+                                            embedding = OpenAIEmbeddings(disallowed_special = ()),
+                                            collection = collection,
+                                            index_name = MONGODB_INDEX_NAME)
+
+def document_retrieval_chroma(llm, prompt):
+    return Chroma(embedding_function = OpenAIEmbeddings(),
+                  persist_directory = CHROMA_DIR)
+
+def document_retrieval_mongodb(llm, prompt):
+    return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
+                                                           MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
+                                                           OpenAIEmbeddings(disallowed_special = ()),
+                                                           index_name = MONGODB_INDEX_NAME)
+
+def llm_chain(llm, prompt):
+    llm_chain = LLMChain(llm = llm,
+                         prompt = LLM_CHAIN_PROMPT,
+                         verbose = False)
+    completion = llm_chain.generate([{"question": prompt}])
+    return completion, llm_chain
+
+def rag_chain(llm, prompt, db):
+    rag_chain = RetrievalQA.from_chain_type(llm,
+                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
+                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
+                                            return_source_documents = True,
+                                            verbose = False)
+    completion = rag_chain({"query": prompt})
+    return completion, rag_chain
+
+def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
+    wandb.init(project = "openai-llm-rag")
+
+    trace = Trace(
+        kind = "chain",
+        name = "" if (chain == None) else type(chain).__name__,
+        status_code = "success" if (str(err_msg) == "") else "error",
+        status_message = str(err_msg),
+        metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
+                    "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
+                   } if (str(err_msg) == "") else {},
+        inputs = {"rag_option": rag_option,
+                  "prompt": prompt,
+                  "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
+                                   str(chain.combine_documents_chain.llm_chain.prompt)),
+                  "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
+                 } if (str(err_msg) == "") else {},
+        outputs = {"result": result,
+                   "generation_info": str(generation_info),
+                   "llm_output": str(llm_output),
+                   "completion": str(completion),
+                  } if (str(err_msg) == "") else {},
+        model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
+                                 str(chain.combine_documents_chain.llm_chain.llm.client)),
+                      "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
+                                     str(chain.combine_documents_chain.llm_chain.llm.model_name)),
+                      "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
+                                      str(chain.combine_documents_chain.llm_chain.llm.temperature)),
+                      "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
+                     } if (str(err_msg) == "") else {},
+        start_time_ms = start_time_ms,
+        end_time_ms = end_time_ms
+    )
+
+    trace.log("evaluation")
+    wandb.finish()
+
 def invoke(openai_api_key, rag_option, prompt):
     if (openai_api_key == ""):
         raise gr.Error("OpenAI API Key is required.")
@@ -29,9 +159,6 @@ def invoke(openai_api_key, rag_option, prompt):
         raise gr.Error("Retrieval Augmented Generation is required.")
     if (prompt == ""):
         raise gr.Error("Prompt is required.")
-
-    if (RAG_BATCH):
-        rag_batch(config)
 
     chain = None
     completion = ""
@@ -43,39 +170,42 @@ def invoke(openai_api_key, rag_option, prompt):
    try:
        start_time_ms = round(time.time() * 1000)
 
-        if (rag_option == RAG_OFF):
-            completion, chain = llm_chain(config, openai_api_key, prompt)
+        llm = ChatOpenAI(model_name = config["model_name"],
+                         openai_api_key = openai_api_key,
+                         temperature = config["temperature"])
+
+        if (rag_option == RAG_CHROMA):
+            #splits = document_loading_splitting()
+            #document_storage_chroma(splits)
+
+            db = document_retrieval_chroma(llm, prompt)
+            completion, chain = rag_chain(llm, prompt, db)
+            result = completion["result"]
+        elif (rag_option == RAG_MONGODB):
+            #splits = document_loading_splitting()
+            #document_storage_mongodb(splits)
+
+            db = document_retrieval_mongodb(llm, prompt)
+            completion, chain = rag_chain(llm, prompt, db)
+            result = completion["result"]
+        else:
+            completion, chain = llm_chain(llm, prompt)
 
             if (completion.generations[0] != None and completion.generations[0][0] != None):
                 result = completion.generations[0][0].text
                 generation_info = completion.generations[0][0].generation_info
 
                 llm_output = completion.llm_output
-        else:
-            completion, chain = rag_chain(config, openai_api_key, rag_option, prompt)
-            result = completion["result"]
    except Exception as e:
        err_msg = e
-
        raise gr.Error(e)
    finally:
        end_time_ms = round(time.time() * 1000)
 
-        wandb_trace(config,
-                    rag_option == RAG_OFF,
-                    prompt,
-                    completion,
-                    result,
-                    generation_info,
-                    llm_output,
-                    chain,
-                    err_msg,
-                    start_time_ms,
-                    end_time_ms)
+        wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
    return result
 
 gr.close_all()
-
 demo = gr.Interface(fn=invoke,
                     inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
                               gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
@@ -84,5 +214,4 @@ demo = gr.Interface(fn=invoke,
                     outputs = [gr.Textbox(label = "Completion", lines = 1)],
                     title = "Generative AI - LLM & RAG",
                     description = os.environ["DESCRIPTION"])
-
-demo.launch()
+demo.launch()
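In the new invoke, the document_loading_splitting() and document_storage_chroma()/document_storage_mongodb() calls are commented out, so both retrieval paths assume the Chroma directory and the MongoDB Atlas collection were populated by an earlier one-time run. A minimal sketch of that ingestion step, trimmed to the PDF source and the Chroma store (the web/YouTube loaders and the MongoDB store follow the same pattern); it assumes OPENAI_API_KEY is set in the environment, since OpenAIEmbeddings reads the key from there rather than from the UI input:

from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Same source URL, splitter settings, and persist directory as app.py.
docs = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf").load()
splits = RecursiveCharacterTextSplitter(chunk_size = 1500,
                                        chunk_overlap = 150).split_documents(docs)
Chroma.from_documents(documents = splits,
                      embedding = OpenAIEmbeddings(disallowed_special = ()),
                      persist_directory = "/data/chroma")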
 
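load_dotenv(find_dotenv()) supplies the values the module reads through os.environ: MONGODB_ATLAS_CLUSTER_URI, WANDB_API_KEY, LLM_TEMPLATE, RAG_TEMPLATE, and DESCRIPTION. A sketch of that environment with hypothetical placeholder values (in a Space these would normally live in .env or repository secrets); the two templates must contain the input_variables declared for their PromptTemplate, and OPENAI_API_KEY is an added assumption because the retrieval helpers construct OpenAIEmbeddings without the user-supplied key:

import os

# Hypothetical placeholder values, not real credentials or templates.
os.environ["MONGODB_ATLAS_CLUSTER_URI"] = "mongodb+srv://user:password@cluster.example.net"
os.environ["WANDB_API_KEY"] = "..."
os.environ["OPENAI_API_KEY"] = "..."  # assumed: used by OpenAIEmbeddings in the retrieval helpers
os.environ["DESCRIPTION"] = "Demo description shown in the Gradio UI"
os.environ["LLM_TEMPLATE"] = "Answer the question at the end. Question: {question}"
os.environ["RAG_TEMPLATE"] = "Use the following context to answer the question. {context} Question: {question}"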