inkchatgpt / app.py
vinhnx90's picture
Remove unused import
21858a4
raw
history blame
5.66 kB
import streamlit as st
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_cohere import ChatCohere
from langchain_community.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
)
from langchain_openai import ChatOpenAI
from calback_handler import PrintRetrievalHandler, StreamHandler
from chat_profile import ChatProfileRoleEnum
from document_retriever import configure_retriever
from llm_provider import LLMProviderEnum
# Constants
GPT_LLM_MODEL = "gpt-3.5-turbo"
COMMAND_R_LLM_MODEL = "command-r"
# Properties
uploaded_files = []
api_key = ""
result_retriever = None
chain = None
llm = None
model_name = ""
# Set up sidebar
if "sidebar_state" not in st.session_state:
st.session_state.sidebar_state = "expanded"
# Streamlit app configuration
st.set_page_config(
page_title="InkChatGPT: Chat with Documents",
page_icon="πŸ“š",
initial_sidebar_state=st.session_state.sidebar_state,
menu_items={
"Get Help": "https://x.com/vinhnx",
"Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
"About": """InkChatGPT is a simple Retrieval Augmented Generation (RAG) application that allows users to upload PDF documents and engage in a conversational Q&A, with a language model (LLM) based on the content of those documents.
GitHub: https://github.com/vinhnx/InkChatGPT""",
},
)
with st.sidebar:
with st.container():
col1, col2 = st.columns([0.2, 0.8])
with col1:
st.image(
"./assets/app_icon.png",
use_column_width="always",
output_format="PNG",
)
with col2:
st.header(":books: InkChatGPT")
# Model
selected_model = st.selectbox(
"Select a model",
options=[
LLMProviderEnum.OPEN_AI.value,
LLMProviderEnum.COHERE.value,
],
index=None,
placeholder="Select a model...",
)
if selected_model:
api_key = st.text_input(f"{selected_model} API Key", type="password")
if selected_model == LLMProviderEnum.OPEN_AI:
model_name = GPT_LLM_MODEL
elif selected_model == LLMProviderEnum.COHERE:
model_name = COMMAND_R_LLM_MODEL
msgs = StreamlitChatMessageHistory()
if len(msgs.messages) == 0:
msgs.clear()
msgs.add_ai_message("""
Hi, your uploaded document(s) had been analyzed.
Feel free to ask me any questions. For example: you can start by asking me something like:
`What is this context about?`
`Help me summarize this!`
""")
if api_key:
# Documents
uploaded_files = st.file_uploader(
label="Select files",
type=["pdf", "txt", "docx"],
accept_multiple_files=True,
disabled=(not selected_model),
)
if api_key and not uploaded_files:
st.info("🌟 You can upload some documents to get started")
# Check if a model is selected
if not selected_model:
st.info(
"πŸ“Ί Please select a model first, open the `Settings` tab from side bar menu to get started"
)
# Check if API key is provided
if selected_model and len(api_key.strip()) == 0:
st.warning(
f"πŸ”‘ API key for {selected_model} is missing or invalid. Please provide a valid API key."
)
# Process uploaded files
if uploaded_files:
result_retriever = configure_retriever(uploaded_files, cohere_api_key=api_key)
if result_retriever is not None:
memory = ConversationBufferMemory(
memory_key="chat_history",
chat_memory=msgs,
return_messages=True,
)
if selected_model == LLMProviderEnum.OPEN_AI:
llm = ChatOpenAI(
model=model_name,
api_key=api_key,
temperature=0,
streaming=True,
)
elif selected_model == LLMProviderEnum.COHERE:
llm = ChatCohere(
model=model_name,
temperature=0.3,
streaming=True,
cohere_api_key=api_key,
)
if llm is None:
st.error(
"Failed to initialize the language model. Please check your configuration."
)
# Create the ConversationalRetrievalChain instance using the llm instance
chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=result_retriever,
memory=memory,
verbose=True,
max_tokens_limit=4000,
)
avatars = {
ChatProfileRoleEnum.HUMAN.value: "user",
ChatProfileRoleEnum.AI.value: "assistant",
}
for msg in msgs.messages:
st.chat_message(avatars[msg.type]).write(msg.content)
# Get user input and generate response
if user_query := st.chat_input(
placeholder="Ask me anything!",
disabled=(not uploaded_files),
):
st.chat_message("user").write(user_query)
with st.chat_message("assistant"):
retrieval_handler = PrintRetrievalHandler(st.empty())
stream_handler = StreamHandler(st.empty())
response = chain.run(
user_query,
callbacks=[retrieval_handler, stream_handler],
)
if selected_model and model_name:
st.sidebar.caption(f"πŸͺ„ Using `{model_name}` model")