import streamlit as st
import pandas as pd
from PIL import Image
import torch
from pipe import PlonkPipeline
from pathlib import Path
from streamlit_extras.colored_header import colored_header
import plotly.express as px
import requests
from io import BytesIO
# Set page config
st.set_page_config(
    page_title="Around the World in 80 Timesteps",
    # World-map emoji (U+1F5FA followed by variation selector U+FE0F). Written
    # as escapes so a wrong file encoding cannot turn it into mojibake again.
    page_icon="\U0001F5FA\uFE0F",
    layout="wide",
)

# "cuda" when torch reports an available CUDA device, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).parent.parent.absolute()

# Define checkpoint path
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"

# Display name -> value passed to PlonkPipeline(model_path=...).
# NOTE(review): the values look like Hugging Face Hub repo ids - confirm.
MODEL_NAMES = {
    "PLONK_YFCC": "nicolas-dufour/PLONK_YFCC",
    "PLONK_OSV_5M": "nicolas-dufour/PLONK_OSV_5M",
    "PLONK_iNaturalist": "nicolas-dufour/PLONK_iNaturalist",
}
@st.cache_resource
def load_model(model_name):
    """Build the PlonkPipeline for ``model_name``.

    The function is wrapped in ``st.cache_resource`` so the pipeline is not
    rebuilt on later calls with the same argument.

    Args:
        model_name: Value forwarded as ``model_path`` to ``PlonkPipeline``.

    Returns:
        The constructed ``PlonkPipeline``. If construction raises, the error
        is reported with ``st.error`` and ``st.stop()`` is called instead.
    """
    try:
        pipeline = PlonkPipeline(model_path=model_name)
    except Exception as exc:
        # Show the failure in the UI rather than letting the traceback escape.
        st.error(f"Error loading model: {str(exc)}")
        st.stop()
    else:
        return pipeline
# One pipeline per MODEL_NAMES entry, built when this module runs and keyed by
# the same display name.
PIPES = {label: load_model(model_id) for label, model_id in MODEL_NAMES.items()}
def predict_location(image, model_name, cfg=0.0, num_samples=256):
    """Sample GPS predictions for an image with the selected pipeline.

    Two pipeline calls are made: one batch of ``num_samples`` samples at the
    requested guidance scale, and one single sample at guidance scale 2.0
    that is reported as the "high confidence" guess.

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts.
        model_name: Key into ``PIPES`` selecting the pipeline.
        cfg: Guidance scale passed to the pipeline for the sample batch.
        num_samples: Number of samples requested (pipeline ``batch_size``).

    Returns:
        dict with ``"lat"`` and ``"lon"`` (lists of floats taken from columns
        0 and 1 of the batch result) and ``"high_conf_lat"`` /
        ``"high_conf_lon"`` (the single guidance-2.0 sample).

    Raises:
        KeyError: If ``model_name`` is not a key of ``PIPES``.
    """
    # Decoding needs no autograd context, so it stays outside no_grad.
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.open(image).convert("RGB")

    pipe = PIPES[model_name]

    with torch.no_grad():
        # Get regular predictions
        predicted_gps = pipe(img, batch_size=num_samples, cfg=cfg, num_steps=32)
        # Get single high-confidence prediction
        high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=32)

    # NOTE(review): assumes the pipeline returns NumPy-style 2-D arrays
    # (``.astype`` and ``[:, i]`` indexing are used below) - confirm.
    return {
        "lat": predicted_gps[:, 0].astype(float).tolist(),
        "lon": predicted_gps[:, 1].astype(float).tolist(),
        "high_conf_lat": high_conf_gps[0, 0].astype(float),
        "high_conf_lon": high_conf_gps[0, 1].astype(float),
    }
def load_example_images():
    """Load example images from the examples directory.

    Looks for ``*.jpg`` files in the ``examples`` directory next to this file.

    Returns:
        dict mapping a display name (file stem with underscores replaced by
        spaces, title-cased) to the image path as a string. The dict is empty
        when the directory is missing or contains no ``.jpg`` files; in those
        cases an error or a warning is shown through Streamlit.
    """
    examples_dir = Path(__file__).parent / "examples"
    if not examples_dir.is_dir():
        st.error(
            "\n"
            "Examples directory not found. Please create the following structure:\n"
            "demo/\n"
            "└── examples/\n"
            "    ├── eiffel_tower.jpg\n"
            "    ├── colosseum.jpg\n"
            "    ├── taj_mahal.jpg\n"
            "    ├── statue_liberty.jpg\n"
            "    └── sydney_opera.jpg\n"
        )
        return {}

    # Sorted so the order of examples does not depend on the order in which
    # the filesystem happens to enumerate the directory.
    examples = {
        # Use filename without extension as the key
        img_path.stem.replace("_", " ").title(): str(img_path)
        for img_path in sorted(examples_dir.glob("*.jpg"))
    }
    if not examples:
        st.warning("No example images found in the examples directory.")
    return examples
def resize_image_for_display(image, max_size=400):
    """Resize image while maintaining aspect ratio.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method (a PIL image).
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        ``image`` itself when neither side exceeds ``max_size``; otherwise a
        resized copy whose longer side is exactly ``max_size``.
    """
    width, height = image.size
    longest = max(width, height)

    # Only resize if image is larger than max_size
    if longest <= max_size:
        return image

    scale = max_size / longest
    # The shorter side is clamped to 1 px: for extreme aspect ratios
    # int(side * scale) truncates to 0, and a zero-sized target is rejected
    # by the resize call.
    if width > height:
        new_size = (max_size, max(1, int(height * scale)))
    else:
        new_size = (max(1, int(width * scale)), max_size)
    return image.resize(new_size, Image.Resampling.LANCZOS)
def load_image_from_url(url, timeout=10):
    """Load an image from a URL.

    Args:
        url: Address to download the image from.
        timeout: Seconds to wait for the server before giving up; forwarded
            to ``requests.get``.

    Returns:
        The opened ``PIL.Image.Image``, or ``None`` if the download or the
        decoding fails (the error is reported with ``st.error``).
    """
    # NOTE(review): if ``url`` is typed in by end users, this fetches arbitrary
    # addresses from the machine running the app - consider restricting the
    # allowed schemes/hosts.
    try:
        # Without a timeout, requests can wait indefinitely on a server that
        # never answers, blocking the whole script run.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberately broad: this is a UI boundary, and any failure here
        # should become a message rather than a crash.
        st.error(f"Error loading image from URL: {str(e)}")
        return None
def main():
# Custom CSS
st.markdown(
"""
""",
unsafe_allow_html=True,
)
# Header with custom styling
colored_header(
label="πΊοΈ Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation",
description="Upload an image and our model, PLONK, will predict possible locations! In red we will sample one point with guidance scale 2.0 for the best guess.
Project page: https://nicolas-dufour.github.io/plonk",
color_name="red-70",
)
# Adjust column ratio to give 2/3 of the space to the map
col1, col2 = st.columns([1, 2], gap="large")
with col1:
# Add model selection before the sliders
model_name = st.selectbox(
"π€ Select Model",
options=MODEL_NAMES.keys(),
index=0, # Default to YFCC
help="Choose which PLONK model variant to use for prediction.",
)
# Modify the slider columns to accommodate both controls
col_slider1, col_slider2 = st.columns([0.5, 0.5])
with col_slider1:
cfg_value = st.slider(
"π― Guidance scale",
min_value=0.0,
max_value=5.0,
value=0.0,
step=0.1,
help="Scale for classifier-free guidance during sampling. A small value makes the model predictions display the diversity of the model, while a large value makes the model predictions more conservative but potentially more accurate.",
)
with col_slider2:
num_samples = st.number_input(
"π² Number of samples",
min_value=1,
max_value=5000,
value=1000,
step=1,
help="Number of location predictions to generate. More samples give better coverage but take longer to compute.",
)
st.markdown("### πΈ Choose your image")
tab1, tab2, tab3 = st.tabs(["Upload", "URL", "Examples"])
with tab1:
uploaded_file = st.file_uploader(
"Choose an image...",
type=["png", "jpg", "jpeg"],
help="Supported formats: PNG, JPG, JPEG",
)
if uploaded_file is not None:
st.markdown('
Number of sampled locations: {len(pred["lat"])}
Best guess location: {pred["high_conf_lat"]:.2f}Β°, {pred["high_conf_lon"]:.2f}Β°
The predicted locations will appear here on an interactive map.