# multimodal / app.py
# NEXAS's picture
# Update app.py
# 76924c2 verified
import base64
import mimetypes
import os
import tempfile

import nest_asyncio
import streamlit as st
from dotenv import load_dotenv
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory

from src.utils.image_qa import query_and_print_results
from src.utils.ingest_image import extract_and_store_images
from src.utils.ingest_text import create_vector_database
from src.utils.text_qa import qa_bot
# Load variables from a local .env file into the process environment
# (presumably credentials for the QA backends — confirm against src.utils).
load_dotenv()
# Patch asyncio so an already-running event loop can be re-entered.
nest_asyncio.apply()

# Must stay the first Streamlit command executed by the script.
st.set_page_config(page_title="InsightFusion Chat", layout='wide')

# Conversation store backed by Streamlit session state under "chat_messages";
# the window memory on top of it exposes only the last k=3 exchanges as
# "chat_history", labelling human turns with "User".
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(
    chat_memory=memory_storage,
    memory_key="chat_history",
    human_prefix="User",
    k=3,
)

# Local image used as the full-page background.
image_bg = "data/pexels-andreea-ch-371539-1166644.jpg"
def build_background_css(image_file):
    """Return a ``<style>`` snippet that uses *image_file* as the app background.

    The image is inlined as a base64 ``data:`` URI, so no static file serving
    is required. The MIME type is guessed from the file extension so the URI
    matches the real image format (e.g. ``image/jpeg`` for ``.jpg``). Unknown
    or non-image extensions fall back to ``image/png``.

    Args:
        image_file: Path to an image on the local filesystem.

    Returns:
        CSS wrapped in a ``<style>`` tag, ready to pass to ``st.markdown``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    mime_type, _ = mimetypes.guess_type(image_file)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"
    with open(image_file, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    return (
        f"<style>.stApp {{background-image: url(data:{mime_type};base64,{encoded});\n"
        f"background-size: cover}}</style>"
    )


def add_bg_from_local(image_file):
    """Set a local image file as the full-page background of the Streamlit app.

    Args:
        image_file: Path to an image on the local filesystem.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    st.markdown(build_background_css(image_file), unsafe_allow_html=True)
# Paint the page background, then draw the app title as white-outlined SVG text.
add_bg_from_local(image_bg)
# "sans-serif" is the generic CSS/SVG font family keyword; the previous
# "San serif" named a non-existent font and fell back to the default face.
st.markdown("""
<svg width="600" height="100">
<text x="50%" y="50%" font-family="sans-serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
stroke-width="0.3" stroke-linejoin="round">InsightFusion Chat
</text>
</svg>
""", unsafe_allow_html=True)
def get_answer(query, chain):
    """Ask *chain* the question *query* and return the answer text.

    The chain's ``invoke`` result is indexed with ``'result'``. Any failure,
    including a response without that key, is reported in the UI via
    ``st.error``; ``None`` is returned instead of raising.

    Args:
        query: The user's question.
        chain: An object exposing ``invoke(query)`` that returns a mapping.

    Returns:
        The ``'result'`` entry of the chain response, or ``None`` on error.
    """
    try:
        answer = chain.invoke(query)["result"]
    except Exception as exc:
        st.error(f"Error in get_answer: {exc}")
        answer = None
    return answer
# Persist the uploaded PDF under ./temp so the ingestion helpers can read it
# from disk; `path` (absolute) is consumed by the processing step below.
uploaded_file = st.file_uploader("File upload", type="pdf")
if uploaded_file is not None:
    # The file name comes from the client: keep only its final component so a
    # crafted name such as "../../x.pdf" cannot write outside ./temp.
    safe_name = os.path.basename(uploaded_file.name)
    os.makedirs("temp", exist_ok=True)
    temp_file_path = os.path.join("temp", safe_name)
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    path = os.path.abspath(temp_file_path)
    st.write(f"File saved to: {path}")
    st.write("Document uploaded successfully!")
# "Start Processing": index the saved PDF (text and images), then cache the QA
# chain and the image store in session state so later reruns can use them.
if st.button("Start Processing"):
    if uploaded_file is None:
        st.error("Please upload a file before starting processing.")
    else:
        with st.spinner("Processing"):
            try:
                # `path` is bound by the upload step whenever a file is present.
                client = create_vector_database(path)
                image_vdb = extract_and_store_images(path)
                chain = qa_bot(client)
                st.session_state['chain'] = chain
                st.session_state['image_vdb'] = image_vdb
                st.success("Processing complete.")
            except Exception as err:
                st.error(f"Error during processing: {err}")
# Black background for the chat input container.
_CHAT_INPUT_CSS = """
<style>
.stChatInputContainer > div {
background-color: #000000;
}
</style>
"""
st.markdown(_CHAT_INPUT_CSS, unsafe_allow_html=True)
# Handle one chat turn:
# - answer the question with the cached text-QA chain;
# - record the exchange in both the LangChain memory and the display transcript;
# - query the image store with the same text.
if user_input := st.chat_input("User Input"):
    if 'chain' in st.session_state and 'image_vdb' in st.session_state:
        chain = st.session_state['chain']
        image_vdb = st.session_state['image_vdb']
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.spinner("Generating Response..."):
            # get_answer reports its own errors and returns None on failure.
            response = get_answer(user_input, chain)
        if response:
            with st.chat_message("assistant"):
                st.markdown(response)
            # Save context in memory
            memory.save_context(
                {"input": user_input},
                {"output": response},
            )
            # The transcript list is otherwise only created further down the
            # script, so make sure it exists before appending to it.
            if "messages" not in st.session_state:
                st.session_state.messages = []
            # Append messages to session state for display
            st.session_state.messages.append({"role": "user", "content": user_input})
            st.session_state.messages.append({"role": "assistant", "content": response})
            try:
                query_and_print_results(image_vdb, user_input)
            except Exception as e:
                st.error(f"Error querying image database: {e}")
        else:
            st.error("Failed to generate response.")
    else:
        st.error("Please start processing before entering user input.")
# Replay the conversation transcript.
# Every exchange is recorded twice: appended to st.session_state.messages and
# saved through memory.save_context into memory_storage. Rendering both copies
# showed each message twice (with roles for the second copy guessed by index
# parity). Only the explicit-role list is rendered; memory_storage still feeds
# the LangChain memory.
# NOTE(review): the input handler also renders the current turn immediately,
# so that turn still appears twice on the run in which it is submitted.
if "messages" not in st.session_state:
    st.session_state.messages = []
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])