Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -30,7 +30,6 @@ le = LabelEncoder()
|
|
30 |
le = joblib.load("SVD/le.gz")
|
31 |
len_classes = len(le.classes_) + 1
|
32 |
|
33 |
-
|
34 |
class ModelPre(torch.nn.Module):
|
35 |
def __init__(self):
|
36 |
super().__init__()
|
@@ -60,26 +59,16 @@ cmp = transforms.Compose([
|
|
60 |
])
|
61 |
|
62 |
def predict(input_img):
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
# Convert dropdown selection into an index (Prediction 1 corresponds to index 0, etc.)
|
74 |
-
if selected_prediction is not None:
|
75 |
-
selected_index = int(selected_prediction.split()[-1]) - 1 # Extract index from "Prediction X"
|
76 |
-
else:
|
77 |
-
selected_index = None # No selection, default view
|
78 |
-
|
79 |
-
# Return the updated map with zoom
|
80 |
-
return create_map_figure(global_predictions, global_predictions[1], selected_index=selected_index)
|
81 |
-
|
82 |
-
|
83 |
|
84 |
# Function to get S2 cell polygon
|
85 |
def get_s2_cell_polygon(cell_id):
|
@@ -146,7 +135,7 @@ def create_label_output(predictions):
|
|
146 |
results, cell_ids = predictions
|
147 |
fig = create_map_figure(results, cell_ids)
|
148 |
return fig
|
149 |
-
|
150 |
def predict_and_plot(input_img, selected_prediction):
|
151 |
predictions = predict(input_img)
|
152 |
|
@@ -159,8 +148,6 @@ def predict_and_plot(input_img, selected_prediction):
|
|
159 |
return create_map_figure(predictions, predictions[1], selected_index=selected_index)
|
160 |
|
161 |
|
162 |
-
|
163 |
-
|
164 |
# Gradio app definition
|
165 |
with gr.Blocks() as gradio_app:
|
166 |
with gr.Column():
|
|
|
30 |
le = joblib.load("SVD/le.gz")
|
31 |
len_classes = len(le.classes_) + 1
|
32 |
|
|
|
33 |
class ModelPre(torch.nn.Module):
|
34 |
def __init__(self):
|
35 |
super().__init__()
|
|
|
59 |
])
|
60 |
|
61 |
def predict(input_img):
|
62 |
+
with torch.inference_mode():
|
63 |
+
img = cmp(input_img).unsqueeze(0)
|
64 |
+
res = model(img.to(device))
|
65 |
+
probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
|
66 |
+
top_10_indices = np.argsort(probabilities)[-10:][::-1]
|
67 |
+
top_10_probabilities = probabilities[top_10_indices]
|
68 |
+
top_10_predictions = le.inverse_transform(top_10_indices)
|
69 |
+
|
70 |
+
results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
|
71 |
+
return results, top_10_predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
# Function to get S2 cell polygon
|
74 |
def get_s2_cell_polygon(cell_id):
|
|
|
135 |
results, cell_ids = predictions
|
136 |
fig = create_map_figure(results, cell_ids)
|
137 |
return fig
|
138 |
+
|
139 |
def predict_and_plot(input_img, selected_prediction):
|
140 |
predictions = predict(input_img)
|
141 |
|
|
|
148 |
return create_map_figure(predictions, predictions[1], selected_index=selected_index)
|
149 |
|
150 |
|
|
|
|
|
151 |
# Gradio app definition
|
152 |
with gr.Blocks() as gradio_app:
|
153 |
with gr.Column():
|