bstraehle commited on
Commit
7ddcfd9
·
1 Parent(s): eb978fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -164
app.py CHANGED
@@ -1,157 +1,27 @@
1
  import gradio as gr
2
- import openai, os, time, wandb
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
- from langchain.chains import LLMChain, RetrievalQA
6
- from langchain.chat_models import ChatOpenAI
7
- from langchain.document_loaders import PyPDFLoader, WebBaseLoader
8
- from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
9
- from langchain.document_loaders.generic import GenericLoader
10
- from langchain.document_loaders.parsers import OpenAIWhisperParser
11
- from langchain.embeddings.openai import OpenAIEmbeddings
12
- from langchain.prompts import PromptTemplate
13
- from langchain.text_splitter import RecursiveCharacterTextSplitter
14
- from langchain.vectorstores import Chroma
15
- from langchain.vectorstores import MongoDBAtlasVectorSearch
16
- from pymongo import MongoClient
17
- from wandb.sdk.data_types.trace_tree import Trace
18
 
19
- _ = load_dotenv(find_dotenv())
20
-
21
- PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
22
- WEB_URL = "https://openai.com/research/gpt-4"
23
- YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
24
- YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
25
- YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
26
-
27
- YOUTUBE_DIR = "/data/youtube"
28
- CHROMA_DIR = "/data/chroma"
29
 
30
- MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
31
- MONGODB_DB_NAME = "langchain_db"
32
- MONGODB_COLLECTION_NAME = "gpt-4"
33
- MONGODB_INDEX_NAME = "default"
34
 
35
- LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
36
- RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
37
 
38
- WANDB_API_KEY = os.environ["WANDB_API_KEY"]
 
 
 
 
 
 
39
 
40
  RAG_OFF = "Off"
41
  RAG_CHROMA = "Chroma"
42
  RAG_MONGODB = "MongoDB"
43
 
44
- client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
45
- collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
46
-
47
- config = {
48
- "chunk_overlap": 150,
49
- "chunk_size": 1500,
50
- "k": 3,
51
- "model_name": "gpt-4-0613",
52
- "temperature": 0,
53
- }
54
-
55
- def document_loading_splitting():
56
- # Document loading
57
- docs = []
58
-
59
- # Load PDF
60
- loader = PyPDFLoader(PDF_URL)
61
- docs.extend(loader.load())
62
-
63
- # Load Web
64
- loader = WebBaseLoader(WEB_URL)
65
- docs.extend(loader.load())
66
-
67
- # Load YouTube
68
- loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
69
- YOUTUBE_URL_2,
70
- YOUTUBE_URL_3], YOUTUBE_DIR),
71
- OpenAIWhisperParser())
72
- docs.extend(loader.load())
73
-
74
- # Document splitting
75
- text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
76
- chunk_size = config["chunk_size"])
77
- split_documents = text_splitter.split_documents(docs)
78
-
79
- return split_documents
80
-
81
- def document_storage_chroma(documents):
82
- Chroma.from_documents(documents = documents,
83
- embedding = OpenAIEmbeddings(disallowed_special = ()),
84
- persist_directory = CHROMA_DIR)
85
-
86
- def document_storage_mongodb(documents):
87
- MongoDBAtlasVectorSearch.from_documents(documents = documents,
88
- embedding = OpenAIEmbeddings(disallowed_special = ()),
89
- collection = collection,
90
- index_name = MONGODB_INDEX_NAME)
91
-
92
- def document_retrieval_chroma(llm, prompt):
93
- return Chroma(embedding_function = OpenAIEmbeddings(),
94
- persist_directory = CHROMA_DIR)
95
-
96
- def document_retrieval_mongodb(llm, prompt):
97
- return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
98
- MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
99
- OpenAIEmbeddings(disallowed_special = ()),
100
- index_name = MONGODB_INDEX_NAME)
101
-
102
- def llm_chain(llm, prompt):
103
- llm_chain = LLMChain(llm = llm,
104
- prompt = LLM_CHAIN_PROMPT,
105
- verbose = False)
106
- completion = llm_chain.generate([{"question": prompt}])
107
- return completion, llm_chain
108
-
109
- def rag_chain(llm, prompt, db):
110
- rag_chain = RetrievalQA.from_chain_type(llm,
111
- chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
112
- retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
113
- return_source_documents = True,
114
- verbose = False)
115
- completion = rag_chain({"query": prompt})
116
- return completion, rag_chain
117
-
118
- def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
119
- wandb.init(project = "openai-llm-rag")
120
-
121
- trace = Trace(
122
- kind = "chain",
123
- name = "" if (chain == None) else type(chain).__name__,
124
- status_code = "success" if (str(err_msg) == "") else "error",
125
- status_message = str(err_msg),
126
- metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
127
- "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
128
- } if (str(err_msg) == "") else {},
129
- inputs = {"rag_option": rag_option,
130
- "prompt": prompt,
131
- "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
132
- str(chain.combine_documents_chain.llm_chain.prompt)),
133
- "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
134
- } if (str(err_msg) == "") else {},
135
- outputs = {"result": result,
136
- "generation_info": str(generation_info),
137
- "llm_output": str(llm_output),
138
- "completion": str(completion),
139
- } if (str(err_msg) == "") else {},
140
- model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
141
- str(chain.combine_documents_chain.llm_chain.llm.client)),
142
- "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
143
- str(chain.combine_documents_chain.llm_chain.llm.model_name)),
144
- "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
145
- str(chain.combine_documents_chain.llm_chain.llm.temperature)),
146
- "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
147
- } if (str(err_msg) == "") else {},
148
- start_time_ms = start_time_ms,
149
- end_time_ms = end_time_ms
150
- )
151
-
152
- trace.log("evaluation")
153
- wandb.finish()
154
-
155
  def invoke(openai_api_key, rag_option, prompt):
156
  if (openai_api_key == ""):
157
  raise gr.Error("OpenAI API Key is required.")
@@ -159,6 +29,9 @@ def invoke(openai_api_key, rag_option, prompt):
159
  raise gr.Error("Retrieval Augmented Generation is required.")
160
  if (prompt == ""):
161
  raise gr.Error("Prompt is required.")
 
 
 
162
 
163
  chain = None
164
  completion = ""
@@ -170,42 +43,39 @@ def invoke(openai_api_key, rag_option, prompt):
170
  try:
171
  start_time_ms = round(time.time() * 1000)
172
 
173
- llm = ChatOpenAI(model_name = config["model_name"],
174
- openai_api_key = openai_api_key,
175
- temperature = config["temperature"])
176
-
177
- if (rag_option == RAG_CHROMA):
178
- #splits = document_loading_splitting()
179
- #document_storage_chroma(splits)
180
-
181
- db = document_retrieval_chroma(llm, prompt)
182
- completion, chain = rag_chain(llm, prompt, db)
183
- result = completion["result"]
184
- elif (rag_option == RAG_MONGODB):
185
- #splits = document_loading_splitting()
186
- #document_storage_mongodb(splits)
187
-
188
- db = document_retrieval_mongodb(llm, prompt)
189
- completion, chain = rag_chain(llm, prompt, db)
190
- result = completion["result"]
191
- else:
192
- completion, chain = llm_chain(llm, prompt)
193
 
194
  if (completion.generations[0] != None and completion.generations[0][0] != None):
195
  result = completion.generations[0][0].text
196
  generation_info = completion.generations[0][0].generation_info
197
 
198
  llm_output = completion.llm_output
 
 
 
199
  except Exception as e:
200
  err_msg = e
 
201
  raise gr.Error(e)
202
  finally:
203
  end_time_ms = round(time.time() * 1000)
204
 
205
- wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
 
 
 
 
 
 
 
 
 
 
206
  return result
207
 
208
  gr.close_all()
 
209
  demo = gr.Interface(fn=invoke,
210
  inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
211
  gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
@@ -214,4 +84,5 @@ demo = gr.Interface(fn=invoke,
214
  outputs = [gr.Textbox(label = "Completion", lines = 1)],
215
  title = "Generative AI - LLM & RAG",
216
  description = os.environ["DESCRIPTION"])
217
- demo.launch()
 
 
1
  import gradio as gr
2
+ import os, time
3
 
4
  from dotenv import load_dotenv, find_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ from rag import llm_chain, rag_chain, rag_batch
7
+ from trace import wandb_trace
 
 
 
 
 
 
 
 
8
 
9
+ _ = load_dotenv(find_dotenv())
 
 
 
10
 
11
+ RAG_BATCH = False # document loading, splitting, storage
 
12
 
13
+ config = {
14
+ "chunk_overlap": 150, # document splitting
15
+ "chunk_size": 1500, # document splitting
16
+ "k": 3, # document retrieval
17
+ "model_name": "gpt-4-0314", # llm
18
+ "temperature": 0, # llm
19
+ }
20
 
21
  RAG_OFF = "Off"
22
  RAG_CHROMA = "Chroma"
23
  RAG_MONGODB = "MongoDB"
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def invoke(openai_api_key, rag_option, prompt):
26
  if (openai_api_key == ""):
27
  raise gr.Error("OpenAI API Key is required.")
 
29
  raise gr.Error("Retrieval Augmented Generation is required.")
30
  if (prompt == ""):
31
  raise gr.Error("Prompt is required.")
32
+
33
+ if (RAG_BATCH):
34
+ rag_batch(config)
35
 
36
  chain = None
37
  completion = ""
 
43
  try:
44
  start_time_ms = round(time.time() * 1000)
45
 
46
+ if (rag_option == RAG_OFF):
47
+ completion, chain = llm_chain(config, openai_api_key, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  if (completion.generations[0] != None and completion.generations[0][0] != None):
50
  result = completion.generations[0][0].text
51
  generation_info = completion.generations[0][0].generation_info
52
 
53
  llm_output = completion.llm_output
54
+ else:
55
+ completion, chain = rag_chain(config, openai_api_key, rag_option, prompt)
56
+ result = completion["result"]
57
  except Exception as e:
58
  err_msg = e
59
+
60
  raise gr.Error(e)
61
  finally:
62
  end_time_ms = round(time.time() * 1000)
63
 
64
+ wandb_trace(config,
65
+ rag_option == RAG_OFF,
66
+ prompt,
67
+ completion,
68
+ result,
69
+ generation_info,
70
+ llm_output,
71
+ chain,
72
+ err_msg,
73
+ start_time_ms,
74
+ end_time_ms)
75
  return result
76
 
77
  gr.close_all()
78
+
79
  demo = gr.Interface(fn=invoke,
80
  inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
81
  gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
 
84
  outputs = [gr.Textbox(label = "Completion", lines = 1)],
85
  title = "Generative AI - LLM & RAG",
86
  description = os.environ["DESCRIPTION"])
87
+
88
+ demo.launch()