vinhnx90 committed on
Commit
e698d82
•
1 Parent(s): 3cd35af

Use Cohere's Rerank to improve search retrieval performance

Browse files
Files changed (3) hide show
  1. app.py +67 -68
  2. document_retriever.py +5 -6
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
 
3
  from langchain.memory import ConversationBufferMemory
4
  from langchain_community.chat_message_histories.streamlit import (
5
  StreamlitChatMessageHistory,
@@ -34,86 +35,84 @@ st.set_page_config(
34
  # Setup memory for contextual conversation
35
  msgs = StreamlitChatMessageHistory()
36
 
37
- with st.container():
38
- col1, col2 = st.columns([0.2, 0.8])
39
- with col1:
40
- st.image(
41
- "./assets/app_icon.png",
42
- use_column_width="always",
43
- output_format="PNG",
44
- )
45
- with col2:
46
- st.header(":books: InkChatGPT")
47
- st.caption(
48
- """
49
- Simple Retrieval Augmented Generation (RAG) application that allows users to upload PDF documents and engage in a conversational Q&A, with a language model (LLM) based on the content of those documents. Built with LangChain as Streamlit.
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- Supports PDF, TXT, DOCX • Limit 200MB per file.
52
- * GitHub: https://github.com/vinhnx/InkChatGPT
53
- * Twitter: https://x.com/vinhnx
54
- """
 
 
 
 
 
55
  )
56
 
57
- chat_tab, documents_tab, settings_tab = st.tabs(["Chat", "Documents", "Settings"])
58
- with settings_tab:
59
- openai_api_key = st.text_input("OpenAI API Key", type="password")
60
- if len(msgs.messages) == 0 or st.button("Clear message history"):
61
- msgs.clear()
62
- msgs.add_ai_message("""
63
- Hi, your uploaded document(s) had been analyzed.
64
-
65
- Feel free to ask me any questions. For example: you can start by asking me `'What is this book about?` or `Tell me about the content of this book!`'
66
- """)
67
-
68
- with documents_tab:
69
- uploaded_files = st.file_uploader(
70
- label="Select files",
71
- type=["pdf", "txt", "docx"],
72
- accept_multiple_files=True,
73
- disabled=(not openai_api_key),
74
- )
75
 
76
- with chat_tab:
77
- if uploaded_files:
78
- result_retriever = configure_retriever(uploaded_files)
79
 
80
- if result_retriever is not None:
81
- memory = ConversationBufferMemory(
82
- memory_key="chat_history",
83
- chat_memory=msgs,
84
- return_messages=True,
85
- )
86
 
87
- # Setup LLM and QA chain
88
- llm = ChatOpenAI(
89
- model=LLM_MODEL,
90
- api_key=openai_api_key,
91
- temperature=0,
92
- streaming=True,
93
- )
94
 
95
- chain = ConversationalRetrievalChain.from_llm(
96
- llm,
97
- retriever=result_retriever,
98
- memory=memory,
99
- verbose=False,
100
- max_tokens_limit=4000,
101
- )
102
 
103
- avatars = {
104
- ChatProfileRoleEnum.HUMAN: "user",
105
- ChatProfileRoleEnum.AI: "assistant",
106
- }
107
-
108
- for msg in msgs.messages:
109
- st.chat_message(avatars[msg.type]).write(msg.content)
110
 
111
- if not openai_api_key:
112
- st.caption("🔑 Add your **OpenAI API key** on the `Settings` to continue.")
113
 
114
  if user_query := st.chat_input(
115
  placeholder="Ask me anything!",
116
- disabled=(not openai_api_key),
117
  ):
118
  st.chat_message("user").write(user_query)
119
 
 
1
  import streamlit as st
2
  from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
3
+ from langchain.chains.retrieval_qa.base import RetrievalQA
4
  from langchain.memory import ConversationBufferMemory
5
  from langchain_community.chat_message_histories.streamlit import (
6
  StreamlitChatMessageHistory,
 
35
  # Setup memory for contextual conversation
36
  msgs = StreamlitChatMessageHistory()
37
 
38
+ with st.sidebar:
39
+ with st.container():
40
+ col1, col2 = st.columns([0.2, 0.8])
41
+ with col1:
42
+ st.image(
43
+ "./assets/app_icon.png",
44
+ use_column_width="always",
45
+ output_format="PNG",
46
+ )
47
+ with col2:
48
+ st.header(":books: InkChatGPT")
49
+
50
+ # chat_tab,
51
+ documents_tab, settings_tab = st.tabs(
52
+ [
53
+ # "Chat",
54
+ "Documents",
55
+ "Settings",
56
+ ]
57
+ )
58
+ with settings_tab:
59
+ openai_api_key = st.text_input("OpenAI API Key", type="password")
60
+ if len(msgs.messages) == 0 or st.button("Clear message history"):
61
+ msgs.clear()
62
+ msgs.add_ai_message("""
63
+ Hi, your uploaded document(s) had been analyzed.
64
 
65
+ Feel free to ask me any questions. For example: you can start by asking me `'What is this book about?` or `Tell me about the content of this book!`'
66
+ """)
67
+
68
+ with documents_tab:
69
+ uploaded_files = st.file_uploader(
70
+ label="Select files",
71
+ type=["pdf", "txt", "docx"],
72
+ accept_multiple_files=True,
73
+ disabled=(not openai_api_key),
74
  )
75
 
76
+ if not openai_api_key:
77
+ st.info("🔑 Please Add your **OpenAI API key** on the `Settings` to continue.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ if uploaded_files:
80
+ result_retriever = configure_retriever(uploaded_files)
 
81
 
82
+ if result_retriever is not None:
83
+ memory = ConversationBufferMemory(
84
+ memory_key="chat_history",
85
+ chat_memory=msgs,
86
+ return_messages=True,
87
+ )
88
 
89
+ # Setup LLM and QA chain
90
+ llm = ChatOpenAI(
91
+ model=LLM_MODEL,
92
+ api_key=openai_api_key,
93
+ temperature=0,
94
+ streaming=True,
95
+ )
96
 
97
+ chain = ConversationalRetrievalChain.from_llm(
98
+ llm,
99
+ retriever=result_retriever,
100
+ memory=memory,
101
+ verbose=False,
102
+ max_tokens_limit=4000,
103
+ )
104
 
105
+ avatars = {
106
+ ChatProfileRoleEnum.HUMAN: "user",
107
+ ChatProfileRoleEnum.AI: "assistant",
108
+ }
 
 
 
109
 
110
+ for msg in msgs.messages:
111
+ st.chat_message(avatars[msg.type]).write(msg.content)
112
 
113
  if user_query := st.chat_input(
114
  placeholder="Ask me anything!",
115
+ disabled=(not openai_api_key and not result_retriever),
116
  ):
117
  st.chat_message("user").write(user_query)
118
 
document_retriever.py CHANGED
@@ -3,7 +3,8 @@ import tempfile
3
 
4
  import streamlit as st
5
  from langchain.retrievers import ContextualCompressionRetriever
6
- from langchain.retrievers.document_compressors import EmbeddingsFilter
 
7
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain_community.vectorstores import DocArrayInMemorySearch
@@ -53,10 +54,8 @@ def configure_retriever(files, use_compression=False):
53
  if not use_compression:
54
  return retriever
55
 
56
- embeddings_filter = EmbeddingsFilter(
57
- embeddings=embeddings, similarity_threshold=0.76
58
- )
59
-
60
  return ContextualCompressionRetriever(
61
- base_compressor=embeddings_filter, base_retriever=retriever
 
62
  )
 
3
 
4
  import streamlit as st
5
  from langchain.retrievers import ContextualCompressionRetriever
6
+
7
+ from langchain_cohere import CohereRerank
8
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
10
  from langchain_community.vectorstores import DocArrayInMemorySearch
 
54
  if not use_compression:
55
  return retriever
56
 
57
+ compressor = CohereRerank()
 
 
 
58
  return ContextualCompressionRetriever(
59
+ base_compressor=compressor,
60
+ base_retriever=retriever,
61
  )
requirements.txt CHANGED
@@ -2,6 +2,7 @@ openai
2
  sentence-transformers
3
  docarray
4
  langchain
 
5
  streamlit
6
  streamlit_chat
7
  streamlit-extras
 
2
  sentence-transformers
3
  docarray
4
  langchain
5
+ langchain_cohere
6
  streamlit
7
  streamlit_chat
8
  streamlit-extras