import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
# NOTE: this rebinds `Image` from PIL to gradio's Image component; import order is kept as-is.
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser
import s2sphere
import folium

# Optional Hugging Face access token, read from the environment.
token = os.environ.get("token")

# Fetch the model repo (label encoder + checkpoint) into ./SVD.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# Label encoder mapping class indices <-> S2 cell labels.
le = joblib.load("SVD/le.gz")
# One more output than known labels; presumably matches the head size the
# checkpoint was trained with (TODO confirm). The extra index has no label in `le`.
len_classes = len(le.classes_) + 1


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small feature extractor followed by an MLP head over `len_classes` outputs."""

    def __init__(self, pretrained_backbone=True):
        """Build the network.

        Args:
            pretrained_backbone: If True (the default), initialise the backbone
                with torchvision's ImageNet weights. Pass False when a full
                checkpoint is loaded afterwards, to skip the weight download.
        """
        super().__init__()
        weights = models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained_backbone else None
        backbone = models.convnext_small(weights=weights)
        self.embedding = torch.nn.Sequential(
            # All backbone children except the final classifier.
            *list(backbone.children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        )

    def forward(self, data):
        """Return raw logits of shape (batch, len_classes)."""
        return self.embedding(data)


# Load the pretrained model. The checkpoint overwrites every parameter
# (load_state_dict is strict), so the ImageNet backbone weights are not needed.
model = ModelPre(pretrained_backbone=False)
model_w = xser.load("SVD/GeoG.pth")
model.load_state_dict(model_w['model'])
model.to(device)
# inference_mode() does not switch module mode; eval() is required to disable
# training-only behaviour such as ConvNeXt's stochastic depth.
model.eval()

# Preprocessing: tensor, resize to 224x224, ImageNet mean/std normalisation.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(input_img, top_k=10):
    """Rank S2-cell labels for an image.

    Args:
        input_img: Image accepted by `cmp`; the UI passes a PIL image.
            PIL images not in RGB mode are converted to RGB first, because the
            normalisation step expects exactly three channels.
        top_k: Number of labels to return (default 10, which the map colours
            and the dropdown assume).

    Returns:
        (results, labels): `results` maps each label to its softmax
        probability; `labels` is the array returned by `le.inverse_transform`,
        best first.
    """
    if getattr(input_img, "mode", "RGB") != "RGB":
        input_img = input_img.convert("RGB")

    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        logits = model(img.to(device))
        probabilities = torch.softmax(logits, dim=1).cpu().numpy().flatten()

    # The head has one more output than `le` knows; rank only the indices the
    # encoder can decode, otherwise inverse_transform raises ValueError.
    num_labels = len(le.classes_)
    decodable = probabilities[:num_labels]
    top_k = min(top_k, num_labels)

    top_indices = np.argsort(decodable)[-top_k:][::-1]
    top_probabilities = decodable[top_indices]
    top_predictions = le.inverse_transform(top_indices)

    results = {label: float(prob) for label, prob in zip(top_predictions, top_probabilities)}
    return results, top_predictions
# Function to get S2 cell polygon
def get_s2_cell_polygon(cell_id):
    """Return the closed outline of an S2 cell as (lat, lng) pairs in degrees.

    The four cell vertices are followed by a repeat of the first vertex, so the
    outline can be drawn as a closed line.
    """
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    vertices = []
    for i in range(4):
        vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
        vertices.append((vertex.lat().degrees, vertex.lng().degrees))
    vertices.append(vertices[0])  # Close the polygon
    return vertices


def _to_cell_id(value):
    """Convert a label (int, numeric string or float) to an integer S2 cell id."""
    try:
        # Exact for ints and integer strings; avoids rounding 64-bit ids through float.
        return int(value)
    except (TypeError, ValueError):
        # Labels such as "5.764607523034235e+18" still go through float, as before.
        return int(float(value))


def create_map_figure(predictions, cell_ids, selected_index=None):
    """Draw the predicted S2 cells on an OpenStreetMap-styled Plotly map.

    Args:
        predictions: Unused; kept so existing callers keep working.
        cell_ids: Ranked S2 cell ids, best first.
        selected_index: 0-based rank to zoom into, or None for the world view.

    Returns:
        A plotly Figure with one filled polygon trace per cell. It is centred on
        the selected cell, or on the top-ranked cell when there is no selection.
    """
    fig = go.Figure()

    zoom_level = 1  # Default (world) zoom level
    center_lat = None
    center_lon = None
    default_center = (0.0, 0.0)  # Used only when cell_ids is empty

    for rank, raw_id in enumerate(cell_ids):
        polygon = get_s2_cell_polygon(_to_cell_id(raw_id))
        lats, lons = zip(*polygon)
        # Top-3 predictions in green, the rest in yellow.
        color = 'rgba(0, 255, 0, 0.2)' if rank < 3 else 'rgba(255, 255, 0, 0.2)'

        # Draw S2 cell polygon
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))

        if rank == 0:
            default_center = (np.mean(lats), np.mean(lons))

        # Zoom in on the selected prediction, if any.
        if selected_index is not None and rank == selected_index:
            zoom_level = 10
            center_lat = np.mean(lats)
            center_lon = np.mean(lons)

    # Compare against None explicitly: 0.0 (equator / prime meridian) is a valid centre.
    if center_lat is None or center_lon is None:
        center_lat, center_lon = default_center

    # Update map layout
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=zoom_level,
        ),
    )
    return fig


# Create label output function
def create_label_output(predictions):
    """Build the map for a `(results, cell_ids)` pair as returned by `predict`."""
    results, cell_ids = predictions
    fig = create_map_figure(results, cell_ids)
    return fig


def predict_and_plot(input_img, selected_prediction):
    """Run `predict` on the image and map its cells, zoomed to the selected rank.

    Args:
        input_img: Uploaded PIL image, or None if nothing was uploaded.
        selected_prediction: Dropdown value such as "Prediction 3", or None.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    if input_img is None:
        raise gr.Error("Please upload an image before predicting.")

    predictions = predict(input_img)

    # Convert dropdown selection into an index (Prediction 1 corresponds to index 0, etc.)
    if selected_prediction is not None:
        selected_index = int(selected_prediction.split()[-1]) - 1  # Extract index from "Prediction X"
    else:
        selected_index = None  # No selection, default view

    return create_map_figure(predictions, predictions[1], selected_index=selected_index)


# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        selected_prediction = gr.Dropdown(
            choices=[f"Prediction {i+1}" for i in range(10)],
            label="Select Prediction to Zoom",
            value="Prediction 1",  # Set default to "Prediction 1"
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")

        # Update click function to include selected prediction
        btn_predict.click(predict_and_plot, inputs=[input_image, selected_prediction], outputs=output_map)

        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()