robocan commited on
Commit
39fc524
1 Parent(s): 52895be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -24
app.py CHANGED
@@ -30,7 +30,6 @@ le = LabelEncoder()
30
  le = joblib.load("SVD/le.gz")
31
  len_classes = len(le.classes_) + 1
32
 
33
-
34
  class ModelPre(torch.nn.Module):
35
  def __init__(self):
36
  super().__init__()
@@ -60,26 +59,16 @@ cmp = transforms.Compose([
60
  ])
61
 
62
  def predict(input_img):
63
- global global_predictions
64
- results, cell_ids = predict(input_img)
65
- global_predictions = (results, cell_ids) # Store predictions globally for zoom functionality
66
- return create_map_figure(global_predictions, global_predictions[1])
67
-
68
- def zoom_on_prediction(selected_prediction):
69
- global global_predictions
70
- if global_predictions is None:
71
- return None # No prediction made yet
72
-
73
- # Convert dropdown selection into an index (Prediction 1 corresponds to index 0, etc.)
74
- if selected_prediction is not None:
75
- selected_index = int(selected_prediction.split()[-1]) - 1 # Extract index from "Prediction X"
76
- else:
77
- selected_index = None # No selection, default view
78
-
79
- # Return the updated map with zoom
80
- return create_map_figure(global_predictions, global_predictions[1], selected_index=selected_index)
81
-
82
-
83
 
84
  # Function to get S2 cell polygon
85
  def get_s2_cell_polygon(cell_id):
@@ -146,7 +135,7 @@ def create_label_output(predictions):
146
  results, cell_ids = predictions
147
  fig = create_map_figure(results, cell_ids)
148
  return fig
149
-
150
  def predict_and_plot(input_img, selected_prediction):
151
  predictions = predict(input_img)
152
 
@@ -159,8 +148,6 @@ def predict_and_plot(input_img, selected_prediction):
159
  return create_map_figure(predictions, predictions[1], selected_index=selected_index)
160
 
161
 
162
-
163
-
164
  # Gradio app definition
165
  with gr.Blocks() as gradio_app:
166
  with gr.Column():
 
30
  le = joblib.load("SVD/le.gz")
31
  len_classes = len(le.classes_) + 1
32
 
 
33
  class ModelPre(torch.nn.Module):
34
  def __init__(self):
35
  super().__init__()
 
59
  ])
60
 
61
  def predict(input_img):
62
+ with torch.inference_mode():
63
+ img = cmp(input_img).unsqueeze(0)
64
+ res = model(img.to(device))
65
+ probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
66
+ top_10_indices = np.argsort(probabilities)[-10:][::-1]
67
+ top_10_probabilities = probabilities[top_10_indices]
68
+ top_10_predictions = le.inverse_transform(top_10_indices)
69
+
70
+ results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
71
+ return results, top_10_predictions
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Function to get S2 cell polygon
74
  def get_s2_cell_polygon(cell_id):
 
135
  results, cell_ids = predictions
136
  fig = create_map_figure(results, cell_ids)
137
  return fig
138
+
139
  def predict_and_plot(input_img, selected_prediction):
140
  predictions = predict(input_img)
141
 
 
148
  return create_map_figure(predictions, predictions[1], selected_index=selected_index)
149
 
150
 
 
 
151
  # Gradio app definition
152
  with gr.Blocks() as gradio_app:
153
  with gr.Column():