DawnC commited on
Commit
b985ec3
1 Parent(s): 7921180

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -27
app.py CHANGED
@@ -177,6 +177,7 @@ def get_akc_breeds_link():
177
  # except Exception as e:
178
  # return f"An error occurred: {e}"
179
 
 
180
  def predict(image):
181
  try:
182
  image_tensor = preprocess_image(image)
@@ -187,52 +188,39 @@ def predict(image):
187
  else:
188
  logits = output
189
 
190
- # 計算預測的概率分佈
191
  probabilities = F.softmax(logits, dim=1)
192
-
193
- # 取得最高的預測分數以及對應的品種
194
  top_confidence, top_index = torch.max(probabilities, 1)
195
- top_confidence = top_confidence.item() # 轉成 Python 數值
196
  top_breed = dog_breeds[top_index.item()]
197
 
198
- # 如果最高預測分數大於等於 60%,直接返回該品種的資訊
199
  if top_confidence >= 0.60:
200
  description = get_dog_description(top_breed)
201
  akc_link = get_akc_breeds_link()
202
  description_str = f"**Breed**: {top_breed}\n\n**Description**: {description}\n"
203
  description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {top_breed}."
204
  return description_str
205
-
206
- # 如果預測分數小於 60%,返回 Top-3 預測並讓用戶選擇
207
  else:
208
  top3_confidences, top3_indices = torch.topk(probabilities, 3, dim=1)
209
  top3_breeds = [dog_breeds[idx] for idx in top3_indices.squeeze().tolist()]
210
  top3_confidences = top3_confidences.squeeze().tolist()
211
 
212
- return {
213
- "top3_breeds": top3_breeds,
214
- "top3_confidences": [f"{conf * 100:.2f}%" for conf in top3_confidences],
215
- "selected_breed": None,
216
- "message": "The confidence score is low. Please select the correct breed from the options or select 'None of the above' if none are correct."
217
- }
 
 
 
 
218
 
219
  except Exception as e:
220
  return f"An error occurred: {e}"
221
 
222
- # 處理用戶選擇的結果
223
- def handle_user_selection(top3_breeds, selected_breed):
224
- if selected_breed in top3_breeds:
225
- breed_index = top3_breeds.index(selected_breed)
226
- description = get_dog_description(selected_breed)
227
-
228
- akc_link = get_akc_breeds_link()
229
- description_str = f"**Breed**: {selected_breed}\n\n**Description**: {description}\n"
230
- description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {selected_breed}."
231
-
232
- return description_str
233
- else:
234
- return "Sorry, the breed could not be identified. Please try uploading a clearer image or another breed."
235
-
236
 
237
  iface = gr.Interface(
238
  fn=predict,
 
177
  # except Exception as e:
178
  # return f"An error occurred: {e}"
179
 
180
+ # Prediction function
181
  def predict(image):
182
  try:
183
  image_tensor = preprocess_image(image)
 
188
  else:
189
  logits = output
190
 
 
191
  probabilities = F.softmax(logits, dim=1)
 
 
192
  top_confidence, top_index = torch.max(probabilities, 1)
193
+ top_confidence = top_confidence.item()
194
  top_breed = dog_breeds[top_index.item()]
195
 
196
+ # If confidence is higher than 60%, return the top prediction directly
197
  if top_confidence >= 0.60:
198
  description = get_dog_description(top_breed)
199
  akc_link = get_akc_breeds_link()
200
  description_str = f"**Breed**: {top_breed}\n\n**Description**: {description}\n"
201
  description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {top_breed}."
202
  return description_str
203
+
204
+ # If confidence is lower than 60%, return top 3 results and explain why
205
  else:
206
  top3_confidences, top3_indices = torch.topk(probabilities, 3, dim=1)
207
  top3_breeds = [dog_breeds[idx] for idx in top3_indices.squeeze().tolist()]
208
  top3_confidences = top3_confidences.squeeze().tolist()
209
 
210
+ top3_info = "\n\n".join([f"{i+1}. {breed} ({conf*100:.2f}% confidence)"
211
+ for i, (breed, conf) in enumerate(zip(top3_breeds, top3_confidences))])
212
+
213
+ # Return top 3 breeds and an explanation
214
+ message = (f"The model couldn't confidently identify the breed. Here are the top 3 possible breeds:\n\n"
215
+ f"{top3_info}\n\n"
216
+ "This can happen if the image quality is low or the breed is rare in the dataset. "
217
+ "Please try uploading a clearer image or a different angle of the dog. "
218
+ "For more accurate results, ensure the dog is the main subject of the photo.")
219
+ return message
220
 
221
  except Exception as e:
222
  return f"An error occurred: {e}"
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  iface = gr.Interface(
226
  fn=predict,