File size: 3,441 Bytes
db70198
 
 
1ec5b20
1ce831d
68eaa27
1ec5b20
9caad80
b5bc349
18a32c9
1ec5b20
18a32c9
 
 
 
 
 
 
 
 
4ddc82f
 
 
 
 
 
7713f97
 
3373c54
0e17e2d
9c5fb2e
0e17e2d
 
1ec5b20
 
 
0e17e2d
 
 
 
 
 
 
 
 
 
 
 
db70198
0e17e2d
18a32c9
1ce831d
cee7091
1ce831d
1ec5b20
18a32c9
 
0e17e2d
 
 
db70198
0e17e2d
1ec5b20
 
 
0e17e2d
3373c54
0e17e2d
 
1ec5b20
0e17e2d
 
 
 
db70198
0e17e2d
1ec5b20
 
 
 
9caad80
0e17e2d
18a32c9
0e17e2d
1ec5b20
0e17e2d
 
18a32c9
0e17e2d
 
7713f97
0e17e2d
 
7713f97
0e17e2d
1ec5b20
 
0e17e2d
 
7713f97
0e17e2d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.chat_models import ChatOpenAI
from calback_handler import PrintRetrievalHandler, StreamHandler
from chat_profile import ChatProfileRoleEnum
from document_retriever import configure_retriever
from langchain.chains import ConversationalRetrievalChain

# Configure the browser tab (title and icon), the initial sidebar state and
# the entries shown in the app menu.
st.set_page_config(
    page_title="InkChatGPT: Chat with Documents",
    # Must be a real emoji character; a mis-encoded byte sequence here is not
    # recognised as an emoji by Streamlit.
    page_icon="📚",
    initial_sidebar_state="collapsed",
    menu_items={
        "Get Help": "https://x.com/vinhnx",
        "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
        "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
    },
)

# Hide the Streamlit toolbar by injecting a CSS rule that matches its
# `data-testid` attribute. Raw HTML has to be explicitly allowed.
HIDE_TOOLBAR_CSS = """<style>.stApp [data-testid="stToolbar"]{display:none;}</style>"""
st.markdown(HIDE_TOOLBAR_CSS, unsafe_allow_html=True)

# Message history used both to render the conversation and as the chat
# memory handed to the retrieval chain.
msgs = StreamlitChatMessageHistory()

# App banner: icon in the narrower left column, title and tagline on the right.
with st.container():
    col1, col2 = st.columns([0.3, 0.8])
    with col1:
        st.image(
            "./assets/app_icon.png",
            use_column_width="always",
            output_format="PNG",
        )
    with col2:
        st.header(":books: InkChatGPT")
        st.write("**Chat** with Documents")
        # NOTE(review): the caption advertises EPUB, but the file uploader in
        # this script only accepts pdf/txt/docx - confirm which one is right.
        st.caption("Supports PDF, TXT, DOCX, EPUB • Limit 200MB per file")

chat_tab, documents_tab, settings_tab = st.tabs(["Chat", "Documents", "Settings"])
with settings_tab:
    openai_api_key = st.text_input("OpenAI API Key", type="password")

    # Create the button before testing the history: putting `st.button(...)`
    # on the right-hand side of an `or` means it is skipped (and therefore not
    # rendered) whenever the left-hand side is already true, i.e. on the very
    # first run when the history is still empty.
    clear_requested = st.button("Clear message history")

    # Seed a greeting both for a brand-new session and after an explicit reset.
    if clear_requested or not msgs.messages:
        msgs.clear()
        msgs.add_ai_message("How can I help you?")

with documents_tab:
    # The uploader stays disabled until an OpenAI API key has been entered.
    uploader_locked = not openai_api_key
    uploaded_files = st.file_uploader(
        "Select files",
        accept_multiple_files=True,
        type=["pdf", "txt", "docx"],
        disabled=uploader_locked,
    )

with chat_tab:
    # The retrieval chain and the rendered history only exist once at least
    # one document has been uploaded.
    if uploaded_files:
        result_retriever = configure_retriever(uploaded_files)

        # Conversation memory backed by the shared message history, exposed to
        # the chain under the "chat_history" key.
        memory = ConversationBufferMemory(
            chat_memory=msgs,
            memory_key="chat_history",
            return_messages=True,
        )

        # Streaming chat model used to answer questions.
        llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            model_name="gpt-3.5-turbo",
            streaming=True,
            temperature=0,
        )

        # Question-answering chain combining the model, retriever and memory.
        chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=result_retriever,
            memory=memory,
            max_tokens_limit=4000,
            verbose=False,
        )

        # Map each stored message type to the chat role it is rendered under.
        avatars = {
            ChatProfileRoleEnum.HUMAN: "user",
            ChatProfileRoleEnum.AI: "assistant",
        }

        # Replay the conversation so far.
        for msg in msgs.messages:
            role = avatars[msg.type]
            bubble = st.chat_message(role)
            bubble.write(msg.content)

# Point the user at the Settings tab while no API key has been entered.
if not openai_api_key:
    st.caption("🔑 Add your **OpenAI API key** on the `Settings` to continue.")

# Chat input is disabled until an API key is available. When the user submits
# a question, echo it and answer it through the retrieval chain.
if user_query := st.chat_input(
    placeholder="Ask me anything!",
    disabled=(not openai_api_key),
):
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        if not uploaded_files:
            # `chain` is only created when documents have been uploaded;
            # calling it without any would raise NameError, so tell the user
            # what is missing instead.
            st.write("Please upload a document in the `Documents` tab first.")
        else:
            # Each handler gets its own placeholder to render into
            # (presumably retrieval status and streamed answer tokens -
            # see calback_handler for the exact behaviour).
            retrieval_handler = PrintRetrievalHandler(st.empty())
            stream_handler = StreamHandler(st.empty())
            response = chain.run(
                user_query,
                callbacks=[retrieval_handler, stream_handler],
            )