robocan commited on
Commit
4b52e6f
1 Parent(s): b8d71d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -4,10 +4,12 @@ from torch.utils.data import Dataset, DataLoader
4
  import pandas as pd
5
  import numpy as np
6
  import joblib
 
 
7
  from PIL import Image
8
  from torchvision import transforms,models
9
  from sklearn.preprocessing import LabelEncoder,MinMaxScaler
10
- from gradio import Interface, Image, Label
11
  from huggingface_hub import snapshot_download
12
 
13
  # Retrieve the token from the environment variables
@@ -85,9 +87,19 @@ def predict(input_img):
85
  prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
86
  return prediction
87
 
 
 
 
 
 
 
 
 
88
  # Create label output function
89
  def create_label_output(predictions):
90
- return f"Predicted values: {predictions}"
 
 
91
 
92
  # Predict and plot function
93
  def predict_and_plot(input_img):
@@ -99,7 +111,7 @@ gradio_app = Interface(
99
  fn=predict_and_plot,
100
  inputs=Image(label="Upload an Image", type="pil"),
101
  examples=["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"],
102
- outputs="text",
103
  title="Predict the Location of this Image"
104
  )
105
 
 
4
  import pandas as pd
5
  import numpy as np
6
  import joblib
7
+ from io import BytesIO
8
+ import folium
9
  from PIL import Image
10
  from torchvision import transforms,models
11
  from sklearn.preprocessing import LabelEncoder,MinMaxScaler
12
+ from gradio import Interface, Image, Label, HTML
13
  from huggingface_hub import snapshot_download
14
 
15
  # Retrieve the token from the environment variables
 
87
  prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
88
  return prediction
89
 
90
+ # Function to generate HTML for map
91
+ def create_map_html(lat, lon):
92
+ m = folium.Map(location=[lat, lon], zoom_start=12)
93
+ folium.Marker([lat, lon]).add_to(m)
94
+ data = BytesIO()
95
+ m.save(data, close_file=False)
96
+ return data.getvalue().decode()
97
+
98
  # Create label output function
99
  def create_label_output(predictions):
100
+ lat, lon = predictions
101
+ map_html = create_map_html(lat, lon)
102
+ return f"<div><h3>Predicted coordinates: ({lat:.6f}, {lon:.6f})</h3>{map_html}</div>"
103
 
104
  # Predict and plot function
105
  def predict_and_plot(input_img):
 
111
  fn=predict_and_plot,
112
  inputs=Image(label="Upload an Image", type="pil"),
113
  examples=["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"],
114
+ outputs=HTML(),
115
  title="Predict the Location of this Image"
116
  )
117