Spaces:
Running
Running
import streamlit as st | |
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain_cohere import ChatCohere | |
from langchain_community.chat_message_histories.streamlit import ( | |
StreamlitChatMessageHistory, | |
) | |
from langchain_openai import ChatOpenAI | |
from calback_handler import PrintRetrievalHandler, StreamHandler | |
from chat_profile import ChatProfileRoleEnum | |
from document_retriever import configure_retriever | |
from llm_provider import LLMProviderEnum | |
# --- Model identifiers -------------------------------------------------------
# Default chat model name used for each LLM provider.
GPT_LLM_MODEL = "gpt-3.5-turbo"
COMMAND_R_LLM_MODEL = "command-r"

# --- Per-run state -----------------------------------------------------------
# Initial values for the globals that the sidebar and document-processing code
# below fill in. Streamlit re-executes the script on each interaction, so these
# reset every run.
uploaded_files = []
api_key = model_name = ""
result_retriever = chain = llm = None
# Set up sidebar: default to "expanded" only on the first run. The value lives
# in session state, so the guard leaves it alone on later reruns.
if "sidebar_state" not in st.session_state:
    st.session_state.sidebar_state = "expanded"

# Streamlit app configuration. Keep this ahead of any other rendering `st.*`
# call (older Streamlit releases require set_page_config to come first).
st.set_page_config(
    page_title="InkChatGPT: Chat with Documents",
    # Restored from a mis-encoded character; matches the ":books:" header icon.
    page_icon="📚",
    initial_sidebar_state=st.session_state.sidebar_state,
    menu_items={
        "Get Help": "https://x.com/vinhnx",
        "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
        "About": """InkChatGPT is a simple Retrieval Augmented Generation (RAG) application that allows users to upload PDF documents and engage in a conversational Q&A, with a language model (LLM) based on the content of those documents.
GitHub: https://github.com/vinhnx/InkChatGPT""",
    },
)
# Greeting seeded into an empty chat history. Kept flush-left so the markdown
# rendering of chat messages does not turn indented lines into a code block.
_WELCOME_MESSAGE = """
Hi, your uploaded document(s) had been analyzed.
Feel free to ask me any questions. For example: you can start by asking me something like:
`What is this context about?`
`Help me summarize this!`
"""

with st.sidebar:
    # Branding row: app icon on the left, title on the right.
    with st.container():
        col1, col2 = st.columns([0.2, 0.8])
        with col1:
            st.image(
                "./assets/app_icon.png",
                use_column_width="always",
                output_format="PNG",
            )
        with col2:
            st.header(":books: InkChatGPT")

    # Model: the options are the enum *values*, so selected_model is one of
    # those values (or None while nothing is selected).
    selected_model = st.selectbox(
        "Select a model",
        options=[
            LLMProviderEnum.OPEN_AI.value,
            LLMProviderEnum.COHERE.value,
        ],
        index=None,
        placeholder="Select a model...",
    )

    if selected_model:
        # Strip surrounding whitespace (e.g. from copy/paste). This makes a
        # blank key count as missing and passes the provider the exact key.
        api_key = st.text_input(
            f"{selected_model} API Key", type="password"
        ).strip()

        # Compare against `.value`: a plain Enum member never equals its value,
        # so comparing with the member only worked if LLMProviderEnum mixes in str.
        if selected_model == LLMProviderEnum.OPEN_AI.value:
            model_name = GPT_LLM_MODEL
        elif selected_model == LLMProviderEnum.COHERE.value:
            model_name = COMMAND_R_LLM_MODEL

    # Chat history stored in Streamlit session state; seed it with a greeting
    # the first time it is empty.
    msgs = StreamlitChatMessageHistory()
    if not msgs.messages:
        msgs.clear()
        msgs.add_ai_message(_WELCOME_MESSAGE)

    if api_key:
        # Documents
        uploaded_files = st.file_uploader(
            label="Select files",
            type=["pdf", "txt", "docx"],
            accept_multiple_files=True,
            disabled=(not selected_model),
        )
# Main-area hints that walk the user through setup: pick a model, enter a key,
# then upload documents.
# NOTE(review): the emoji below were reconstructed from mis-encoded characters
# in this file; confirm they match the intended icons.
if api_key and not uploaded_files:
    st.info("📄 You can upload some documents to get started")

# Check if a model is selected
if not selected_model:
    st.info(
        "📺 Please select a model first, open the `Settings` tab from side bar menu to get started"
    )

# Check if API key is provided (a whitespace-only key counts as missing)
if selected_model and not api_key.strip():
    st.warning(
        f"🔑 API key for {selected_model} is missing or invalid. Please provide a valid API key."
    )
# Process uploaded files: build a retriever over them, then wire the selected
# LLM and the shared chat history into a ConversationalRetrievalChain.
if uploaded_files:
    # NOTE(review): the user's key is passed as `cohere_api_key` even when the
    # OpenAI provider is selected; confirm that is what configure_retriever expects.
    result_retriever = configure_retriever(uploaded_files, cohere_api_key=api_key)

    if result_retriever is not None:
        # Memory backed by `msgs`, so the chain reads prior turns from, and
        # records new turns into, the same history the page renders.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            chat_memory=msgs,
            return_messages=True,
        )

        # selected_model holds an enum *value* (see the selectbox options), so
        # compare against `.value`. Fall back to the provider's default model
        # if model_name was not resolved.
        if selected_model == LLMProviderEnum.OPEN_AI.value:
            llm = ChatOpenAI(
                model=model_name or GPT_LLM_MODEL,
                api_key=api_key,
                temperature=0,
                streaming=True,
            )
        elif selected_model == LLMProviderEnum.COHERE.value:
            llm = ChatCohere(
                model=model_name or COMMAND_R_LLM_MODEL,
                temperature=0.3,
                streaming=True,
                cohere_api_key=api_key,
            )

        if llm is None:
            # Do not build a chain around a missing model; `chain` stays None.
            st.error(
                "Failed to initialize the language model. Please check your configuration."
            )
        else:
            # Create the ConversationalRetrievalChain instance using the llm instance
            chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=result_retriever,
                memory=memory,
                verbose=True,
                max_tokens_limit=4000,
            )
# Map stored message types (msg.type) to Streamlit chat roles.
avatars = {
    ChatProfileRoleEnum.HUMAN.value: "user",
    ChatProfileRoleEnum.AI.value: "assistant",
}
# Replay the stored conversation on every rerun. An unmapped message type is
# shown under its own name instead of raising KeyError.
for msg in msgs.messages:
    st.chat_message(avatars.get(msg.type, msg.type)).write(msg.content)

# Get user input and generate response
if user_query := st.chat_input(
    placeholder="Ask me anything!",
    disabled=(not uploaded_files),
):
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        if chain is None:
            # Files are uploaded, but the retriever or LLM setup above failed.
            st.error(
                "The document chat is not ready yet. Please check your model selection and API key."
            )
        else:
            # Each handler gets its own empty placeholder; presumably they render
            # the retrieved sources and the streamed answer there (see calback_handler).
            retrieval_handler = PrintRetrievalHandler(st.empty())
            stream_handler = StreamHandler(st.empty())
            # The result is not used here; the chain's memory records the turn in `msgs`.
            chain.run(
                user_query,
                callbacks=[retrieval_handler, stream_handler],
            )
# Show the active model name in the sidebar, after the other sidebar widgets.
# The emoji is restored from a mis-encoded character (UTF-8 F0 9F AA .. -> 🪄).
if selected_model and model_name:
    st.sidebar.caption(f"🪄 Using `{model_name}` model")