Update app.py
app.py
CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
-from streamlit_chat import message
+#from streamlit_chat import message
+from streamlit_option_menu import option_menu
 
 import os
 from langchain.llms import HuggingFaceHub # for calling HuggingFace Inference API (free for our use case)
@@ -26,15 +27,17 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using HuggingFace Inference API
 
-
 from langchain.callbacks.base import BaseCallbackHandler
+
+# callback is needed to print intermediate steps of agent reasoning in the chatbot
+# i.e. when action is taken, when tool is called, when tool call is complete etc.
 class MyCallbackHandler(BaseCallbackHandler):
     def __init__(self):
         self.tokens = []
 
-    def on_llm_new_token(self, token, **kwargs) -> None: # HuggingFaceHub() cannot stream
-        self.tokens.append(token)
-        print(token)
+    # def on_llm_new_token(self, token, **kwargs) -> None: # HuggingFaceHub() cannot stream unfortunately!
+    #     self.tokens.append(token)
+    #     print(token)
 
     def on_agent_action(self, action, **kwargs):
         """Run on agent action."""
@@ -74,7 +77,7 @@ class MyCallbackHandler(BaseCallbackHandler):
     def on_tool_end(self, output, **kwargs):
         """Run when tool ends running."""
         #print("\n\nTool End: ", output)
-        tool_output = f"Tool Output: {output} \n \nI am processing the output from the tool..."
+        tool_output = f"Tool Output for Me: {output} \n \nI am processing the output from the tool..."
         st.session_state.messages.append(
             {"role": "assistant", "content": tool_output}
         )
@@ -114,7 +117,6 @@ if 'countries_to_scrape' not in st.session_state:
 # in main app, add configuration for user to upload PDF to override country's existing policies in vectorstore
 
 
-
 # Retriever config
 if 'chroma_n_similar_documents' not in st.session_state:
     st.session_state['chroma_n_similar_documents'] = 5 # number of chunks returned by chroma vector store retriever (semantic)
@@ -155,7 +157,6 @@ countries = [
     "Germany",
 ]
 
-
 @st.cache_data # only going to get once
 def get_llm(temp = st.session_state['temperature'], tokens = st.session_state['max_new_tokens']):
     # This is an inference endpoint API from huggingface, the model is not run locally, it is run on huggingface
@@ -264,6 +265,7 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
         return_source_documents=True # returned in result['source_documents']
     )
     result = qa(query)
+    st.session_state['source_documents'].append(f"Documents retrieved for agent query '{query}' for country '{country}'.")
     st.session_state['source_documents'].append(result['source_documents']) # let user know what source docs are used
     return result['result']
 
@@ -305,14 +307,17 @@ def generic_chat_llm(query: str) -> str:
 @tool
 def compare(query:str) -> str:
     """Use this tool to give you hints and instructions on how you can compare between policies of countries.
-    Use this tool
+    Use this tool as a final step, only after you have used other tools to obtain all the information you need.
     When putting the query into this tool, look at the entire query that the user has asked at the start,
     do not leave any details in the query out.
     """
-    return f"""
+    return f"""Once again, check through all your previous observations to answer the user query.
+    Make sure every part of the query is addressed by the context, or that you have at least tried to do so.
+    Make sure you have not forgotten to address anything in the query.
+    If you still need more details, you can use another tool to find out more if you have not tried using the same tool with the necessary input earlier.
+    If you have enough information, use your reasoning to answer them to the best of your ability.
+    Give as much elaboration in your answer as possible but they MUST be from the earlier context.
+    Do not give details that cannot be found in the earlier context."""
 
 retrieve_answer_for_country.callbacks = [my_callback_handler]
 compare.callbacks = [my_callback_handler]
@@ -333,77 +338,94 @@ agent = initialize_agent(
     # max_iterations=10
 )
 
+# original menu options
+if "menu" not in st.session_state:
+    st.session_state["menu"] = ["Chatbot", 'Source Documents \n (for Last Query, Click Only After Full Execution)', 'Settings']
 
+with st.sidebar:
+    selected = option_menu("Main Menu", st.session_state["menu"],
+                           icons=['house', 'gear', 'gear'], menu_icon="cast", default_index=0)
 
 col1, col2 = st.columns(2)
 # with col1:
 
+if selected == "Chatbot":
+    st.header("Chat")
+
+    # Store the conversation in the session state.
+    # Used to render the chat conversation.
+    # Initialize it with the first message for users to be greeted with
+    if "messages" not in st.session_state:
+        st.session_state.messages = [
+            {"role": "assistant", "content": "How may I help you today? example qn, "}
+        ]
+
+    if "current_response" not in st.session_state:
+        st.session_state.current_response = ""
+
+    # Loop through each message in the session state and render it as a chat message.
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # We initialize the quantized LLM from a local path.
+    # Currently most parameters are fixed but we can make them
+    # configurable.
+    #llm_chain = create_chain(retriever)
+
+    # We take questions/instructions from the chat input to pass to the LLM
+    if user_query := st.chat_input("Your message here", key="user_input"):
+        # remove source documents option from menu while query is running
+
+        st.session_state['source_documents'] = [f"User query: '{user_query}'"] # reset source documents list
+
+        formatted_user_query = f":blue[{user_query}]"
+        # Add our input to the session state
+        st.session_state.messages.append(
+            {"role": "user", "content": formatted_user_query}
+        )
 
+        # Add our input to the chat window
+        with st.chat_message("user"):
+            st.markdown(formatted_user_query)
 
+        # Let user know agent is planning the actions
+        action_plan_message = "Please wait while I plan out a best set of actions to obtain the necessary information to answer your query."
 
+        # Add the response to the session state
+        st.session_state.messages.append(
+            {"role": "assistant", "content": action_plan_message}
+        )
+        # Add the response to the chat window
+        with st.chat_message("assistant"):
+            st.markdown(action_plan_message)
 
+        # Pass our input to the llm chain and capture the final responses.
+        # It is worth noting that the Stream Handler is already receiving the
+        # streaming response as the llm is generating. We get our response
+        # here once the llm has finished generating the complete response.
+        results = agent(user_query)
+        response = f":blue[The answer to your query is:] {results['output']}"
 
+        # Add the response to the session state
+        st.session_state.messages.append(
+            {"role": "assistant", "content": response}
+        )
 
+        # Add the response to the chat window
+        with st.chat_message("assistant"):
+            st.markdown(response)
 
 
+# with col2:
+#     st.write("hi")
 
 
+if selected == "Source Documents \n (for Last Query, Click Only After Full Execution)":
+    st.header("Source Documents for Last Query")
+    try:
+        st.subheader(st.session_state['source_documents'][0])
+        for doc in st.session_state['source_documents'][1:]:
+            st.write(doc)
+    except:
+        st.write("No source documents retrieved yet. Please run a user query before coming back to this page.")