vinhnx90 committed on
Commit
14ea497
•
1 Parent(s): 4497a00

Refactor app to have a LLM models selection

Browse files
Files changed (4) hide show
  1. app.py +101 -48
  2. document_retriever.py +1 -0
  3. llm_provider.py +6 -0
  4. requirements.txt +1 -0
app.py CHANGED
@@ -1,22 +1,39 @@
 
1
  import streamlit as st
2
  from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
3
- from langchain.chains.retrieval_qa.base import RetrievalQA
4
  from langchain.memory import ConversationBufferMemory
 
5
  from langchain_community.chat_message_histories.streamlit import (
6
  StreamlitChatMessageHistory,
7
  )
8
- from langchain_community.chat_models.openai import ChatOpenAI
9
 
10
  from calback_handler import PrintRetrievalHandler, StreamHandler
11
  from chat_profile import ChatProfileRoleEnum
12
  from document_retriever import configure_retriever
 
13
 
14
- LLM_MODEL = "gpt-3.5-turbo"
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  st.set_page_config(
17
  page_title="InkChatGPT: Chat with Documents",
18
  page_icon="📚",
19
- initial_sidebar_state="collapsed",
20
  menu_items={
21
  "Get Help": "https://x.com/vinhnx",
22
  "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
@@ -26,15 +43,6 @@ st.set_page_config(
26
  },
27
  )
28
 
29
- # Hide Header
30
- # st.markdown(
31
- # """<style>.stApp [data-testid="stToolbar"]{display:none;}</style>""",
32
- # unsafe_allow_html=True,
33
- # )
34
-
35
- # Setup memory for contextual conversation
36
- msgs = StreamlitChatMessageHistory()
37
-
38
  with st.sidebar:
39
  with st.container():
40
  col1, col2 = st.columns([0.2, 0.8])
@@ -47,40 +55,65 @@ with st.sidebar:
47
  with col2:
48
  st.header(":books: InkChatGPT")
49
 
50
- documents_tab, settings_tab = st.tabs(["Documents", "Settings"])
51
- with settings_tab:
52
- openai_api_key = st.text_input("OpenAI API Key", type="password")
 
 
 
 
 
 
 
53
 
54
- cohere_api_key = ""
55
- if st.toggle(
56
- label="Use Cohere's Rerank", help="https://txt.cohere.com/rerank/"
57
- ):
58
- cohere_api_key = st.text_input("Cohere API Key", type="password")
 
59
 
60
- if len(msgs.messages) == 0 or st.button("Clear message history"):
 
61
  msgs.clear()
62
  msgs.add_ai_message("""
63
  Hi, your uploaded document(s) had been analyzed.
64
 
65
- Feel free to ask me any questions. For example: you can start by asking me `What is this book about?` or `Tell me about the content of this book!`'
 
 
 
 
66
  """)
67
 
68
- with documents_tab:
69
- uploaded_files = st.file_uploader(
70
- label="Select files",
71
- type=["pdf", "txt", "docx"],
72
- accept_multiple_files=True,
73
- disabled=(not openai_api_key),
74
- )
 
75
 
76
- if not openai_api_key:
77
- st.info("🔑 Please open the `Settings` tab from side bar menu to get started.")
78
 
79
- if uploaded_files:
80
- result_retriever = configure_retriever(
81
- uploaded_files, cohere_api_key=cohere_api_key
 
82
  )
83
 
 
 
 
 
 
 
 
 
 
 
84
  if result_retriever is not None:
85
  memory = ConversationBufferMemory(
86
  memory_key="chat_history",
@@ -88,37 +121,57 @@ if uploaded_files:
88
  return_messages=True,
89
  )
90
 
91
- # Setup LLM and QA chain
92
- llm = ChatOpenAI(
93
- model=LLM_MODEL,
94
- api_key=openai_api_key,
95
- temperature=0,
96
- streaming=True,
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
98
 
 
99
  chain = ConversationalRetrievalChain.from_llm(
100
- llm,
101
  retriever=result_retriever,
102
  memory=memory,
103
- verbose=False,
104
  max_tokens_limit=4000,
105
  )
106
 
107
  avatars = {
108
- ChatProfileRoleEnum.HUMAN: "user",
109
- ChatProfileRoleEnum.AI: "assistant",
110
  }
111
 
112
  for msg in msgs.messages:
113
  st.chat_message(avatars[msg.type]).write(msg.content)
114
 
 
115
  if user_query := st.chat_input(
116
  placeholder="Ask me anything!",
117
- disabled=(not openai_api_key),
118
  ):
119
  st.chat_message("user").write(user_query)
120
 
121
  with st.chat_message("assistant"):
122
  retrieval_handler = PrintRetrievalHandler(st.empty())
123
  stream_handler = StreamHandler(st.empty())
124
- response = chain.run(user_query, callbacks=[retrieval_handler, stream_handler])
 
 
 
 
 
 
 
1
+ from sklearn import model_selection
2
  import streamlit as st
3
  from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
 
4
  from langchain.memory import ConversationBufferMemory
5
+ from langchain_cohere import ChatCohere
6
  from langchain_community.chat_message_histories.streamlit import (
7
  StreamlitChatMessageHistory,
8
  )
9
+ from langchain_openai import ChatOpenAI
10
 
11
  from calback_handler import PrintRetrievalHandler, StreamHandler
12
  from chat_profile import ChatProfileRoleEnum
13
  from document_retriever import configure_retriever
14
+ from llm_provider import LLMProviderEnum
15
 
16
+ # Constants
17
+ GPT_LLM_MODEL = "gpt-3.5-turbo"
18
+ COMMAND_R_LLM_MODEL = "command-r"
19
 
20
+ # Properties
21
+ uploaded_files = []
22
+ api_key = ""
23
+ result_retriever = None
24
+ chain = None
25
+ llm = None
26
+ model_name = ""
27
+
28
+ # Set up sidebar
29
+ if "sidebar_state" not in st.session_state:
30
+ st.session_state.sidebar_state = "expanded"
31
+
32
+ # Streamlit app configuration
33
  st.set_page_config(
34
  page_title="InkChatGPT: Chat with Documents",
35
  page_icon="📚",
36
+ initial_sidebar_state=st.session_state.sidebar_state,
37
  menu_items={
38
  "Get Help": "https://x.com/vinhnx",
39
  "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
 
43
  },
44
  )
45
 
 
 
 
 
 
 
 
 
 
46
  with st.sidebar:
47
  with st.container():
48
  col1, col2 = st.columns([0.2, 0.8])
 
55
  with col2:
56
  st.header(":books: InkChatGPT")
57
 
58
+ # Model
59
+ selected_model = st.selectbox(
60
+ "Select a model",
61
+ options=[
62
+ LLMProviderEnum.OPEN_AI.value,
63
+ LLMProviderEnum.COHERE.value,
64
+ ],
65
+ index=None,
66
+ placeholder="Select a model...",
67
+ )
68
 
69
+ if selected_model:
70
+ api_key = st.text_input(f"{selected_model} API Key", type="password")
71
+ if selected_model == LLMProviderEnum.OPEN_AI:
72
+ model_name = GPT_LLM_MODEL
73
+ elif selected_model == LLMProviderEnum.COHERE:
74
+ model_name = COMMAND_R_LLM_MODEL
75
 
76
+ msgs = StreamlitChatMessageHistory()
77
+ if len(msgs.messages) == 0:
78
  msgs.clear()
79
  msgs.add_ai_message("""
80
  Hi, your uploaded document(s) had been analyzed.
81
 
82
+ Feel free to ask me any questions. For example: you can start by asking me something like:
83
+
84
+ `What is this context about?`
85
+
86
+ `Help me summarize this!`
87
  """)
88
 
89
+ if api_key:
90
+ # Documents
91
+ uploaded_files = st.file_uploader(
92
+ label="Select files",
93
+ type=["pdf", "txt", "docx"],
94
+ accept_multiple_files=True,
95
+ disabled=(not selected_model),
96
+ )
97
 
98
+ if api_key and not uploaded_files:
99
+ st.info("🌟 You can upload some documents to get started")
100
 
101
+ # Check if a model is selected
102
+ if not selected_model:
103
+ st.info(
104
+ "📺 Please select a model first, open the `Settings` tab from side bar menu to get started"
105
  )
106
 
107
+ # Check if API key is provided
108
+ if selected_model and len(api_key.strip()) == 0:
109
+ st.warning(
110
+ f"🔑 API key for {selected_model} is missing or invalid. Please provide a valid API key."
111
+ )
112
+
113
+ # Process uploaded files
114
+ if uploaded_files:
115
+ result_retriever = configure_retriever(uploaded_files, cohere_api_key=api_key)
116
+
117
  if result_retriever is not None:
118
  memory = ConversationBufferMemory(
119
  memory_key="chat_history",
 
121
  return_messages=True,
122
  )
123
 
124
+ if selected_model == LLMProviderEnum.OPEN_AI:
125
+ llm = ChatOpenAI(
126
+ model=model_name,
127
+ api_key=api_key,
128
+ temperature=0,
129
+ streaming=True,
130
+ )
131
+ elif selected_model == LLMProviderEnum.COHERE:
132
+ llm = ChatCohere(
133
+ model=model_name,
134
+ temperature=0.3,
135
+ streaming=True,
136
+ cohere_api_key=api_key,
137
+ )
138
+
139
+ if llm is None:
140
+ st.error(
141
+ "Failed to initialize the language model. Please check your configuration."
142
+ )
143
 
144
+ # Create the ConversationalRetrievalChain instance using the llm instance
145
  chain = ConversationalRetrievalChain.from_llm(
146
+ llm=llm,
147
  retriever=result_retriever,
148
  memory=memory,
149
+ verbose=True,
150
  max_tokens_limit=4000,
151
  )
152
 
153
  avatars = {
154
+ ChatProfileRoleEnum.HUMAN.value: "user",
155
+ ChatProfileRoleEnum.AI.value: "assistant",
156
  }
157
 
158
  for msg in msgs.messages:
159
  st.chat_message(avatars[msg.type]).write(msg.content)
160
 
161
+ # Get user input and generate response
162
  if user_query := st.chat_input(
163
  placeholder="Ask me anything!",
164
+ disabled=(not uploaded_files),
165
  ):
166
  st.chat_message("user").write(user_query)
167
 
168
  with st.chat_message("assistant"):
169
  retrieval_handler = PrintRetrievalHandler(st.empty())
170
  stream_handler = StreamHandler(st.empty())
171
+ response = chain.run(
172
+ user_query,
173
+ callbacks=[retrieval_handler, stream_handler],
174
+ )
175
+
176
+ if selected_model and model_name:
177
+ st.sidebar.caption(f"🪄 Using `{model_name}` model")
document_retriever.py CHANGED
@@ -5,6 +5,7 @@ import streamlit as st
5
  from langchain.retrievers import ContextualCompressionRetriever
6
  from langchain.retrievers.document_compressors import EmbeddingsFilter
7
  from langchain_cohere import CohereRerank
 
8
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
10
  from langchain_community.vectorstores import DocArrayInMemorySearch
 
5
  from langchain.retrievers import ContextualCompressionRetriever
6
  from langchain.retrievers.document_compressors import EmbeddingsFilter
7
  from langchain_cohere import CohereRerank
8
+
9
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
10
  from langchain_community.embeddings import HuggingFaceEmbeddings
11
  from langchain_community.vectorstores import DocArrayInMemorySearch
llm_provider.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
+ class LLMProviderEnum(str, Enum):
5
+ OPEN_AI = "OpenAI"
6
+ COHERE = "Cohere"
requirements.txt CHANGED
@@ -3,6 +3,7 @@ sentence-transformers
3
  docarray
4
  langchain
5
  langchain_cohere
 
6
  streamlit
7
  streamlit_chat
8
  streamlit-extras
 
3
  docarray
4
  langchain
5
  langchain_cohere
6
+ langchain_openai
7
  streamlit
8
  streamlit_chat
9
  streamlit-extras