MrSimple07 committed on
Commit e4c3943 · verified · 1 Parent(s): 47ca3d8

Create app.py

Files changed (1)
app.py +542 -0
app.py ADDED
@@ -0,0 +1,542 @@
import sys
import os
import time  # used for exponential backoff between retries
import logging
import gradio as gr
import faiss
import numpy as np
import pandas as pd
import requests
from geopy.geocoders import Nominatim
from sentence_transformers import SentenceTransformer
from typing import Tuple, Optional
from huggingface_hub import hf_hub_download
import geonamescache

logging.basicConfig(level=logging.INFO)

token = os.getenv('HF_TOKEN')

df_path = hf_hub_download(
    repo_id='MrSimple07/raggg',
    filename='15_rag_data.csv',
    repo_type='dataset',
    token=token
)
embeddings_path = hf_hub_download(
    repo_id='MrSimple07/raggg',
    filename='rag_embeddings.npy',
    repo_type='dataset',
    token=token
)

df = pd.read_csv(df_path)
embeddings = np.load(embeddings_path)
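
# NOTE: rows of `df` and rows of `embeddings` are assumed to be aligned one-to-one;
# rag_query() below selects embedding rows by DataFrame index, which matches row
# order for the freshly loaded CSV.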

# Read the key from the environment instead of hardcoding it in source
# (the original committed a live API key)
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"

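# Hand-curated synonym lists used to enrich the query text before embedding;
# expand_category_once() splices these in the first time a category word appears.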
category_synonyms = {
    "museum": [
        "museums", "art galleries", "natural museums", "modern art museums"
    ],
    "cafe": [
        "coffee shops"
    ],
    "restaurant": [
        "local dining spots", "fine dining", "casual eateries",
        "family-friendly restaurants", "street food places"
    ],
    "parks": [
        "national parks", "urban green spaces", "botanical gardens",
        "recreational parks", "wildlife reserves"
    ],
    "park": [
        "national parks", "urban green spaces", "botanical gardens",
        "recreational parks", "wildlife reserves"
    ],
    "spa": ["bath", "swimming", "pool"]
}
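

# Resolve a location by scanning every word n-gram of the query against the
# offline GeoNames city and country lists (no network call required).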
def extract_location_geonames(query: str) -> dict:
    gc = geonamescache.GeonamesCache()
    countries = {c['name'].lower(): c['name'] for c in gc.get_countries().values()}
    cities = {c['name'].lower(): c['name'] for c in gc.get_cities().values()}

    words = query.split()

    for i in range(len(words)):
        for j in range(i + 1, len(words) + 1):
            potential_location = ' '.join(words[i:j]).lower()

            # Check if it's a city first
            if potential_location in cities:
                return {
                    'city': cities[potential_location],
                }

            # Then check if it's a country; the leftover words, if any, are
            # treated as the city (the original `i+j < len(words)` test could
            # wrongly return an empty string here)
            if potential_location in countries:
                remaining = words[:i] + words[j:]
                return {
                    'city': ' '.join(remaining) if remaining else None,
                    'country': countries[potential_location]
                }

    return {'city': query}
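

# Example: expand_category_once("best museum in town", "museum")
#   -> "best museum (museums, art galleries, natural museums, modern art museums) in town"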
def expand_category_once(query, target_category):
    """
    Expand the target category term in the query only once with synonyms and related phrases.
    """
    if not target_category:  # rag_query may pass category=None
        return query
    target_lower = target_category.lower()
    if target_lower in query.lower():
        synonyms = category_synonyms.get(target_lower, [])
        if synonyms:
            expanded_term = f"{target_category} ({', '.join(synonyms)})"
            query = query.replace(target_category, expanded_term, 1)  # Replace only the first occurrence
    return query


CATEGORY_FILTER_WORDS = [
    'museum', 'art', 'gallery', 'tourism', 'historical',
    'bar', 'cafe', 'restaurant', 'park', 'landmark',
    'beach', 'mountain', 'theater', 'church', 'monument',
    'garden', 'library', 'university', 'shopping', 'market',
    'hotel', 'resort', 'cultural', 'natural', 'science',
    'educational', 'entertainment', 'sports', 'memorial', 'historic',
    'spa', 'landmarks', 'sleep', 'coffee shops', 'shops', 'buildings',
    'gothic', 'castle', 'fortress', 'aquarium', 'zoo', 'wildlife',
    'adventure', 'hiking', 'lighthouse', 'vineyard', 'brewery',
    'winery', 'pub', 'nightclub', 'observatory', 'theme park',
    'botanical', 'sanctuary', 'heritage', 'island', 'waterfall',
    'canyon', 'valley', 'desert', 'artisans', 'crafts', 'music hall',
    'dance clubs', 'opera house', 'skyscraper', 'bridge', 'fountain',
    'temple', 'shrine', 'archaeological', 'planetarium', 'marketplace',
    'street art', 'local cuisine', 'eco-tourism', 'carnival', 'festival', 'film'
]

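
# Returns the first keyword that occurs in the query, so earlier entries in
# CATEGORY_FILTER_WORDS take precedence over later ones.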
def extract_category_from_query(query: str) -> Optional[str]:
    query_lower = query.lower()
    for word in CATEGORY_FILTER_WORDS:
        if word in query_lower:
            return word

    return None


def get_location_details(min_lat, max_lat, min_lon, max_lon):
    """Get detailed location information for a bounding box with improved city detection and error handling"""
    geolocator = Nominatim(user_agent="location_finder", timeout=10)

    try:
        # Sample several points within the bounding box: the centre plus the four corners
        sample_points = [
            ((float(min_lat) + float(max_lat)) / 2,
             (float(min_lon) + float(max_lon)) / 2),
            (float(min_lat), float(min_lon)),
            (float(max_lat), float(min_lon)),
            (float(min_lat), float(max_lon)),
            (float(max_lat), float(max_lon))
        ]

        # Collect unique cities from all points
        cities = set()
        full_addresses = []

        for lat, lon in sample_points:
            try:
                # Reverse-geocode with up to 3 attempts and exponential backoff;
                # `location` must start as None so a total failure is handled cleanly
                location = None
                for attempt in range(3):
                    try:
                        location = geolocator.reverse(f"{lat}, {lon}", language='en')
                        break
                    except Exception:
                        if attempt == 2:  # Last attempt
                            print(f"Failed to retrieve location for {lat}, {lon} after 3 attempts")
                        else:
                            time.sleep(2 ** attempt)  # Exponential backoff

                if location:
                    address = location.raw.get('address', {})

                    # Extract city with multiple fallback options
                    city = (
                        address.get('city') or
                        address.get('town') or
                        address.get('municipality') or
                        address.get('county') or
                        address.get('state')
                    )

                    if city:
                        cities.add(city)
                        full_addresses.append(location.address)

            except Exception as point_error:
                print(f"Error processing point {lat}, {lon}: {point_error}")
                continue

        # If no cities were found, return an empty default result
        if not cities:
            print("No cities detected. Returning default location information.")
            return {
                'location_parts': [],
                'full_address_parts': '',
                'full_address': '',
                'city': [],
                'state': '',
                'country': '',
                'cities_or_query': ''
            }

        # Prioritize cities, keeping all detected cities
        city_list = list(cities)

        # Use the last processed address for state and country
        state = address.get('state', '')
        country = address.get('country', '')

        # Create a formatted list of cities for the query
        cities_or_query = " or ".join(city_list)

        location_parts = [part for part in [cities_or_query, state, country] if part]
        full_address_parts = ', '.join(location_parts)

        print(f"Detected Cities: {cities}")
        print(f"Cities for Query: {cities_or_query}")
        print(f"Full Address Parts: {full_address_parts}")

        return {
            'location_parts': city_list,
            'full_address_parts': full_address_parts,
            'full_address': full_addresses[0] if full_addresses else '',
            'city': city_list,
            'state': state,
            'country': country,
            'cities_or_query': cities_or_query
        }

    except Exception as e:
        print(f"Comprehensive error in location details retrieval: {e}")
        import traceback
        traceback.print_exc()

        return None

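
# Retrieval pipeline: resolve a location (explicit city > bounding box > GeoNames
# text lookup), pre-filter the catalogue rows, then rank the survivors with FAISS
# and optionally ask Mistral to curate the final answer.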
def rag_query(
    query: str,
    df: pd.DataFrame,
    model: SentenceTransformer,
    precomputed_embeddings: np.ndarray,
    index: faiss.IndexFlatL2,
    min_lat: str = None,
    max_lat: str = None,
    min_lon: str = None,
    max_lon: str = None,
    category: str = None,
    city: str = None,
) -> Tuple[str, str]:
    """Enhanced RAG function with prioritized location extraction"""
    print("\n=== Starting RAG Query ===")
    print(f"Initial DataFrame size: {len(df)}")

    # Prioritized location extraction
    location_info = None
    location_names = []

    # Priority 1: Explicitly provided city name
    if city:
        location_names = [city]
        print(f"Using explicitly provided city: {city}")

    # Priority 2: Coordinates (Nominatim)
    elif all(coord is not None and coord != "" for coord in [min_lat, max_lat, min_lon, max_lon]):
        try:
            location_info = get_location_details(
                float(min_lat),
                float(max_lat),
                float(min_lon),
                float(max_lon)
            )

            # Extract location names from the Nominatim result
            if location_info:
                if location_info.get('city'):
                    location_names.extend(location_info['city'] if isinstance(location_info['city'], list) else [location_info['city']])
                if location_info.get('state'):
                    location_names.append(location_info['state'])
                if location_info.get('country'):
                    location_names.append(location_info['country'])

            print(f"Using coordinates-based location: {location_names}")
        except Exception as e:
            print(f"Location details error: {e}")

    # Priority 3: Extract from the query via GeoNames only if nothing else worked
    if not location_names:
        geonames_info = extract_location_geonames(query)
        if geonames_info.get('city'):
            location_names = [geonames_info['city']]
            print(f"Using GeoNames-extracted city: {location_names}")

    # Start with a copy of the original DataFrame
    filtered_df = df.copy()

    # Filter the DataFrame by location names
    if location_names:
        # Case-insensitive filter over the city and combined_field columns
        location_filter = (
            filtered_df['city'].str.lower().isin([name.lower() for name in location_names]) |
            filtered_df['city'].apply(lambda x: any(name.lower() in str(x).lower() for name in location_names)) |
            filtered_df['combined_field'].apply(lambda x: any(name.lower() in str(x).lower() for name in location_names))
        )

        filtered_df = filtered_df[location_filter]

        print(f"Location Names Used for Filtering: {location_names}")
        print(f"Results after location filtering: {len(filtered_df)}")

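    # Build an enriched text query for the embedding model from every available
    # signal: the raw question, category, city, bounding box, and resolved location.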
    enhanced_query_parts = []
    if query:
        enhanced_query_parts.append(query)
    if category:
        enhanced_query_parts.append(f"{category} category")
    if city:
        enhanced_query_parts.append(f"in {city}")

    if min_lat is not None and max_lat is not None and min_lon is not None and max_lon is not None:
        enhanced_query_parts.append(f"within latitudes {min_lat} to {max_lat} and longitudes {min_lon} to {max_lon}")

    # Add location context
    if location_info:
        location_context = " ".join(filter(None, [
            ", ".join(location_info.get('city', [])),
            location_info.get('state', ''),
            # location_info.get('country', '')
        ]))
        if location_context:
            enhanced_query_parts.append(f"in {location_context}")

    enhanced_query = " ".join(enhanced_query_parts)

    if enhanced_query:
        enhanced_query = expand_category_once(enhanced_query, category)

    print(f"Enhanced Query: {enhanced_query}")

    detected_category = extract_category_from_query(enhanced_query)
    if detected_category:
        category_filter = (
            filtered_df['category'].str.contains(detected_category, case=False, na=False) |
            filtered_df['combined_field'].str.contains(detected_category, case=False, na=False)
        )
        filtered_df = filtered_df[category_filter]

        print(f"Filtered by detected category '{detected_category}': {len(filtered_df)} results")

    try:
        # Guard against an empty result set before building the FAISS index
        if len(filtered_df) == 0:
            return "No matching locations found.", ""

        query_vector = model.encode([enhanced_query])[0]

        # Select the precomputed embeddings for the filtered rows
        filtered_embeddings = precomputed_embeddings[filtered_df.index]

        # Create a FAISS index over the filtered embeddings
        filtered_index = faiss.IndexFlatL2(filtered_embeddings.shape[1])
        filtered_index.add(filtered_embeddings.astype(np.float32))

        # Perform semantic search on the filtered results
        k = min(20, len(filtered_df))
        distances, local_indices = filtered_index.search(
            np.array([query_vector]).astype(np.float32),
            k
        )

        # Get the top results
        results_df = filtered_df.iloc[local_indices[0]]

        # Format results
        formatted_results = []
        for i, (_, row) in enumerate(results_df.iterrows(), 1):
            formatted_results.append(
                f"\n=== Result {i} ===\n"
                f"Name: {row['name']}\n"
                f"Category: {row['category']}\n"
                f"City: {row['city']}\n"
                f"Address: {row['address']}\n"
                f"Description: {row['description']}\n"
                f"Latitude: {row['latitude']}\n"
                f"Longitude: {row['longitude']}\n"
            )

        search_results = "\n".join(formatted_results) if formatted_results else "No matching locations found."

        # Optional: use Mistral to curate the raw results
        try:
            answer = query_mistral(enhanced_query, search_results)
        except Exception as e:
            print(f"Error in Mistral query: {e}")
            answer = "Unable to generate additional insights."

        return search_results, answer

    except Exception as e:
        print(f"Error in semantic search: {e}")
        return f"Error performing search: {str(e)}", ""


def query_mistral(prompt: str, context: str, max_retries: int = 3) -> str:
    """
    Robust Mistral verification with exponential backoff
    """
    # Early return if there is no usable context
    if not context or context.strip() == "No matching locations found.":
        return context

    verification_prompt = f"""Precise Location Curation Task:
    REQUIREMENTS:
    - Source Query: {prompt}
    - Current Context: {context}

    DETAILED INSTRUCTIONS:
    1. State the min/max latitude and min/max longitude from the query at the start of the response.
    2. Curate a comprehensive list of 15 locations inside these coordinates and strictly relevant to the place.
    3. Include STRICTLY ONLY places relevant to the Source Query.
    4. Add a short description of each place (2-3 sentences).
    5. Add its coordinates (latitude and longitude).
    6. Add the address of the place.
    7. Remove any duplicate entries from the list.
    8. If fewer than 15 places are found, quickly generate new places relevant to the Source Query and inside the coordinates.

    CRITICAL: Do NOT use placeholders. A quick, concise response is required.
    """

    for attempt in range(max_retries):
        try:
            # Call the Mistral chat completions API
            response = requests.post(
                MISTRAL_API_URL,
                headers={
                    "Authorization": f"Bearer {MISTRAL_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "mistral-large-latest",
                    "messages": [
                        {"role": "system", "content": "You are a precise location curator specializing in comprehensive travel information."},
                        {"role": "user", "content": verification_prompt}
                    ],
                    "temperature": 0.1,
                    "max_tokens": 5000
                },
                timeout=100  # Generous timeout for long completions
            )

            response.raise_for_status()

            # Extract the curated response
            verified_response = response.json()['choices'][0]['message']['content']

            # Validate response length; retry if it is suspiciously short
            if len(verified_response.strip()) < 100:
                if attempt == max_retries - 1:
                    return context
                time.sleep(2 ** attempt)  # Exponential backoff
                continue

            return verified_response

        except requests.Timeout:
            logging.warning(f"Mistral API timeout (Attempt {attempt + 1}/{max_retries})")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                logging.error("Mistral API consistently timing out")
                return context

        except requests.RequestException as e:
            logging.error(f"Mistral API request error: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                return context

        except Exception as e:
            logging.error(f"Unexpected error in Mistral verification: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                return context

    return context


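# Note: the lambda below discards rag_query's raw search results ([1] keeps only
# the Mistral-curated answer), which is why the interface has a single output box.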
def create_interface(
    df: pd.DataFrame,
    model: SentenceTransformer,
    precomputed_embeddings: np.ndarray,
    index: faiss.IndexFlatL2
):
    """Create the Gradio interface: free-text question, four bounding-box inputs, city, and category"""
    return gr.Interface(
        fn=lambda q, min_lat, max_lat, min_lon, max_lon, city, cat: rag_query(
            query=q,
            df=df,
            model=model,
            precomputed_embeddings=precomputed_embeddings,
            index=index,
            min_lat=min_lat,
            max_lat=max_lat,
            min_lon=min_lon,
            max_lon=max_lon,
            city=city,
            category=cat
        )[1],
        inputs=[
            gr.Textbox(lines=2, label="Question"),
            gr.Textbox(label="Min Latitude"),
            gr.Textbox(label="Max Latitude"),
            gr.Textbox(label="Min Longitude"),
            gr.Textbox(label="Max Longitude"),
            gr.Textbox(label="City"),
            gr.Textbox(label="Category")
        ],
        outputs=[
            gr.Textbox(label="Locations Found"),
        ],
        title="Tourist Information System with Bounding Box Search",
        examples=[
            ["Museums in area", "40.71", "40.86", "-74.1", "-74.0", "", "museum"],
            ["Restaurants", "48.8575", "48.9", "2.3514", "2.4", "Paris", "restaurant"],
            ["Coffee shops", "51.5", "51.6", "-0.2", "-0.1", "London", "cafe"],
            ["Spa places", "", "", "", "", "Budapest", ""],
            ["Lambic brewery", "50.84211068618749", "50.849274898691244", "4.339536387173865", "4.361188801802462", "", ""],
            ["Art nouveau architecture buildings", "44.42563381188614", "44.43347927669681", "26.008709832230608", "26.181744493414488", "", ""],
            ["Harry Potter filming locations", "51.52428877891333", "51.54738884423489", "-0.1955164690977472", "-0.05082973945560466", "", ""]
        ]
    )

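
# The query encoder must be the same model that produced rag_embeddings.npy
# (assumed here: 384-dimensional all-MiniLM-L6-v2 vectors); a mismatch would make
# the FAISS index dimensions or distances meaningless.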
if __name__ == "__main__":
    try:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        precomputed_embeddings = embeddings
        index = faiss.IndexFlatL2(precomputed_embeddings.shape[1])
        index.add(precomputed_embeddings.astype(np.float32))

        iface = create_interface(df, model, precomputed_embeddings, index)
        iface.launch(share=True, debug=True)
    except Exception as e:
        logging.error(f"Startup error: {e}")
        sys.exit(1)