Theo Alves Da Costa commited on
Commit
787d3cb
1 Parent(s): 977441e

Added streamign

Browse files
app.py CHANGED
@@ -68,89 +68,77 @@ from langchain.callbacks.base import BaseCallbackHandler
68
  from queue import Queue, Empty
69
  from threading import Thread
70
  from collections.abc import Generator
 
 
 
 
 
71
 
72
- # Create a Queue
73
- Q = Queue()
74
 
75
- class QueueCallback(BaseCallbackHandler):
76
- """Callback handler for streaming LLM responses to a queue."""
77
 
78
- def __init__(self, q):
 
 
 
 
79
  self.q = q
80
 
81
- def on_llm_new_token(self, token: str, **kwargs: any) -> None:
 
 
 
 
 
 
 
 
 
 
 
82
  self.q.put(token)
83
 
84
- def on_llm_end(self, *args, **kwargs: any) -> None:
85
- return self.q.empty()
86
-
 
 
 
 
 
 
 
 
 
87
 
88
  # Create embeddings function and LLM
89
  embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
90
- llm = get_llm(max_tokens = 1024,temperature = 0.0,verbose = True,streaming = True,
91
- callbacks=[QueueCallback(Q)],
 
92
  )
93
 
94
  # Create vectorstore and retriever
95
  vectorstore = get_pinecone_vectorstore(embeddings_function)
96
  retriever = ClimateQARetriever(vectorstore=vectorstore,sources = ["IPCC"],k_summary = 3,k_total = 10)
97
- chain = load_climateqa_chain(retriever,llm)
98
 
99
 
100
  #---------------------------------------------------------------------------
101
  # ClimateQ&A Streaming
102
  # From https://github.com/gradio-app/gradio/issues/5345
 
103
  #---------------------------------------------------------------------------
104
 
 
105
 
106
-
107
- # Create a function that will return our generator
108
- def stream(chain, input_text) -> Generator:
109
- with Q.mutex:
110
- Q.queue.clear()
111
- job_done = object()
112
-
113
- # Create a function to call - this will run in a thread
114
- def task():
115
- answer = chain({"query":input_text,"audience":"expert climate scientist"})
116
- Q.put(job_done)
117
-
118
- # Create a thread and start the function
119
- t = Thread(target=task)
120
- t.start()
121
-
122
- content = ""
123
-
124
- # Get each new token from the queue and yield for our generator
125
- while True:
126
- try:
127
- next_token = Q.get(True, timeout=1)
128
- if next_token is job_done:
129
- break
130
- content += next_token
131
- yield next_token, content
132
- except Empty:
133
- continue
134
-
135
-
136
- def stream_sentences(chain, input_text) -> Generator:
137
- """wrapper to stream function"""
138
- sentence = ""
139
- for next_token, content in stream(chain, input_text):
140
- sentence += next_token
141
- if "\n\n" in next_token:
142
- yield sentence
143
- sentence = ""
144
- if sentence:
145
- yield sentence
146
-
147
-
148
-
149
 
150
  def answer_user(message,history):
151
  return message, history + [[message, None]]
152
 
153
-
154
  def answer_bot(message,history,audience):
155
 
156
  if audience == "Children":
@@ -170,25 +158,39 @@ def answer_bot(message,history,audience):
170
  # for next_token, content in stream(message):
171
  # yield(content)
172
 
173
- output = chain({"query":message,"audience":audience_prompt})
174
- question = output["question"]
175
- sources = output["source_documents"]
176
-
177
- if len(sources) > 0:
178
- sources_text = []
179
- for i, d in enumerate(sources, 1):
180
- sources_text.append(make_html_source(d,i))
181
- sources_text = "\n\n".join([f"Query used for retrieval:\n{question}"] + sources_text)
182
 
183
- history[-1][1] = output["answer"]
184
- return "",history,sources_text
 
185
 
186
- else:
187
- sources_text = "⚠️ No relevant passages found in the climate science reports (IPCC and IPBES)"
188
- complete_response = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"
189
- history[-1][1] = complete_response
190
- return "",history, sources_text
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  #---------------------------------------------------------------------------
194
  # ClimateQ&A core functions
@@ -348,7 +350,19 @@ def log_on_azure(file, logs, share_client):
348
  # --------------------------------------------------------------------
349
 
350
 
 
 
 
 
 
 
 
351
 
 
 
 
 
 
352
 
353
 
354
  with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
@@ -363,7 +377,9 @@ with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
363
  with gr.Row(elem_id="chatbot-row"):
364
  with gr.Column(scale=2):
365
  # state = gr.State([system_template])
366
- bot = gr.Chatbot(show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",avatar_images = ("assets/logo4.png",None))
 
 
367
 
368
  with gr.Row(elem_id = "input-message"):
369
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7)
@@ -441,7 +457,6 @@ with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
441
  examples_hidden.change(answer_user, [examples_hidden, bot], [textbox, bot], queue=False).then(
442
  answer_bot, [textbox,bot,dropdown_audience], [textbox,bot,sources_textbox]
443
  )
444
-
445
  submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=False).then(
446
  answer_bot, [textbox,bot,dropdown_audience], [textbox,bot,sources_textbox]
447
  )
@@ -619,6 +634,8 @@ Or around 2 to 4 times more than a typical Google search.
619
  - ClimateQ&A on Hugging Face is finally working again with all the new features !
620
  - Switched all python code to langchain codebase for cleaner code, easier maintenance and future features
621
  - Updated GPT model to August version
 
 
622
  - Use of HuggingFace embed on https://climateqa.com to avoid demultiplying deployments
623
 
624
  ##### v1.0.0 - *2023-05-11*
 
68
  from queue import Queue, Empty
69
  from threading import Thread
70
  from collections.abc import Generator
71
+ from langchain.schema import LLMResult
72
+ from typing import Any, Union,Dict,List
73
+ from queue import SimpleQueue
74
+ # # Create a Queue
75
+ # Q = Queue()
76
 
 
 
77
 
 
 
78
 
79
+ Q = SimpleQueue()
80
+ job_done = object() # signals the processing is done
81
+
82
+ class StreamingGradioCallbackHandler(BaseCallbackHandler):
83
+ def __init__(self, q: SimpleQueue):
84
  self.q = q
85
 
86
+ def on_llm_start(
87
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
88
+ ) -> None:
89
+ """Run when LLM starts running. Clean the queue."""
90
+ while not self.q.empty():
91
+ try:
92
+ self.q.get(block=False)
93
+ except Empty:
94
+ continue
95
+
96
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
97
+ """Run on new LLM token. Only available when streaming is enabled."""
98
  self.q.put(token)
99
 
100
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
101
+ """Run when LLM ends running."""
102
+ self.q.put(job_done)
103
+
104
+ def on_llm_error(
105
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
106
+ ) -> None:
107
+ """Run when LLM errors."""
108
+ self.q.put(job_done)
109
+
110
+
111
+
112
 
113
  # Create embeddings function and LLM
114
  embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
115
+ llm_reformulation = get_llm(max_tokens = 512,temperature = 0.0,verbose = True,streaming = False)
116
+ llm_streaming = get_llm(max_tokens = 1024,temperature = 0.0,verbose = True,streaming = True,
117
+ callbacks=[StreamingGradioCallbackHandler(Q),StreamingStdOutCallbackHandler()],
118
  )
119
 
120
  # Create vectorstore and retriever
121
  vectorstore = get_pinecone_vectorstore(embeddings_function)
122
  retriever = ClimateQARetriever(vectorstore=vectorstore,sources = ["IPCC"],k_summary = 3,k_total = 10)
123
+ chain = load_climateqa_chain(retriever,llm_reformulation,llm_streaming)
124
 
125
 
126
  #---------------------------------------------------------------------------
127
  # ClimateQ&A Streaming
128
  # From https://github.com/gradio-app/gradio/issues/5345
129
+ # And https://stackoverflow.com/questions/76057076/how-to-stream-agents-response-in-langchain
130
  #---------------------------------------------------------------------------
131
 
132
+ from threading import Thread
133
 
134
+ def threaded_chain(query,audience):
135
+ response = chain({"query":query,"audience":audience})
136
+ Q.put(response)
137
+ Q.put(job_done)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  def answer_user(message,history):
140
  return message, history + [[message, None]]
141
 
 
142
  def answer_bot(message,history,audience):
143
 
144
  if audience == "Children":
 
158
  # for next_token, content in stream(message):
159
  # yield(content)
160
 
161
+ thread = Thread(target=threaded_chain, kwargs={"query":message,"audience":audience_prompt})
162
+ thread.start()
 
 
 
 
 
 
 
163
 
164
+ history[-1][1] = ""
165
+ while True:
166
+ next_item = Q.get(block=True) # Blocks until an input is available
167
 
168
+ if next_item is job_done:
169
+ continue
 
 
 
170
 
171
+ elif isinstance(next_item, dict): # assuming LLMResult is a dictionary
172
+ response = next_item
173
+ if "source_documents" in response and len(response["source_documents"]) > 0:
174
+ sources_text = []
175
+ for i, d in enumerate(response["source_documents"], 1):
176
+ sources_text.append(make_html_source(d, i))
177
+ sources_text = "\n\n".join([f"Query used for retrieval:\n{response['question']}"] + sources_text)
178
+ # history[-1][1] += next_item["answer"]
179
+ # history[-1][1] += "\n\n" + sources_text
180
+ yield "", history, sources_text
181
+
182
+ else:
183
+ sources_text = "⚠️ No relevant passages found in the scientific reports (IPCC and IPBES)"
184
+ complete_response = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate and biodiversity issues).**"
185
+ history[-1][1] += "\n\n" + complete_response
186
+ yield "", history, sources_text
187
+ break
188
+
189
+ elif isinstance(next_item, str):
190
+ history[-1][1] += next_item
191
+ yield "", history, ""
192
+
193
+ thread.join()
194
 
195
  #---------------------------------------------------------------------------
196
  # ClimateQ&A core functions
 
350
  # --------------------------------------------------------------------
351
 
352
 
353
+ init_prompt = """
354
+ Hello ! I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
355
+
356
+ 💡 How to use
357
+ - **Language**: You can ask me your questions in any language.
358
+ - **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
359
+ - **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
360
 
361
+ 📚 Limitations
362
+ *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
363
+
364
+ ❓ What do you want to learn ?
365
+ """
366
 
367
 
368
  with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
 
377
  with gr.Row(elem_id="chatbot-row"):
378
  with gr.Column(scale=2):
379
  # state = gr.State([system_template])
380
+ bot = gr.Chatbot(
381
+ value=[[None,init_prompt]],
382
+ show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",avatar_images = ("assets/logo4.png",None))
383
 
384
  with gr.Row(elem_id = "input-message"):
385
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7)
 
457
  examples_hidden.change(answer_user, [examples_hidden, bot], [textbox, bot], queue=False).then(
458
  answer_bot, [textbox,bot,dropdown_audience], [textbox,bot,sources_textbox]
459
  )
 
460
  submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=False).then(
461
  answer_bot, [textbox,bot,dropdown_audience], [textbox,bot,sources_textbox]
462
  )
 
634
  - ClimateQ&A on Hugging Face is finally working again with all the new features !
635
  - Switched all python code to langchain codebase for cleaner code, easier maintenance and future features
636
  - Updated GPT model to August version
637
+ - Added streaming response to improve UX
638
+ - Created a custom Retriever chain to avoid calling the LLM if there is no documents retrieved
639
  - Use of HuggingFace embed on https://climateqa.com to avoid demultiplying deployments
640
 
641
  ##### v1.0.0 - *2023-05-11*
climateqa/chains.py CHANGED
@@ -8,7 +8,7 @@ from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
10
  from climateqa.prompts import answer_prompt, reformulation_prompt,audience_prompts
11
-
12
 
13
  def load_reformulation_chain(llm):
14
 
@@ -38,6 +38,7 @@ def load_reformulation_chain(llm):
38
 
39
 
40
 
 
41
  def load_answer_chain(retriever,llm):
42
  prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question","audience","language"])
43
  qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff",prompt = prompt)
@@ -45,24 +46,27 @@ def load_answer_chain(retriever,llm):
45
  # This could be improved by providing a document prompt to avoid modifying page_content in the docs
46
  # See here https://github.com/langchain-ai/langchain/issues/3523
47
 
48
- answer_chain = RetrievalQAWithSourcesChain(
49
  combine_documents_chain = qa_chain,
50
  retriever=retriever,
51
  return_source_documents = True,
 
 
52
  )
53
  return answer_chain
54
 
55
 
56
- def load_climateqa_chain(retriever,llm):
57
 
58
- reformulation_chain = load_reformulation_chain(llm)
59
- answer_chain = load_answer_chain(retriever,llm)
60
 
61
  climateqa_chain = SequentialChain(
62
  chains = [reformulation_chain,answer_chain],
63
  input_variables=["query","audience"],
64
  output_variables=["answer","question","language","source_documents"],
65
  return_all = True,
 
66
  )
67
  return climateqa_chain
68
 
 
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
10
  from climateqa.prompts import answer_prompt, reformulation_prompt,audience_prompts
11
+ from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
12
 
13
  def load_reformulation_chain(llm):
14
 
 
38
 
39
 
40
 
41
+
42
  def load_answer_chain(retriever,llm):
43
  prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question","audience","language"])
44
  qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff",prompt = prompt)
 
46
  # This could be improved by providing a document prompt to avoid modifying page_content in the docs
47
  # See here https://github.com/langchain-ai/langchain/issues/3523
48
 
49
+ answer_chain = CustomRetrievalQAWithSourcesChain(
50
  combine_documents_chain = qa_chain,
51
  retriever=retriever,
52
  return_source_documents = True,
53
+ verbose = True,
54
+ fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
55
  )
56
  return answer_chain
57
 
58
 
59
+ def load_climateqa_chain(retriever,llm_reformulation,llm_answer):
60
 
61
+ reformulation_chain = load_reformulation_chain(llm_reformulation)
62
+ answer_chain = load_answer_chain(retriever,llm_answer)
63
 
64
  climateqa_chain = SequentialChain(
65
  chains = [reformulation_chain,answer_chain],
66
  input_variables=["query","audience"],
67
  output_variables=["answer","question","language","source_documents"],
68
  return_all = True,
69
+ verbose = True,
70
  )
71
  return climateqa_chain
72
 
climateqa/custom_retrieval_chain.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import inspect
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from pydantic import Extra
6
+
7
+ from langchain.schema.language_model import BaseLanguageModel
8
+ from langchain.callbacks.manager import (
9
+ AsyncCallbackManagerForChainRun,
10
+ CallbackManagerForChainRun,
11
+ )
12
+ from langchain.chains.base import Chain
13
+ from langchain.prompts.base import BasePromptTemplate
14
+
15
+ from typing import Any, Dict, List
16
+
17
+ from langchain.callbacks.manager import (
18
+ AsyncCallbackManagerForChainRun,
19
+ CallbackManagerForChainRun,
20
+ )
21
+ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
22
+ from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
23
+ from langchain.docstore.document import Document
24
+ from langchain.pydantic_v1 import Field
25
+ from langchain.schema import BaseRetriever
26
+
27
+ from langchain.chains import RetrievalQAWithSourcesChain
28
+
29
+
30
+ from langchain.chains.router.llm_router import LLMRouterChain
31
+
32
+ class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
33
+
34
+ fallback_answer:str = "No sources available to answer this question."
35
+
36
+ def _call(self,inputs,run_manager=None):
37
+ _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
38
+ accepts_run_manager = (
39
+ "run_manager" in inspect.signature(self._get_docs).parameters
40
+ )
41
+ if accepts_run_manager:
42
+ docs = self._get_docs(inputs, run_manager=_run_manager)
43
+ else:
44
+ docs = self._get_docs(inputs) # type: ignore[call-arg]
45
+
46
+
47
+ if len(docs) == 0:
48
+ answer = self.fallback_answer
49
+ sources = []
50
+ else:
51
+
52
+ answer = self.combine_documents_chain.run(
53
+ input_documents=docs, callbacks=_run_manager.get_child(), **inputs
54
+ )
55
+ answer, sources = self._split_sources(answer)
56
+
57
+ result: Dict[str, Any] = {
58
+ self.answer_key: answer,
59
+ self.sources_answer_key: sources,
60
+ }
61
+ if self.return_source_documents:
62
+ result["source_documents"] = docs
63
+ return result