Update functions.py
Browse files- functions.py +5 -1
functions.py
CHANGED
@@ -32,6 +32,7 @@ from langchain.chat_models import ChatOpenAI
|
|
32 |
from langchain.callbacks.base import CallbackManager
|
33 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
34 |
from langchain.chains import ConversationalRetrievalChain, QAGenerationChain
|
|
|
35 |
|
36 |
from langchain.prompts.chat import (
|
37 |
ChatPromptTemplate,
|
@@ -57,6 +58,8 @@ time_str = time.strftime("%d%m%Y-%H%M%S")
|
|
57 |
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
|
58 |
margin-bottom: 2.5rem">{}</div> """
|
59 |
|
|
|
|
|
60 |
#Stuff Chain Type Prompt template
|
61 |
|
62 |
@st.cache_resource
|
@@ -230,9 +233,10 @@ def embed_text(query,embedding_model,_docsearch):
|
|
230 |
chain = ConversationalRetrievalChain.from_llm(chat_llm,
|
231 |
retriever= _docsearch.as_retriever(),
|
232 |
qa_prompt = load_prompt(),
|
|
|
233 |
return_source_documents=True)
|
234 |
|
235 |
-
answer = chain({"question": query
|
236 |
|
237 |
return answer
|
238 |
|
|
|
32 |
from langchain.callbacks.base import CallbackManager
|
33 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
34 |
from langchain.chains import ConversationalRetrievalChain, QAGenerationChain
|
35 |
+
from langchain.memory import ConversationBufferMemory
|
36 |
|
37 |
from langchain.prompts.chat import (
|
38 |
ChatPromptTemplate,
|
|
|
58 |
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
|
59 |
margin-bottom: 2.5rem">{}</div> """
|
60 |
|
61 |
+
memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
|
62 |
+
|
63 |
#Stuff Chain Type Prompt template
|
64 |
|
65 |
@st.cache_resource
|
|
|
233 |
chain = ConversationalRetrievalChain.from_llm(chat_llm,
|
234 |
retriever= _docsearch.as_retriever(),
|
235 |
qa_prompt = load_prompt(),
|
236 |
+
memory = memory,
|
237 |
return_source_documents=True)
|
238 |
|
239 |
+
answer = chain({"question": query})
|
240 |
|
241 |
return answer
|
242 |
|