import streamlit as st import pandas as pd from PIL import Image import torch from pipe import PlonkPipeline from pathlib import Path from streamlit_extras.colored_header import colored_header import plotly.express as px import requests from io import BytesIO # Set page config st.set_page_config( page_title="Around the World in 80 Timesteps", page_icon="πΊοΈ", layout="wide" ) device = "cuda" if torch.cuda.is_available() else "cpu" PROJECT_ROOT = Path(__file__).parent.parent.absolute() # Define checkpoint path CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints" MODEL_NAMES = { "PLONK_YFCC": "nicolas-dufour/PLONK_YFCC", "PLONK_OSV_5M": "nicolas-dufour/PLONK_OSV_5M", "PLONK_iNaturalist": "nicolas-dufour/PLONK_iNaturalist", } @st.cache_resource def load_model(model_name): """Load the model and cache it to prevent reloading""" try: pipe = PlonkPipeline(model_path=model_name) return pipe except Exception as e: st.error(f"Error loading model: {str(e)}") st.stop() PIPES = {model_name: load_model(MODEL_NAMES[model_name]) for model_name in MODEL_NAMES} def predict_location(image, model_name, cfg=0.0, num_samples=256): with torch.no_grad(): batch = {"img": [], "emb": []} # If image is already a PIL Image, use it directly if isinstance(image, Image.Image): img = image.convert("RGB") else: img = Image.open(image).convert("RGB") pipe = PIPES[model_name] # Create a progress bar progress_bar = st.progress(0) status_text = st.empty() def update_progress(step, total_steps): progress = float(step) / float(total_steps) progress_bar.progress(progress) status_text.text(f"Sampling step {step + 1}/{total_steps}") # Get regular predictions with progress updates predicted_gps = pipe( img, batch_size=num_samples, cfg=cfg, num_steps=16, callback=update_progress ) # Get single high-confidence prediction status_text.text("Generating high-confidence prediction...") high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16) # Clear the status text and progress bar status_text.empty() progress_bar.empty() return { "lat": predicted_gps[:, 0].astype(float).tolist(), "lon": predicted_gps[:, 1].astype(float).tolist(), "high_conf_lat": high_conf_gps[0, 0].astype(float), "high_conf_lon": high_conf_gps[0, 1].astype(float), } def load_example_images(): """Load example images from the examples directory""" examples_dir = Path(__file__).parent / "examples" if not examples_dir.exists(): st.error( """ Examples directory not found. Please create the following structure: demo/ βββ examples/ βββ eiffel_tower.jpg βββ colosseum.jpg βββ taj_mahal.jpg βββ statue_liberty.jpg βββ sydney_opera.jpg """ ) return {} examples = {} for img_path in examples_dir.glob("*.jpg"): # Use filename without extension as the key name = img_path.stem.replace("_", " ").title() examples[name] = str(img_path) if not examples: st.warning("No example images found in the examples directory.") return examples def resize_image_for_display(image, max_size=400): """Resize image while maintaining aspect ratio""" # Get current size width, height = image.size # Calculate ratio to maintain aspect ratio if width > height: if width > max_size: ratio = max_size / width new_size = (max_size, int(height * ratio)) else: if height > max_size: ratio = max_size / height new_size = (int(width * ratio), max_size) # Only resize if image is larger than max_size if width > max_size or height > max_size: return image.resize(new_size, Image.Resampling.LANCZOS) return image def load_image_from_url(url): """Load an image from a URL""" try: response = requests.get(url) response.raise_for_status() # Raise an exception for bad status codes return Image.open(BytesIO(response.content)) except Exception as e: st.error(f"Error loading image from URL: {str(e)}") return None def main(): # Custom CSS st.markdown( """ """, unsafe_allow_html=True, ) # Header with custom styling colored_header( label="πΊοΈ Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation", description="Upload an image and our model, PLONK, will predict possible locations! In red we will sample one point with guidance scale 2.0 for the best guess. Project page: https://nicolas-dufour.github.io/plonk", color_name="red-70", ) # Adjust column ratio to give 2/3 of the space to the map col1, col2 = st.columns([1, 2], gap="large") with col1: # Add model selection before the sliders model_name = st.selectbox( "π€ Select Model", options=MODEL_NAMES.keys(), index=0, # Default to YFCC help="Choose which PLONK model variant to use for prediction.", ) # Modify the slider columns to accommodate both controls col_slider1, col_slider2 = st.columns([0.5, 0.5]) with col_slider1: cfg_value = st.slider( "π― Guidance scale", min_value=0.0, max_value=5.0, value=0.0, step=0.1, help="Scale for classifier-free guidance during sampling. A small value makes the model predictions display the diversity of the model, while a large value makes the model predictions more conservative but potentially more accurate.", ) with col_slider2: num_samples = st.number_input( "π² Number of samples", min_value=1, max_value=5000, value=64, step=1, help="Number of location predictions to generate. More samples give better coverage but take longer to compute.", ) st.markdown("### πΈ Choose your image") tab1, tab2, tab3 = st.tabs(["Upload", "URL", "Examples"]) with tab1: uploaded_file = st.file_uploader( "Choose an image...", type=["png", "jpg", "jpeg"], help="Supported formats: PNG, JPG, JPEG", ) if uploaded_file is not None: st.markdown('
Number of sampled locations: {len(pred["lat"])}
Best guess location: {pred["high_conf_lat"]:.2f}Β°, {pred["high_conf_lon"]:.2f}Β°
The predicted locations will appear here on an interactive map.