bohmian committed on
Commit 06c654b · verified · 1 Parent(s): eb5e46b

Update app.py

Files changed (1):
  1. app.py +94 -72
app.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
-from streamlit_chat import message
+#from streamlit_chat import message
+from streamlit_option_menu import option_menu
 
 import os
 from langchain.llms import HuggingFaceHub # for calling HuggingFace Inference API (free for our use case)
@@ -26,15 +27,17 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using HuggingFace Inference API
 
-
 from langchain.callbacks.base import BaseCallbackHandler
+
+# callback is needed to print intermediate steps of agent reasoning in the chatbot
+# i.e. when action is taken, when tool is called, when tool call is complete etc.
 class MyCallbackHandler(BaseCallbackHandler):
     def __init__(self):
         self.tokens = []
 
-    def on_llm_new_token(self, token, **kwargs) -> None: # HuggingFaceHub() cannot stream
-        self.tokens.append(token)
-        print(token)
+    # def on_llm_new_token(self, token, **kwargs) -> None: # HuggingFaceHub() cannot stream unfortunately!
+    #     self.tokens.append(token)
+    #     print(token)
 
     def on_agent_action(self, action, **kwargs):
         """Run on agent action."""
@@ -74,7 +77,7 @@ class MyCallbackHandler(BaseCallbackHandler):
     def on_tool_end(self, output, **kwargs):
         """Run when tool ends running."""
         #print("\n\nTool End: ", output)
-        tool_output = f"Tool Output: {output} \n \nI am processing the output from the tool..."
+        tool_output = f"Tool Output for Me: {output} \n \nI am processing the output from the tool..."
         st.session_state.messages.append(
             {"role": "assistant", "content": tool_output}
         )
@@ -114,7 +117,6 @@ if 'countries_to_scrape' not in st.session_state:
 # in main app, add configuration for user to upload PDF to override country's existing policies in vectorstore
 
 
-
 # Retriever config
 if 'chroma_n_similar_documents' not in st.session_state:
     st.session_state['chroma_n_similar_documents'] = 5 # number of chunks returned by chroma vector store retriever (semantic)
@@ -155,7 +157,6 @@ countries = [
     "Germany",
 ]
 
-
 @st.cache_data # only going to get once
 def get_llm(temp = st.session_state['temperature'], tokens = st.session_state['max_new_tokens']):
     # This is an inference endpoint API from huggingface, the model is not run locally, it is run on huggingface
@@ -264,6 +265,7 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
         return_source_documents=True # returned in result['source_documents']
     )
     result = qa(query)
+    st.session_state['source_documents'].append(f"Documents retrieved for agent query '{query}' for country '{country}'.")
     st.session_state['source_documents'].append(result['source_documents']) # let user know what source docs are used
     return result['result']
 
@@ -305,14 +307,17 @@ def generic_chat_llm(query: str) -> str:
 @tool
 def compare(query:str) -> str:
     """Use this tool to give you hints and instructions on how you can compare between policies of countries.
-    Use this tool only at one of your final steps, do not use it at the start.
+    Use this tool as a final step, only after you have used other tools to obtain all the information you need.
     When putting the query into this tool, look at the entire query that the user has asked at the start,
     do not leave any details in the query out.
     """
-    return f"""Look at all your previous observations to answer the user query.
-    Use as much relevant information as possible but only from your previous thoughts and observations.
-    If you need more details, you can use a tool to find out more. If you have enough information,
-    use your reasoning to answer them to the best of your ability. Give as much detail as you want in your answer."""
+    return f"""Once again, check through all your previous observations to answer the user query.
+    Make sure every part of the query is addressed by the context, or that you have at least tried to do so.
+    Make sure you have not forgotten to address anything in the query.
+    If you still need more details, you can use another tool to find out more if you have not tried using the same tool with the necessary input earlier.
+    If you have enough information, use your reasoning to answer them to the best of your ability.
+    Give as much elaboration in your answer as possible but they MUST be from the earlier context.
+    Do not give details that cannot be found in the earlier context."""
 
 retrieve_answer_for_country.callbacks = [my_callback_handler]
 compare.callbacks = [my_callback_handler]
@@ -333,77 +338,94 @@ agent = initialize_agent(
     # max_iterations=10
 )
 
+# original menu options
+if "menu" not in st.session_state:
+    st.session_state["menu"] = ["Chatbot", 'Source Documents \n (for Last Query, Click Only After Full Execution)', 'Settings']
 
-
-# Create a header element
-st.header("Chat")
+with st.sidebar:
+    selected = option_menu("Main Menu", st.session_state["menu"],
+                           icons=['house', 'gear', 'gear'], menu_icon="cast", default_index=0)
 
 col1, col2 = st.columns(2)
 # with col1:
 
-# Store the conversation in the session state.
-# Used to render the chat conversation.
-# Initialize it with the first message for users to be greeted with
-if "messages" not in st.session_state:
-    st.session_state.messages = [
-        {"role": "assistant", "content": "How may I help you today?"}
-    ]
-
-if "current_response" not in st.session_state:
-    st.session_state.current_response = ""
-
-# Loop through each message in the session state and render it as a chat message.
-for message in st.session_state.messages:
-    with st.chat_message(message["role"]):
-        st.markdown(message["content"])
-
-# We initialize the quantized LLM from a local path.
-# Currently most parameters are fixed but we can make them
-# configurable.
-#llm_chain = create_chain(retriever)
-
-# We take questions/instructions from the chat input to pass to the LLM
-if user_query := st.chat_input("Your message here", key="user_input"):
-
-    # Add our input to the session state
-    st.session_state.messages.append(
-        {"role": "user", "content": user_query}
-    )
-
-    # Add our input to the chat window
-    with st.chat_message("user"):
-        st.markdown(user_query)
-
-    # Let user know agent is planning the actions
-    action_plan_message = "Please wait while I plan out a best set of actions to obtain the information and answer your query."
-
-    # Add the response to the session state
-    st.session_state.messages.append(
-        {"role": "assistant", "content": action_plan_message}
-    )
-    # Add the response to the chat window
-    with st.chat_message("assistant"):
-        st.markdown(action_plan_message)
-
-    # Pass our input to the llm chain and capture the final responses.
-    # It is worth noting that the Stream Handler is already receiving the
-    # streaming response as the llm is generating. We get our response
-    # here once the llm has finished generating the complete response.
-    results = agent(user_query)
-    response = f"The answer to your query is: {results['output']}"
-
-    # Add the response to the session state
-    st.session_state.messages.append(
-        {"role": "assistant", "content": response}
-    )
-
-    # Add the response to the chat window
-    with st.chat_message("assistant"):
-        st.markdown(response)
-
-# with col2:
-    # st.write("hi")
+if selected == "Chatbot":
+    st.header("Chat")
+
+    # Store the conversation in the session state.
+    # Used to render the chat conversation.
+    # Initialize it with the first message for users to be greeted with
+    if "messages" not in st.session_state:
+        st.session_state.messages = [
+            {"role": "assistant", "content": "How may I help you today? example qn, "}
+        ]
+
+    if "current_response" not in st.session_state:
+        st.session_state.current_response = ""
+
+    # Loop through each message in the session state and render it as a chat message.
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # We initialize the quantized LLM from a local path.
+    # Currently most parameters are fixed but we can make them
+    # configurable.
+    #llm_chain = create_chain(retriever)
+
+    # We take questions/instructions from the chat input to pass to the LLM
+    if user_query := st.chat_input("Your message here", key="user_input"):
+        # remove source documents option from menu while query is running
+
+        st.session_state['source_documents'] = [f"User query: '{user_query}'"] # reset source documents list
+
+        formatted_user_query = f":blue[{user_query}]"
+        # Add our input to the session state
+        st.session_state.messages.append(
+            {"role": "user", "content": formatted_user_query}
+        )
+
+        # Add our input to the chat window
+        with st.chat_message("user"):
+            st.markdown(formatted_user_query)
+
+        # Let user know agent is planning the actions
+        action_plan_message = "Please wait while I plan out a best set of actions to obtain the necessary information to answer your query."
+
+        # Add the response to the session state
+        st.session_state.messages.append(
+            {"role": "assistant", "content": action_plan_message}
+        )
+        # Add the response to the chat window
+        with st.chat_message("assistant"):
+            st.markdown(action_plan_message)
+
+        # Pass our input to the llm chain and capture the final responses.
+        # It is worth noting that the Stream Handler is already receiving the
+        # streaming response as the llm is generating. We get our response
+        # here once the llm has finished generating the complete response.
+        results = agent(user_query)
+        response = f":blue[The answer to your query is:] {results['output']}"
+
+        # Add the response to the session state
+        st.session_state.messages.append(
+            {"role": "assistant", "content": response}
+        )
+
+        # Add the response to the chat window
+        with st.chat_message("assistant"):
+            st.markdown(response)
+
+    # with col2:
+    #     st.write("hi")
+
+if selected == "Source Documents \n (for Last Query, Click Only After Full Execution)":
+    st.header("Source Documents for Last Query")
+    try:
+        st.subheader(st.session_state['source_documents'][0])
+        for doc in st.session_state['source_documents'][1:]:
+            st.write(doc)
+    except:
+        st.write("No source documents retrieved yet. Please run a user query before coming back to this page.")
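The headline change in this commit is the move from a single chat page to sidebar navigation via streamlit_option_menu. Below is a minimal, self-contained sketch of that navigation pattern; the page labels, icons, and headers here are illustrative placeholders rather than the app's actual menu entries.

import streamlit as st
from streamlit_option_menu import option_menu  # assumes streamlit-option-menu is installed

# Keep the menu options in session state so they survive reruns and can be edited at runtime.
if "menu" not in st.session_state:
    st.session_state["menu"] = ["Chatbot", "Source Documents", "Settings"]  # illustrative labels

# option_menu renders the sidebar menu and returns the label the user selected.
with st.sidebar:
    selected = option_menu("Main Menu", st.session_state["menu"],
                           icons=["chat", "book", "gear"], menu_icon="cast", default_index=0)

# Each "page" is rendered only when its label is selected.
if selected == "Chatbot":
    st.header("Chat")
elif selected == "Source Documents":
    st.header("Source Documents for Last Query")
elif selected == "Settings":
    st.header("Settings")

Returning the selected label from a single option_menu call is what lets the commit gate the chat UI and the new source-documents view behind separate "if selected == ..." blocks.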
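The commit also keeps routing the agent's intermediate steps into the chat transcript through a LangChain BaseCallbackHandler (see on_tool_end in the diff). A minimal sketch of that pattern follows; ChatProgressHandler and the message wording are illustrative, not the app's MyCallbackHandler.

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler

class ChatProgressHandler(BaseCallbackHandler):
    """Appends tool activity to the Streamlit chat history so users can see progress."""

    def on_tool_start(self, serialized, input_str, **kwargs):
        # Record that a tool call has begun, using the raw tool input.
        st.session_state.setdefault("messages", []).append(
            {"role": "assistant", "content": f"Calling a tool with input: {input_str}"}
        )

    def on_tool_end(self, output, **kwargs):
        # Record the tool's output once the call completes.
        st.session_state.setdefault("messages", []).append(
            {"role": "assistant", "content": f"Tool Output: {output}\n\nI am processing the output from the tool..."}
        )

The handler only records messages; the main script still renders st.session_state.messages on each rerun, which is why the message-rendering loop in the diff stays unchanged.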
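The new Source Documents page simply renders whatever the last query left in st.session_state['source_documents']: the chat handler resets the list when a query starts, the retrieval tool appends a note plus the retrieved chunks on each lookup, and the page reads the list back. A sketch of that flow is below; the helper names and the explicit key check (a gentler alternative to the bare except: used in the commit) are illustrative choices.

import streamlit as st

def start_new_query(user_query: str) -> None:
    # Reset the list so the page only shows documents for the most recent query.
    st.session_state["source_documents"] = [f"User query: '{user_query}'"]

def record_retrieval(note: str, documents: list) -> None:
    # Called after each vector-store lookup: a human-readable note, then the chunks themselves.
    st.session_state["source_documents"].append(note)
    st.session_state["source_documents"].append(documents)

def render_source_documents_page() -> None:
    docs = st.session_state.get("source_documents")
    if not docs:
        st.write("No source documents retrieved yet. Please run a user query first.")
        return
    st.subheader(docs[0])  # the query that produced these documents
    for doc in docs[1:]:
        st.write(doc)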