robocan commited on
Commit
a8f08c9
1 Parent(s): 3e8fcf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -80,17 +80,21 @@ def get_s2_cell_polygon(cell_id):
80
  vertices.append(vertices[0]) # Close the polygon
81
  return vertices
82
 
83
- def create_map_figure(predictions, cell_ids):
84
  fig = go.Figure()
85
 
86
  # Assign colors based on rank
87
  colors = ['rgba(0, 255, 0, 0.2)'] * 3 + ['rgba(255, 255, 0, 0.2)'] * 7
 
 
 
88
 
89
  for rank, cell_id in enumerate(cell_ids):
90
  cell_id = int(float(cell_id))
91
  polygon = get_s2_cell_polygon(cell_id)
92
  lats, lons = zip(*polygon)
93
  color = colors[rank]
 
94
  fig.add_trace(go.Scattermapbox(
95
  lat=lats,
96
  lon=lons,
@@ -101,17 +105,23 @@ def create_map_figure(predictions, cell_ids):
101
  name=f'Prediction {rank + 1}', # Updated label
102
  ))
103
 
 
 
 
 
 
 
104
  fig.update_layout(
105
- mapbox_style="open-street-map", # Change this line to use 'light' style
106
  hovermode='closest',
107
  mapbox=dict(
108
  bearing=0,
109
  center=go.layout.mapbox.Center(
110
- lat=np.mean(lats),
111
- lon=np.mean(lons)
112
  ),
113
  pitch=0,
114
- zoom=1
115
  ),
116
  )
117
 
@@ -124,20 +134,24 @@ def create_label_output(predictions):
124
  fig = create_map_figure(results, cell_ids)
125
  return fig
126
 
127
- # Predict and plot function
128
- def predict_and_plot(input_img):
129
  predictions = predict(input_img)
130
- return create_label_output(predictions)
 
131
 
132
 
133
  # Gradio app definition
134
  with gr.Blocks() as gradio_app:
135
  with gr.Column():
136
  input_image = gr.Image(label="Upload an Image", type="pil")
 
137
  output_map = gr.Plot(label="Predicted Location on Map")
138
  btn_predict = gr.Button("Predict")
139
 
140
- btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
 
 
141
  examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
142
  gr.Examples(examples=examples, inputs=input_image)
143
- gradio_app.launch()
 
80
  vertices.append(vertices[0]) # Close the polygon
81
  return vertices
82
 
83
+ def create_map_figure(predictions, cell_ids, selected_index=None):
84
  fig = go.Figure()
85
 
86
  # Assign colors based on rank
87
  colors = ['rgba(0, 255, 0, 0.2)'] * 3 + ['rgba(255, 255, 0, 0.2)'] * 7
88
+ zoom_level = 1
89
+ center_lat = None
90
+ center_lon = None
91
 
92
  for rank, cell_id in enumerate(cell_ids):
93
  cell_id = int(float(cell_id))
94
  polygon = get_s2_cell_polygon(cell_id)
95
  lats, lons = zip(*polygon)
96
  color = colors[rank]
97
+
98
  fig.add_trace(go.Scattermapbox(
99
  lat=lats,
100
  lon=lons,
 
105
  name=f'Prediction {rank + 1}', # Updated label
106
  ))
107
 
108
+ # Set zoom based on the selected index
109
+ if selected_index is not None and rank == selected_index:
110
+ zoom_level = 10 # Adjust zoom level
111
+ center_lat = np.mean(lats)
112
+ center_lon = np.mean(lons)
113
+
114
  fig.update_layout(
115
+ mapbox_style="open-street-map",
116
  hovermode='closest',
117
  mapbox=dict(
118
  bearing=0,
119
  center=go.layout.mapbox.Center(
120
+ lat=center_lat if center_lat else np.mean(lats),
121
+ lon=center_lon if center_lon else np.mean(lons)
122
  ),
123
  pitch=0,
124
+ zoom=zoom_level # Zoom in if an index is selected
125
  ),
126
  )
127
 
 
134
  fig = create_map_figure(results, cell_ids)
135
  return fig
136
 
137
+ # Update the predict_and_plot function to handle zoom on selection
138
+ def predict_and_plot(input_img, selected_prediction):
139
  predictions = predict(input_img)
140
+ return create_map_figure(predictions, predictions[1], selected_index=selected_prediction)
141
+
142
 
143
 
144
  # Gradio app definition
145
  with gr.Blocks() as gradio_app:
146
  with gr.Column():
147
  input_image = gr.Image(label="Upload an Image", type="pil")
148
+ selected_prediction = gr.Dropdown(choices=[f"Prediction {i+1}" for i in range(10)], label="Select Prediction to Zoom")
149
  output_map = gr.Plot(label="Predicted Location on Map")
150
  btn_predict = gr.Button("Predict")
151
 
152
+ # Update click function to include selected prediction
153
+ btn_predict.click(predict_and_plot, inputs=[input_image, selected_prediction], outputs=output_map)
154
+
155
  examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
156
  gr.Examples(examples=examples, inputs=input_image)
157
+ gradio_app.launch()