Théo ALVES DA COSTA commited on
Commit
d4c1a74
1 Parent(s): cf1bde3

Updated to async + new storage for logs

Browse files
Files changed (3) hide show
  1. app.py +109 -65
  2. climateqa/engine/llm.py +1 -0
  3. climateqa/sample_questions.py +0 -1
app.py CHANGED
@@ -60,7 +60,7 @@ credential = {
60
  }
61
 
62
  account_url = os.environ["BLOB_ACCOUNT_URL"]
63
- file_share_name = "climategpt"
64
  service = ShareServiceClient(account_url=account_url, credential=credential)
65
  share_client = service.get_share_client(file_share_name)
66
 
@@ -104,7 +104,25 @@ def serialize_docs(docs):
104
  return new_docs
105
 
106
 
107
- def chat(query,history,audience,sources,reports):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
109
  (messages in gradio format, messages in langchain format, source documents)"""
110
 
@@ -124,7 +142,8 @@ def chat(query,history,audience,sources,reports):
124
  if len(reports) == 0:
125
  reports = []
126
 
127
- retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,reports = reports,k_summary = 3,k_total = 10,threshold=0.7)
 
128
  rag_chain = make_rag_chain(retriever,llm)
129
 
130
  source_string = ""
@@ -144,45 +163,45 @@ def chat(query,history,audience,sources,reports):
144
  # memory.chat_memory.add_message(message)
145
 
146
  inputs = {"query": query,"audience": audience_prompt}
147
- # result = rag_chain.astream_log(inputs)
148
- result = rag_chain.stream(inputs)
149
 
150
  reformulated_question_path_id = "/logs/flatten_dict/final_output"
151
  retriever_path_id = "/logs/Retriever/final_output"
152
  streaming_output_path_id = "/logs/AzureChatOpenAI:2/streamed_output_str/-"
153
  final_output_path_id = "/streamed_output/-"
154
 
155
- docs_html = "No sources found for this question"
156
  output_query = ""
157
  output_language = ""
158
  gallery = []
159
 
160
- for output in result:
161
 
162
- if "language" in output:
163
- output_language = output["language"]
164
- if "question" in output:
165
- output_query = output["question"]
166
- if "docs" in output:
167
 
168
- try:
169
- docs = output['docs'] # List[Document]
170
- docs_html = []
171
- for i, d in enumerate(docs, 1):
172
- docs_html.append(make_html_source(d, i))
173
- docs_html = "".join(docs_html)
174
- except TypeError:
175
- print("No documents found")
176
- continue
177
 
178
- if "answer" in output:
179
- new_token = output["answer"] # str
180
- time.sleep(0.03)
181
- answer_yet = history[-1][1] + new_token
182
- answer_yet = parse_output_llm_with_sources(answer_yet)
183
- history[-1] = (query,answer_yet)
184
 
185
- yield history,docs_html,output_query,output_language,gallery
186
 
187
 
188
 
@@ -195,54 +214,54 @@ def chat(query,history,audience,sources,reports):
195
  # raise gr.Error(f"ClimateQ&A Error: {e}\nThe error has been noted, try another question and if the error remains, you can contact us :)")
196
 
197
 
198
- # async for op in fallback_iterator(result):
199
 
200
- # op = op.ops[0]
201
- # print("yo",op)
202
 
203
- # if op['path'] == reformulated_question_path_id: # reforulated question
204
- # output_language = op['value']["language"] # str
205
- # output_query = op["value"]["question"]
206
 
207
- # elif op['path'] == retriever_path_id: # documents
208
- # try:
209
- # docs = op['value']['documents'] # List[Document]
210
- # docs_html = []
211
- # for i, d in enumerate(docs, 1):
212
- # docs_html.append(make_html_source(d, i))
213
- # docs_html = "".join(docs_html)
214
- # except TypeError:
215
- # print("No documents found")
216
- # print("op: ",op)
217
- # continue
218
 
219
- # elif op['path'] == streaming_output_path_id: # final answer
220
- # new_token = op['value'] # str
221
- # time.sleep(0.03)
222
- # answer_yet = history[-1][1] + new_token
223
- # answer_yet = parse_output_llm_with_sources(answer_yet)
224
- # history[-1] = (query,answer_yet)
225
 
226
- # # elif op['path'] == final_output_path_id:
227
- # # final_output = op['value']
228
 
229
- # # if "answer" in final_output:
230
 
231
- # # final_output = final_output["answer"]
232
- # # print(final_output)
233
- # # answer = history[-1][1] + final_output
234
- # # answer = parse_output_llm_with_sources(answer)
235
- # # history[-1] = (query,answer)
236
 
237
- # else:
238
- # continue
239
 
240
- # history = [tuple(x) for x in history]
241
- # yield history,docs_html,output_query,output_language,gallery
242
 
243
 
244
  # Log answer on Azure Blob Storage
245
- if os.getenv("GRADIO_ENV") != "local":
246
  timestamp = str(datetime.now().timestamp())
247
  file = timestamp + ".json"
248
  prompt = history[-1][0]
@@ -269,6 +288,31 @@ def chat(query,history,audience,sources,reports):
269
  # memory.save_context(inputs, {"answer": gradio_format[-1][1]})
270
  # yield gradio_format, memory.load_memory_variables({})["history"], source_string
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
 
274
  def make_html_source(source,i):
@@ -701,4 +745,4 @@ Or around 2 to 4 times more than a typical Google search.
701
 
702
  demo.queue()
703
 
704
- demo.launch(max_threads = 8)
 
60
  }
61
 
62
  account_url = os.environ["BLOB_ACCOUNT_URL"]
63
+ file_share_name = "climateqa"
64
  service = ShareServiceClient(account_url=account_url, credential=credential)
65
  share_client = service.get_share_client(file_share_name)
66
 
 
104
  return new_docs
105
 
106
 
107
+ # import asyncio
108
+ # from typing import Any, Dict, List
109
+ # from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
110
+
111
+ # class MyCustomAsyncHandler(AsyncCallbackHandler):
112
+ # """Async callback handler that can be used to handle callbacks from langchain."""
113
+
114
+ # async def on_chain_start(
115
+ # self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
116
+ # ) -> Any:
117
+ # """Run when chain starts running."""
118
+ # print("zzzz....")
119
+ # await asyncio.sleep(3)
120
+ # print(f"on_chain_start {serialized['name']}")
121
+ # # raise gr.Error("ClimateQ&A Error: Timeout, try another question and if the error remains, you can contact us :)")
122
+
123
+
124
+
125
+ async def chat(query,history,audience,sources,reports):
126
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
127
  (messages in gradio format, messages in langchain format, source documents)"""
128
 
 
142
  if len(reports) == 0:
143
  reports = []
144
 
145
+
146
+ retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,reports = reports,k_summary = 3,k_total = 10,threshold=0.5)
147
  rag_chain = make_rag_chain(retriever,llm)
148
 
149
  source_string = ""
 
163
  # memory.chat_memory.add_message(message)
164
 
165
  inputs = {"query": query,"audience": audience_prompt}
166
+ result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
167
+ # result = rag_chain.stream(inputs)
168
 
169
  reformulated_question_path_id = "/logs/flatten_dict/final_output"
170
  retriever_path_id = "/logs/Retriever/final_output"
171
  streaming_output_path_id = "/logs/AzureChatOpenAI:2/streamed_output_str/-"
172
  final_output_path_id = "/streamed_output/-"
173
 
174
+ docs_html = ""
175
  output_query = ""
176
  output_language = ""
177
  gallery = []
178
 
179
+ # for output in result:
180
 
181
+ # if "language" in output:
182
+ # output_language = output["language"]
183
+ # if "question" in output:
184
+ # output_query = output["question"]
185
+ # if "docs" in output:
186
 
187
+ # try:
188
+ # docs = output['docs'] # List[Document]
189
+ # docs_html = []
190
+ # for i, d in enumerate(docs, 1):
191
+ # docs_html.append(make_html_source(d, i))
192
+ # docs_html = "".join(docs_html)
193
+ # except TypeError:
194
+ # print("No documents found")
195
+ # continue
196
 
197
+ # if "answer" in output:
198
+ # new_token = output["answer"] # str
199
+ # time.sleep(0.03)
200
+ # answer_yet = history[-1][1] + new_token
201
+ # answer_yet = parse_output_llm_with_sources(answer_yet)
202
+ # history[-1] = (query,answer_yet)
203
 
204
+ # yield history,docs_html,output_query,output_language,gallery
205
 
206
 
207
 
 
214
  # raise gr.Error(f"ClimateQ&A Error: {e}\nThe error has been noted, try another question and if the error remains, you can contact us :)")
215
 
216
 
217
+ async for op in result:
218
 
219
+ op = op.ops[0]
220
+ # print("ITERATION",op)
221
 
222
+ if op['path'] == reformulated_question_path_id: # reforulated question
223
+ output_language = op['value']["language"] # str
224
+ output_query = op["value"]["question"]
225
 
226
+ elif op['path'] == retriever_path_id: # documents
227
+ try:
228
+ docs = op['value']['documents'] # List[Document]
229
+ docs_html = []
230
+ for i, d in enumerate(docs, 1):
231
+ docs_html.append(make_html_source(d, i))
232
+ docs_html = "".join(docs_html)
233
+ except TypeError:
234
+ print("No documents found")
235
+ print("op: ",op)
236
+ continue
237
 
238
+ elif op['path'] == streaming_output_path_id: # final answer
239
+ new_token = op['value'] # str
240
+ time.sleep(0.02)
241
+ answer_yet = history[-1][1] + new_token
242
+ answer_yet = parse_output_llm_with_sources(answer_yet)
243
+ history[-1] = (query,answer_yet)
244
 
245
+ # elif op['path'] == final_output_path_id:
246
+ # final_output = op['value']
247
 
248
+ # if "answer" in final_output:
249
 
250
+ # final_output = final_output["answer"]
251
+ # print(final_output)
252
+ # answer = history[-1][1] + final_output
253
+ # answer = parse_output_llm_with_sources(answer)
254
+ # history[-1] = (query,answer)
255
 
256
+ else:
257
+ continue
258
 
259
+ history = [tuple(x) for x in history]
260
+ yield history,docs_html,output_query,output_language,gallery
261
 
262
 
263
  # Log answer on Azure Blob Storage
264
+ if os.getenv("GRADIO_ENV") == "local":
265
  timestamp = str(datetime.now().timestamp())
266
  file = timestamp + ".json"
267
  prompt = history[-1][0]
 
288
  # memory.save_context(inputs, {"answer": gradio_format[-1][1]})
289
  # yield gradio_format, memory.load_memory_variables({})["history"], source_string
290
 
291
+ # async def chat_with_timeout(query, history, audience, sources, reports, timeout_seconds=2):
292
+ # async def timeout_gen(async_gen, timeout):
293
+ # try:
294
+ # while True:
295
+ # try:
296
+ # yield await asyncio.wait_for(async_gen.__anext__(), timeout)
297
+ # except StopAsyncIteration:
298
+ # break
299
+ # except asyncio.TimeoutError:
300
+ # raise gr.Error("Operation timed out. Please try again.")
301
+
302
+ # return timeout_gen(chat(query, history, audience, sources, reports), timeout_seconds)
303
+
304
+
305
+
306
+ # # A wrapper function that includes a timeout
307
+ # async def chat_with_timeout(query, history, audience, sources, reports, timeout_seconds=2):
308
+ # try:
309
+ # # Use asyncio.wait_for to apply a timeout to the chat function
310
+ # return await asyncio.wait_for(chat(query, history, audience, sources, reports), timeout_seconds)
311
+ # except asyncio.TimeoutError:
312
+ # # Handle the timeout error as desired
313
+ # raise gr.Error("Operation timed out. Please try again.")
314
+
315
+
316
 
317
 
318
  def make_html_source(source,i):
 
745
 
746
  demo.queue()
747
 
748
+ demo.launch()
climateqa/engine/llm.py CHANGED
@@ -18,6 +18,7 @@ def get_llm(max_tokens = 1024,temperature = 0.0,verbose = True,streaming = False
18
  openai_api_type = "azure",
19
  max_tokens = max_tokens,
20
  temperature = temperature,
 
21
  verbose = verbose,
22
  streaming = streaming,
23
  **kwargs,
 
18
  openai_api_type = "azure",
19
  max_tokens = max_tokens,
20
  temperature = temperature,
21
+ request_timeout = 60,
22
  verbose = verbose,
23
  streaming = streaming,
24
  **kwargs,
climateqa/sample_questions.py CHANGED
@@ -1,7 +1,6 @@
1
 
2
  QUESTIONS = {
3
  "Popular Questions": [
4
- "Is climate change real?",
5
  "What evidence do we have of climate change?",
6
  "Are human activities causing global warming?",
7
  "What are the impacts of climate change?",
 
1
 
2
  QUESTIONS = {
3
  "Popular Questions": [
 
4
  "What evidence do we have of climate change?",
5
  "Are human activities causing global warming?",
6
  "What are the impacts of climate change?",