Spaces:
Sleeping
Sleeping
import os | |
import tempfile | |
import streamlit as st | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.memory import ConversationBufferMemory | |
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import ( | |
Docx2txtLoader, | |
PyPDFLoader, | |
TextLoader, | |
UnstructuredEPubLoader, | |
) | |
from langchain_community.vectorstores import DocArrayInMemorySearch | |
from calback_handler import PrintRetrievalHandler, StreamHandler | |
from chat_profile import ChatProfileRoleEnum | |
# configs
# Chat model requested from OpenAI and sentence-embedding model used to index
# the uploaded documents.
LLM_MODEL_NAME = "gpt-3.5-turbo"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Page-level settings (title, icon, sidebar state, menu entries).
st.set_page_config(
    page_title=":books: InkChatGPT: Chat with Documents",
    # NOTE(review): the icon arrived here as the single character "π", which
    # looks like a mis-decoded emoji. Replaced with the ASCII shortcode that
    # matches the title so it survives further re-encoding -- confirm this is
    # the intended icon.
    page_icon=":books:",
    initial_sidebar_state="collapsed",
    menu_items={
        "Get Help": "https://x.com/vinhnx",
        "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
        "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
    },
)

# Page header: app icon, title and tagline.
st.image("./assets/icon.jpg", width=100)
st.header(
    ":gray[:books: InkChatGPT]",
    divider="blue",
)
st.write("**Chat** with Documents")

# Setup memory for contextual conversation
msgs = StreamlitChatMessageHistory()
def configure_retriever(files):
    """Build a retriever over the contents of the uploaded documents.

    Each upload is written to a scratch directory, parsed with the loader
    that matches its file extension, split into overlapping chunks
    (1500 characters, 200 overlap), embedded with ``EMBEDDING_MODEL_NAME``
    and indexed in a ``DocArrayInMemorySearch`` store.

    Args:
        files: Iterable of uploaded-file objects exposing ``.name`` and
            ``.getvalue()``.

    Returns:
        The retriever produced by ``vectordb.as_retriever`` (MMR search,
        ``k=2``, ``fetch_k=4``), or ``None`` if any file has an unsupported
        extension; in that case a message is written to the page.
    """
    # Extension -> loader class. Keys are lower-case; the lookup below
    # lower-cases the extension so "REPORT.PDF" is handled like "report.pdf".
    loaders_by_extension = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".txt": TextLoader,
        ".epub": UnstructuredEPubLoader,
    }

    # Read documents
    docs = []
    # The context manager removes the scratch directory deterministically,
    # including on the early return and when a loader raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in files:
            # The name is supplied by the client; keep only its final
            # component so the write cannot land outside the scratch dir.
            safe_name = os.path.basename(file.name)
            temp_filepath = os.path.join(temp_dir, safe_name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())

            # Load the file using the appropriate loader
            _, extension = os.path.splitext(temp_filepath)
            loader_cls = loaders_by_extension.get(extension.lower())
            if loader_cls is None:
                st.write("This document format is not supported!")
                return None

            docs.extend(loader_cls(temp_filepath).load())

    # Split documents
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create embeddings and store in vectordb
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)

    # Define retriever
    return vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": 2, "fetch_k": 4}
    )
# Sidebar: document upload.
with st.sidebar.expander("Documents"):
    st.subheader("Files")
    uploaded_files = st.file_uploader(
        label="Select files",
        type=["pdf", "txt", "docx", "epub"],
        accept_multiple_files=True,
    )

# Sidebar: API key entry and chat-history reset.
# NOTE(review): the original indentation was lost; the reset logic is assumed
# to sit inside the "Setup" expander -- confirm the intended placement.
with st.sidebar.expander("Setup"):
    st.subheader("API Key")
    openai_api_key = st.text_input("OpenAI API Key", type="password")

    # Seed the greeting on a fresh session. The button is only evaluated
    # (and therefore only rendered) once the history is non-empty.
    is_empty_chat_messages = not msgs.messages
    if is_empty_chat_messages or st.button("Clear message history"):
        msgs.clear()
        msgs.add_ai_message("How can I help you?")

# Without a key there is nothing more to do on this run.
if not openai_api_key:
    st.info("Please add your OpenAI API key in the sidebar to continue.")
    st.stop()
# NOTE(review): the original nesting of this section was lost. Guard clauses
# are used so that the retriever and chain are always defined before the chat
# input can reference them.
if not uploaded_files:
    st.info("Please upload a document in the sidebar to continue.")
    st.stop()

result_retriever = configure_retriever(uploaded_files)
if result_retriever is None:
    # configure_retriever returns None for an unsupported file format after
    # writing a message to the page; a chain cannot be built without a
    # retriever, so end this run here.
    st.stop()

# Conversation memory backed by the Streamlit message history, so the chain
# sees the same messages that are rendered on the page.
memory = ConversationBufferMemory(
    memory_key="chat_history", chat_memory=msgs, return_messages=True
)

# Setup LLM and QA chain
llm = ChatOpenAI(
    model_name=LLM_MODEL_NAME,
    openai_api_key=openai_api_key,
    temperature=0,
    streaming=True,
)
chain = ConversationalRetrievalChain.from_llm(
    llm, retriever=result_retriever, memory=memory, verbose=False
)

# Map stored message types to the chat bubble role used for rendering.
# Assumes ChatProfileRoleEnum members hash/compare equal to the `msg.type`
# values of the stored messages -- TODO confirm against chat_profile.
avatars = {
    ChatProfileRoleEnum.Human: "user",
    ChatProfileRoleEnum.AI: "assistant",
}

# Replay the conversation so far.
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)

if user_query := st.chat_input(placeholder="Ask me anything!"):
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        # Each handler gets its own placeholder inside the assistant bubble.
        retrieval_handler = PrintRetrievalHandler(st.empty())
        stream_handler = StreamHandler(st.empty())
        response = chain.run(
            user_query, callbacks=[retrieval_handler, stream_handler]
        )