File size: 5,664 Bytes
db70198
21858a4
c65178d
db70198
14ea497
c65178d
 
 
14ea497
0c5e8e3
1ce831d
68eaa27
1ec5b20
14ea497
b5bc349
14ea497
 
 
6f96801
14ea497
 
 
 
 
 
 
 
 
 
 
 
 
18a32c9
1ec5b20
18a32c9
14ea497
18a32c9
 
 
a499b16
 
 
18a32c9
 
 
e698d82
 
 
 
 
 
 
 
 
 
 
 
14ea497
 
 
 
 
 
 
 
 
 
d1e97a4
14ea497
 
 
 
 
 
d1e97a4
14ea497
 
e698d82
212e943
 
 
14ea497
 
 
 
 
212e943
e698d82
14ea497
 
 
 
 
 
 
 
0e17e2d
14ea497
 
18a32c9
14ea497
 
 
 
d1e97a4
db70198
14ea497
 
 
 
 
 
 
 
 
 
e698d82
 
 
 
 
 
c65178d
14ea497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c65178d
14ea497
e698d82
14ea497
e698d82
 
14ea497
e698d82
 
c65178d
e698d82
14ea497
 
e698d82
7713f97
e698d82
 
7713f97
14ea497
0e17e2d
1ec5b20
14ea497
0e17e2d
 
7713f97
0e17e2d
 
 
14ea497
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import streamlit as st

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_cohere import ChatCohere
from langchain_community.chat_message_histories.streamlit import (
    StreamlitChatMessageHistory,
)
from langchain_openai import ChatOpenAI

from calback_handler import PrintRetrievalHandler, StreamHandler
from chat_profile import ChatProfileRoleEnum
from document_retriever import configure_retriever
from llm_provider import LLMProviderEnum

# Constants
# Default model identifier for each supported LLM provider.
GPT_LLM_MODEL = "gpt-3.5-turbo"  # OpenAI chat model
COMMAND_R_LLM_MODEL = "command-r"  # Cohere chat model

# Properties
# Module-level state, re-populated on every Streamlit script re-run.
uploaded_files = []  # files returned by st.file_uploader in the sidebar
api_key = ""  # provider API key typed into the sidebar
result_retriever = None  # retriever built from the uploaded documents
chain = None  # ConversationalRetrievalChain, built once a retriever exists
llm = None  # chat model instance (ChatOpenAI or ChatCohere)
model_name = ""  # concrete model id mapped from the selected provider

# Set up sidebar
# Persist the sidebar expansion state across Streamlit re-runs.
if "sidebar_state" not in st.session_state:
    st.session_state.sidebar_state = "expanded"

# Streamlit app configuration
st.set_page_config(
    page_title="InkChatGPT: Chat with Documents",
    page_icon="πŸ“š",
    initial_sidebar_state=st.session_state.sidebar_state,
    menu_items={
        "Get Help": "https://x.com/vinhnx",
        "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
        "About": """InkChatGPT is a simple Retrieval Augmented Generation (RAG) application that allows users to upload PDF documents and engage in a conversational Q&A, with a language model (LLM) based on the content of those documents.
        
        GitHub: https://github.com/vinhnx/InkChatGPT""",
    },
)

with st.sidebar:
    with st.container():
        # App logo + title header.
        col1, col2 = st.columns([0.2, 0.8])
        with col1:
            st.image(
                "./assets/app_icon.png",
                use_column_width="always",
                output_format="PNG",
            )
        with col2:
            st.header(":books: InkChatGPT")

        # Model selection. NOTE: the options are the enum *string values*,
        # so the selectbox returns a plain string, not an enum member.
        selected_model = st.selectbox(
            "Select a model",
            options=[
                LLMProviderEnum.OPEN_AI.value,
                LLMProviderEnum.COHERE.value,
            ],
            index=None,
            placeholder="Select a model...",
        )

        if selected_model:
            api_key = st.text_input(f"{selected_model} API Key", type="password")
            # FIX: compare against `.value` — the selectbox yields a string,
            # so comparing to the enum member itself would never match
            # (unless LLMProviderEnum happens to be a StrEnum), leaving
            # `model_name` empty.
            if selected_model == LLMProviderEnum.OPEN_AI.value:
                model_name = GPT_LLM_MODEL
            elif selected_model == LLMProviderEnum.COHERE.value:
                model_name = COMMAND_R_LLM_MODEL

        # Chat history backed by Streamlit session state; seed it with a
        # greeting only on the first run (clearing an empty history was a
        # redundant no-op and has been removed).
        msgs = StreamlitChatMessageHistory()
        if not msgs.messages:
            msgs.add_ai_message("""
            Hi, your uploaded document(s) had been analyzed. 
            
            Feel free to ask me any questions. For example: you can start by asking me something like: 
            
            `What is this context about?`
            
            `Help me summarize this!`
            """)

        if api_key:
            # Documents: uploading stays disabled until a model is chosen.
            uploaded_files = st.file_uploader(
                label="Select files",
                type=["pdf", "txt", "docx"],
                accept_multiple_files=True,
                disabled=(not selected_model),
            )

# Nudge the user to upload documents once an API key has been entered.
if api_key and not uploaded_files:
    st.info("🌟 You can upload some documents to get started")

# Prompt for a model selection when none has been made yet.
if not selected_model:
    st.info(
        "πŸ“Ί Please select a model first, open the `Settings` tab from side bar menu to get started"
    )

# Warn when a model is chosen but its API key is blank or whitespace-only.
if selected_model and not api_key.strip():
    st.warning(
        f"πŸ”‘ API key for {selected_model} is missing or invalid. Please provide a valid API key."
    )

# Process uploaded files
if uploaded_files:
    # NOTE(review): the key is always forwarded as `cohere_api_key`, even
    # when OpenAI is the selected provider — confirm configure_retriever
    # handles this correctly for both providers.
    result_retriever = configure_retriever(uploaded_files, cohere_api_key=api_key)

    if result_retriever is not None:
        # Conversation memory backed by the Streamlit chat message history.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            chat_memory=msgs,
            return_messages=True,
        )

        # FIX: the selectbox yields the enum's *string value*, so compare
        # against `.value` — comparing to the enum member would fail for a
        # plain Enum and silently leave `llm` as None.
        if selected_model == LLMProviderEnum.OPEN_AI.value:
            llm = ChatOpenAI(
                model=model_name,
                api_key=api_key,
                temperature=0,
                streaming=True,
            )
        elif selected_model == LLMProviderEnum.COHERE.value:
            llm = ChatCohere(
                model=model_name,
                temperature=0.3,
                streaming=True,
                cohere_api_key=api_key,
            )

        if llm is None:
            st.error(
                "Failed to initialize the language model. Please check your configuration."
            )
            # FIX: halt this script run — previously execution fell through
            # and crashed building the chain with llm=None.
            st.stop()

        # Create the ConversationalRetrievalChain instance using the llm instance
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=result_retriever,
            memory=memory,
            verbose=True,
            max_tokens_limit=4000,
        )

        # Map LangChain message types onto Streamlit chat roles.
        avatars = {
            ChatProfileRoleEnum.HUMAN.value: "user",
            ChatProfileRoleEnum.AI.value: "assistant",
        }

        # Replay the conversation so far into the chat area.
        for msg in msgs.messages:
            st.chat_message(avatars[msg.type]).write(msg.content)

# Get user input and generate response
if user_query := st.chat_input(
    placeholder="Ask me anything!",
    disabled=(not uploaded_files),
):
    st.chat_message("user").write(user_query)

    # FIX: guard against a missing chain (e.g. retriever configuration
    # failed) instead of crashing with AttributeError on None.
    if chain is None:
        st.error("The document chain is not ready. Please re-upload your documents.")
    else:
        with st.chat_message("assistant"):
            # Stream retrieval status and generated tokens into placeholders.
            retrieval_handler = PrintRetrievalHandler(st.empty())
            stream_handler = StreamHandler(st.empty())
            response = chain.run(
                user_query,
                callbacks=[retrieval_handler, stream_handler],
            )

# Show which concrete model is in use once a provider has been selected
# and mapped to a model id.
if selected_model and model_name:
    st.sidebar.caption(f"πŸͺ„ Using `{model_name}` model")