Spaces:
Running
Running
File size: 5,568 Bytes
9b889da afe2deb 8a4658f afe2deb 38a655b 4b52e6f afe2deb cc6f22c 4b52e6f dbf177b 4b67dd1 cc6f22c 9b889da dbf177b 1de1740 db673be dbf177b db673be 9b889da 955fc23 14085cb 4184b6d 52895be 955fc23 be60ccb 955fc23 1de1740 be60ccb 1de1740 955fc23 f07227d d497c1d f07227d 955fc23 6d49cf1 39fc524 cc6f22c 6d49cf1 a8f08c9 cc6f22c 05501c8 4345f8a a8f08c9 05501c8 03ec0c3 cc6f22c 05501c8 4345f8a cc6f22c 05501c8 cc6f22c 4345f8a cc6f22c 24d6812 4345f8a a8f08c9 4345f8a a8f08c9 4345f8a 24d6812 a8f08c9 24d6812 a8f08c9 24d6812 4345f8a 24d6812 4b52e6f 05501c8 f07227d afe2deb cc6f22c 24d6812 39fc524 a8f08c9 8a4658f 4345f8a a8f08c9 f07227d 24d6812 6ecb023 24d6812 52895be a8f08c9 dd76dcd a8f08c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser
import s2sphere
import folium
# Download the model repository (label encoder + checkpoint) from the Hub into
# ./SVD. The access token, if any, is read from the "token" env variable.
token = os.environ.get("token")
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token,
)
device = 'cpu'
# Fitted LabelEncoder mapping class indices back to the S2 cell-id labels that
# are later drawn on the map. Resolve it from the download directory rather
# than re-spelling the "SVD" path.
le = joblib.load(os.path.join(local_dir, "le.gz"))
# The classifier head has one more output than the encoder has classes.
# NOTE(review): presumably this matches how the checkpoint was trained;
# the extra index has no label in `le` — confirm.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small trunk followed by a three-layer MLP classification head.

    The modules are kept in a single ``torch.nn.Sequential`` named
    ``embedding``; the checkpoint loaded with ``load_state_dict`` is keyed by
    the positions inside it, so the layer order must not change.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # Every top-level child except the last one (the classifier) is kept;
        # the head below consumes a 768-wide flattened feature vector.
        trunk = list(backbone.children())[:-1]
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        ]
        self.embedding = torch.nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Return the raw (pre-softmax) class scores for a batch of images."""
        return self.embedding(data)
# Build the network and load the fine-tuned weights from the downloaded repo.
# The checkpoint was written with torch_xla's serializer; the state dict is
# stored under the "model" key.
model = ModelPre()
model_w = xser.load(os.path.join(local_dir, "GeoG.pth"))
model.load_state_dict(model_w['model'])
# Inference only: `torch.inference_mode()` does not switch modules out of
# training mode, so do it explicitly. Training-only layers (torchvision's
# ConvNeXt blocks use stochastic depth) must be disabled for stable output.
# Also keep the weights on the same device the inputs are moved to.
model.eval()
model.to(device)

# Preprocessing for `predict`: PIL image -> float tensor, resized to 224x224,
# then normalised with the ImageNet channel mean/std used by the backbone.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img):
with torch.inference_mode():
img = cmp(input_img).unsqueeze(0)
res = model(img.to(device))
probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
top_10_indices = np.argsort(probabilities)[-10:][::-1]
top_10_probabilities = probabilities[top_10_indices]
top_10_predictions = le.inverse_transform(top_10_indices)
results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
return results, top_10_predictions
def get_s2_cell_polygon(cell_id):
    """Return the outline of an S2 cell as a closed list of (lat, lng) degrees.

    The four corners come from ``cell.get_vertex(0..3)``. The first corner is
    repeated at the end so the ring is closed for plotting.
    """
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    corners = [s2sphere.LatLng.from_point(cell.get_vertex(k)) for k in range(4)]
    ring = [(corner.lat().degrees, corner.lng().degrees) for corner in corners]
    ring.append(ring[0])
    return ring
def create_map_figure(predictions, cell_ids, selected_index=None,
                      selected_zoom=10, default_zoom=1):
    """Draw the predicted S2 cells as filled polygons on an OpenStreetMap map.

    Args:
        predictions: unused; kept for call-site compatibility.
        cell_ids: predicted S2 cell ids, best first. Each entry is coerced with
            ``int(float(...))`` before being looked up.
        selected_index: 0-based rank of the cell to zoom to, or None.
        selected_zoom: zoom level used when ``selected_index`` matches a cell.
        default_zoom: zoom level used otherwise.

    Returns:
        A ``plotly.graph_objects.Figure`` with one trace per cell.
    """
    fig = go.Figure()
    zoom_level = default_zoom
    selected_center = None
    # Without a matching selection the map centres on the last cell drawn,
    # or on (0, 0) when there are no cells at all.
    fallback_center = (0.0, 0.0)
    for rank, cell_id in enumerate(cell_ids):
        polygon = get_s2_cell_polygon(int(float(cell_id)))
        lats, lons = zip(*polygon)
        # Top three predictions are shaded green, all others yellow.
        color = 'rgba(0, 255, 0, 0.2)' if rank < 3 else 'rgba(255, 255, 0, 0.2)'
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))
        fallback_center = (np.mean(lats), np.mean(lons))
        if selected_index is not None and rank == selected_index:
            zoom_level = selected_zoom
            selected_center = fallback_center
    # Test against None rather than truthiness: latitude/longitude 0.0 is a
    # legitimate centre and must not trigger the fallback.
    center_lat, center_lon = (
        selected_center if selected_center is not None else fallback_center
    )
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=zoom_level,
        ),
    )
    return fig
def create_label_output(predictions):
    """Turn the ``(results, cell_ids)`` pair returned by ``predict`` into a map figure."""
    label_map, ranked_cells = predictions
    figure = create_map_figure(label_map, ranked_cells)
    return figure
def predict_and_plot(input_img, selected_prediction):
    """Gradio callback: run the model on an image and plot its predicted cells.

    Args:
        input_img: the uploaded image, or None if nothing was uploaded.
        selected_prediction: dropdown value of the form "Prediction X"
            (1-based), or None/"" for the default view.

    Returns:
        The Plotly figure built by ``create_map_figure``.

    Raises:
        gr.Error: if no image was provided, so the UI shows a clear message
            instead of failing inside the preprocessing transform.
    """
    if input_img is None:
        raise gr.Error("Please upload an image before clicking Predict.")
    results, cell_ids = predict(input_img)
    selected_index = None  # default view: no prediction selected
    if selected_prediction:
        # "Prediction X" (1-based) -> 0-based rank index.
        selected_index = int(selected_prediction.split()[-1]) - 1
    return create_map_figure(results, cell_ids, selected_index=selected_index)
# Gradio UI: upload an image, choose which ranked prediction to zoom to, and
# render the predicted S2 cells on a map.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        selected_prediction = gr.Dropdown(
            choices=[f"Prediction {i + 1}" for i in range(10)],
            label="Select Prediction to Zoom",
            value="Prediction 1",  # Default: zoom to the top prediction
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        # Run the model and redraw the map, zooming to the selected rank.
        btn_predict.click(
            predict_and_plot,
            inputs=[input_image, selected_prediction],
            outputs=output_map,
        )
        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)
gradio_app.launch()