vinhnx90 committed
Commit 68eaa27
Parent: 18a32c9
Files changed (4)
  1. app.py +3 -41
  2. calback_handler.py +31 -0
  3. requirements.txt +2 -1
  4. token_stream_handler.py +0 -13
app.py CHANGED
@@ -2,24 +2,18 @@ import os
 import tempfile
 
 import streamlit as st
-from chat_profile import ChatProfileRoleEnum
-
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
-from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.memory import ConversationBufferMemory
 from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
 from langchain_community.vectorstores import DocArrayInMemorySearch
-from streamlit_extras.add_vertical_space import add_vertical_space
 
-# TODO: refactor
-# TODO: extract class
-# TODO: modularize
-# TODO: hide side bar
-# TODO: make the page attactive
+from chat_profile import ChatProfileRoleEnum
+from calback_handler import StreamHandler, PrintRetrievalHandler
 
 # configs
 LLM_MODEL_NAME = "gpt-3.5-turbo"
@@ -89,38 +83,6 @@ def configure_retriever(uploaded_files):
     return retriever
 
 
-class StreamHandler(BaseCallbackHandler):
-    def __init__(
-        self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
-    ):
-        self.container = container
-        self.text = initial_text
-        self.run_id_ignore_token = None
-
-    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
-        # Workaround to prevent showing the rephrased question as output
-        if prompts[0].startswith("Human"):
-            self.run_id_ignore_token = kwargs.get("run_id")
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        if self.run_id_ignore_token == kwargs.get("run_id", False):
-            return
-        self.text += token
-        self.container.markdown(self.text)
-
-
-class PrintRetrievalHandler(BaseCallbackHandler):
-    def __init__(self, container):
-        self.status = container.status("**Thinking...**")
-        self.container = container
-
-    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
-        self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
-
-    def on_retriever_end(self, documents, **kwargs):
-        self.container.empty()
-
-
 with st.sidebar.expander("Documents"):
     st.subheader("Files")
     uploaded_files = st.file_uploader(
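With the handler classes extracted, app.py only imports them. For reference, a minimal sketch of how such handlers are typically passed into the chain call; the chat loop and the qa_chain name are assumptions, since the actual call site lies outside this diff's hunks:

# Hypothetical wiring sketch -- qa_chain and the chat loop are assumed,
# not shown in this commit's hunks.
if user_query := st.chat_input("Ask a question about your documents"):
    st.chat_message("user").write(user_query)
    with st.chat_message("assistant"):
        retrieval_handler = PrintRetrievalHandler(st.container())
        stream_handler = StreamHandler(st.empty())
        response = qa_chain.run(
            user_query, callbacks=[retrieval_handler, stream_handler]
        )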
calback_handler.py ADDED
@@ -0,0 +1,31 @@
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+class StreamHandler(BaseCallbackHandler):
+    def __init__(self, container, initial_text: str = ""):
+        self.container = container
+        self.text = initial_text
+        self.run_id_ignore_token = None
+
+    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
+        # Workaround to prevent showing the rephrased question as output
+        if prompts[0].startswith("Human"):
+            self.run_id_ignore_token = kwargs.get("run_id")
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        if self.run_id_ignore_token == kwargs.get("run_id", False):
+            return
+        self.text += token
+        self.container.markdown(self.text)
+
+
+class PrintRetrievalHandler(BaseCallbackHandler):
+    def __init__(self, container):
+        self.status = container.status("**Thinking...**")
+        self.container = container
+
+    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
+        self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
+
+    def on_retriever_end(self, documents, **kwargs):
+        self.container.empty()
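The run_id bookkeeping exists because ConversationalRetrievalChain makes two LLM calls per turn: one that rephrases the follow-up question into a standalone one (its prompt starts with "Human"), and one that answers it. A quick standalone illustration of the filter, using a stand-in container that is purely hypothetical:

# Demonstrates StreamHandler's run_id filter without Streamlit.
from calback_handler import StreamHandler

class FakeContainer:
    def markdown(self, text):
        print(text)

handler = StreamHandler(FakeContainer())
handler.on_llm_start({}, ["Human: rephrase the question"], run_id="condense")
handler.on_llm_new_token("rephrased?", run_id="condense")  # suppressed
handler.on_llm_new_token("Hello", run_id="answer")         # prints "Hello"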
requirements.txt CHANGED
@@ -5,4 +5,5 @@ langchain
 streamlit
 streamlit_chat
 streamlit-extras
-pypdf
+pypdf
+docx2txt
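Both packages back the document loaders imported in app.py: PyPDFLoader is built on pypdf, and Docx2txtLoader on docx2txt. A sketch of the extension-based dispatch these dependencies support; make_loader is a hypothetical name, since configure_retriever's body is not shown in this diff:

from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader

def make_loader(path: str):
    # Choose a loader by file extension; each loader pulls in the matching package.
    if path.endswith(".pdf"):
        return PyPDFLoader(path)     # requires pypdf
    if path.endswith(".docx"):
        return Docx2txtLoader(path)  # requires docx2txt
    return TextLoader(path)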
token_stream_handler.py DELETED
@@ -1,13 +0,0 @@
-import os
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-
-class StreamHandler(BaseCallbackHandler):
-    def __init__(self, container, initial_text=""):
-        self.container = container
-        self.text = initial_text
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        self.text += token
-        self.container.markdown(self.text)
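This deleted handler is superseded by the StreamHandler in calback_handler.py above, which keeps the same token-streaming behavior and adds the run_id filter that hides the rephrased question.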