# inkchatgpt / app.py
# (Hugging Face Spaces capture: vinhnx90, "Handle epub loader", commit 1ce831d, 4.75 kB)
import os
import tempfile
import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
Docx2txtLoader,
PyPDFLoader,
TextLoader,
UnstructuredEPubLoader,
)
from langchain_community.vectorstores import DocArrayInMemorySearch
from calback_handler import PrintRetrievalHandler, StreamHandler
from chat_profile import ChatProfileRoleEnum
# configs
# Chat model and sentence-embedding model names; both are referenced again
# further down when the QA chain and the retriever are built.
LLM_MODEL_NAME = "gpt-3.5-turbo"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Page chrome: title/icon/menu links. Sidebar starts collapsed because the
# key/file inputs live there and are only needed once per session.
st.set_page_config(
    page_title=":books: InkChatGPT: Chat with Documents",
    page_icon="πŸ“š",
    initial_sidebar_state="collapsed",
    menu_items={
        "Get Help": "https://x.com/vinhnx",
        "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
        "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
    },
)

# Static header block (logo + title + tagline).
st.image("./assets/icon.jpg", width=100)
st.header(
    ":gray[:books: InkChatGPT]",
    divider="blue",
)
st.write("**Chat** with Documents")

# Setup memory for contextual conversation — message history is persisted in
# Streamlit session state, so it survives script reruns within one session.
msgs = StreamlitChatMessageHistory()
@st.cache_resource(ttl="1h")
def configure_retriever(files):
    """Build an MMR retriever over the uploaded documents.

    Writes each uploaded file into a temporary directory, loads it with the
    loader matching its extension, splits the text into overlapping chunks,
    embeds the chunks with a HuggingFace model, and indexes them in an
    in-memory DocArray vector store. Cached for one hour per input set.

    Args:
        files: Streamlit ``UploadedFile`` objects from ``st.file_uploader``.

    Returns:
        A vector-store retriever using MMR search, or ``None`` if any file
        has an unsupported extension.
    """
    # Dispatch table: extension -> document loader class. Extend here to
    # support new formats.
    loader_by_extension = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".txt": TextLoader,
        ".epub": UnstructuredEPubLoader,
    }

    docs = []
    # Context-manage the temp dir so it is removed deterministically when
    # loading finishes, instead of whenever the object happens to be GC'd.
    # All loader.load() calls happen inside this block, while the files
    # still exist on disk.
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in files:
            temp_filepath = os.path.join(temp_dir, file.name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())

            # Case-insensitive match so e.g. ".PDF" is accepted too.
            _, extension = os.path.splitext(temp_filepath)
            loader_cls = loader_by_extension.get(extension.lower())
            if loader_cls is None:
                st.write("This document format is not supported!")
                return None

            docs.extend(loader_cls(temp_filepath).load())

    # Split documents into overlapping chunks sized for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create embeddings and store in an in-memory vector database.
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)

    # MMR balances relevance and diversity across the retrieved chunks.
    return vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": 2, "fetch_k": 4}
    )
# Sidebar inputs: document upload and OpenAI API key.
with st.sidebar.expander("Documents"):
    st.subheader("Files")
    uploaded_files = st.file_uploader(
        label="Select files",
        type=["pdf", "txt", "docx", "epub"],
        accept_multiple_files=True,
    )

with st.sidebar.expander("Setup"):
    st.subheader("API Key")
    openai_api_key = st.text_input("OpenAI API Key", type="password")

# Seed the conversation with a greeting when the history is empty, or when
# the user explicitly clears it.
is_empty_chat_messages = len(msgs.messages) == 0
if is_empty_chat_messages or st.button("Clear message history"):
    msgs.clear()
    msgs.add_ai_message("How can I help you?")

# Both an API key and at least one document are required before chatting.
if not openai_api_key:
    st.info("Please add your OpenAI API key in the sidebar to continue.")
    st.stop()

# BUG FIX: `result_retriever` was previously assigned only inside
# `if uploaded_files:` but used unconditionally below, raising a NameError
# on the first run with no documents. Stop the script early instead.
if not uploaded_files:
    st.info("Please upload at least one document to continue.")
    st.stop()

result_retriever = configure_retriever(uploaded_files)
if result_retriever is None:
    # configure_retriever already reported the unsupported format.
    st.stop()

# Conversation memory backed by the Streamlit message history above.
memory = ConversationBufferMemory(
    memory_key="chat_history", chat_memory=msgs, return_messages=True
)

# Setup LLM and QA chain; temperature 0 for deterministic answers,
# streaming enabled so tokens render as they arrive.
llm = ChatOpenAI(
    model_name=LLM_MODEL_NAME,
    openai_api_key=openai_api_key,
    temperature=0,
    streaming=True,
)
chain = ConversationalRetrievalChain.from_llm(
    llm, retriever=result_retriever, memory=memory, verbose=False
)

# Map stored message roles to Streamlit chat avatars and replay history.
avatars = {
    ChatProfileRoleEnum.Human: "user",
    ChatProfileRoleEnum.AI: "assistant",
}
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)

if user_query := st.chat_input(placeholder="Ask me anything!"):
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        # Surface the retrieved context and stream the answer tokens into
        # the assistant message as the chain runs.
        retrieval_handler = PrintRetrievalHandler(st.empty())
        stream_handler = StreamHandler(st.empty())
        response = chain.run(
            user_query, callbacks=[retrieval_handler, stream_handler]
        )