nicolas-dufour commited on
Commit
68bc627
Β·
1 Parent(s): 3648fa8

initial commit

Browse files
Files changed (2) hide show
  1. app.py +388 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from PIL import Image
4
+ import torch
5
+ from plonk.pipe import PlonkPipeline
6
+ from pathlib import Path
7
+ from streamlit_extras.colored_header import colored_header
8
+ import plotly.express as px
9
+ import requests
10
+ from io import BytesIO
11
+
12
+ # Set page config
13
+ st.set_page_config(
14
+ page_title="Around the World in 80 Timesteps", page_icon="πŸ—ΊοΈ", layout="wide"
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ PROJECT_ROOT = Path(__file__).parent.parent.absolute()
19
+ # Define checkpoint path
20
+ CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"
21
+
22
+ MODEL_NAMES = {
23
+ "PLONK_YFCC": "nicolas-dufour/PLONK_YFCC",
24
+ "PLONK_OSV_5M": "nicolas-dufour/PLONK_OSV_5M",
25
+ "PLONK_iNaturalist": "nicolas-dufour/PLONK_iNaturalist",
26
+ }
27
+
28
+
29
+ @st.cache_resource
30
+ def load_model(model_name):
31
+ """Load the model and cache it to prevent reloading"""
32
+ try:
33
+ pipe = PlonkPipeline(model_path=model_name)
34
+ return pipe
35
+ except Exception as e:
36
+ st.error(f"Error loading model: {str(e)}")
37
+ st.stop()
38
+
39
+
40
+ PIPES = {model_name: load_model(MODEL_NAMES[model_name]) for model_name in MODEL_NAMES}
41
+
42
+
43
+ def predict_location(image, model_name, cfg=0.0, num_samples=256):
44
+ with torch.no_grad():
45
+ batch = {"img": [], "emb": []}
46
+
47
+ # If image is already a PIL Image, use it directly
48
+ if isinstance(image, Image.Image):
49
+ img = image.convert("RGB")
50
+ else:
51
+ img = Image.open(image).convert("RGB")
52
+
53
+ pipe = PIPES[model_name]
54
+
55
+ # Get regular predictions
56
+ predicted_gps = pipe(img, batch_size=num_samples, cfg=cfg, num_steps=32)
57
+
58
+ # Get single high-confidence prediction
59
+ high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=32)
60
+ return {
61
+ "lat": predicted_gps[:, 0].astype(float).tolist(),
62
+ "lon": predicted_gps[:, 1].astype(float).tolist(),
63
+ "high_conf_lat": high_conf_gps[0, 0].astype(float),
64
+ "high_conf_lon": high_conf_gps[0, 1].astype(float),
65
+ }
66
+
67
+
68
+ def load_example_images():
69
+ """Load example images from the examples directory"""
70
+ examples_dir = Path(__file__).parent / "examples"
71
+ if not examples_dir.exists():
72
+ st.error(
73
+ """
74
+ Examples directory not found. Please create the following structure:
75
+ demo/
76
+ └── examples/
77
+ β”œβ”€β”€ eiffel_tower.jpg
78
+ β”œβ”€β”€ colosseum.jpg
79
+ β”œβ”€β”€ taj_mahal.jpg
80
+ β”œβ”€β”€ statue_liberty.jpg
81
+ └── sydney_opera.jpg
82
+ """
83
+ )
84
+ return {}
85
+
86
+ examples = {}
87
+ for img_path in examples_dir.glob("*.jpg"):
88
+ # Use filename without extension as the key
89
+ name = img_path.stem.replace("_", " ").title()
90
+ examples[name] = str(img_path)
91
+
92
+ if not examples:
93
+ st.warning("No example images found in the examples directory.")
94
+
95
+ return examples
96
+
97
+
98
+ def resize_image_for_display(image, max_size=400):
99
+ """Resize image while maintaining aspect ratio"""
100
+ # Get current size
101
+ width, height = image.size
102
+
103
+ # Calculate ratio to maintain aspect ratio
104
+ if width > height:
105
+ if width > max_size:
106
+ ratio = max_size / width
107
+ new_size = (max_size, int(height * ratio))
108
+ else:
109
+ if height > max_size:
110
+ ratio = max_size / height
111
+ new_size = (int(width * ratio), max_size)
112
+
113
+ # Only resize if image is larger than max_size
114
+ if width > max_size or height > max_size:
115
+ return image.resize(new_size, Image.Resampling.LANCZOS)
116
+ return image
117
+
118
+
119
+ def load_image_from_url(url):
120
+ """Load an image from a URL"""
121
+ try:
122
+ response = requests.get(url)
123
+ response.raise_for_status() # Raise an exception for bad status codes
124
+ return Image.open(BytesIO(response.content))
125
+ except Exception as e:
126
+ st.error(f"Error loading image from URL: {str(e)}")
127
+ return None
128
+
129
+
130
+ def main():
131
+ # Custom CSS
132
+ st.markdown(
133
+ """
134
+ <style>
135
+ .main {
136
+ padding: 0rem 1rem;
137
+ }
138
+ .stButton>button {
139
+ width: 100%;
140
+ background-color: #FF4B4B;
141
+ color: white;
142
+ border: none;
143
+ padding: 0.5rem 1rem;
144
+ border-radius: 0.5rem;
145
+ }
146
+ .stButton>button:hover {
147
+ background-color: #FF6B6B;
148
+ }
149
+ .prediction-box {
150
+ background-color: #f0f2f6;
151
+ padding: 1.5rem;
152
+ border-radius: 0.5rem;
153
+ margin: 1rem 0;
154
+ }
155
+ /* New styles for image containers */
156
+ .upload-container {
157
+ max-height: 300px;
158
+ overflow-y: auto;
159
+ margin-bottom: 1rem;
160
+ }
161
+ .examples-container {
162
+ max-height: 200px;
163
+ display: flex;
164
+ gap: 10px;
165
+ }
166
+ .stTabs [data-baseweb="tab-panel"] {
167
+ padding-top: 1rem;
168
+ }
169
+ </style>
170
+ """,
171
+ unsafe_allow_html=True,
172
+ )
173
+
174
+ # Header with custom styling
175
+ colored_header(
176
+ label="πŸ—ΊοΈ Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation",
177
+ description="Upload an image and our model, PLONK, will predict possible locations! In red we will sample one point with guidance scale 2.0 for the best guess. <br> <br> Project page: https://nicolas-dufour.github.io/plonk",
178
+ color_name="red-70",
179
+ )
180
+
181
+ # Adjust column ratio to give 2/3 of the space to the map
182
+ col1, col2 = st.columns([1, 2], gap="large")
183
+
184
+ with col1:
185
+ # Add model selection before the sliders
186
+ model_name = st.selectbox(
187
+ "πŸ€– Select Model",
188
+ options=MODEL_NAMES.keys(),
189
+ index=0, # Default to YFCC
190
+ help="Choose which PLONK model variant to use for prediction.",
191
+ )
192
+
193
+ # Modify the slider columns to accommodate both controls
194
+ col_slider1, col_slider2 = st.columns([0.5, 0.5])
195
+ with col_slider1:
196
+ cfg_value = st.slider(
197
+ "🎯 Guidance scale",
198
+ min_value=0.0,
199
+ max_value=5.0,
200
+ value=0.0,
201
+ step=0.1,
202
+ help="Scale for classifier-free guidance during sampling. A small value makes the model predictions display the diversity of the model, while a large value makes the model predictions more conservative but potentially more accurate.",
203
+ )
204
+
205
+ with col_slider2:
206
+ num_samples = st.number_input(
207
+ "🎲 Number of samples",
208
+ min_value=1,
209
+ max_value=5000,
210
+ value=1000,
211
+ step=1,
212
+ help="Number of location predictions to generate. More samples give better coverage but take longer to compute.",
213
+ )
214
+
215
+ st.markdown("### πŸ“Έ Choose your image")
216
+ tab1, tab2, tab3 = st.tabs(["Upload", "URL", "Examples"])
217
+
218
+ with tab1:
219
+ uploaded_file = st.file_uploader(
220
+ "Choose an image...",
221
+ type=["png", "jpg", "jpeg"],
222
+ help="Supported formats: PNG, JPG, JPEG",
223
+ )
224
+
225
+ if uploaded_file is not None:
226
+ st.markdown('<div class="upload-container">', unsafe_allow_html=True)
227
+ original_image = Image.open(uploaded_file)
228
+ display_image = resize_image_for_display(
229
+ original_image.copy(), max_size=300
230
+ )
231
+ st.image(
232
+ display_image, caption="Uploaded Image", use_container_width=True
233
+ )
234
+ st.markdown("</div>", unsafe_allow_html=True)
235
+
236
+ if st.button("πŸ” Predict Location", key="predict_upload"):
237
+ with st.spinner("🌍 Analyzing image and predicting locations..."):
238
+ predictions = predict_location(
239
+ original_image,
240
+ model_name=model_name,
241
+ cfg=cfg_value,
242
+ num_samples=num_samples,
243
+ )
244
+ st.session_state["predictions"] = predictions
245
+
246
+ with tab2:
247
+ url = st.text_input("Enter image URL:", key="image_url")
248
+
249
+ if url:
250
+ image = load_image_from_url(url)
251
+ if image:
252
+ st.markdown(
253
+ '<div class="upload-container">', unsafe_allow_html=True
254
+ )
255
+ display_image = resize_image_for_display(image.copy(), max_size=300)
256
+ st.image(
257
+ display_image,
258
+ caption="Image from URL",
259
+ use_container_width=True,
260
+ )
261
+ st.markdown("</div>", unsafe_allow_html=True)
262
+
263
+ if st.button("πŸ” Predict Location", key="predict_url"):
264
+ with st.spinner(
265
+ "🌍 Analyzing image and predicting locations..."
266
+ ):
267
+ predictions = predict_location(
268
+ image,
269
+ model_name=model_name,
270
+ cfg=cfg_value,
271
+ num_samples=num_samples,
272
+ )
273
+ st.session_state["predictions"] = predictions
274
+
275
+ with tab3:
276
+ examples = load_example_images()
277
+ st.markdown('<div class="examples-container">', unsafe_allow_html=True)
278
+ example_cols = st.columns(len(examples))
279
+
280
+ for idx, (name, path) in enumerate(examples.items()):
281
+ with example_cols[idx]:
282
+ original_image = Image.open(path)
283
+ display_image = resize_image_for_display(
284
+ original_image.copy(), max_size=150
285
+ )
286
+
287
+ if st.container().button(
288
+ "πŸ“Έ",
289
+ key=f"img_{name}",
290
+ help=f"Click to predict location for {name}",
291
+ use_container_width=True,
292
+ ):
293
+ with st.spinner(
294
+ "🌍 Analyzing image and predicting locations..."
295
+ ):
296
+ predictions = predict_location(
297
+ original_image,
298
+ model_name=model_name,
299
+ cfg=cfg_value,
300
+ num_samples=num_samples,
301
+ )
302
+ st.session_state["predictions"] = predictions
303
+ st.rerun()
304
+
305
+ st.image(display_image, caption=name, use_container_width=True)
306
+ st.markdown("</div>", unsafe_allow_html=True)
307
+
308
+ with col2:
309
+ st.markdown("### 🌍 Predicted Locations")
310
+
311
+ if "predictions" in st.session_state:
312
+ pred = st.session_state["predictions"]
313
+
314
+ # Create DataFrame for all predictions
315
+ df = pd.DataFrame(
316
+ {
317
+ "lat": pred["lat"],
318
+ "lon": pred["lon"],
319
+ "type": ["Sample"] * len(pred["lat"]),
320
+ }
321
+ )
322
+
323
+ # Add high-confidence prediction
324
+ df = pd.concat(
325
+ [
326
+ df,
327
+ pd.DataFrame(
328
+ {
329
+ "lat": [pred["high_conf_lat"]],
330
+ "lon": [pred["high_conf_lon"]],
331
+ "type": ["Best Guess"],
332
+ }
333
+ ),
334
+ ]
335
+ )
336
+
337
+ # Create a more interactive map using Plotly
338
+ fig = px.scatter_mapbox(
339
+ df,
340
+ lat="lat",
341
+ lon="lon",
342
+ zoom=2,
343
+ opacity=0.6,
344
+ color="type",
345
+ color_discrete_map={"Sample": "blue", "Best Guess": "red"},
346
+ mapbox_style="carto-positron",
347
+ )
348
+
349
+ fig.update_traces(selector=dict(name="Best Guess"), marker_size=15)
350
+
351
+ fig.update_layout(
352
+ margin={"r": 0, "t": 0, "l": 0, "b": 0},
353
+ height=500,
354
+ showlegend=True,
355
+ legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
356
+ )
357
+
358
+ # Display map in a container
359
+ with st.container():
360
+ st.plotly_chart(fig, use_container_width=True)
361
+
362
+ # Display stats in a styled container
363
+ with st.container():
364
+ st.markdown(
365
+ f"""
366
+ <div class="prediction-box">
367
+ <h4>πŸ“Š Prediction Statistics</h4>
368
+ <p>Number of sampled locations: {len(pred["lat"])}</p>
369
+ <p>Best guess location: {pred["high_conf_lat"]:.2f}Β°, {pred["high_conf_lon"]:.2f}Β°</p>
370
+ </div>
371
+ """,
372
+ unsafe_allow_html=True,
373
+ )
374
+ else:
375
+ # Empty state with better styling
376
+ st.markdown(
377
+ """
378
+ <div class="prediction-box" style="text-align: center;">
379
+ <h4>πŸ‘† Upload an image and click 'Predict Location'</h4>
380
+ <p>The predicted locations will appear here on an interactive map.</p>
381
+ </div>
382
+ """,
383
+ unsafe_allow_html=True,
384
+ )
385
+
386
+
387
+ if __name__ == "__main__":
388
+ main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ git+https://github.com/nicolas-dufour/plonk.git@master
2
+ pandas
3
+ torch
4
+ torchvision
5
+ streamlit_extras
6
+ plotly