Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -177,6 +177,7 @@ def get_akc_breeds_link():
|
|
177 |
# except Exception as e:
|
178 |
# return f"An error occurred: {e}"
|
179 |
|
|
|
180 |
def predict(image):
|
181 |
try:
|
182 |
image_tensor = preprocess_image(image)
|
@@ -187,52 +188,39 @@ def predict(image):
|
|
187 |
else:
|
188 |
logits = output
|
189 |
|
190 |
-
# 計算預測的概率分佈
|
191 |
probabilities = F.softmax(logits, dim=1)
|
192 |
-
|
193 |
-
# 取得最高的預測分數以及對應的品種
|
194 |
top_confidence, top_index = torch.max(probabilities, 1)
|
195 |
-
top_confidence = top_confidence.item()
|
196 |
top_breed = dog_breeds[top_index.item()]
|
197 |
|
198 |
-
#
|
199 |
if top_confidence >= 0.60:
|
200 |
description = get_dog_description(top_breed)
|
201 |
akc_link = get_akc_breeds_link()
|
202 |
description_str = f"**Breed**: {top_breed}\n\n**Description**: {description}\n"
|
203 |
description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {top_breed}."
|
204 |
return description_str
|
205 |
-
|
206 |
-
#
|
207 |
else:
|
208 |
top3_confidences, top3_indices = torch.topk(probabilities, 3, dim=1)
|
209 |
top3_breeds = [dog_breeds[idx] for idx in top3_indices.squeeze().tolist()]
|
210 |
top3_confidences = top3_confidences.squeeze().tolist()
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
218 |
|
219 |
except Exception as e:
|
220 |
return f"An error occurred: {e}"
|
221 |
|
222 |
-
# 處理用戶選擇的結果
|
223 |
-
def handle_user_selection(top3_breeds, selected_breed):
|
224 |
-
if selected_breed in top3_breeds:
|
225 |
-
breed_index = top3_breeds.index(selected_breed)
|
226 |
-
description = get_dog_description(selected_breed)
|
227 |
-
|
228 |
-
akc_link = get_akc_breeds_link()
|
229 |
-
description_str = f"**Breed**: {selected_breed}\n\n**Description**: {description}\n"
|
230 |
-
description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {selected_breed}."
|
231 |
-
|
232 |
-
return description_str
|
233 |
-
else:
|
234 |
-
return "Sorry, the breed could not be identified. Please try uploading a clearer image or another breed."
|
235 |
-
|
236 |
|
237 |
iface = gr.Interface(
|
238 |
fn=predict,
|
|
|
177 |
# except Exception as e:
|
178 |
# return f"An error occurred: {e}"
|
179 |
|
180 |
+
# Prediction function
|
181 |
def predict(image):
|
182 |
try:
|
183 |
image_tensor = preprocess_image(image)
|
|
|
188 |
else:
|
189 |
logits = output
|
190 |
|
|
|
191 |
probabilities = F.softmax(logits, dim=1)
|
|
|
|
|
192 |
top_confidence, top_index = torch.max(probabilities, 1)
|
193 |
+
top_confidence = top_confidence.item()
|
194 |
top_breed = dog_breeds[top_index.item()]
|
195 |
|
196 |
+
# If confidence is higher than 60%, return the top prediction directly
|
197 |
if top_confidence >= 0.60:
|
198 |
description = get_dog_description(top_breed)
|
199 |
akc_link = get_akc_breeds_link()
|
200 |
description_str = f"**Breed**: {top_breed}\n\n**Description**: {description}\n"
|
201 |
description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {top_breed}."
|
202 |
return description_str
|
203 |
+
|
204 |
+
# If confidence is lower than 60%, return top 3 results and explain why
|
205 |
else:
|
206 |
top3_confidences, top3_indices = torch.topk(probabilities, 3, dim=1)
|
207 |
top3_breeds = [dog_breeds[idx] for idx in top3_indices.squeeze().tolist()]
|
208 |
top3_confidences = top3_confidences.squeeze().tolist()
|
209 |
|
210 |
+
top3_info = "\n\n".join([f"{i+1}. {breed} ({conf*100:.2f}% confidence)"
|
211 |
+
for i, (breed, conf) in enumerate(zip(top3_breeds, top3_confidences))])
|
212 |
+
|
213 |
+
# Return top 3 breeds and an explanation
|
214 |
+
message = (f"The model couldn't confidently identify the breed. Here are the top 3 possible breeds:\n\n"
|
215 |
+
f"{top3_info}\n\n"
|
216 |
+
"This can happen if the image quality is low or the breed is rare in the dataset. "
|
217 |
+
"Please try uploading a clearer image or a different angle of the dog. "
|
218 |
+
"For more accurate results, ensure the dog is the main subject of the photo.")
|
219 |
+
return message
|
220 |
|
221 |
except Exception as e:
|
222 |
return f"An error occurred: {e}"
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
iface = gr.Interface(
|
226 |
fn=predict,
|