File size: 5,568 Bytes
9b889da
afe2deb
 
8a4658f
afe2deb
 
38a655b
 
4b52e6f
afe2deb
cc6f22c
 
4b52e6f
dbf177b
4b67dd1
cc6f22c
 
9b889da
 
 
dbf177b
1de1740
db673be
dbf177b
db673be
 
9b889da
955fc23
 
14085cb
4184b6d
52895be
955fc23
 
 
 
be60ccb
955fc23
1de1740
be60ccb
1de1740
 
 
955fc23
 
 
 
f07227d
 
 
 
d497c1d
f07227d
955fc23
 
 
 
 
 
6d49cf1
 
39fc524
 
 
 
 
 
 
 
 
 
cc6f22c
 
 
 
 
 
 
 
 
 
6d49cf1
a8f08c9
cc6f22c
 
05501c8
 
4345f8a
a8f08c9
 
05501c8
 
03ec0c3
cc6f22c
 
05501c8
4345f8a
 
cc6f22c
 
 
 
 
05501c8
cc6f22c
4345f8a
cc6f22c
24d6812
4345f8a
a8f08c9
4345f8a
a8f08c9
 
 
4345f8a
24d6812
a8f08c9
24d6812
 
 
 
a8f08c9
 
24d6812
 
4345f8a
24d6812
 
 
 
4b52e6f
05501c8
f07227d
afe2deb
cc6f22c
 
24d6812
39fc524
a8f08c9
8a4658f
4345f8a
 
 
 
 
 
 
 
 
a8f08c9
f07227d
24d6812
 
 
6ecb023
 
 
 
 
24d6812
 
 
52895be
 
a8f08c9
dd76dcd
 
a8f08c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser
import s2sphere
import folium

# Hugging Face access token read from the environment (None when unset).
token = os.environ.get("token")

# Download the model repository into ./SVD; the call returns the local folder.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token
)

device = 'cpu'

# Label encoder that maps class indices back to the labels used as S2 cell ids.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only acceptable because the artifact comes from a trusted repository.
le = joblib.load(os.path.join(local_dir, "le.gz"))

# The classifier head is built with one more output than the encoder has
# classes. NOTE(review): presumably an extra slot used during training — confirm.
len_classes = len(le.classes_) + 1

class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a three-layer MLP classifier.

    The network is stored as a single ``torch.nn.Sequential`` under
    ``self.embedding``; the position of each layer in that sequence
    determines the parameter names in the state dict, so the order below
    must not change.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # torchvision ConvNeXt-Small with ImageNet weights; every child module
        # except the last one is reused (presumably dropping its own
        # classification head — confirm against torchvision).
        convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1)
        trunk = list(convnext.children())[:-1]

        # Replacement head: 768 -> 1024 -> 1024 -> len_classes.
        head = [
            nn.Flatten(),
            nn.Linear(in_features=768, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=len_classes),
        ]

        self.embedding = nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Run ``data`` through the full sequence and return the raw logits."""
        logits = self.embedding(data)
        return logits
    
# Load the pretrained model
model = ModelPre()

# The checkpoint is read with torch_xla's serialization helper; the weights
# live under the 'model' key of the loaded object.
model_w = xser.load(os.path.join(local_dir, "GeoG.pth"))
model.load_state_dict(model_w['model'])

# Keep the model on the same device the inputs are moved to, and switch to
# evaluation mode: torch.inference_mode() only disables autograd, it does not
# turn off train-time layers, so without eval() the outputs would vary from
# call to call.
model.to(device)
model.eval()

# Preprocessing applied to every input image: tensor conversion, resize to
# 224x224 and normalisation with the mean/std values listed below.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict(input_img, top_k=10):
    """Classify an image and return its most likely labels.

    Args:
        input_img: Image accepted by the module-level ``cmp`` transform
            pipeline.
        top_k: Number of highest-probability classes to return. Defaults to
            10 and is capped at the number of model outputs.

    Returns:
        A tuple ``(results, top_predictions)`` where ``results`` maps each
        decoded label to its softmax probability (as a Python float) and
        ``top_predictions`` is the array of decoded labels ordered from most
        to least likely.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)  # add the batch dimension
        res = model(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()

    # Never ask for more entries than the model produces.
    k = min(top_k, probabilities.size)

    # argsort is ascending: take the last k entries and reverse them.
    top_indices = np.argsort(probabilities)[-k:][::-1]
    top_probabilities = probabilities[top_indices]

    # NOTE(review): the classifier head has one more output than `le` has
    # classes; if that extra index ever lands in the top-k, inverse_transform
    # will reject it — confirm how the extra slot is meant to be handled.
    top_predictions = le.inverse_transform(top_indices)

    results = {
        label: float(probability)
        for label, probability in zip(top_predictions, top_probabilities)
    }
    return results, top_predictions

# Build the closed outline of an S2 cell
def get_s2_cell_polygon(cell_id):
    """Return the corner coordinates of an S2 cell as a closed ring.

    Args:
        cell_id: Integer id passed to ``s2sphere.CellId``.

    Returns:
        A list of five ``(latitude, longitude)`` tuples in degrees: the four
        cell vertices followed by the first vertex again so the outline is
        closed.
    """
    s2_cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    corners = [s2sphere.LatLng.from_point(s2_cell.get_vertex(k)) for k in range(4)]
    ring = [(corner.lat().degrees, corner.lng().degrees) for corner in corners]
    # Repeat the first corner at the end to close the polygon.
    return ring + ring[:1]

def create_map_figure(predictions, cell_ids, selected_index=None):
    """Draw the predicted S2 cells on an OpenStreetMap-backed Plotly figure.

    Args:
        predictions: Unused; kept so existing callers keep working.
        cell_ids: Iterable of S2 cell ids ordered by rank (best first). Each
            value is converted with ``int(float(...))`` before use.
        selected_index: Zero-based rank of the cell to zoom in on, or ``None``
            for the default wide view.

    Returns:
        A ``plotly.graph_objects.Figure`` with one filled polygon trace per
        cell. The map is centred on the selected cell when there is one,
        otherwise on the last cell drawn, or on (0, 0) when ``cell_ids`` is
        empty.
    """
    fig = go.Figure()

    # Assign colors based on rank: top three green, the rest yellow.
    colors = ['rgba(0, 255, 0, 0.2)'] * 3 + ['rgba(255, 255, 0, 0.2)'] * 7
    zoom_level = 1  # Default zoom level
    center_lat = None
    center_lon = None

    # Centre used when nothing is selected; stays at (0, 0) for empty input.
    fallback_lat = 0.0
    fallback_lon = 0.0

    for rank, cell_id in enumerate(cell_ids):
        # NOTE(review): labels presumably arrive as floats or float strings,
        # hence the float() round-trip — confirm it is lossless for the ids used.
        cell_id = int(float(cell_id))
        polygon = get_s2_cell_polygon(cell_id)
        lats, lons = zip(*polygon)

        # Ranks beyond the palette reuse its last colour instead of failing.
        color = colors[min(rank, len(colors) - 1)]

        # Draw S2 cell polygon
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))

        fallback_lat = np.mean(lats)
        fallback_lon = np.mean(lons)

        # Zoom in on the selected prediction when it is found.
        if selected_index is not None and rank == selected_index:
            zoom_level = 10  # Adjust the zoom level to your liking
            center_lat = fallback_lat
            center_lon = fallback_lon

    # Compare against None explicitly: a coordinate of exactly 0.0 is valid.
    map_lat = center_lat if center_lat is not None else fallback_lat
    map_lon = center_lon if center_lon is not None else fallback_lon

    # Update map layout
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(
                lat=map_lat,
                lon=map_lon
            ),
            pitch=0,
            zoom=zoom_level  # Zoom in based on selection
        ),
    )

    return fig


# Turn a predict() result into a map figure
def create_label_output(predictions):
    """Build the map figure for a ``(results, cell_ids)`` pair.

    Args:
        predictions: Two-item sequence holding the label/probability mapping
            and the ranked cell ids.

    Returns:
        The figure produced by ``create_map_figure`` with no cell selected.
    """
    label_scores, ranked_cells = predictions
    return create_map_figure(label_scores, ranked_cells)

def predict_and_plot(input_img, selected_prediction):
    """Run the classifier on an image and plot the predicted cells.

    Args:
        input_img: Image supplied by the Gradio image component, or ``None``
            when nothing has been uploaded.
        selected_prediction: Dropdown value of the form ``"Prediction X"``
            (1-based), or ``None`` for the default wide view.

    Returns:
        The Plotly figure produced by ``create_map_figure``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a failure from the preprocessing pipeline.
    """
    if input_img is None:
        raise gr.Error("Please upload an image before pressing Predict.")

    results, cell_ids = predict(input_img)

    # "Prediction X" is 1-based; ranks inside the figure are 0-based.
    selected_index = (
        int(selected_prediction.split()[-1]) - 1
        if selected_prediction is not None
        else None
    )

    return create_map_figure(results, cell_ids, selected_index=selected_index)


# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        selected_prediction = gr.Dropdown(
            choices=[f"Prediction {i+1}" for i in range(10)], 
            label="Select Prediction to Zoom", 
            value="Prediction 1"  # Set default to "Prediction 1"
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")

        # Clicking the button sends the image and the dropdown value to
        # predict_and_plot and renders the returned figure.
        btn_predict.click(predict_and_plot, inputs=[input_image, selected_prediction], outputs=output_map)

        # Sample images offered below the input component.
        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

# Start the server only after the Blocks context has been closed, so the
# layout is fully assembled before it is served.
gradio_app.launch()