GeoGuessrRobot / app.py
robocan's picture
Update app.py
a8f08c9 verified
raw
history blame
5.15 kB
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser
import s2sphere
import folium
# Hugging Face access token, read from the "token" environment variable
# (None when the variable is not set).
token = os.environ.get("token")

# Download the model repository into the local "SVD" folder.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# Fitted label encoder mapping class indices back to their original labels.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load artifacts from a repository you trust.
le = joblib.load("SVD/le.gz")

# Number of outputs of the classifier head: one more than the encoder knows.
# NOTE(review): presumably the checkpoint was trained with an extra class
# slot -- confirm against the training code before changing the "+ 1".
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a three-layer fully connected head.

    The backbone is every child module of torchvision's ``convnext_small``
    except the last one; its flattened output (768 features) is mapped through
    two 1024-unit ReLU layers to ``len_classes`` logits.
    """

    def __init__(self):
        super().__init__()
        convnext = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # Keep all children but the final one of the pretrained network.
        backbone = list(convnext.children())[:-1]
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        ]
        # A single flat Sequential keeps the positional sub-module indices,
        # and therefore the state_dict keys, exactly as checkpoints expect.
        self.embedding = torch.nn.Sequential(*backbone, *head)

    def forward(self, data):
        """Return the raw class logits for a batch of images."""
        return self.embedding(data)
# Load the pretrained model
model = ModelPre()
model_w = xser.load("SVD/GeoG.pth")
model.load_state_dict(model_w['model'])
# Modules are created in training mode. torch.inference_mode() only disables
# autograd; it does not switch off train-time stochastic layers, so the model
# must be put in evaluation mode explicitly to get deterministic predictions.
model.eval()

# Preprocessing applied to every input image: convert to a tensor, resize to
# 224x224 and normalise with the mean/std values listed below.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img):
    """Classify an image and return its highest-probability labels.

    Args:
        input_img: Image accepted by the ``cmp`` preprocessing pipeline.

    Returns:
        A ``(results, labels)`` tuple. ``results`` maps each of the (up to)
        ten most probable labels to its softmax probability; ``labels`` is the
        array of those labels ordered from most to least probable.
    """
    with torch.inference_mode():
        batch = cmp(input_img).unsqueeze(0)
        logits = model(batch.to(device))
        probabilities = torch.softmax(logits, dim=1).cpu().numpy().flatten()

    # The network has one more output than the label encoder has classes.
    # LabelEncoder.inverse_transform raises ValueError for an index it has
    # never seen, so rank only the indices that can actually be decoded.
    decodable = probabilities[:len(le.classes_)]

    # Never ask for more entries than there are classes.
    top_k = min(10, decodable.size)
    top_indices = np.argsort(decodable)[-top_k:][::-1]
    top_predictions = le.inverse_transform(top_indices)

    results = {
        label: float(probability)
        for label, probability in zip(top_predictions, decodable[top_indices])
    }
    return results, top_predictions
# Function to get S2 cell polygon
def get_s2_cell_polygon(cell_id):
    """Return the outline of an S2 cell as a closed ring of coordinates.

    Args:
        cell_id: Integer identifier passed to ``s2sphere.CellId``.

    Returns:
        A list of five ``(latitude, longitude)`` tuples in degrees: the four
        cell vertices followed by a repeat of the first one.
    """
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    corners = [s2sphere.LatLng.from_point(cell.get_vertex(k)) for k in range(4)]
    ring = [(corner.lat().degrees, corner.lng().degrees) for corner in corners]
    # Repeat the first vertex so the outline is closed.
    ring.append(ring[0])
    return ring
def create_map_figure(predictions, cell_ids, selected_index=None):
    """Draw the predicted S2 cells on an OpenStreetMap-backed Plotly figure.

    Args:
        predictions: Not used by this function; kept so existing callers
            keep working.
        cell_ids: Iterable of S2 cell ids, best prediction first. Each value
            is converted with ``int(float(...))`` before use.
        selected_index: Zero-based integer rank of the cell to zoom in on, or
            ``None`` for no zoom. It is compared to the integer rank with
            ``==``, so a non-integer value never matches.

    Returns:
        A ``plotly.graph_objects.Figure`` with one filled outline per cell.
    """
    fig = go.Figure()

    # The three best-ranked cells are green; every later one is yellow.
    # Choosing by rank (instead of indexing a fixed 10-item list) keeps this
    # working when more than ten cells are supplied.
    top_color = 'rgba(0, 255, 0, 0.2)'
    other_color = 'rgba(255, 255, 0, 0.2)'

    zoom_level = 1
    selected_center = None
    last_center = None

    for rank, cell_id in enumerate(cell_ids):
        polygon = get_s2_cell_polygon(int(float(cell_id)))
        lats, lons = zip(*polygon)

        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=top_color if rank < 3 else other_color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))

        last_center = (float(np.mean(lats)), float(np.mean(lons)))

        # Zoom in on the cell the caller selected.
        if selected_index is not None and rank == selected_index:
            zoom_level = 10
            selected_center = last_center

    # Decide the map centre with explicit None checks: a latitude or
    # longitude of exactly 0.0 is a valid coordinate and must not be treated
    # as "missing". Without a selection, fall back to the last cell drawn;
    # with no cells at all, centre on (0, 0) instead of failing.
    if selected_center is not None:
        center_lat, center_lon = selected_center
    elif last_center is not None:
        center_lat, center_lon = last_center
    else:
        center_lat, center_lon = 0.0, 0.0

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=zoom_level,
        ),
    )
    return fig
# Create label output function
def create_label_output(predictions):
    """Build the map figure for a ``(results, cell_ids)`` pair.

    No cell is selected, so the figure is returned at its default zoom.
    """
    scores, cells = predictions
    return create_map_figure(scores, cells)
# Update the predict_and_plot function to handle zoom on selection
def _selection_to_index(selected_prediction):
    """Convert a dropdown value into a zero-based rank, or None if unusable."""
    if selected_prediction is None:
        return None
    if isinstance(selected_prediction, (int, np.integer)):
        return int(selected_prediction)
    # Dropdown labels look like "Prediction 3": take the trailing number.
    number = str(selected_prediction).rpartition(" ")[2]
    return int(number) - 1 if number.isdigit() else None


def predict_and_plot(input_img, selected_prediction):
    """Run the classifier on an image and plot the predicted cells.

    Args:
        input_img: Image to classify; ``None`` when nothing was uploaded.
        selected_prediction: Which prediction to zoom in on. Accepts the
            dropdown label (``"Prediction 1"`` ... ``"Prediction 10"``), a
            zero-based integer rank, or ``None`` for no zoom.

    Returns:
        The Plotly figure produced by ``create_map_figure``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        raise gr.Error("Please upload an image before predicting.")

    results, cell_ids = predict(input_img)
    # create_map_figure compares the selection to an integer rank, so the
    # dropdown's text label has to be translated first or it never matches.
    return create_map_figure(
        results,
        cell_ids,
        selected_index=_selection_to_index(selected_prediction),
    )
# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        # Image to classify, handed to the callback as a PIL image.
        input_image = gr.Image(type="pil", label="Upload an Image")
        # Lets the user choose which of the ten predictions to zoom in on.
        selected_prediction = gr.Dropdown(
            label="Select Prediction to Zoom",
            choices=[f"Prediction {i+1}" for i in range(10)],
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")

    # Clicking the button runs the model and redraws the map; the current
    # dropdown value is passed along with the image.
    btn_predict.click(
        predict_and_plot,
        inputs=[input_image, selected_prediction],
        outputs=output_map,
    )

    # Sample images the user can load into the image input.
    examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
    gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()