ReGeo – A Direct Regression Approach for Global Image Geo-Localization

This paper presents a novel approach to geo-localization, the task of predicting the geographic coordinates (latitude and longitude) at which an image was taken, based on its visual content. Traditional methods in this domain often rely on databases, complex pipelines, or large-scale image classification networks. In contrast, we propose a direct regression approach that simplifies the process by predicting the geographic coordinates directly from image features. We leverage a pre-trained Vision Transformer (ViT), specifically a pre-trained CLIP model, for feature extraction and add a regression head for coordinate prediction. We test and evaluate various configurations, including pre-training and task-specific adaptations, and these experiments result in our model, ReGeo. Experimental results show that ReGeo achieves performance competitive with existing state-of-the-art (SOTA) approaches, while being simpler and requiring minimal supporting code.
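Direct regression means the network outputs the two coordinate values themselves rather than a class label or a database match. The inference example in this card rescales the model output by 90 (latitude) and 180 (longitude), which suggests that the training targets are degrees divided by those factors. The sketch below illustrates that assumed encoding, its inverse, and a plain mean squared error as one possible regression loss. The function names are hypothetical, and neither the encoding nor the loss is confirmed to match the actual ReGeo training code.

```python
from typing import Sequence, Tuple


def normalize_coordinates(latitude: float, longitude: float) -> Tuple[float, float]:
    """Map degrees to [-1, 1] (assumed target encoding, inferred from the inference example)."""
    if not -90.0 <= latitude <= 90.0:
        raise ValueError(f"latitude out of range: {latitude}")
    if not -180.0 <= longitude <= 180.0:
        raise ValueError(f"longitude out of range: {longitude}")
    return latitude / 90.0, longitude / 180.0


def denormalize_coordinates(normalized: Sequence[float]) -> Tuple[float, float]:
    """Invert normalize_coordinates, mirroring the x90 / x180 scaling used at inference."""
    norm_lat, norm_lon = normalized
    return norm_lat * 90.0, norm_lon * 180.0


def mse_loss(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between normalized predictions and targets (illustrative loss choice)."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


# Round trip for a location near Zurich (approx. 47.37 N, 8.54 E)
target = normalize_coordinates(47.37, 8.54)
print(target)
print(denormalize_coordinates(target))
print(mse_loss((0.5, 0.05), target))
```

Keeping both outputs on a similar scale means neither coordinate dominates a squared-error loss, which is one plausible reason for this kind of normalization.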

  • Demo: Coming soon

Model Details

How to Get Started with the Model

Example inference:

```python
# imports
import torch
from PIL import Image
from transformers import CLIPProcessor

# ReGeo model class, defined in
# https://github.com/TobiasRothlin/GeoLocalization/blob/main/src/DGX1/src/RegressionPretraining/Model.py
# It must be available locally as a module named `model` for this import to work.
from model import LocationDecoder

# load the custom config (do not use AutoConfig); an example can be found in this repo
config = { ... }

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
model = LocationDecoder.from_pretrained('OSTswiss/ReGeo', config=config)

# move the model to the target device and switch to inference mode
model.to(device)
model.eval()

# load the image (any size; the processor resizes and crops it)
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')  # ensure 3 channels, e.g. for grayscale or RGBA files
model_input = preprocessor(images=image, return_tensors="pt")
pixel_values = model_input['pixel_values'].to(device)

# run inference without tracking gradients
with torch.no_grad():
    output = model(pixel_values)

# the model outputs normalized coordinates; rescale them to degrees
normalized_coordinates = output.squeeze().tolist()
latitude = normalized_coordinates[0] * 90
longitude = normalized_coordinates[1] * 180
print(f"latitude: {latitude:.4f}, longitude: {longitude:.4f}")
```
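To compare a predicted latitude/longitude with a known ground-truth location, geo-localization errors are commonly measured as great-circle distances. The sketch below computes that distance with the haversine formula. The mean Earth radius of 6371 km, the function name, and the example coordinates are assumptions for illustration and are not taken from the ReGeo evaluation code.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in kilometres between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    d_phi = math.radians(lat2 - lat1)
    d_lambda = math.radians(lon2 - lon1)
    a = (math.sin(d_phi / 2) ** 2
         + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2)
    # min() guards against tiny floating-point overshoots above 1.0
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(min(1.0, a)))


# Hypothetical prediction vs. ground truth, both as (latitude, longitude) in degrees
predicted = (47.0, 8.0)
ground_truth = (47.37, 8.54)
print(f"error: {haversine_km(*predicted, *ground_truth):.1f} km")
```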