Spaces:
Sleeping
Sleeping
import sys | |
import logging | |
import gradio as gr | |
import faiss | |
import numpy as np | |
import pandas as pd | |
import requests | |
from geopy.geocoders import Nominatim | |
from sentence_transformers import SentenceTransformer | |
from typing import Tuple, Optional | |
import os | |
from huggingface_hub import hf_hub_download | |
import geonamescache | |
logging.basicConfig(level=logging.INFO) | |
from huggingface_hub import login | |
# Hugging Face access token; os.getenv yields None when HF_TOKEN is unset.
token = os.getenv('HF_TOKEN')


def _fetch_dataset_file(filename):
    """Download one file from the 'MrSimple07/raggg' dataset repo and return its local path."""
    return hf_hub_download(
        repo_id='MrSimple07/raggg',
        filename=filename,
        repo_type='dataset',
        token=token,
    )


# Resolve local paths for the places table and its precomputed embeddings.
df_path = _fetch_dataset_file('15_rag_data.csv')
embeddings_path = _fetch_dataset_file('rag_embeddings.npy')

# The CSV is read fully into memory; the embedding matrix is memory-mapped
# read-only instead of being loaded eagerly.
df = pd.read_csv(df_path)
embeddings = np.load(embeddings_path, mmap_mode='r')
# The Mistral credential is read from the environment so that no secret is
# committed to source control. An empty string means "not configured"; API
# calls made with it will be rejected by the server.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")
if not MISTRAL_API_KEY:
    logging.warning("MISTRAL_API_KEY is not set; Mistral requests will be rejected.")

# Chat-completions endpoint used for result curation.
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
# Maps a lower-cased category keyword to the related phrases that get appended
# to the query text when that category is expanded. Every entry must be a
# non-empty phrase: an empty string would render as a dangling ", " in the
# expanded query.
category_synonyms = {
    "museum": [
        "museums", "art galleries", "natural museums", "modern art museums",
    ],
    "cafe": [
        "coffee shops",
    ],
    "restaurant": [
        "local dining spots", "fine dining", "casual eateries",
        "family-friendly restaurants", "street food places",
    ],
    # "parks" and "park" intentionally carry the same phrases so that either
    # spelling of the category expands identically.
    "parks": [
        "national parks", "urban green spaces", "botanical gardens",
        "recreational parks", "wildlife reserves",
    ],
    "park": [
        "national parks", "urban green spaces", "botanical gardens",
        "recreational parks", "wildlife reserves",
    ],
    "spa": ['bath', 'swimming', 'pool'],
}
# Lazily built (countries, cities) lookup tables; see _geonames_lookup().
_GEONAMES_LOOKUP = None


def _geonames_lookup():
    """Build the lower-cased country/city name tables once and reuse them on later calls."""
    global _GEONAMES_LOOKUP
    if _GEONAMES_LOOKUP is None:
        gc = geonamescache.GeonamesCache()
        countries = {c['name'].lower(): c['name'] for c in gc.get_countries().values()}
        cities = {c['name'].lower(): c['name'] for c in gc.get_cities().values()}
        _GEONAMES_LOOKUP = (countries, cities)
    return _GEONAMES_LOOKUP


def extract_location_geonames(query: str) -> dict:
    """Find the first city or country name mentioned in a free-text query.

    Every contiguous run of whitespace-separated words is compared
    (case-insensitively) against the GeoNames city and country names. Runs are
    tried left to right, shortest first, and a city match wins over a country
    match for the same run.

    Args:
        query: Free-text user query.

    Returns:
        ``{'city': <name>}`` when a city is matched;
        ``{'city': <str or None>, 'country': <name>}`` when a country is
        matched; ``{'city': query}`` (the unchanged input) when nothing
        matches or the query has no words.
    """
    words = query.split() if query else []
    if not words:
        # Nothing to match: skip loading the gazetteer entirely.
        return {'city': query}

    countries, cities = _geonames_lookup()

    for i in range(len(words)):
        for j in range(i + 1, len(words) + 1):
            potential_location = ' '.join(words[i:j]).lower()

            # Check if it's a city first
            if potential_location in cities:
                return {
                    'city': cities[potential_location],
                }

            # Then check if it's a country
            if potential_location in countries:
                return {
                    # NOTE(review): the `i + j < len(words)` condition is kept
                    # exactly as written; it looks like it was meant to test
                    # "are there words left over" — confirm the intent before
                    # changing it.
                    'city': ' '.join(words[:i] + words[j:]) if i + j < len(words) else None,
                    'country': countries[potential_location],
                }

    return {'city': query}
def expand_category_once(query, target_category):
    """Expand the first mention of a category in the query with its synonyms.

    The category is looked up (lower-cased) in ``category_synonyms``. When it
    has synonyms and occurs in the query, the first occurrence is followed by
    a parenthesised, comma-separated synonym list, e.g.
    ``"spa"`` -> ``"spa (bath, swimming, pool)"``.

    Args:
        query: Query text to expand. Returned unchanged when empty or None.
        target_category: Category keyword to expand. Returned query is
            unchanged when this is empty or None.

    Returns:
        The query with at most one occurrence of the category expanded.
    """
    import re

    # Callers may pass no category at all (None or ""); nothing to expand then.
    if not query or not target_category:
        return query

    target_lower = target_category.lower()
    if target_lower not in query.lower():
        return query

    # Drop blank entries so the parenthetical never contains a dangling ", ".
    synonyms = [s for s in category_synonyms.get(target_lower, []) if s]
    if not synonyms:
        return query

    synonym_suffix = f" ({', '.join(synonyms)})"

    # Prefer an exact-case occurrence so queries that already expanded
    # correctly keep producing the same text.
    if target_category in query:
        return query.replace(target_category, target_category + synonym_suffix, 1)

    # The containment check above ignores case, so the replacement must too;
    # otherwise "Museums ..." with category "museum" would silently not expand.
    return re.sub(
        re.escape(target_category),
        lambda match: match.group(0) + synonym_suffix,
        query,
        count=1,
        flags=re.IGNORECASE,
    )
# Category keywords recognised in a query, in priority order: when several
# keywords occur in the same query, the one listed first here wins.
CATEGORY_FILTER_WORDS = [
    'museum', 'art', 'gallery', 'tourism', 'historical',
    'bar', 'cafe', 'restaurant', 'park', 'landmark',
    'beach', 'mountain', 'theater', 'church', 'monument',
    'garden', 'library', 'university', 'shopping', 'market',
    'hotel', 'resort', 'cultural', 'natural', 'science',
    'educational', 'entertainment', 'sports', 'memorial', 'historic',
    'spa', 'landmarks', 'sleep', 'coffee shops', 'shops', 'buildings',
    'gothic', 'castle', 'fortress', 'aquarium', 'zoo', 'wildlife',
    'adventure', 'hiking', 'lighthouse', 'vineyard', 'brewery',
    'winery', 'pub', 'nightclub', 'observatory', 'theme park',
    'botanical', 'sanctuary', 'heritage', 'island', 'waterfall',
    'canyon', 'valley', 'desert', 'artisans', 'crafts', 'music hall',
    'dance clubs', 'opera house', 'skyscraper', 'bridge', 'fountain',
    'temple', 'shrine', 'archaeological', 'planetarium', 'marketplace',
    'street art', 'local cuisine', 'eco-tourism', 'carnival', 'festival', 'film',
]


def extract_category_from_query(query: str) -> Optional[str]:
    """Return the first category keyword that appears as a word in the query.

    Keywords are tried in ``CATEGORY_FILTER_WORDS`` order. A keyword matches
    only as a whole word, optionally followed by ``s``, ``es`` or ``ing``
    (so "museums" -> 'museum', "filming" -> 'film'), which avoids accidental
    hits inside unrelated words such as 'bar' in "Barcelona" or 'spa' in
    "Spain".

    Args:
        query: Free-text query; matching is case-insensitive.

    Returns:
        The matching keyword exactly as listed, or None when the query is
        empty or contains no keyword.
    """
    import re

    if not query:
        return None

    query_lower = query.lower()
    for word in CATEGORY_FILTER_WORDS:
        pattern = rf"\b{re.escape(word)}(?:s|es|ing)?\b"
        if re.search(pattern, query_lower):
            return word
    return None
def get_location_details(min_lat, max_lat, min_lon, max_lon):
    """Reverse-geocode a bounding box into city/state/country information.

    The centre and the four corners of the box are each reverse-geocoded with
    Nominatim (up to three attempts per point, with exponential backoff), and
    the distinct place names found are collected.

    Args:
        min_lat, max_lat, min_lon, max_lon: Box edges; anything accepted by
            ``float()``.

    Returns:
        A dict with keys ``location_parts``, ``full_address_parts``,
        ``full_address``, ``city`` (list of names), ``state``, ``country`` and
        ``cities_or_query`` (names joined with " or "). All values are empty
        when no place name could be determined. Returns None when an
        unexpected error occurs (for example non-numeric coordinates).
    """
    # Imported locally: neither module is imported at file level.
    import time
    import traceback

    geolocator = Nominatim(user_agent="location_finder", timeout=10)
    try:
        south, north = float(min_lat), float(max_lat)
        west, east = float(min_lon), float(max_lon)

        # Strategy 1: Try multiple points within the bounding box
        sample_points = [
            ((south + north) / 2, (west + east) / 2),
            (south, west),
            (north, west),
            (south, east),
            (north, east),
        ]

        # Distinct place names in first-seen order, so the result is
        # deterministic for a given set of geocoder answers.
        city_list = []
        full_addresses = []
        # Address of the most recent successfully geocoded point.
        address = {}

        for lat, lon in sample_points:
            try:
                # Reset per point so a failed lookup can never reuse the
                # previous point's result.
                location = None
                for attempt in range(3):
                    try:
                        location = geolocator.reverse(f"{lat}, {lon}", language='en')
                        break
                    except Exception:
                        if attempt == 2:  # Last attempt
                            print(f"Failed to retrieve location for {lat}, {lon} after 3 attempts")
                        else:
                            time.sleep(2 ** attempt)  # Exponential backoff

                if location:
                    address = location.raw.get('address', {})
                    # Extract city with multiple fallback options
                    city = (
                        address.get('city') or
                        address.get('town') or
                        address.get('municipality') or
                        address.get('county') or
                        address.get('state')
                    )
                    if city and city not in city_list:
                        city_list.append(city)
                    full_addresses.append(location.address)
            except Exception as point_error:
                print(f"Error processing point {lat}, {lon}: {point_error}")
                continue

        # If no cities found, return default (empty) location information
        if not city_list:
            print("No cities detected. Returning default location information.")
            return {
                'location_parts': [],
                'full_address_parts': '',
                'full_address': '',
                'city': [],
                'state': '',
                'country': '',
                'cities_or_query': ''
            }

        # Use the last processed address for state and country
        state = address.get('state', '')
        country = address.get('country', '')

        # Create a formatted list of cities for query
        cities_or_query = " or ".join(city_list)
        location_parts = [part for part in [cities_or_query, state, country] if part]
        full_address_parts = ', '.join(location_parts)

        print(f"Detected Cities: {city_list}")
        print(f"Cities for Query: {cities_or_query}")
        print(f"Full Address Parts: {full_address_parts}")

        return {
            'location_parts': city_list,
            'full_address_parts': full_address_parts,
            'full_address': full_addresses[0] if full_addresses else '',
            'city': city_list,
            'state': state,
            'country': country,
            'cities_or_query': cities_or_query
        }
    except Exception as e:
        print(f"Comprehensive error in location details retrieval: {e}")
        traceback.print_exc()
        return None
def rag_query(
    query: str,
    df: pd.DataFrame,
    model: SentenceTransformer,
    precomputed_embeddings: np.ndarray,
    index: faiss.IndexFlatL2,
    min_lat: str = None,
    max_lat: str = None,
    min_lon: str = None,
    max_lon: str = None,
    category: str = None,
    city: str = None,
) -> Tuple[str, str]:
    """Filter the places table by location and category, then rank it semantically.

    Location is resolved in priority order: the explicit ``city``; otherwise
    the bounding box (reverse-geocoded via ``get_location_details``);
    otherwise a place name extracted from the query text. The surviving rows
    are ranked against the enhanced query with a temporary FAISS L2 index and
    the top 20 are formatted and handed to ``query_mistral`` for curation.

    Args:
        query: Free-text user question.
        df: Places table; the columns ``city``, ``combined_field``,
            ``category``, ``name``, ``address``, ``description``,
            ``latitude`` and ``longitude`` are read.
        model: Sentence encoder used to embed the enhanced query.
        precomputed_embeddings: One embedding row per ``df`` row.
        index: Accepted for interface compatibility; not used here because a
            fresh index is built over the filtered rows.
        min_lat, max_lat, min_lon, max_lon: Optional bounding-box edges. The
            box is used only when all four are non-empty.
        category: Optional category keyword.
        city: Optional explicit city name.

    Returns:
        ``(search_results, answer)``: the formatted top matches and the
        curated answer. When no rows survive filtering, both are
        "No matching locations found.". On a search failure the first element
        is an error message and the second is an empty string.
    """
    print("\n=== Starting RAG Query ===")
    print(f"Initial DataFrame size: {len(df)}")

    # A bounding box counts as provided only when all four edges are
    # non-empty; UI text boxes deliver "" (not None) when left blank.
    has_bbox = all(
        coord is not None and coord != ""
        for coord in (min_lat, max_lat, min_lon, max_lon)
    )

    # Prioritized location extraction
    location_info = None
    location_names = []

    # Priority 1: Explicitly provided city name
    if city:
        location_names = [city]
        print(f"Using explicitly provided city: {city}")
    # Priority 2: Coordinates (Nominatim)
    elif has_bbox:
        try:
            location_info = get_location_details(
                float(min_lat),
                float(max_lat),
                float(min_lon),
                float(max_lon)
            )
            # Extract location names from Nominatim result
            if location_info:
                detected_city = location_info.get('city')
                if detected_city:
                    location_names.extend(
                        detected_city if isinstance(detected_city, list) else [detected_city]
                    )
                if location_info.get('state'):
                    location_names.append(location_info['state'])
                if location_info.get('country'):
                    location_names.append(location_info['country'])
                print(f"Using coordinates-based location: {location_names}")
        except Exception as e:
            print(f"Location details error: {e}")

    # Priority 3: Extract from query using GeoNames only if no previous methods worked
    if not location_names:
        geonames_info = extract_location_geonames(query)
        if geonames_info.get('city'):
            location_names = [geonames_info['city']]
            print(f"Using GeoNames-extracted city: {location_names}")

    # Boolean-mask filtering below always produces new frames, so the input
    # table is never mutated and no defensive copy is needed.
    filtered_df = df

    # Filter DataFrame by location names (case-insensitive)
    if location_names:
        lowered_names = [name.lower() for name in location_names]

        def mentions_location(value):
            text = str(value).lower()
            return any(name in text for name in lowered_names)

        location_filter = (
            filtered_df['city'].str.lower().isin(lowered_names) |
            filtered_df['city'].apply(mentions_location) |
            filtered_df['combined_field'].apply(mentions_location)
        )
        filtered_df = filtered_df[location_filter]
        print(f"Location Names Used for Filtering: {location_names}")
        print(f"Results after location filtering: {len(filtered_df)}")

    enhanced_query_parts = []
    if query:
        enhanced_query_parts.append(query)
    if category:
        enhanced_query_parts.append(f"{category} category")
    if city:
        enhanced_query_parts.append(f" in {city}")
    # Only describe the box when it was actually supplied; blank text boxes
    # would otherwise yield "within latitudes  to  ...".
    if has_bbox:
        enhanced_query_parts.append(
            f"within latitudes {min_lat} to {max_lat} and longitudes {min_lon} to {max_lon}"
        )

    # Add location context
    if location_info:
        location_context = " ".join(filter(None, [
            ", ".join(location_info.get('city', [])),
            location_info.get('state', ''),
        ]))
        if location_context:
            enhanced_query_parts.append(f"in {location_context}")

    enhanced_query = " ".join(enhanced_query_parts)
    # Expansion needs a category; calling it with None would raise.
    if enhanced_query and category:
        enhanced_query = expand_category_once(enhanced_query, category)

    print(f"Filtered by city '{city}': {len(filtered_df)} results")
    print(f"Enhanced Query: {enhanced_query}")

    detected_category = extract_category_from_query(enhanced_query)
    if detected_category:
        category_filter = (
            filtered_df['category'].str.contains(detected_category, case=False, na=False) |
            filtered_df['combined_field'].str.contains(detected_category, case=False, na=False)
        )
        filtered_df = filtered_df[category_filter]
        print(f"Filtered by query words '{detected_category}': {len(filtered_df)} results")

    # Nothing survived filtering: report that directly instead of asking
    # FAISS for k=0 neighbours, which surfaces as a search error.
    if filtered_df.empty:
        no_match = "No matching locations found."
        return no_match, no_match

    try:
        query_vector = model.encode([enhanced_query])[0]

        # Select the embeddings of the surviving rows.
        # NOTE(review): this uses index labels as row positions, which assumes
        # df carries its default 0..n-1 index — confirm against the loader.
        filtered_embeddings = precomputed_embeddings[filtered_df.index.to_numpy()]

        # Create FAISS index with filtered embeddings
        filtered_index = faiss.IndexFlatL2(filtered_embeddings.shape[1])
        filtered_index.add(filtered_embeddings.astype(np.float32))

        # Perform semantic search on filtered results
        k = min(20, len(filtered_df))
        _, local_indices = filtered_index.search(
            np.array([query_vector]).astype(np.float32),
            k
        )

        # Get the top results
        results_df = filtered_df.iloc[local_indices[0]]

        # Format results
        formatted_results = []
        for i, (_, row) in enumerate(results_df.iterrows(), 1):
            formatted_results.append(
                f"\n=== Result {i} ===\n"
                f"Name: {row['name']}\n"
                f"Category: {row['category']}\n"
                f"City: {row['city']}\n"
                f"Address: {row['address']}\n"
                f"Description: {row['description']}\n"
                f"Latitude: {row['latitude']}\n"
                f"Longitude: {row['longitude']}\n"
            )
        search_results = "\n".join(formatted_results) if formatted_results else "No matching locations found."

        # Optional: Use Mistral for further refinement
        try:
            answer = query_mistral(enhanced_query, search_results)
        except Exception as e:
            print(f"Error in Mistral query: {e}")
            answer = "Unable to generate additional insights."

        return search_results, answer
    except Exception as e:
        print(f"Error in semantic search: {e}")
        return f"Error performing search: {str(e)}", ""
def query_mistral(prompt: str, context: str, max_retries: int = 3) -> str:
    """Ask Mistral to curate the retrieved locations, retrying with exponential backoff.

    Args:
        prompt: The (enhanced) user query.
        context: Formatted search results to curate.
        max_retries: Maximum number of request attempts.

    Returns:
        The model's curated text when a response of at least 100 characters is
        obtained. Otherwise ``context`` is returned unchanged: when it is
        empty or the "no match" message, when the API rejects the request
        with a non-retryable client error, or when every attempt fails.
    """
    import time

    # Early return if no context
    if not context or context.strip() == "No matching locations found.":
        return context

    verification_prompt = (
        "Precise Location Curation Task:\n"
        "REQUIREMENTS:\n"
        f"- Source Query: {prompt}\n"
        f"- Current Context: {context}\n"
        "DETAILED INSTRUCTIONS:\n"
        "1. Write the min, max latitude and min, max longitude in the beginning taking from the query\n"
        "2. Curate a comprehensive list of 15 locations inside of these coordinates and strictly relevant to place.\n"
        "3. Take STRICTLY ONLY relevant places to Source Query.\n"
        "4. Add a short description about the place (2-3 sentences)\n"
        "5. Add coordinates (lat and long) if there are in the Current Context.\n"
        "6. If no coordinates in the Current Context, then give only name and description\n"
        "7. Add address for the place\n"
        "8. Remove any duplicate entries in the list\n"
        "9. If places > 10, quick generation a new places relevant to Source Query and inside of the coordinates\n"
        "CRITICAL: Do NOT use placeholder. Quick and fast response required\n"
    )

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(
                MISTRAL_API_URL,
                headers={
                    "Authorization": f"Bearer {MISTRAL_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "mistral-large-latest",
                    "messages": [
                        {"role": "system", "content": "You are a precise location curator specializing in comprehensive travel information."},
                        {"role": "user", "content": verification_prompt}
                    ],
                    "temperature": 0.1,
                    "max_tokens": 5000
                },
                timeout=100
            )
            response.raise_for_status()

            verified_response = response.json()['choices'][0]['message']['content']
            # Very short replies are treated as failed generations and retried.
            if len(verified_response.strip()) >= 100:
                return verified_response
        except requests.Timeout:
            logging.warning(f"Mistral API timeout (Attempt {attempt + 1}/{max_retries})")
            if is_last_attempt:
                logging.error("Mistral API consistently timing out")
        except requests.HTTPError as e:
            logging.error(f"Mistral API request error: {e}")
            status = e.response.status_code if e.response is not None else None
            # Client errors such as 400/401/403 will not succeed on retry;
            # 408 (timeout) and 429 (rate limit) are worth another attempt.
            if status is not None and 400 <= status < 500 and status not in (408, 429):
                return context
        except requests.RequestException as e:
            logging.error(f"Mistral API request error: {e}")
        except Exception as e:
            logging.error(f"Unexpected error in Mistral verification: {e}")

        if is_last_attempt:
            return context
        time.sleep(2 ** attempt)  # Exponential backoff

    return context
def create_interface(
    df: pd.DataFrame,
    model: SentenceTransformer,
    precomputed_embeddings: np.ndarray,
    index: faiss.IndexFlatL2
):
    """Create the Gradio interface with a question, four bounding-box inputs, city and category.

    Args:
        df: Places table forwarded to ``rag_query``.
        model: Sentence encoder forwarded to ``rag_query``.
        precomputed_embeddings: Embedding matrix forwarded to ``rag_query``.
        index: FAISS index forwarded to ``rag_query``.

    Returns:
        A ``gr.Interface`` whose single output is the curated answer (the
        second element of the ``rag_query`` result).
    """

    def answer_question(q, min_lat, max_lat, min_lon, max_lon, city, cat):
        """Run the RAG pipeline for one form submission and return only the answer text."""
        _, answer = rag_query(
            query=q,
            df=df,
            model=model,
            precomputed_embeddings=precomputed_embeddings,
            index=index,
            min_lat=min_lat,
            max_lat=max_lat,
            min_lon=min_lon,
            max_lon=max_lon,
            city=city,
            category=cat
        )
        return answer

    return gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Textbox(lines=2, label="Question"),
            gr.Textbox(label="Min Latitude"),
            gr.Textbox(label="Max Latitude"),
            gr.Textbox(label="Min Longitude"),
            gr.Textbox(label="Max Longitude"),
            gr.Textbox(label="City"),
            gr.Textbox(label="Category")
        ],
        outputs=[
            gr.Textbox(label="Locations Found"),
        ],
        title="Tourist Information System with Bounding Box Search",
        # Column order: question, min lat, max lat, min lon, max lon, city, category.
        examples=[
            # Longitudes ordered so that min (-74.1) is below max (-74.0).
            ["Museums in area", "40.71", "40.86", "-74.1", "-74.0", "", "museum"],
            ["Restaurants", "48.8575", "48.9", "2.3514", "2.4", "Paris", "restaurant"],
            ["Coffee shops", "51.5", "51.6", "-0.2", "-0.1", "London", "cafe"],
            ["Spa places", "", "", "", "", "Budapest", ""],
            ["Lambic brewery", "50.84211068618749", "50.849274898691244", "4.339536387173865", "4.361188801802462", "", ""],
            ["Art nouveau architecture buildings", "44.42563381188614", "44.43347927669681", "26.008709832230608", "26.181744493414488", "", ""],
            ["Harry Potter filming locations", "51.52428877891333", "51.54738884423489", "-0.1955164690977472", "-0.05082973945560466", "", ""]
        ]
    )
if __name__ == "__main__":
    # Script entry point: load the sentence encoder, index the precomputed
    # embeddings, then build and launch the Gradio app. Any failure during
    # startup is logged and the process exits with status 1.
    try:
        encoder = SentenceTransformer('all-MiniLM-L6-v2')

        # Flat L2 index over every precomputed embedding row.
        embedding_dim = embeddings.shape[1]
        flat_index = faiss.IndexFlatL2(embedding_dim)
        flat_index.add(embeddings.astype(np.float32))

        app = create_interface(df, encoder, embeddings, flat_index)
        app.launch(share=True, debug=True)
    except Exception as startup_error:
        logging.error(f"Startup error: {startup_error}")
        sys.exit(1)