robocan commited on
Commit
940a536
1 Parent(s): 4345f8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -29,6 +29,8 @@ device = 'cpu'
29
  le = LabelEncoder()
30
  le = joblib.load("SVD/le.gz")
31
  len_classes = len(le.classes_) + 1
 
 
32
 
33
  class ModelPre(torch.nn.Module):
34
  def __init__(self):
@@ -59,16 +61,26 @@ cmp = transforms.Compose([
59
  ])
60
 
61
  def predict(input_img):
62
- with torch.inference_mode():
63
- img = cmp(input_img).unsqueeze(0)
64
- res = model(img.to(device))
65
- probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
66
- top_10_indices = np.argsort(probabilities)[-10:][::-1]
67
- top_10_probabilities = probabilities[top_10_indices]
68
- top_10_predictions = le.inverse_transform(top_10_indices)
69
-
70
- results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
71
- return results, top_10_predictions
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Function to get S2 cell polygon
74
  def get_s2_cell_polygon(cell_id):
@@ -154,12 +166,15 @@ def predict_and_plot(input_img, selected_prediction):
154
  with gr.Blocks() as gradio_app:
155
  with gr.Column():
156
  input_image = gr.Image(label="Upload an Image", type="pil")
157
- selected_prediction = gr.Dropdown(choices=[f"Prediction {i+1}" for i in range(10)], label="Select Prediction to Zoom", value=None)
158
  output_map = gr.Plot(label="Predicted Location on Map")
159
  btn_predict = gr.Button("Predict")
 
 
 
 
160
 
161
- # Update click function to include selected prediction
162
- btn_predict.click(predict_and_plot, inputs=[input_image, selected_prediction], outputs=output_map)
163
 
164
  examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
165
  gr.Examples(examples=examples, inputs=input_image)
 
29
  le = LabelEncoder()
30
  le = joblib.load("SVD/le.gz")
31
  len_classes = len(le.classes_) + 1
32
+ # Global variable to store predictions for dynamic zoom
33
+ global_predictions = None
34
 
35
  class ModelPre(torch.nn.Module):
36
  def __init__(self):
 
61
  ])
62
 
63
  def predict(input_img):
64
+ global global_predictions
65
+ results, cell_ids = predict(input_img)
66
+ global_predictions = (results, cell_ids) # Store predictions globally for zoom functionality
67
+ return create_map_figure(global_predictions, global_predictions[1])
68
+
69
+ def zoom_on_prediction(selected_prediction):
70
+ global global_predictions
71
+ if global_predictions is None:
72
+ return None # No prediction made yet
73
+
74
+ # Convert dropdown selection into an index (Prediction 1 corresponds to index 0, etc.)
75
+ if selected_prediction is not None:
76
+ selected_index = int(selected_prediction.split()[-1]) - 1 # Extract index from "Prediction X"
77
+ else:
78
+ selected_index = None # No selection, default view
79
+
80
+ # Return the updated map with zoom
81
+ return create_map_figure(global_predictions, global_predictions[1], selected_index=selected_index)
82
+
83
+
84
 
85
  # Function to get S2 cell polygon
86
  def get_s2_cell_polygon(cell_id):
 
166
  with gr.Blocks() as gradio_app:
167
  with gr.Column():
168
  input_image = gr.Image(label="Upload an Image", type="pil")
 
169
  output_map = gr.Plot(label="Predicted Location on Map")
170
  btn_predict = gr.Button("Predict")
171
+ selected_prediction = gr.Dropdown(choices=[f"Prediction {i+1}" for i in range(10)], label="Select Prediction to Zoom", value=None)
172
+
173
+ # Perform the prediction and plot the initial map
174
+ btn_predict.click(predict, inputs=input_image, outputs=output_map)
175
 
176
+ # Allow the user to zoom in on a selected prediction after the prediction is made
177
+ selected_prediction.change(zoom_on_prediction, inputs=selected_prediction, outputs=output_map)
178
 
179
  examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
180
  gr.Examples(examples=examples, inputs=input_image)