Aqib2489 committed on
Commit
164d1e2
·
verified ·
1 Parent(s): 046cb0e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -0
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import os
from huggingface_hub import login  # NOTE(review): imported but never called — presumably auth is handled via env; confirm
import PIL.Image
from byaldi import RAGMultiModalModel
import PIL.Image as PILImage  # NOTE(review): duplicate PIL import (re-bound again below)
import io
import textwrap  # NOTE(review): unused in this file
import google.generativeai as genai
import gradio as gr # Add Gradio for UI
from PIL import Image as PILImage

# Initialize Google API and model
import torch

device = torch.device("cpu") # Force CPU
import os  # NOTE(review): duplicate of the `import os` above
# Read the Gemini API key from the environment (never hard-code secrets).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel('models/gemini-1.5-flash-latest')

# Load the RAG multi-modal model
# NOTE(review): this result is immediately overwritten by `from_index` below,
# so the pretrained download may be wasted work — confirm whether byaldi's
# from_index requires a prior from_pretrained call; if not, drop this line.
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", verbose=1)

# Specify the index path where the index was saved during the first run
# NOTE(review): absolute path to a developer machine — will not exist on a
# deployed Space; should come from config/env.
index_path = "/home/mohammadaqib/Desktop/project/research/Multi-Modal-RAG/Colpali/BCC"
RAG = RAGMultiModalModel.from_index(index_path)

# Initialize conversation history
# NOTE(review): this global appears unused; chat state is threaded through
# gr.State in the UI section instead.
conversation_history = []
31
+
32
def get_user_input(query):
    """Pass the user's query through unchanged (identity hook for input processing)."""
    user_query = query
    return user_query
35
+
36
def process_image_from_results(results):
    """Decode up to three base64 result pages and stitch them side by side.

    Each result is expected to expose a `.base64` attribute holding an
    encoded image; results lacking it are reported and skipped. The merged
    strip is also written to 'merged_image.jpg'. Returns the merged PIL
    image, or None when no image could be decoded.
    """
    decoded = []
    for idx in range(min(3, len(results))):
        try:
            # A result without a `base64` attribute raises AttributeError here.
            raw_bytes = base64.b64decode(results[idx].base64)
            decoded.append(PILImage.open(io.BytesIO(raw_bytes)))
        except AttributeError:
            print(f"Result {idx} does not contain a 'base64' attribute.")

    # Nothing usable was decoded.
    if not decoded:
        return None

    # Lay the pages out horizontally on a canvas sized to fit them all.
    canvas_width = sum(page.width for page in decoded)
    canvas_height = max(page.height for page in decoded)
    merged_image = PILImage.new('RGB', (canvas_width, canvas_height))

    x_cursor = 0
    for page in decoded:
        merged_image.paste(page, (x_cursor, 0))
        x_cursor += page.width

    # Persist the strip for debugging/inspection, then hand it back.
    merged_image.save('merged_image.jpg')
    return merged_image
64
+
65
def generate_answer(query, image):
    """Ask the Gemini model to answer *query* grounded in the merged page image.

    The prompt asks the model to cite supporting metadata (table/statement
    numbers) from the image. Returns the model's full text reply.
    """
    prompt = f'Answer to the question asked using the image. Also mention the reference from image to support your answer. Example, Table Number or Statement number or any metadata. Question: {query}'
    reply = model.generate_content([prompt, image], stream=True)
    # Streaming response: block until all chunks have arrived.
    reply.resolve()
    return reply.text
70
+
71
def classify_system_question(query):
    """Return True when *query* asks about the assistant itself (e.g. 'Who are you?')."""
    prompt = f"Determine if the question is about the system itself, like 'Who are you?' or 'What can you do?' or 'Introduce yourself' . Answer with 'yes' or 'no'. Question: {query}"
    reply = model.generate_content([prompt], stream=True)
    reply.resolve()
    # Normalise the model's text before comparing against the expected 'yes'.
    verdict = reply.text.strip().lower()
    return verdict == "yes"
76
+
77
def classify_question(query):
    """Classify *query* as 'general' or 'domain-specific' via a Gemini prompt.

    Returns the model's one-word answer, stripped and lower-cased; callers
    compare it against 'general'.
    """
    prompt = f"Classify this question as 'general' or 'domain-specific'. Give one word answer i.e general or domain-specific. General questions are greetings and questions involving general knowledge like the capital of France. General questions also involve politics, geography, history, economics, cosmology, information about famous personalities, etc. Question: {query}"
    reply = model.generate_content([prompt], stream=True)
    reply.resolve()
    # Assuming the response is either 'general' or 'domain-specific'.
    return reply.text.strip().lower()
82
+
83
def chatbot(query, history):
    """Answer *query* and update the running chat *history*.

    Routing:
      1. system question  -> canned self-introduction;
      2. 'general'        -> direct Gemini answer;
      3. domain-specific  -> RAG page retrieval + image-grounded Gemini
         answer, with a keyword-based fallback to a plain Gemini answer
         when the grounded reply looks like a "not found" message.

    Returns (answer_text, updated_history, updated_history) — the history
    is returned twice to feed both the gr.State and gr.Chatbot outputs.
    """
    max_history_length = 50 # Number of recent exchanges to keep

    # Truncate the history to the last `max_history_length` exchanges
    # (slicing presumably yields a fresh list, so the caller's list is not
    # mutated — TODO confirm `history` is always a plain list here)
    truncated_history = history[-max_history_length:]

    # Add user input to the history; the "Model:" placeholder is replaced
    # with the real answer at the end of this function.
    truncated_history.append(("You: " + query, "Model:"))

    # Step 1: Check if the question is about the system
    if classify_system_question(query):
        text = "I am an AI assistant capable of answering queries related to the National Building Code of Canada and general questions. I was developed by the research group [SITE] at the University of Alberta. How can I assist you further?"

    else:
        # Step 2: Classify the question as general or domain-specific
        question_type = classify_question(query)

        # If the question is general, use Gemini to directly answer it
        if question_type == "general":
            text = model.generate_content([f"Answer this general question: {query}. If it is a greeting respond accordingly and if it is not greeting add a prefix saying that it is a general query."], stream=True)
            text.resolve()
            text = text.text

        else:
            # Step 3: Query the RAG model for domain-specific answers
            results = RAG.search(query, k=3)

            # Check if RAG found any results
            if not results:
                text = model.generate_content([f"Answer this question: {query}"], stream=True)
                text.resolve()
                text = text.text
                text = "It is a general query. ANSWER:" + text
            else:
                # Process images from the results
                image = process_image_from_results(results)

                # Generate the answer using the Gemini model if an image is found
                if image:
                    text = generate_answer(query, image)
                    text = "It is a query from NBCC. ANSWER:" + text

                    # Check if the answer is a fallback message (indicating no relevant answer)
                    # NOTE(review): substring heuristics against the model's
                    # wording — brittle; a false hit silently re-answers as a
                    # general query.
                    if any(keyword in text.lower() for keyword in [
                        "does not provide",
                        "cannot answer",
                        "does not contain",
                        "no relevant answer",
                        "not found",
                        "information unavailable",
                        "not in the document",
                        "unable to provide",
                        "no data",
                        "missing information",
                        "no match",
                        "provided text does not describe",
                        "are not explicitly listed",
                        "are not explicitly mentioned",
                        "no results",
                        "not available",
                        "query not found"
                    ]):
                        # Fallback to Gemini for answering
                        text = model.generate_content([f"Answer this general question in concise manner: {query}"], stream=True)
                        text.resolve()
                        text = text.text
                        text = "It is a general query. ANSWER: " + text
                else:
                    # No image could be decoded from the results: answer
                    # without visual grounding.
                    text = model.generate_content([f"Answer this question: {query}"], stream=True)
                    text.resolve()
                    text = text.text
                    text = "It is a query from NBCC. ANSWER: " + text

    # Add the model's response to the truncated history
    truncated_history[-1] = (truncated_history[-1][0], "Model: " + text) # Update the most recent message with model's answer

    # Return the output text, updated state, and chat history (as tuple pairs)
    return text, truncated_history, truncated_history # Ensure all three outputs are returned in the correct order
161
+
162
+
163
+
164
import gradio as gr  # NOTE(review): duplicate of the top-of-file import; harmless but redundant

# Define Gradio interface
with gr.Blocks() as iface:
    # Set the conversation state as an empty list
    state = gr.State([])

    # Custom CSS to beautify the interface
    # NOTE(review): assigning `iface.css` after construction may be ignored by
    # some Gradio versions, which expect css=... in the Blocks() constructor —
    # verify the styling actually takes effect.
    iface.css = """
    .gradio-container {
        background-color: #f9f9f9;
        border-radius: 15px;
        padding: 20px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    .gr-chatbox {
        background-color: #f0f0f0;
        border-radius: 10px;
        padding: 10px;
        max-height: 1000px;
        overflow-y: scroll;
        margin-bottom: 10px;
    }
    .gr-textbox input {
        border-radius: 10px;
        padding: 12px;
        font-size: 16px;
        border: 1px solid #ccc;
        width: 100%;
        margin-top: 10px;
        box-sizing: border-box;
    }
    .gr-textbox input:focus {
        border-color: #4CAF50;
        outline: none;
    }
    .gr-button {
        background-color: #4CAF50;
        color: white;
        padding: 12px;
        border-radius: 10px;
        font-size: 16px;
        border: none;
        cursor: pointer;
    }
    .gr-button:hover {
        background-color: #45a049;
    }
    .gr-chatbot {
        font-family: "Arial", sans-serif;
        font-size: 14px;
    }
    .gr-chatbot .gr-chatbot-user {
        background-color: #e1f5fe;
        border-radius: 10px;
        padding: 8px;
        margin-bottom: 10px;
        max-width: 80%;
    }
    .gr-chatbot .gr-chatbot-model {
        background-color: #ffffff;
        border-radius: 10px;
        padding: 8px;
        margin-bottom: 10px;
        max-width: 80%;
    }
    .gr-chatbot .gr-chatbot-user p,
    .gr-chatbot .gr-chatbot-model p {
        margin: 0;
    }
    #input_box {
        position: fixed;
        bottom: 20px;
        width: 95%;
        padding: 10px;
        border-radius: 10px;
        box-shadow: 0 0 5px rgba(0, 0, 0, 0.2);
    }
    """

    # Add an image at the top of the page
    with gr.Column():
        # NOTE(review): absolute path to a developer machine — this asset will
        # not exist on a deployed Space; confirm and bundle the image instead.
        gr.Image("/home/mohammadaqib/Pictures/Screenshots/site.png",height = 300) # Use the image URL
        gr.Markdown(
            "# Question Answering System Over National Building Code of Canada"
        )

    # Chatbot UI
    with gr.Row():

        chat_history = gr.Chatbot(label="Chat History", height=250)


    # Place input at the bottom
    with gr.Row():
        query = gr.Textbox(
            label="Ask a Question",
            placeholder="Enter your question here...",
            lines=1,
            interactive=True,
            elem_id="input_box" # Custom ID for styling
        )

    # Output for the response
    output_text = gr.Textbox(label="Answer", interactive=False, visible=False) # Optional to hide

    # Define the interaction behavior
    # Submitting the textbox runs chatbot(query, state) and fans the three
    # returned values out to the hidden answer box, the state, and the chat
    # history widget.
    query.submit(
        chatbot,
        inputs=[query, state],
        outputs=[output_text, state, chat_history],
        show_progress=True
    ).then(
        # NOTE(review): `inputs=None` passes no argument, but the lambda
        # declares one parameter — confirm this chained step doesn't raise
        # a TypeError on the installed Gradio version.
        lambda _: "", # Clear the input after submission
        inputs=None,
        outputs=query
    )

    gr.Markdown("<p style='position: fixed; bottom:0; width: 100%; text-align: left; font-style: italic; margin-left: 15%; font-size: 18px;'>Developed by Mohammad Aqib, MSc Student at the University of Alberta, supervised by Dr. Qipei (Gavin) Mei.</p>", elem_id="footer")




# Launch the interface
# NOTE(review): share=True opens a public tunnel; it is a no-op on HF Spaces
# but exposes the app publicly when run locally.
iface.launch(share=True)
289
+
290
+