yunusserhat committed on
Commit
cbdb152
·
verified ·
1 Parent(s): d93bc09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -34
app.py CHANGED
@@ -14,9 +14,19 @@ from models.huggingface import Geolocalizer
14
  import spacy
15
  from collections import Counter
16
  from spacy.cli import download
 
17
 
18
 
19
- def load_spacy_model(model_name="en_core_web_md"):
 
 
 
 
 
 
 
 
 
20
  try:
21
  return spacy.load(model_name)
22
  except IOError:
@@ -31,9 +41,14 @@ IMAGE_SIZE = (224, 224)
31
  GEOLOC_MODEL_NAME = "osv5m/baseline"
32
 
33
 
34
- # Load geolocation model
35
  @st.cache_resource(show_spinner=True)
36
- def load_geoloc_model() -> Geolocalizer:
 
 
 
 
 
 
37
  with st.spinner('Loading model...'):
38
  try:
39
  model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
@@ -44,31 +59,43 @@ def load_geoloc_model() -> Geolocalizer:
44
  return None
45
 
46
 
47
- # Function to find the most frequent location
48
- def most_frequent_locations(text: str):
 
 
 
 
 
 
 
 
49
  doc = nlp(text)
50
  locations = []
51
 
52
- # Collect all identified location entities
53
  for ent in doc.ents:
54
  if ent.label_ in ['LOC', 'GPE']:
55
  print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
56
  locations.append(ent.text)
57
 
58
- # Count occurrences and extract the most common locations
59
  if locations:
60
  location_counts = Counter(locations)
61
- most_common_locations = location_counts.most_common(2) # Adjust the number as needed
62
- # Format the output to show location names along with their counts
63
  common_locations_str = ', '.join([f"{loc[0]} ({loc[1]} occurrences)" for loc in most_common_locations])
64
-
65
  return f"Most Mentioned Locations: {common_locations_str}", [loc[0] for loc in most_common_locations]
66
  else:
67
  return "No locations found", []
68
 
69
 
70
- # Transform image for model prediction
71
  def transform_image(image: Image) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
72
  transform = transforms.Compose([
73
  transforms.Resize(IMAGE_SIZE),
74
  transforms.ToTensor(),
@@ -77,7 +104,17 @@ def transform_image(image: Image) -> torch.Tensor:
77
  return transform(image).unsqueeze(0)
78
 
79
 
80
- def check_location_match(location_query, most_common_locations):
 
 
 
 
 
 
 
 
 
 
81
  name = location_query['name']
82
  admin1 = location_query['admin1']
83
  cc = location_query['cc']
@@ -88,8 +125,16 @@ def check_location_match(location_query, most_common_locations):
88
  return False
89
 
90
 
91
- # Fetch city GeoJSON data
92
- def get_city_geojson(location_name: str) -> dict:
 
 
 
 
 
 
 
 
93
  geolocator = Nominatim(user_agent="predictGeolocforImage")
94
  try:
95
  location = geolocator.geocode(location_name, geometry='geojson')
@@ -99,8 +144,16 @@ def get_city_geojson(location_name: str) -> dict:
99
  return None
100
 
101
 
102
- # Fetch media from URL
103
- def get_media(url: str) -> list:
 
 
 
 
 
 
 
 
104
  try:
105
  response = requests.get(url)
106
  response.raise_for_status()
@@ -112,8 +165,17 @@ def get_media(url: str) -> list:
112
  return None
113
 
114
 
115
- # Predict location from image
116
- def predict_location(image: Image, model: Geolocalizer) -> tuple:
 
 
 
 
 
 
 
 
 
117
  with st.spinner('Processing image and predicting location...'):
118
  start_time = time.time()
119
  try:
@@ -130,8 +192,14 @@ def predict_location(image: Image, model: Geolocalizer) -> tuple:
130
  return None
131
 
132
 
133
- # Display map in Streamlit
134
- def display_map(city_geojson: dict, gps_degrees: list) -> None:
 
 
 
 
 
 
135
  map_view = pdk.Deck(
136
  map_style='mapbox://styles/mapbox/light-v9',
137
  initial_view_state=pdk.ViewState(
@@ -156,8 +224,13 @@ def display_map(city_geojson: dict, gps_degrees: list) -> None:
156
  st.pydeck_chart(map_view)
157
 
158
 
159
- # Display image
160
  def display_image(image_url: str) -> None:
 
 
 
 
 
 
161
  try:
162
  response = requests.get(image_url)
163
  response.raise_for_status()
@@ -169,8 +242,16 @@ def display_image(image_url: str) -> None:
169
  st.error(f"An error occurred: {e}")
170
 
171
 
172
- # Scrape webpage for text and images
173
- def scrape_webpage(url: str) -> tuple:
 
 
 
 
 
 
 
 
174
  with st.spinner('Scraping web page...'):
175
  try:
176
  response = requests.get(url)
@@ -185,27 +266,31 @@ def scrape_webpage(url: str) -> tuple:
185
  return None, None
186
 
187
 
188
- def main():
 
 
 
189
  st.title('Welcome to Geolocation Guesstimation Demo 👋')
190
 
191
- # Define page navigation using the sidebar
192
  page = st.sidebar.selectbox(
193
  "Choose your action:",
194
  ("Home", "Images", "Social Media", "Web Pages"),
195
- index=0 # Default to Home
196
  )
197
 
198
  st.sidebar.success("Select a demo above.")
199
  st.sidebar.info(
200
  """
201
  - Web App URL: <https://yunusserhat-guesstimatelocation.hf.space/>
202
- """)
 
203
 
204
  st.sidebar.title("Contact")
205
  st.sidebar.info(
206
  """
207
  Yunus Serhat Bıçakçı at [yunusserhat.com](https://yunusserhat.com) | [GitHub](https://github.com/yunusserhat) | [Twitter](https://twitter.com/yunusserhat) | [LinkedIn](https://www.linkedin.com/in/yunusserhat)
208
- """)
 
209
 
210
  if page == "Home":
211
  st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
@@ -220,7 +305,10 @@ def main():
220
  web_page_url_page()
221
 
222
 
223
- def upload_images_page():
 
 
 
224
  st.header("Image Upload for Geolocation Prediction")
225
  uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
226
  if uploaded_files:
@@ -230,7 +318,7 @@ def upload_images_page():
230
  st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)
231
  model = load_geoloc_model()
232
  if model:
233
- result = predict_location(image, model) # Assume this function is defined elsewhere
234
  if result:
235
  gps_degrees, location_query, city_geojson, processing_time = result
236
  st.write(
@@ -240,7 +328,10 @@ def upload_images_page():
240
  st.write(f"Processing Time (seconds): {processing_time}")
241
 
242
 
243
- def social_media_page():
 
 
 
244
  st.header("Social Media Analyser")
245
  social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
246
  if social_media_url:
@@ -270,7 +361,6 @@ def social_media_page():
270
  if city_geojson:
271
  display_map(city_geojson, gps_degrees)
272
  st.write(f"Processing Time (seconds): {processing_time}")
273
- # Check for match and notify
274
  if check_location_match(location_query, most_common_locations):
275
  st.success(
276
  f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
@@ -278,7 +368,10 @@ def social_media_page():
278
  st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
279
 
280
 
281
- def web_page_url_page():
 
 
 
282
  st.header("Web Page Analyser")
283
  web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
284
  if web_page_url:
@@ -307,7 +400,6 @@ def web_page_url_page():
307
  if city_geojson:
308
  display_map(city_geojson, gps_degrees)
309
  st.write(f"Processing Time (seconds): {processing_time}")
310
- # Check for match and notify
311
  if check_location_match(location_query, most_common_locations):
312
  st.success(
313
  f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
 
14
  import spacy
15
  from collections import Counter
16
  from spacy.cli import download
17
+ from typing import Tuple, List, Optional, Union, Dict
18
 
19
 
20
+ def load_spacy_model(model_name: str = "en_core_web_md") -> spacy.Language:
21
+ """
22
+ Load the specified spaCy model.
23
+
24
+ Args:
25
+ model_name (str): Name of the spaCy model to load.
26
+
27
+ Returns:
28
+ spacy.Language: Loaded spaCy model.
29
+ """
30
  try:
31
  return spacy.load(model_name)
32
  except IOError:
 
41
  GEOLOC_MODEL_NAME = "osv5m/baseline"
42
 
43
 
 
44
  @st.cache_resource(show_spinner=True)
45
+ def load_geoloc_model() -> Optional[Geolocalizer]:
46
+ """
47
+ Load the geolocation model.
48
+
49
+ Returns:
50
+ Optional[Geolocalizer]: Loaded geolocation model or None if loading fails.
51
+ """
52
  with st.spinner('Loading model...'):
53
  try:
54
  model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
 
59
  return None
60
 
61
 
62
def most_frequent_locations(text: str, top_n: int = 2) -> Tuple[str, List[str]]:
    """
    Find the most frequently mentioned locations in the text.

    The text is passed through ``nlp`` (defined outside this function;
    presumably the spaCy pipeline returned by ``load_spacy_model`` --
    confirm against the module-level setup). Entities labelled ``LOC`` or
    ``GPE`` are collected and ranked by how often their text occurs.

    Args:
        text (str): Input text to analyze.
        top_n (int): Maximum number of distinct locations to report.
            Defaults to 2, the value that used to be hard-coded.

    Returns:
        Tuple[str, List[str]]: A human-readable summary string and the list
        of location names, most frequent first. When no matching entity is
        found the result is ``("No locations found", [])``.
    """
    doc = nlp(text)
    locations = []

    for ent in doc.ents:
        if ent.label_ in ('LOC', 'GPE'):
            # Debug trace of every matched entity, printed to stdout.
            print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
            locations.append(ent.text)

    if not locations:
        return "No locations found", []

    # Counting is done on the exact entity text, so differently-cased or
    # differently-spelled mentions are tallied as separate locations.
    most_common_locations = Counter(locations).most_common(top_n)
    common_locations_str = ', '.join(
        f"{name} ({count} occurrences)" for name, count in most_common_locations
    )
    return (
        f"Most Mentioned Locations: {common_locations_str}",
        [name for name, _ in most_common_locations],
    )
87
 
88
 
 
89
  def transform_image(image: Image) -> torch.Tensor:
90
+ """
91
+ Transform the input image for model prediction.
92
+
93
+ Args:
94
+ image (Image): Input image.
95
+
96
+ Returns:
97
+ torch.Tensor: Transformed image tensor.
98
+ """
99
  transform = transforms.Compose([
100
  transforms.Resize(IMAGE_SIZE),
101
  transforms.ToTensor(),
 
104
  return transform(image).unsqueeze(0)
105
 
106
 
107
+ def check_location_match(location_query: dict, most_common_locations: List[str]) -> bool:
108
+ """
109
+ Check if the predicted location matches any of the most common locations.
110
+
111
+ Args:
112
+ location_query (dict): Predicted location details.
113
+ most_common_locations (List[str]): List of most common locations.
114
+
115
+ Returns:
116
+ bool: True if a match is found, False otherwise.
117
+ """
118
  name = location_query['name']
119
  admin1 = location_query['admin1']
120
  cc = location_query['cc']
 
125
  return False
126
 
127
 
128
+ def get_city_geojson(location_name: str) -> Optional[dict]:
129
+ """
130
+ Fetch the GeoJSON data for the specified city.
131
+
132
+ Args:
133
+ location_name (str): Name of the city.
134
+
135
+ Returns:
136
+ Optional[dict]: GeoJSON data of the city or None if fetching fails.
137
+ """
138
  geolocator = Nominatim(user_agent="predictGeolocforImage")
139
  try:
140
  location = geolocator.geocode(location_name, geometry='geojson')
 
144
  return None
145
 
146
 
147
+ def get_media(url: str) -> Optional[List[Tuple[str, str]]]:
148
+ """
149
+ Fetch media URLs and associated text from the specified URL.
150
+
151
+ Args:
152
+ url (str): URL to fetch media from.
153
+
154
+ Returns:
155
+ Optional[List[Tuple[str, str]]]: List of tuples containing media URLs and associated text or None if fetching fails.
156
+ """
157
  try:
158
  response = requests.get(url)
159
  response.raise_for_status()
 
165
  return None
166
 
167
 
168
+ def predict_location(image: Image, model: Geolocalizer) -> Optional[Tuple[List[float], dict, Optional[dict], float]]:
169
+ """
170
+ Predict the location from the input image using the specified model.
171
+
172
+ Args:
173
+ image (Image): Input image.
174
+ model (Geolocalizer): Geolocation model.
175
+
176
+ Returns:
177
+ Optional[Tuple[List[float], dict, Optional[dict], float]]: Predicted GPS coordinates, location query, city GeoJSON data, and processing time or None if prediction fails.
178
+ """
179
  with st.spinner('Processing image and predicting location...'):
180
  start_time = time.time()
181
  try:
 
192
  return None
193
 
194
 
195
+ def display_map(city_geojson: dict, gps_degrees: List[float]) -> None:
196
+ """
197
+ Display a map with the specified city GeoJSON data and GPS coordinates.
198
+
199
+ Args:
200
+ city_geojson (dict): GeoJSON data of the city.
201
+ gps_degrees (List[float]): GPS coordinates.
202
+ """
203
  map_view = pdk.Deck(
204
  map_style='mapbox://styles/mapbox/light-v9',
205
  initial_view_state=pdk.ViewState(
 
224
  st.pydeck_chart(map_view)
225
 
226
 
 
227
  def display_image(image_url: str) -> None:
228
+ """
229
+ Display an image from the specified URL.
230
+
231
+ Args:
232
+ image_url (str): URL of the image.
233
+ """
234
  try:
235
  response = requests.get(image_url)
236
  response.raise_for_status()
 
242
  st.error(f"An error occurred: {e}")
243
 
244
 
245
+ def scrape_webpage(url: str) -> Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]:
246
+ """
247
+ Scrape the specified webpage for text and images.
248
+
249
+ Args:
250
+ url (str): URL of the webpage to scrape.
251
+
252
+ Returns:
253
+ Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]: Extracted text and list of image URLs or None if scraping fails.
254
+ """
255
  with st.spinner('Scraping web page...'):
256
  try:
257
  response = requests.get(url)
 
266
  return None, None
267
 
268
 
269
+ def main() -> None:
270
+ """
271
+ Main function to run the Streamlit app.
272
+ """
273
  st.title('Welcome to Geolocation Guesstimation Demo 👋')
274
 
 
275
  page = st.sidebar.selectbox(
276
  "Choose your action:",
277
  ("Home", "Images", "Social Media", "Web Pages"),
278
+ index=0
279
  )
280
 
281
  st.sidebar.success("Select a demo above.")
282
  st.sidebar.info(
283
  """
284
  - Web App URL: <https://yunusserhat-guesstimatelocation.hf.space/>
285
+ """
286
+ )
287
 
288
  st.sidebar.title("Contact")
289
  st.sidebar.info(
290
  """
291
  Yunus Serhat Bıçakçı at [yunusserhat.com](https://yunusserhat.com) | [GitHub](https://github.com/yunusserhat) | [Twitter](https://twitter.com/yunusserhat) | [LinkedIn](https://www.linkedin.com/in/yunusserhat)
292
+ """
293
+ )
294
 
295
  if page == "Home":
296
  st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
 
305
  web_page_url_page()
306
 
307
 
308
+ def upload_images_page() -> None:
309
+ """
310
+ Display the image upload page for geolocation prediction.
311
+ """
312
  st.header("Image Upload for Geolocation Prediction")
313
  uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
314
  if uploaded_files:
 
318
  st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)
319
  model = load_geoloc_model()
320
  if model:
321
+ result = predict_location(image, model)
322
  if result:
323
  gps_degrees, location_query, city_geojson, processing_time = result
324
  st.write(
 
328
  st.write(f"Processing Time (seconds): {processing_time}")
329
 
330
 
331
+ def social_media_page() -> None:
332
+ """
333
+ Display the social media analysis page.
334
+ """
335
  st.header("Social Media Analyser")
336
  social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
337
  if social_media_url:
 
361
  if city_geojson:
362
  display_map(city_geojson, gps_degrees)
363
  st.write(f"Processing Time (seconds): {processing_time}")
 
364
  if check_location_match(location_query, most_common_locations):
365
  st.success(
366
  f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
 
368
  st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
369
 
370
 
371
+ def web_page_url_page() -> None:
372
+ """
373
+ Display the web page URL analysis page.
374
+ """
375
  st.header("Web Page Analyser")
376
  web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
377
  if web_page_url:
 
400
  if city_geojson:
401
  display_map(city_geojson, gps_degrees)
402
  st.write(f"Processing Time (seconds): {processing_time}")
 
403
  if check_location_match(location_query, most_common_locations):
404
  st.success(
405
  f"The predicted location {location_name} matches one of the most frequently mentioned locations!")