Vinh Nguyen committed on
Commit
7713f97
•
1 Parent(s): 36e3301

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -60
app.py CHANGED
@@ -2,8 +2,6 @@ import os
2
  import tempfile
3
 
4
  import streamlit as st
5
- from streamlit_extras.add_vertical_space import add_vertical_space
6
- from streamlit_extras.colored_header import colored_header
7
 
8
  from langchain.callbacks.base import BaseCallbackHandler
9
  from langchain.chains import ConversationalRetrievalChain
@@ -14,15 +12,22 @@ from langchain.memory import ConversationBufferMemory
14
  from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
15
  from langchain.text_splitter import RecursiveCharacterTextSplitter
16
  from langchain_community.vectorstores import DocArrayInMemorySearch
 
 
 
 
 
 
 
17
 
18
- st.set_page_config(page_title="πŸ“š InkChatGPT: Chat with Documents", page_icon="πŸ“š")
19
 
20
- add_vertical_space(30)
21
- colored_header(
22
- label="πŸ“š InkChatGPT",
23
- description="Chat with Documents",
24
- color_name="light-blue-70",
25
- )
26
 
27
 
28
  @st.cache_resource(ttl="1h")
@@ -75,66 +80,65 @@ class StreamHandler(BaseCallbackHandler):
75
 
76
  class PrintRetrievalHandler(BaseCallbackHandler):
77
  def __init__(self, container):
78
- self.status = container.status("**Context Retrieval**")
 
79
 
80
  def on_retriever_start(self, serialized: dict, query: str, **kwargs):
81
- self.status.write(f"**Question:** {query}")
82
- self.status.update(label=f"**Context Retrieval:** {query}")
83
 
84
  def on_retriever_end(self, documents, **kwargs):
85
- for idx, doc in enumerate(documents):
86
- source = os.path.basename(doc.metadata["source"])
87
- self.status.write(f"**Document {idx} from {source}**")
88
- self.status.markdown(doc.page_content)
89
- self.status.update(state="complete")
 
 
 
 
90
 
 
 
 
91
 
92
- openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
 
 
 
93
 
94
  if not openai_api_key:
95
- st.info("Please add your OpenAI API key to continue.")
96
  st.stop()
97
 
98
- uploaded_files = st.sidebar.file_uploader(
99
- label="Upload PDF files", type=["pdf"], accept_multiple_files=True
100
- )
101
- if not uploaded_files:
102
- st.info("Please upload PDF documents to continue.")
103
- st.stop()
104
 
105
- retriever = configure_retriever(uploaded_files)
 
 
106
 
107
- # Setup memory for contextual conversation
108
- msgs = StreamlitChatMessageHistory()
109
- memory = ConversationBufferMemory(
110
- memory_key="chat_history", chat_memory=msgs, return_messages=True
111
- )
112
-
113
- # Setup LLM and QA chain
114
- llm = ChatOpenAI(
115
- model_name="gpt-3.5-turbo",
116
- openai_api_key=openai_api_key,
117
- temperature=0,
118
- streaming=True,
119
- )
120
- qa_chain = ConversationalRetrievalChain.from_llm(
121
- llm, retriever=retriever, memory=memory, verbose=True
122
- )
123
-
124
- if len(msgs.messages) == 0 or st.sidebar.button("Clear message history"):
125
- msgs.clear()
126
- msgs.add_ai_message("How can I help you?")
127
-
128
- avatars = {"human": "user", "ai": "assistant"}
129
- for msg in msgs.messages:
130
- st.chat_message(avatars[msg.type]).write(msg.content)
131
-
132
- if user_query := st.chat_input(placeholder="Ask me anything!"):
133
- st.chat_message("user").write(user_query)
134
-
135
- with st.chat_message("assistant"):
136
- retrieval_handler = PrintRetrievalHandler(st.container())
137
- stream_handler = StreamHandler(st.empty())
138
- response = qa_chain.run(
139
- user_query, callbacks=[retrieval_handler, stream_handler]
140
- )
 
2
  import tempfile
3
 
4
  import streamlit as st
 
 
5
 
6
  from langchain.callbacks.base import BaseCallbackHandler
7
  from langchain.chains import ConversationalRetrievalChain
 
12
  from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain_community.vectorstores import DocArrayInMemorySearch
15
+ from streamlit_extras.add_vertical_space import add_vertical_space
16
+
17
+ # TODO: refactor
18
+ # TODO: extract class
19
+ # TODO: modularize
20
+ # TODO: hide side bar
21
+ # TODO: make the page attractive
22
 
23
+ st.set_page_config(page_title=":books: InkChatGPT: Chat with Documents", page_icon="πŸ“š")
24
 
25
+ st.image("./assets/icon.jpg", width=150)
26
+ st.header(":gray[:books: InkChatGPT]", divider="blue")
27
+ st.write("**Chat** with Documents")
28
+
29
+ # Setup memory for contextual conversation
30
+ msgs = StreamlitChatMessageHistory()
31
 
32
 
33
  @st.cache_resource(ttl="1h")
 
80
 
81
  class PrintRetrievalHandler(BaseCallbackHandler):
82
  def __init__(self, container):
83
+ self.status = container.status("**Thinking...**")
84
+ self.container = container
85
 
86
  def on_retriever_start(self, serialized: dict, query: str, **kwargs):
87
+ self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
 
88
 
89
  def on_retriever_end(self, documents, **kwargs):
90
+ self.container.empty()
91
+
92
+
93
+ with st.sidebar.expander("Documents"):
94
+ st.subheader("Files")
95
+ uploaded_files = st.file_uploader(
96
+ label="Select PDF files", type=["pdf"], accept_multiple_files=True
97
+ )
98
+
99
 
100
+ with st.sidebar.expander("Setup"):
101
+ st.subheader("API Key")
102
+ openai_api_key = st.text_input("OpenAI API Key", type="password")
103
 
104
+ is_empty_chat_messages = len(msgs.messages) == 0
105
+ if is_empty_chat_messages or st.button("Clear message history"):
106
+ msgs.clear()
107
+ msgs.add_ai_message("How can I help you?")
108
 
109
  if not openai_api_key:
110
+ st.info("Please add your OpenAI API key in the sidebar to continue.")
111
  st.stop()
112
 
113
+ if uploaded_files:
114
+ retriever = configure_retriever(uploaded_files)
 
 
 
 
115
 
116
+ memory = ConversationBufferMemory(
117
+ memory_key="chat_history", chat_memory=msgs, return_messages=True
118
+ )
119
 
120
+ # Setup LLM and QA chain
121
+ llm = ChatOpenAI(
122
+ model_name="gpt-3.5-turbo",
123
+ openai_api_key=openai_api_key,
124
+ temperature=0,
125
+ streaming=True,
126
+ )
127
+
128
+ chain = ConversationalRetrievalChain.from_llm(
129
+ llm, retriever=retriever, memory=memory, verbose=False
130
+ )
131
+
132
+ avatars = {"human": "user", "ai": "assistant"}
133
+ for msg in msgs.messages:
134
+ st.chat_message(avatars[msg.type]).write(msg.content)
135
+
136
+ if user_query := st.chat_input(placeholder="Ask me anything!"):
137
+ st.chat_message("user").write(user_query)
138
+
139
+ with st.chat_message("assistant"):
140
+ retrieval_handler = PrintRetrievalHandler(st.empty())
141
+ stream_handler = StreamHandler(st.empty())
142
+ response = chain.run(
143
+ user_query, callbacks=[retrieval_handler, stream_handler]
144
+ )