Create app.py
app.py ADDED
@@ -0,0 +1,290 @@
import base64
import io
import os

import google.generativeai as genai
import gradio as gr  # Gradio for the UI
import torch
from byaldi import RAGMultiModalModel
from PIL import Image as PILImage

device = torch.device("cpu")  # Force CPU

# Initialize the Google API client and the Gemini model
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel('models/gemini-1.5-flash-latest')
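# Note: genai.configure() does not validate the key, so a missing secret only
# surfaces later as an authentication error at the first generate_content()
# call. A minimal fail-fast guard (an added sketch, assuming failing at
# startup is preferable to failing on the first query):
if not GOOGLE_API_KEY:
    raise RuntimeError("GOOGLE_API_KEY is not set; add it as a Space secret.")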
# Load the RAG multi-modal model
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", verbose=1)

# Specify the index path where the index was saved during the first run.
# Note: from_index() reloads the model from the saved index, replacing the
# instance loaded above, and this absolute local path does not exist inside
# a Space, which is a likely cause of the runtime error.
index_path = "/home/mohammadaqib/Desktop/project/research/Multi-Modal-RAG/Colpali/BCC"
RAG = RAGMultiModalModel.from_index(index_path)

# Initialize conversation history
conversation_history = []
def get_user_input(query):
    """Process user input."""
    return query
def process_image_from_results(results):
    """Decode images from the top search results and merge them side by side."""
    image_list = []
    for i in range(min(3, len(results))):
        try:
            # Each result is expected to carry a base64-encoded page image
            image_bytes = base64.b64decode(results[i].base64)
            image = PILImage.open(io.BytesIO(image_bytes))  # Open image directly from bytes
            image_list.append(image)
        except AttributeError:
            print(f"Result {i} does not contain a 'base64' attribute.")

    # Merge the images horizontally, if any were decoded
    if image_list:
        total_width = sum(img.width for img in image_list)
        max_height = max(img.height for img in image_list)

        merged_image = PILImage.new('RGB', (total_width, max_height))
        x_offset = 0
        for img in image_list:
            merged_image.paste(img, (x_offset, 0))
            x_offset += img.width

        # Save the merged image for inspection
        merged_image.save('merged_image.jpg')
        return merged_image
    else:
        return None
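# Note: results[i].base64 is only populated when the byaldi index was built
# with store_collection_with_index=True; otherwise the AttributeError branch
# above fires for every result and the function returns None.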
def generate_answer(query, image):
    """Generate an answer using the Gemini model and the merged image."""
    response = model.generate_content(
        [
            'Answer the question using the image. Also cite the reference from the image '
            'that supports your answer, e.g. a table number, statement number, or other '
            f'metadata. Question: {query}',
            image,
        ],
        stream=True,
    )
    response.resolve()
    return response.text
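# Note: with stream=True the call returns immediately and resolve() blocks
# until the full response has streamed in, so the net effect here is the same
# as a non-streaming call.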
def classify_system_question(query):
    """Check whether the question is about the system itself."""
    response = model.generate_content([f"Determine if the question is about the system itself, like 'Who are you?', 'What can you do?', or 'Introduce yourself'. Answer with 'yes' or 'no'. Question: {query}"], stream=True)
    response.resolve()
    return response.text.strip().lower() == "yes"

def classify_question(query):
    """Classify a question as general or domain-specific using Gemini."""
    response = model.generate_content([f"Classify this question as 'general' or 'domain-specific'. Give a one-word answer, i.e. general or domain-specific. General questions are greetings and questions involving general knowledge, such as the capital of France, politics, geography, history, economics, cosmology, or famous personalities. Question: {query}"], stream=True)
    response.resolve()
    return response.text.strip().lower()  # Expected to be 'general' or 'domain-specific'
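# The exact string comparisons above are brittle: the model may reply "Yes."
# or add extra wording. A more tolerant check (an illustrative sketch, not
# wired into the code below):
def _looks_affirmative(raw_reply: str) -> bool:
    """Return True if a one-word LLM reply begins with 'yes', ignoring case and punctuation."""
    return raw_reply.strip().lower().startswith("yes")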
def chatbot(query, history):
    max_history_length = 50  # Number of recent exchanges to keep

    # Truncate the history to the last `max_history_length` exchanges
    truncated_history = history[-max_history_length:]

    # Add the user input to the history
    truncated_history.append(("You: " + query, "Model:"))

    # Step 1: Check if the question is about the system itself
    if classify_system_question(query):
        text = ("I am an AI assistant capable of answering queries related to the "
                "National Building Code of Canada as well as general questions. I was "
                "developed by the research group [SITE] at the University of Alberta. "
                "How can I assist you further?")

    else:
        # Step 2: Classify the question as general or domain-specific
        question_type = classify_question(query)

        # If the question is general, use Gemini to answer it directly
        if question_type == "general":
            text = model.generate_content([f"Answer this general question: {query}. If it is a greeting, respond accordingly; if it is not a greeting, add a prefix saying that it is a general query."], stream=True)
            text.resolve()
            text = text.text

        else:
            # Step 3: Query the RAG model for domain-specific answers
            results = RAG.search(query, k=3)

            # Check if RAG found any results
            if not results:
                text = model.generate_content([f"Answer this question: {query}"], stream=True)
                text.resolve()
                text = text.text
                text = "It is a general query. ANSWER: " + text
            else:
                # Process images from the results
                image = process_image_from_results(results)

                # Generate the answer with the Gemini model if an image is available
                if image:
                    text = generate_answer(query, image)
                    text = "It is a query from NBCC. ANSWER: " + text

                    # Check if the answer is a fallback message (indicating no relevant answer)
                    if any(keyword in text.lower() for keyword in [
                        "does not provide",
                        "cannot answer",
                        "does not contain",
                        "no relevant answer",
                        "not found",
                        "information unavailable",
                        "not in the document",
                        "unable to provide",
                        "no data",
                        "missing information",
                        "no match",
                        "provided text does not describe",
                        "are not explicitly listed",
                        "are not explicitly mentioned",
                        "no results",
                        "not available",
                        "query not found"
                    ]):
                        # Fall back to Gemini for answering
                        text = model.generate_content([f"Answer this general question in a concise manner: {query}"], stream=True)
                        text.resolve()
                        text = text.text
                        text = "It is a general query. ANSWER: " + text
                else:
                    text = model.generate_content([f"Answer this question: {query}"], stream=True)
                    text.resolve()
                    text = text.text
                    text = "It is a query from NBCC. ANSWER: " + text

    # Update the most recent history entry with the model's answer
    truncated_history[-1] = (truncated_history[-1][0], "Model: " + text)

    # Return the output text, updated state, and chat history (as tuple pairs)
    return text, truncated_history, truncated_history
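# Note: gr.Chatbot's default tuple format expects a list of
# (user_message, bot_message) pairs, which is what truncated_history holds,
# e.g. [("You: What is the NBCC?", "Model: It is a query from NBCC. ANSWER: ...")].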
# Define the Gradio interface.
# Note: custom CSS must be passed to gr.Blocks(css=...); assigning iface.css
# after construction has no effect.
custom_css = """
.gradio-container {
    background-color: #f9f9f9;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-chatbox {
    background-color: #f0f0f0;
    border-radius: 10px;
    padding: 10px;
    max-height: 1000px;
    overflow-y: scroll;
    margin-bottom: 10px;
}
.gr-textbox input {
    border-radius: 10px;
    padding: 12px;
    font-size: 16px;
    border: 1px solid #ccc;
    width: 100%;
    margin-top: 10px;
    box-sizing: border-box;
}
.gr-textbox input:focus {
    border-color: #4CAF50;
    outline: none;
}
.gr-button {
    background-color: #4CAF50;
    color: white;
    padding: 12px;
    border-radius: 10px;
    font-size: 16px;
    border: none;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #45a049;
}
.gr-chatbot {
    font-family: "Arial", sans-serif;
    font-size: 14px;
}
.gr-chatbot .gr-chatbot-user {
    background-color: #e1f5fe;
    border-radius: 10px;
    padding: 8px;
    margin-bottom: 10px;
    max-width: 80%;
}
.gr-chatbot .gr-chatbot-model {
    background-color: #ffffff;
    border-radius: 10px;
    padding: 8px;
    margin-bottom: 10px;
    max-width: 80%;
}
.gr-chatbot .gr-chatbot-user p,
.gr-chatbot .gr-chatbot-model p {
    margin: 0;
}
#input_box {
    position: fixed;
    bottom: 20px;
    width: 95%;
    padding: 10px;
    border-radius: 10px;
    box-shadow: 0 0 5px rgba(0, 0, 0, 0.2);
}
"""

with gr.Blocks(css=custom_css) as iface:
    # The conversation state starts as an empty list
    state = gr.State([])
    # Add an image at the top of the page.
    # Note: this absolute local path will not exist inside a Space; the image
    # needs to be added to the repo and referenced by a relative path.
    with gr.Column():
        gr.Image("/home/mohammadaqib/Pictures/Screenshots/site.png", height=300)
        gr.Markdown(
            "# Question Answering System Over National Building Code of Canada"
        )

    # Chatbot UI
    with gr.Row():
        chat_history = gr.Chatbot(label="Chat History", height=250)

    # Place the input at the bottom
    with gr.Row():
        query = gr.Textbox(
            label="Ask a Question",
            placeholder="Enter your question here...",
            lines=1,
            interactive=True,
            elem_id="input_box"  # Custom ID for styling
        )
    # Output textbox for the raw response (hidden; the chatbot displays it)
    output_text = gr.Textbox(label="Answer", interactive=False, visible=False)

    # Define the interaction behavior
    query.submit(
        chatbot,
        inputs=[query, state],
        outputs=[output_text, state, chat_history],
        show_progress=True
    ).then(
        lambda: "",  # Clear the input after submission; with inputs=None the function takes no arguments
        inputs=None,
        outputs=query
    )

    gr.Markdown(
        "<p style='position: fixed; bottom: 0; width: 100%; text-align: left; font-style: italic; margin-left: 15%; font-size: 18px;'>"
        "Developed by Mohammad Aqib, MSc Student at the University of Alberta, supervised by Dr. Qipei (Gavin) Mei.</p>",
        elem_id="footer"
    )
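    # Note: the outputs list above must match the chatbot() return signature
    # (text, state, history) in both order and length, or Gradio raises an
    # error when the event fires.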

# Launch the interface. share=True creates a public share link when running
# locally; on a Hugging Face Space the app is already hosted publicly, so the
# flag is unnecessary there.
iface.launch(share=True)
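Note: for the Space to build, app.py needs an accompanying requirements.txt.
A minimal, unpinned sketch, assuming these are the only third-party
dependencies used above:

    byaldi
    google-generativeai
    gradio
    pillow
    torch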