LuckRafly committed on
Commit
3aba7d8
1 Parent(s): b4f02d2

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +102 -0
  2. function.py +88 -0
  3. htmlTemplate.py +89 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from function import bounding_box, captioning_query
4
+ from tempfile import NamedTemporaryFile
5
+ import os
6
+ from function import ImageCaptionTools, ObjectDetectionTool
7
+ from langchain.agents import initialize_agent, AgentType
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain.memory import ConversationBufferWindowMemory
10
+ from htmlTemplate import css, bot_template, user_template
11
+
12
# Directory that holds the uploaded image while it is being processed.
DIR_PATH = './temp'
# exist_ok avoids the check-then-create race when several sessions import
# this module at the same time.
os.makedirs(DIR_PATH, exist_ok=True)
15
+
16
+ # initialize Agent
17
def agent_init():
    """Create the conversational agent that answers questions about an image.

    The agent is given the image-caption and object-detection tools, a
    "gemini-pro" chat model, and a windowed conversation memory stored under
    the ``chat_history`` key (window size 5).

    Returns:
        The agent executor produced by ``initialize_agent``.
    """
    image_tools = [ImageCaptionTools(), ObjectDetectionTool()]
    chat_model = ChatGoogleGenerativeAI(model="gemini-pro")
    window_memory = ConversationBufferWindowMemory(
        memory_key='chat_history',
        k=5,
        return_messages=True,
    )
    return initialize_agent(
        tools=image_tools,
        llm=chat_model,
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        memory=window_memory,
        max_iterations=5,
        verbose=True,
    )
32
+
33
def delete_temp_files():
    """Delete every regular file directly inside ``DIR_PATH``.

    Sub-directories are left in place. The clean-up is best effort: a missing
    directory, or a file that vanishes between listing and removal (another
    session may be cleaning up at the same time), is silently ignored.
    """
    try:
        filenames = os.listdir(DIR_PATH)
    except FileNotFoundError:
        # Nothing to clean up if the directory itself is gone.
        return

    for filename in filenames:
        file_path = os.path.join(DIR_PATH, filename)
        if not os.path.isfile(file_path):
            continue
        try:
            os.unlink(file_path)
        except FileNotFoundError:
            # Already removed by someone else; the goal is achieved anyway.
            pass
38
+
39
def main():
    """Render the "Chat with an Image" Streamlit page.

    Left column: upload an image, copy it to a temporary file under
    ``DIR_PATH`` and show the object-detection overlay returned by
    ``bounding_box``. Right column: send a question (plus the temporary image
    path) to the agent and display the exchange using the HTML templates.
    """
    # Function-scope import keeps this change self-contained.
    from html import escape

    st.set_page_config(
        page_title="Chat with an Image",
        page_icon="🖼️",
        layout="wide"
    )
    st.write(css, unsafe_allow_html=True)

    # Streamlit re-runs this function on every interaction. Building the agent
    # once per session keeps its conversation memory alive between questions
    # instead of starting from an empty history each time.
    if "agent" not in st.session_state:
        st.session_state.agent = agent_init()
    agent = st.session_state.agent

    if "image_processed" not in st.session_state:
        st.session_state.image_processed = None

    if "result_bounding" not in st.session_state:
        st.session_state.result_bounding = None

    # Until an image has been processed in this session, clear leftover files.
    if st.session_state.image_processed is None:
        delete_temp_files()

    col1, col2 = st.columns([1, 1])
    with col1:
        image_upload = st.file_uploader(label="Please Upload Your Image", type=['jpg', 'png', 'jpeg'])
        if not image_upload:
            st.warning("Please upload your image")
        else:
            st.image(
                image_upload,
                use_column_width=True
            )
        click_process = st.button("Process Image", disabled=not image_upload)
        if click_process:
            delete_temp_files()
            # delete=False: the file must outlive this block so the tools can
            # read it by path later.
            with NamedTemporaryFile(dir=DIR_PATH, delete=False) as f:
                f.write(image_upload.getbuffer())
                st.session_state.image_path = f.name
            st.session_state.image_processed = True

        if (st.session_state.image_processed and st.session_state.result_bounding is None) or click_process:
            with st.spinner("Please Wait"):
                result_bounding = bounding_box(st.session_state.image_path)
                st.session_state.result_bounding = result_bounding

        # Expander to show/hide image
        if st.session_state.result_bounding is not None:
            with st.expander("Show Image (Bounding Box)"):
                st.image(st.session_state.result_bounding)

    with col2:
        user_question = st.text_area("Ask About your image",
                                     disabled=not st.session_state.image_processed,
                                     max_chars=150)
        click_ask = st.button("Ask Question", disabled=not st.session_state.image_processed)
        if click_ask:
            # Both messages are rendered as raw HTML, so escape them to stop
            # user input or model output from injecting markup into the page.
            st.write(user_template.replace("{{MSG}}", escape(user_question)), unsafe_allow_html=True)
            with st.spinner("AI Searching for Answer🔎"):
                chat_history = agent.invoke({"input": f"{user_question}, this is the image path: {st.session_state.image_path}"})
                response = chat_history['output']
            st.write(bot_template.replace("{{MSG}}", escape(response)), unsafe_allow_html=True)
100
+
101
+ if __name__ == "__main__":
102
+ main()
function.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import BaseTool
2
+ from PIL import Image, ImageDraw
3
+ import requests
4
+ from dotenv import load_dotenv
5
+ import os
6
+ load_dotenv()
7
+
8
+
9
def object_detection_query(filepath, timeout=60):
    """Send an image file to the hosted DETR object-detection model.

    Args:
        filepath: Path of the image file whose raw bytes are uploaded.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled connection would block the caller
            indefinitely.

    Returns:
        The decoded JSON body of the response, returned as-is. Callers should
        be prepared for a non-list payload when the service reports an error.

    Raises:
        KeyError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set in the environment.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
    headers = {"Authorization": "Bearer " + os.environ['HUGGINGFACEHUB_API_TOKEN']}
    with open(filepath, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data, timeout=timeout)
    return response.json()
16
+
17
def bounding_box(filepath):
    """Return the image at ``filepath`` with detected objects drawn on it.

    Every detection reported by ``object_detection_query`` is outlined in red
    and captioned with its label and score (two decimals).

    Args:
        filepath: Path of the image file to annotate.

    Returns:
        A new RGB ``PIL.Image.Image``; the file on disk is not modified.

    Raises:
        RuntimeError: If the detection query returns anything other than a
            list of detections (for example an error payload).
    """
    output = object_detection_query(filepath)
    if not isinstance(output, list):
        # Fail with the service's own message instead of an opaque TypeError
        # from indexing into an unexpected payload.
        raise RuntimeError(f"Object detection failed: {output}")

    # convert() yields an independent image, so the underlying file can be
    # closed immediately; this lets the temp file be deleted later.
    with Image.open(filepath) as source:
        image = source.convert('RGB')

    draw = ImageDraw.Draw(image)

    for detection in output:
        label = detection['label']
        score = detection['score']
        box = detection['box']

        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline="red", width=2)

        text = f"{label} ({score:.2f})"
        # Clamp so the caption stays on the canvas when the box touches the
        # top edge of the image.
        draw.text((box['xmin'], max(box['ymin'] - 10, 0)), text, fill='red')

    return image
41
+
42
def captioning_query(filepath, timeout=60):
    """Send an image file to the hosted BLIP image-captioning model.

    Args:
        filepath: Path of the image file whose raw bytes are uploaded.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled connection would block the caller
            indefinitely.

    Returns:
        The decoded JSON body of the response, returned as-is. Callers should
        be prepared for a non-list payload when the service reports an error.

    Raises:
        KeyError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set in the environment.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
    headers = {"Authorization": "Bearer " + os.environ['HUGGINGFACEHUB_API_TOKEN']}
    with open(filepath, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data, timeout=timeout)
    return response.json()
49
+
50
class ImageCaptionTools(BaseTool):
    """Agent tool that captions an image via ``captioning_query``."""

    # Annotated overrides: pydantic-v2-based BaseTool versions reject
    # un-annotated field overrides; the annotations are harmless on older ones.
    name: str = "Image_Caption_Tools"
    description: str = "Use this tool with any given image path to receive a personalized description, poem, story, or more. "\
                       "Ideal for agents seeking tailored insights. "\
                       "Let the tool craft content based on your image for a unique perspective."

    def _run(self, image_path) -> str:
        """Return the generated caption for the image at ``image_path``.

        If the captioning service does not return a non-empty list, a
        description of the failure is returned as text so the agent can react
        to it instead of the whole run crashing.
        """
        result = captioning_query(image_path)
        if not isinstance(result, list) or not result:
            return f"Image captioning failed: {result}"
        return result[0]['generated_text']

    async def _arun(self, query: str) -> str:
        """Asynchronous use is not supported by this tool."""
        raise NotImplementedError("Image_Caption_Tools does not support async")
65
+
66
+
67
class ObjectDetectionTool(BaseTool):
    """Agent tool that lists detected objects via ``object_detection_query``."""

    # Annotated overrides: pydantic-v2-based BaseTool versions reject
    # un-annotated field overrides; the annotations are harmless on older ones.
    name: str = "Object_Detection_Tool"
    description: str = "Object Detection Tool: Use this tool to detect objects in an image. Provide the image path, " \
                       "and it will return a list of detected objects. Each element in the list is in the format: " \
                       "[x1, y1, x2, y2] class_name confidence_score. This tool focuses on object detection, providing " \
                       "locations of objects in the image. For image descriptions or other insights, explore additional tools."

    def _run(self, image_path) -> str:
        """Return one ``[x1, y1, x2, y2] label score`` line per detection.

        If the detection service does not return a list, a description of the
        failure is returned as text so the agent can react to it instead of
        the whole run crashing. An empty list yields an empty string.
        """
        results = object_detection_query(image_path)
        if not isinstance(results, list):
            return f"Object detection failed: {results}"

        lines = []
        for result in results:
            box = result['box']
            corners = ', '.join(
                str(int(box[key])) for key in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            lines.append(f"[{corners}] {result['label']} {result['score']}\n")
        return ''.join(lines)

    async def _arun(self, query: str) -> str:
        """Asynchronous use is not supported by this tool."""
        raise NotImplementedError("Object_Detection_Tool does not support async")
88
+
htmlTemplate.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
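+ # Stylesheet injected into the Streamlit page via st.write(css, unsafe_allow_html=True).
2
+ # It styles the chat container, message bubbles, and avatars used by the templates below.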
3
+ css = '''
4
+ <style>
5
+ /* Styling for the body of the Streamlit app */
6
+ body {
7
+ background-color: #f2f7ff; /* Soft blue background */
8
+ margin: 0; /* Remove default margin */
9
+ padding: 0; /* Remove default padding */
10
+ }
11
+
12
+ /* Styling for the chat container */
13
+ .chat-container {
14
+ max-width: 600px; /* Adjust the maximum width as needed */
15
+ margin: 0 auto; /* Center the chat container */
16
+ background-color: #ffffff; /* White background */
17
+ padding: 1rem; /* Add padding to the chat container */
18
+ border-radius: 1rem; /* Rounded corners for the chat container */
19
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); /* Add a subtle box shadow */
20
+ }
21
+
22
+ /* Styling for the chat messages */
23
+ .chat-message {
24
+ padding: 1rem;
25
+ border-radius: 0.5rem;
26
+ margin-bottom: 1rem;
27
+ display: flex;
28
+ border: 1px solid #d3d3d3; /* Add a subtle border */
29
+ }
30
+
31
+ /* Styling for user messages */
32
+ .chat-message.user {
33
+ background-color: #ffffff; /* White background for user messages */
34
+ }
35
+
36
+ /* Styling for bot messages */
37
+ .chat-message.bot {
38
+ background-color: #9dc8e5; /* Soft blue background for bot messages */
39
+ }
40
+
41
+ /* Styling for the avatar */
42
+ .chat-message .avatar {
43
+ width: 15%; /* Adjust avatar size */
44
+ }
45
+
46
+ /* Styling for the avatar image */
47
+ .chat-message .avatar img {
48
+ max-width: 60px;
49
+ max-height: 60px;
50
+ border-radius: 50%;
51
+ object-fit: cover;
52
+ }
53
+
54
+ /* Styling for the message content */
55
+ .chat-message .message {
56
+ flex: 1; /* Allow the message to take up remaining space */
57
+ padding: 0.75rem;
58
+ color: #495057; /* Dark text color for better readability */
59
+ }
60
+
61
+ /* Styling for strong (name) in the message */
62
+ .chat-message .message strong {
63
+ margin-right: 0.25rem; /* Adjust the margin as needed */
64
+ }
65
+ </style>
66
+ '''
67
+
68
+ # HTML Templates for Bot and User Messages
69
# Bot chat bubble; "{{MSG}}" is the placeholder callers replace with the reply.
# The avatar carries alt text so the image is described to assistive technology
# and when it fails to load.
bot_template = '''
<div class="chat-message bot">
    <div class="avatar">
        <img src="https://i.ibb.co/dp2yyWP/bot.jpg" alt="Bot avatar">
    </div>
    <div class="message">
        <strong>Doraemon:</strong> {{MSG}}
    </div>
</div>
'''
79
+
80
# User chat bubble; "{{MSG}}" is the placeholder callers replace with the
# question. The avatar carries alt text so the image is described to assistive
# technology and when it fails to load.
user_template = '''
<div class="chat-message user">
    <div class="avatar">
        <img src="https://i.ibb.co/JB2sps1/human.jpg" alt="User avatar">
    </div>
    <div class="message">
        <strong>Nobita:</strong> {{MSG}}
    </div>
</div>
'''
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ langchain
2
+ requests
3
+ streamlit
4
+ langchain-google-genai
5
+ transformers
6
+ python-dotenv