Carlosito16 committed
Commit f3517b6 • 1 Parent(s): 74d04be

Update pages/3_chat.py

Files changed (1)
  1. pages/3_chat.py +119 -13
pages/3_chat.py CHANGED
@@ -1,7 +1,108 @@
 import streamlit as st
-from streamlit_extras.stateful_chat import chat, add_message
+from streamlit_chat import message as st_message
+import pandas as pd
+import numpy as np
+import datetime
+import gspread
+import torch
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+# from langchain.vectorstores import Chroma
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+
+from langchain import HuggingFacePipeline
 from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.memory import ConversationBufferWindowMemory
+
+
+from langchain.chains import LLMChain
+from langchain.chains import ConversationalRetrievalChain
+from langchain.chains.question_answering import load_qa_chain
+from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
+
+
+prompt_template = """
+You are the chatbot and your job is to give answers.
+You MUST use only the following pieces of context to answer the question at the end. If the answer is not in the context or you are not sure of the answer, just say that you don't know; don't try to make up an answer.
+{context}
+Question: {question}
+When encountering abusive, offensive, or harmful language, such as fuck, bitch, etc., just politely ask the users to maintain appropriate behaviour.
+Always make sure to elaborate your response.
+Never answer with an unfinished response.
+Answer:
+"""
+PROMPT = PromptTemplate(
+    template=prompt_template, input_variables=["context", "question"]
+)
+# chain_type_kwargs = {"prompt": PROMPT}
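+
+# An LLMChain condenses the follow-up question and recent chat history into a
+# standalone query (CONDENSE_QUESTION_PROMPT), a "stuff" QA chain answers it
+# from the retrieved context, and a 3-turn window memory keeps conversation state.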
+@st.cache_resource
+def load_conversational_qa_memory_retriever(llm_model, vector_database):
+
+    question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
+    doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt=PROMPT)
+    memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history", return_messages=True, output_key='answer')
+
+    conversational_qa_memory_retriever = ConversationalRetrievalChain(
+        retriever=vector_database.as_retriever(),
+        question_generator=question_generator,
+        combine_docs_chain=doc_chain,
+        return_source_documents=True,
+        memory=memory,
+        get_chat_history=lambda h: h)
+    return conversational_qa_memory_retriever, question_generator
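+
+# Callback for the text input below: it runs on_change, i.e. when the user
+# submits a question, before the script reruns and redraws the chat history.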
+def new_retrieve_answer():
+    prompt_answer = st.session_state.my_text_input + ". Try to be elaborate and informative in your answer."
+    answer = conversational_qa_memory_retriever({"question": prompt_answer})
+
+    print(f"condensed question: {question_generator.run({'chat_history': answer['chat_history'], 'question': prompt_answer})}")
+    print(answer["chat_history"])
+
+    st.session_state.chat_history.append({"message": st.session_state.my_text_input, "is_user": True})
+    st.session_state.chat_history.append({"message": answer['answer'][6:], "is_user": False})
+
+    st.session_state.my_text_input = ""
+
+    return answer['answer'][6:]  # this positional slicing removes the leading "<pad> " token
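+
+# "Start new convo" must reset both the UI-side chat_history and the chain's
+# internal memory; clearing only one would leak stale turns into the next
+# condensed question.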
+def clean_chat_history():
+    st.session_state.chat_history = []
+    conversational_qa_memory_retriever.memory.chat_memory.clear()
+
+
+if "history" not in st.session_state:  # this one is for the Google Sheet logging
+    st.session_state.history = []
+
+if "chat_history" not in st.session_state:  # this one passes previous messages into the chat flow
+    st.session_state.chat_history = []
+
+
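+# st.cache_resource memoizes the loader across reruns, so the chains are built
+# once per argument set instead of on every user interaction.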
+conversational_qa_memory_retriever, question_generator = load_conversational_qa_memory_retriever(llm_model=st.session_state['model'],
+                                                                                                 vector_database=st.session_state['faiss_db'])
+
+print("all load done")
+
+
+
+# Clear the chain's memory at the start of each new session
+if st.session_state.chat_history == []:
+    conversational_qa_memory_retriever.memory.chat_memory.clear()
+
+
+st.write("# extraGPT 🤖 ")
 
 with st.expander("key information"):
     st.write( st.session_state['chunked_df'], unsafe_allow_html=True)
@@ -10,19 +111,24 @@ with st.expander("key information"):
     st.markdown(st.session_state['repetition_penalty'])
 
 
-# qa_retriever = RetrievalQA.from_chain_type(llm=st.session_state['llm_model'], chain_type="stuff",
-#                                            retriever=st.session_state['faiss_db'].as_retriever())
 
+st.write(' ⚠️ Please expect to wait **~ 10 - 20 seconds per question**, as this app is running on a CPU against a 3-billion-parameter LLM')
+
+st.markdown("---")
+st.write(" ")
+st.write("""
+### ❔ Ask a question
+""")
 
 
-# with chat(key="my_chat"):
-#     if prompt := st.chat_input():
-#         add_message("user", prompt, avatar="🧑‍💻")
-#         # def stream_echo():
-#         #     for word in prompt.split():
-#         #         yield word + " "
-#         #         time.sleep(0.15)
-#         add_message("assistant", "Echo: ", qa_retriever.run(prompt), avatar="🦜")
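+# Replay the stored history on every rerun, then render the input box; its
+# on_change callback (new_retrieve_answer) appends to chat_history and clears
+# the box before this section is drawn again.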
+for chat in st.session_state.chat_history:
+    st_message(**chat)
+
+query_input = st.text_input(label='Type a question', key='my_text_input', on_change=new_retrieve_answer)
 
-# query = "How to process documents about HR"
-# docs = st.session_state['faiss_db'].similarity_search(query)
+clear_button = st.button("Start new convo",
+                         on_click=clean_chat_history)