robocan commited on
Commit
24d6812
·
verified ·
1 Parent(s): a3e98ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -17
app.py CHANGED
@@ -87,19 +87,40 @@ def predict(input_img):
87
  prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
88
  return prediction
89
 
90
- # Function to generate HTML for map
91
- def create_map_html(lat, lon):
92
- m = folium.Map(location=[lat, lon], zoom_start=12)
93
- folium.Marker([lat, lon]).add_to(m)
94
- data = BytesIO()
95
- m.save(data, close_file=False)
96
- return data.getvalue().decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # Create label output function
99
  def create_label_output(predictions):
100
  lat, lon = predictions
101
- map_html = create_map_html(lat, lon)
102
- return f"<div><h3>Predicted coordinates: ({lat:.6f}, {lon:.6f})</h3>{map_html}</div>"
103
 
104
  # Predict and plot function
105
  def predict_and_plot(input_img):
@@ -107,13 +128,12 @@ def predict_and_plot(input_img):
107
  return create_label_output(predictions)
108
 
109
  # Gradio app definition
110
- gradio_app = Interface(
111
- fn=predict_and_plot,
112
- inputs=Image(label="Upload an Image", type="pil"),
113
- examples=["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"],
114
- outputs=HTML(),
115
- title="Predict the Location of this Image"
116
- )
117
 
118
- if __name__ == "__main__":
119
  gradio_app.launch()
 
87
  prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
88
  return prediction
89
 
90
+ # Function to generate Plotly map figure
91
+ def create_map_figure(lat, lon):
92
+ fig = go.Figure(go.Scattermapbox(
93
+ lat=[lat],
94
+ lon=[lon],
95
+ mode='markers',
96
+ marker=go.scattermapbox.Marker(
97
+ size=14
98
+ ),
99
+ text=[f'Lat: {lat}, Lon: {lon}'],
100
+ hoverinfo='text'
101
+ ))
102
+
103
+ fig.update_layout(
104
+ mapbox_style="open-street-map",
105
+ hovermode='closest',
106
+ mapbox=dict(
107
+ bearing=0,
108
+ center=go.layout.mapbox.Center(
109
+ lat=lat,
110
+ lon=lon
111
+ ),
112
+ pitch=0,
113
+ zoom=10
114
+ ),
115
+ )
116
+
117
+ return fig
118
 
119
  # Create label output function
120
  def create_label_output(predictions):
121
  lat, lon = predictions
122
+ fig = create_map_figure(lat, lon)
123
+ return fig
124
 
125
  # Predict and plot function
126
  def predict_and_plot(input_img):
 
128
  return create_label_output(predictions)
129
 
130
  # Gradio app definition
131
+ with gr.Blocks() as gradio_app:
132
+ with gr.Column():
133
+ input_image = gr.Image(label="Upload an Image", type="pil")
134
+ output_map = gr.Plot(label="Predicted Location on Map")
135
+ btn_predict = gr.Button("Predict")
136
+
137
+ btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
138
 
 
139
  gradio_app.launch()