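# NOTE: the next three lines are page-header residue captured along with this
# file; they are kept as comments so the module remains valid Python.
# LuckRafly's picture
# Update app.py
# 2de941d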
import streamlit as st
from PIL import Image
from function import bounding_box
from tempfile import NamedTemporaryFile
import os
from function import ImageCaptionTools, ObjectDetectionTool
from langchain.agents import initialize_agent, AgentType
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.memory import ConversationBufferWindowMemory
from htmlTemplate import css, bot_template, user_template
import random
# Root folder holding one scratch sub-directory per browser session.
DIR = './temp'
# exist_ok=True removes the check-then-create race that an
# `os.path.exists` + `os.mkdir` pair has when several sessions start at once.
os.makedirs(DIR, exist_ok=True)

# Streamlit re-executes this script on every interaction, so the session
# directory must only be chosen once. Guard on the key this block itself
# writes (`dirpath`) so that a failure later in the first run cannot cause a
# new directory to be created on every rerun.
if "dirpath" not in st.session_state:
    DIR_PATH = os.path.join(DIR, str(random.randint(1, 999999999)))
    st.session_state.dirpath = DIR_PATH
    os.makedirs(DIR_PATH, exist_ok=True)
def delete_temp_files(dirpath=None):
    """Remove every regular file located directly inside a scratch directory.

    Sub-directories (and anything inside them) are left untouched, and the
    directory itself is kept so it can be reused for the next upload.

    Args:
        dirpath: Directory to clean. When omitted, the current session's
            directory stored in ``st.session_state.dirpath`` is used.

    Returns:
        None. A directory that does not exist is treated as already clean.
    """
    if dirpath is None:
        dirpath = st.session_state.dirpath

    try:
        filenames = os.listdir(dirpath)
    except FileNotFoundError:
        # Nothing to delete if the scratch directory was removed externally.
        return

    for filename in filenames:
        file_path = os.path.join(dirpath, filename)
        if os.path.isfile(file_path):
            os.unlink(file_path)
# Agent construction
def agent_init():
    """Build and return the conversational agent used to answer image questions.

    The agent is wired with the image-captioning and object-detection tools,
    a Gemini ("gemini-pro") chat model, and a windowed conversation memory
    (``k=5``) stored under the ``chat_history`` key.
    """
    toolbox = [ImageCaptionTools(), ObjectDetectionTool()]
    gemini = ChatGoogleGenerativeAI(model="gemini-pro")
    window_memory = ConversationBufferWindowMemory(
        memory_key='chat_history',
        k=5,
        return_messages=True,
    )

    # Collect the settings first so the factory call itself stays short.
    agent_settings = {
        "agent": AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        "llm": gemini,
        "tools": toolbox,
        "max_iterations": 5,
        "verbose": True,
        "memory": window_memory,
    }
    return initialize_agent(**agent_settings)
def main():
    """Render the "Chat with an Image" page and handle one script run.

    Left column: upload an image, save it into the session's scratch
    directory when "Process Image" is clicked, run ``bounding_box`` on it and
    show the result in an expander.

    Right column: once an image has been processed, take a question from the
    user, pass it (together with the saved image path) to the agent, and
    render both the question and the answer with the HTML chat templates.
    """
    # Local import: only needed to escape text before it is placed into HTML.
    import html

    st.set_page_config(
        page_title="Chat with an Image",
        page_icon="🖼️",
        layout="wide"
    )
    st.write(css, unsafe_allow_html=True)
    st.title("Chat with an Image 🖼️")

    # Streamlit re-runs this function on every interaction. Building a new
    # agent each time would also build a new, empty conversation memory, so
    # the agent is created once and kept in the session state.
    if "agent" not in st.session_state:
        st.session_state.agent = agent_init()
    agent = st.session_state.agent

    # False on the first run of a session, True on every later rerun.
    st.session_state.reloaded = 'reloaded' in st.session_state

    if "image_processed" not in st.session_state:
        st.session_state.image_processed = None
    if "result_bounding" not in st.session_state:
        st.session_state.result_bounding = None

    col1, col2 = st.columns([1, 1])

    with col1:
        image_upload = st.file_uploader(label="Please Upload Your Image", type=['jpg', 'png', 'jpeg'])
        if not image_upload:
            st.warning("Please upload your image")
        else:
            st.image(
                image_upload,
                use_column_width=True
            )

        click_process = st.button("Process Image", disabled=not image_upload)
        if click_process:
            # Make sure the scratch directory still exists before cleaning
            # it and writing the new upload into it.
            os.makedirs(st.session_state.dirpath, exist_ok=True)
            delete_temp_files()
            # delete=False: the file must outlive this `with` so the tools
            # can open it by path later.
            with NamedTemporaryFile(dir=st.session_state.dirpath, delete=False) as f:
                f.write(image_upload.getbuffer())
            st.session_state.image_path = f.name
            st.session_state.image_processed = True

        needs_bounding = (
            st.session_state.image_processed
            and st.session_state.result_bounding is None
        )
        if needs_bounding or click_process:
            with st.spinner("Please Wait"):
                st.session_state.result_bounding = bounding_box(st.session_state.image_path)

        # Expander to show/hide image
        if st.session_state.result_bounding is not None:
            with st.expander("Show Image (Bounding Box)"):
                st.image(st.session_state.result_bounding)

    with col2:
        user_question = st.text_area("Ask About your image",
                                     disabled=not st.session_state.image_processed,
                                     max_chars=150)
        click_ask = st.button("Ask Question", disabled=not st.session_state.image_processed)
        if click_ask:
            if not user_question.strip():
                # Do not spend an agent call on an empty prompt.
                st.warning("Please type a question first")
            else:
                # The templates are rendered with unsafe_allow_html=True, so
                # any text spliced into them is escaped first.
                st.write(user_template.replace("{{MSG}}", html.escape(user_question)), unsafe_allow_html=True)
                with st.spinner("Doraemon Searching for Answer🔎"):
                    result = agent.invoke({"input": f"{user_question}, this is the image path: {st.session_state.image_path}"})
                response = result['output']
                st.write(bot_template.replace("{{MSG}}", html.escape(str(response))), unsafe_allow_html=True)
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()