ReGeo – A Direct Regression Approach for Global Image Geo-Localization
This paper presents a novel approach to geo-localization, the task of predicting the geographic coordinates (latitude and longitude) of an image from its visual content. Traditional methods in this domain often rely on databases, complex pipelines, or large-scale image classification networks. In contrast, we propose a direct regression approach that simplifies the process by predicting the geographic coordinates directly from image features. We use a pre-trained Vision Transformer (ViT), specifically the image encoder of a pre-trained CLIP model, for feature extraction and add a regression head for coordinate prediction. We test and evaluate various configurations, including pre-training and task-specific adaptations, which results in our model, ReGeo. Experimental results show that ReGeo achieves performance competitive with existing state-of-the-art (SOTA) approaches, despite being simpler and requiring minimal supporting code.
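Because a direct regression model outputs a single point rather than a class label, one common way to score a prediction is the great-circle distance between the predicted and the true coordinates. The following is a minimal sketch using the haversine formula, assuming a spherical Earth with a mean radius of 6371 km. The function name `great_circle_km` is hypothetical and not part of the ReGeo code, and the metric actually used in the ReGeo evaluation is not stated in this card.

```python
import math

# Mean Earth radius in kilometres (spherical-Earth assumption).
EARTH_RADIUS_KM = 6371.0


def great_circle_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Haversine distance in km between two (latitude, longitude) points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    # Clamp to guard against tiny floating-point overshoot above 1.0 before asin.
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(min(1.0, a)))


# Hypothetical prediction vs. ground truth (approximate coordinates of Zurich and Bern).
print(f"{great_circle_km(47.3769, 8.5417, 46.9480, 7.4474):.1f} km")
```

Errors of this kind are often summarized as the share of predictions that fall within fixed distance thresholds.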
- Demo: Coming soon
Model Details
- Developed by: Tobias Rothlin, tobias.rothlin@ost.ch
- Supervisor: Mitra Purandare, mitra.purandare@ost.ch
- Model Card author: Kevin Löffler, kevin.loeffler@ost.ch
How to Get Started with the Model
Example inference:
# imports
import torch
from PIL import Image
from model import LocationDecoder # ReGeo model class: https://github.com/TobiasRothlin/GeoLocalization/blob/main/src/DGX1/src/RegressionPretraining/Model.py
from transformers import CLIPProcessor
# load the custom config (do not use AutoConfig); an example config can be found in this repo
config = { ... }
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
model = LocationDecoder.from_pretrained('OSTswiss/ReGeo', config=config)
# Load model for inference
model.to(device)
model.eval()
# load image
image_path = 'path/to/your/image.jpg' # can be any size
image = Image.open(image_path).convert("RGB")  # ensure 3-channel RGB input (e.g. PNGs with alpha, grayscale images)
model_input = preprocessor(images=image, return_tensors="pt")
pixel_values = model_input['pixel_values'].to(device)
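# run inference
with torch.no_grad():
    output = model(pixel_values)
# the model outputs normalized coordinates; scale them back to degrees
normalized_coordinates = output.squeeze().tolist()
latitude = normalized_coordinates[0] * 90
longitude = normalized_coordinates[1] * 180
print(f"latitude: {latitude:.4f}, longitude: {longitude:.4f}")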
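The inference example converts the model's two outputs to degrees by multiplying them by 90 and 180, which suggests that latitude and longitude are normalized to [-1, 1] as regression targets. A minimal sketch of that mapping in both directions, assuming exactly this linear scaling; the helper names `normalize_coordinates` and `denormalize_coordinates` are hypothetical and not part of the ReGeo repository, and the clamping step is an added assumption the original example does not perform.

```python
def normalize_coordinates(latitude: float, longitude: float) -> tuple[float, float]:
    """Map degrees (lat in [-90, 90], lon in [-180, 180]) to assumed [-1, 1] regression targets."""
    if not -90.0 <= latitude <= 90.0:
        raise ValueError(f"latitude out of range: {latitude}")
    if not -180.0 <= longitude <= 180.0:
        raise ValueError(f"longitude out of range: {longitude}")
    return latitude / 90.0, longitude / 180.0


def denormalize_coordinates(norm_lat: float, norm_lon: float) -> tuple[float, float]:
    """Scale normalized model outputs back to degrees, as in the inference example.

    Raw regression outputs are not guaranteed to lie in [-1, 1], so this sketch
    clamps them first (an assumption, not part of the original example).
    """
    norm_lat = max(-1.0, min(1.0, norm_lat))
    norm_lon = max(-1.0, min(1.0, norm_lon))
    return norm_lat * 90.0, norm_lon * 180.0


# Round trip: degrees -> normalized targets -> degrees.
print(denormalize_coordinates(*normalize_coordinates(47.3769, 8.5417)))
```

Clamping keeps an out-of-range prediction on a valid coordinate; an alternative design would bound the outputs inside the model itself.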
Model tree for OSTswiss/ReGeo
- Base model: openai/clip-vit-large-patch14