Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
import pydeck as pdk | |
from geopy.geocoders import Nominatim | |
import time | |
import requests | |
from io import BytesIO | |
import reverse_geocoder as rg | |
from bs4 import BeautifulSoup | |
from urllib.parse import urljoin | |
from models.huggingface import Geolocalizer | |
import spacy | |
from collections import Counter | |
from spacy.cli import download | |
from typing import Tuple, List, Optional, Union, Dict | |
def load_spacy_model(model_name: str = "en_core_web_md") -> spacy.Language:
    """
    Return a loaded spaCy pipeline, downloading it first when loading fails.

    Args:
        model_name (str): Name of the spaCy model to load.

    Returns:
        spacy.Language: The loaded spaCy pipeline.
    """
    try:
        pipeline = spacy.load(model_name)
    except IOError:
        # First load failed (reported below as "not found"): fetch the model
        # through spaCy's CLI helper and try exactly once more.
        print(f"Model {model_name} not found, downloading...")
        download(model_name)
        pipeline = spacy.load(model_name)
    return pipeline
# Target size handed to transforms.Resize before inference.
IMAGE_SIZE = (224, 224)
# Identifier passed to Geolocalizer.from_pretrained.
GEOLOC_MODEL_NAME = "osv5m/baseline"
# Shared spaCy pipeline, loaded once at import time with the default model name.
nlp = load_spacy_model()
def load_geoloc_model() -> Optional[Geolocalizer]:
    """
    Instantiate the pretrained geolocation model and switch it to eval mode.

    A Streamlit spinner is shown while loading; any failure is reported in the
    UI through ``st.error`` instead of being raised.

    Returns:
        Optional[Geolocalizer]: The model in evaluation mode, or None if loading failed.
    """
    with st.spinner('Loading model...'):
        try:
            geolocalizer = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
            geolocalizer.eval()
        except Exception as e:
            st.error(f"Failed to load the model: {e}")
            return None
        return geolocalizer
def most_frequent_locations(text: str) -> Tuple[str, List[str]]:
    """
    Extract location entities from the text and report the two most frequent.

    Only entities labelled ``LOC`` or ``GPE`` by the module-level spaCy
    pipeline are counted. Each one is also printed for debugging.

    Args:
        text (str): Input text to analyze.

    Returns:
        Tuple[str, List[str]]: A human-readable summary and the list of the
        (at most two) most mentioned location names. When no location is
        found the summary is "No locations found" and the list is empty.
    """
    wanted_labels = ['LOC', 'GPE']
    locations = []
    for ent in nlp(text).ents:
        if ent.label_ not in wanted_labels:
            continue
        print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
        locations.append(ent.text)

    if not locations:
        return "No locations found", []

    summary_parts = []
    top_names = []
    for place, count in Counter(locations).most_common(2):
        summary_parts.append(f"{place} ({count} occurrences)")
        top_names.append(place)
    return f"Most Mentioned Locations: {', '.join(summary_parts)}", top_names
def transform_image(image: Image) -> torch.Tensor:
    """
    Convert an image into the tensor passed to the model.

    The image is resized to ``IMAGE_SIZE``, converted to a tensor, normalised
    per channel, and given a leading dimension of size 1 via ``unsqueeze(0)``.

    Args:
        image (Image): Input image.

    Returns:
        torch.Tensor: Transformed image tensor with an extra leading dimension.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # NOTE(review): these look like the usual ImageNet channel statistics - confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    preprocess = transforms.Compose(steps)
    tensor = preprocess(image)
    return tensor.unsqueeze(0)
def check_location_match(location_query: dict, most_common_locations: List[str]) -> bool:
    """
    Check if the predicted location matches any of the most common locations.

    A mention matches when it contains the predicted city name or region name
    (case-insensitively), or when the predicted country code appears in it as
    a whole, case-sensitive token. Previously all three components had to be
    substrings of a single mention, which a plain place name such as "Paris"
    never satisfies, and empty components matched every mention.

    Args:
        location_query (dict): Predicted location details; the keys 'name',
            'admin1' and 'cc' are read, and missing or empty values are ignored.
        most_common_locations (List[str]): List of most common locations.

    Returns:
        bool: True if a match is found, False otherwise.
    """
    city = str(location_query.get('name') or '').strip().casefold()
    region = str(location_query.get('admin1') or '').strip().casefold()
    country_code = str(location_query.get('cc') or '').strip()

    for loc in most_common_locations:
        if not loc:
            continue
        mention = loc.casefold()
        # Empty components are skipped: '' is a substring of every string and
        # would otherwise turn every mention into a match.
        if (city and city in mention) or (region and region in mention):
            return True
        # Country codes are short, so compare whole tokens (case-sensitive)
        # rather than substrings to avoid hits such as "US" inside "Russia".
        if country_code and country_code in loc.replace(',', ' ').split():
            return True
    return False
def get_city_geojson(location_name: str) -> Optional[dict]:
    """
    Look up a place with Nominatim and return its GeoJSON geometry.

    Args:
        location_name (str): Name of the place to geocode.

    Returns:
        Optional[dict]: The 'geojson' entry of the geocoder's raw result, or
        None when nothing was found or the lookup raised (the error is shown
        through ``st.error``).
    """
    geolocator = Nominatim(user_agent="predictGeolocforImage")
    try:
        location = geolocator.geocode(location_name, geometry='geojson')
        if not location:
            return None
        return location.raw['geojson']
    except Exception as e:
        st.error(f"Failed to geocode location: {e}")
        return None
def get_media(url: str) -> Optional[List[Tuple[str, str]]]:
    """
    Fetch media URLs and associated text from the specified URL.

    The endpoint is expected to return a JSON list of entries, each with a
    'full_text' field and an optional 'media' list whose items may carry a
    'media_url'. Network errors, invalid JSON and unexpected payload shapes
    are all reported through ``st.error`` instead of being raised.

    Args:
        url (str): URL to fetch media from.

    Returns:
        Optional[List[Tuple[str, str]]]: List of (media URL, entry text)
        tuples, or None if fetching or parsing fails.
    """
    try:
        # A timeout keeps the app from hanging forever on an unresponsive host.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions whose decode
        # error is not a RequestException.
        st.error(f"Failed to fetch media URL: {e}")
        return None

    try:
        return [(media['media_url'], entry['full_text'])
                for entry in data
                # `or []` also tolerates an explicit "media": null.
                for media in (entry.get('media') or [])
                if 'media_url' in media]
    except (AttributeError, KeyError, TypeError) as e:
        st.error(f"Unexpected media response format: {e}")
        return None
def predict_location(image: Image, model: Geolocalizer) -> Optional[Tuple[List[float], dict, Optional[dict], float]]:
    """
    Predict the location from the input image using the specified model.

    The image is preprocessed with ``transform_image`` and passed through the
    model. The model output is converted from radians to degrees, then
    reverse-geocoded. The resulting "name, admin1, cc" string is used to fetch
    a GeoJSON outline. A spinner is shown while this runs and failures are
    reported through ``st.error``.

    Args:
        image (Image): Input image.
        model (Geolocalizer): Geolocation model.

    Returns:
        Optional[Tuple[List[float], dict, Optional[dict], float]]: Predicted
        GPS coordinates in degrees (presumably [latitude, longitude] - they
        are passed to the reverse geocoder in that order), the reverse-geocoder
        record, the GeoJSON data (or None), and the processing time in
        seconds; None if prediction fails.
    """
    with st.spinner('Processing image and predicting location...'):
        start_time = time.time()
        try:
            img_tensor = transform_image(image)
            # Inference only: disable autograd so no graph is built or kept
            # alive for the forward pass.
            with torch.no_grad():
                gps_radians = model(img_tensor)
            gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
            location_query = rg.search((gps_degrees[0], gps_degrees[1]))[0]
            location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
            city_geojson = get_city_geojson(location_name)
            processing_time = time.time() - start_time
            return gps_degrees, location_query, city_geojson, processing_time
        except Exception as e:
            st.error(f"Failed to predict the location: {e}")
            return None
def display_map(city_geojson: dict, gps_degrees: List[float]) -> None:
    """
    Render a pydeck map centred on the coordinates with the GeoJSON overlaid.

    Args:
        city_geojson (dict): GeoJSON data drawn as a filled, stroked layer.
        gps_degrees (List[float]): Coordinates; index 0 is used as latitude
            and index 1 as longitude for the initial view.
    """
    latitude, longitude = gps_degrees[0], gps_degrees[1]
    view_state = pdk.ViewState(
        latitude=latitude,
        longitude=longitude,
        zoom=8,
        pitch=0,
    )
    outline_layer = pdk.Layer(
        'GeoJsonLayer',
        data=city_geojson,
        get_fill_color=[255, 180, 0, 140],
        pickable=True,
        stroked=True,
        filled=True,
        extruded=False,
        line_width_min_pixels=1,
    )
    deck = pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=view_state,
        layers=[outline_layer],
    )
    st.pydeck_chart(deck)
def display_image(image_url: str) -> None:
    """
    Display an image from the specified URL.

    The image bytes are downloaded and handed to ``st.image``. Download
    failures and any other error are reported through ``st.error``.

    Args:
        image_url (str): URL of the image.
    """
    try:
        # Bound the wait so an unresponsive host cannot block the app forever.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image_bytes = BytesIO(response.content)
        st.image(image_bytes, caption=f'Image from URL: {image_url}', use_column_width=True)
    except requests.RequestException as e:
        st.error(f"Failed to fetch image at URL {image_url}: {e}")
    except Exception as e:
        st.error(f"An error occurred: {e}")
def scrape_webpage(url: str) -> Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]:
    """
    Scrape the specified webpage for text and images.

    Text is the content of all ``<p>`` elements joined with single spaces, so
    words at paragraph boundaries are not glued together. Image URLs are the
    ``src`` attributes of ``<img>`` elements, resolved against the document's
    ``<base href>`` when present and otherwise against ``url``. Only http(s)
    results are kept, because callers download them with ``requests``.

    Args:
        url (str): URL of the webpage to scrape.

    Returns:
        Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]:
        Extracted text and list of image URLs, or (None, None) if the page
        could not be fetched.
    """
    with st.spinner('Scraping web page...'):
        try:
            # Bound the wait so an unresponsive host cannot block the app forever.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
        except requests.RequestException as e:
            st.error(f"Failed to fetch and parse the URL: {e}")
            return None, None

        soup = BeautifulSoup(response.content, 'html.parser')

        # Relative URLs resolve against <base href> when the page declares one.
        base_tag = soup.find('base', href=True)
        base_url = urljoin(url, base_tag['href']) if base_tag else url

        paragraphs = (p.text.strip() for p in soup.find_all('p'))
        text = ' '.join(chunk for chunk in paragraphs if chunk)

        images = []
        for img in soup.find_all('img'):
            src = img.get('src')
            if not src:
                # A missing or empty src would resolve to the page URL itself.
                continue
            absolute = urljoin(base_url, src)
            # Skip data: URIs and other schemes that cannot be downloaded.
            if absolute.startswith(('http://', 'https://')):
                images.append(absolute)
        return text, images
def main() -> None:
    """
    Entry point of the Streamlit app: build the sidebar and render the chosen page.
    """
    st.title('Welcome to Geolocation Guesstimation Demo 👋')

    available_pages = ("Home", "Images", "Social Media", "Web Pages")
    page = st.sidebar.selectbox("Choose your action:", available_pages, index=0)

    st.sidebar.success("Select a demo above.")
    st.sidebar.info("\n- Web App URL: <https://yunusserhat-guesstimatelocation.hf.space/>\n")
    st.sidebar.title("Contact")
    st.sidebar.info(
        "\nYunus Serhat Bıçakçı at [yunusserhat.com](https://yunusserhat.com) | "
        "[GitHub](https://github.com/yunusserhat) | "
        "[Twitter](https://twitter.com/yunusserhat) | "
        "[LinkedIn](https://www.linkedin.com/in/yunusserhat)\n"
    )

    if page == "Home":
        st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
        return

    # Every other page is drawn by its own function; unknown values draw nothing.
    page_renderers = {
        "Images": upload_images_page,
        "Social Media": social_media_page,
        "Web Pages": web_page_url_page,
    }
    render_page = page_renderers.get(page)
    if render_page is not None:
        render_page()
def upload_images_page() -> None:
    """
    Display the image upload page for geolocation prediction.

    Each uploaded image is shown, run through the geolocation model, and
    followed by the predicted city/region/country, a map when GeoJSON data is
    available, and the processing time. The model is loaded once for the whole
    batch instead of once per image.
    """
    st.header("Image Upload for Geolocation Prediction")
    uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
    if not uploaded_files:
        return

    # Loop-invariant: load the model a single time before iterating the files.
    model = load_geoloc_model()

    for file in uploaded_files:
        with st.spinner(f"Processing {file.name}..."):
            try:
                image = Image.open(file).convert('RGB')
            except (OSError, ValueError) as e:
                # Unreadable or corrupt image data: report it and move on to
                # the next file instead of crashing the page.
                st.error(f"Could not read image {file.name}: {e}")
                continue
            st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)

            if model is None:
                # The loading failure was already reported by load_geoloc_model.
                continue
            result = predict_location(image, model)
            if not result:
                continue

            gps_degrees, location_query, city_geojson, processing_time = result
            st.write(
                f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
            if city_geojson:
                display_map(city_geojson, gps_degrees)
            st.write(f"Processing Time (seconds): {processing_time}")
def social_media_page() -> None:
    """
    Display the social media analysis page.

    The text of the first returned entry is shown and analysed for its most
    frequent locations. Every media image is then downloaded, displayed and
    geolocated, and a success note is shown when the prediction matches one
    of the mentioned locations. Download and decode failures are reported per
    image without aborting the remaining ones.
    """
    st.header("Social Media Analyser")
    social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
    if not social_media_url:
        return

    media_data = get_media(social_media_url)
    if not media_data:
        return

    full_text = media_data[0][1]
    st.subheader("Full Text")
    st.write(full_text)
    most_used_location, most_common_locations = most_frequent_locations(full_text)
    st.subheader("Most Frequent Location")
    st.write(most_used_location)

    # Loop-invariant: load the model once rather than once per image.
    model = load_geoloc_model()

    for idx, (media_url, _) in enumerate(media_data, start=1):
        st.subheader(f"Image {idx}")
        try:
            # Bound the wait so one unresponsive host cannot block the page.
            response = requests.get(media_url, timeout=10)
        except requests.RequestException as e:
            st.error(f"Failed to fetch image at URL {media_url}: {e}")
            continue
        if response.status_code != 200:
            st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
            continue

        try:
            image = Image.open(BytesIO(response.content)).convert('RGB')
        except (OSError, ValueError) as e:
            # The response body was not a readable image.
            st.error(f"Could not decode image at URL {media_url}: {e}")
            continue
        st.image(image, caption=f'Image from URL: {media_url}', use_column_width=True)

        if model is None:
            # The loading failure was already reported by load_geoloc_model.
            continue
        result = predict_location(image, model)
        if not result:
            continue

        gps_degrees, location_query, city_geojson, processing_time = result
        location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
        st.write(
            f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
        if city_geojson:
            display_map(city_geojson, gps_degrees)
        st.write(f"Processing Time (seconds): {processing_time}")
        if check_location_match(location_query, most_common_locations):
            st.success(
                f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
def web_page_url_page() -> None:
    """
    Display the web page URL analysis page.

    The page is scraped for paragraph text and image URLs. The first 500
    characters of the text are shown together with the most frequently
    mentioned locations. The user can then pick one image, which is
    downloaded, displayed and geolocated, and a success note is shown when
    the prediction matches one of the mentioned locations.
    """
    st.header("Web Page Analyser")
    web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
    if not web_page_url:
        return

    text, images = scrape_webpage(web_page_url)

    # Defined up front so the image branch still works for pages that have
    # images but no paragraph text (previously an UnboundLocalError).
    most_common_locations: List[str] = []
    if text:
        st.subheader("Extracted Text First 500 Characters:")
        st.write(text[:500])
        most_used_location, most_common_locations = most_frequent_locations(text)
        st.subheader("Most Frequent Location")
        st.write(most_used_location)

    if not images:
        return
    selected_image_url = st.selectbox("Select an image to predict location:", images)
    if not selected_image_url:
        return

    try:
        # Bound the wait so an unresponsive host cannot block the page.
        response = requests.get(selected_image_url, timeout=10)
    except requests.RequestException as e:
        st.error(f"Failed to fetch image at URL {selected_image_url}: {e}")
        return
    if response.status_code != 200:
        # Report the failure instead of silently showing nothing.
        st.error(f"Failed to fetch image at URL {selected_image_url}: HTTP {response.status_code}")
        return

    try:
        image = Image.open(BytesIO(response.content)).convert('RGB')
    except (OSError, ValueError) as e:
        # The response body was not a readable image.
        st.error(f"Could not decode image at URL {selected_image_url}: {e}")
        return
    st.image(image, caption=f'Selected Image from URL: {selected_image_url}', use_column_width=True)

    model = load_geoloc_model()
    if model is None:
        # The loading failure was already reported by load_geoloc_model.
        return
    result = predict_location(image, model)
    if not result:
        return

    gps_degrees, location_query, city_geojson, processing_time = result
    location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
    st.write(
        f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
    if city_geojson:
        display_map(city_geojson, gps_degrees)
    st.write(f"Processing Time (seconds): {processing_time}")
    if check_location_match(location_query, most_common_locations):
        st.success(
            f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
if __name__ == '__main__': | |
main() | |