Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -27,6 +27,8 @@ from html_templates import (
|
|
27 |
format_single_dog_result,
|
28 |
format_multiple_breeds_result,
|
29 |
format_error_message,
|
|
|
|
|
30 |
format_warning_html,
|
31 |
format_multi_dog_container,
|
32 |
format_breed_details_html,
|
@@ -238,36 +240,85 @@ def predict_single_dog(image):
|
|
238 |
|
239 |
return probabilities[0], breeds[:3], relative_probs
|
240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
@spaces.GPU
|
242 |
def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
|
|
|
|
|
|
|
243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
results = model_manager.yolo_model(image, conf=conf_threshold,
|
245 |
iou=iou_threshold)[0]
|
246 |
|
247 |
dogs = []
|
248 |
boxes = []
|
|
|
|
|
249 |
for box in results.boxes:
|
250 |
-
|
|
|
251 |
xyxy = box.xyxy[0].tolist()
|
252 |
confidence = box.conf.item()
|
253 |
-
boxes.append((xyxy, confidence))
|
254 |
|
255 |
if not boxes:
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
|
|
|
|
|
|
|
|
271 |
|
272 |
def non_max_suppression(boxes, iou_threshold):
|
273 |
keep = []
|
@@ -324,17 +375,137 @@ def create_breed_comparison(breed1: str, breed2: str) -> dict:
|
|
324 |
return comparison_data
|
325 |
|
326 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
def predict(image):
|
328 |
"""
|
329 |
-
|
330 |
-
|
|
|
331 |
Args:
|
332 |
-
image: PIL Image
|
333 |
-
|
334 |
Returns:
|
335 |
tuple: (html_output, annotated_image, initial_state)
|
336 |
"""
|
337 |
-
|
338 |
if image is None:
|
339 |
return format_warning_html("Please upload an image to start."), None, None
|
340 |
|
@@ -342,11 +513,11 @@ def predict(image):
|
|
342 |
if isinstance(image, np.ndarray):
|
343 |
image = Image.fromarray(image)
|
344 |
|
345 |
-
#
|
346 |
dogs = detect_multiple_dogs(image)
|
347 |
color_scheme = get_color_scheme(len(dogs) == 1)
|
348 |
|
349 |
-
#
|
350 |
annotated_image = image.copy()
|
351 |
draw = ImageDraw.Draw(annotated_image)
|
352 |
|
@@ -357,18 +528,18 @@ def predict(image):
|
|
357 |
|
358 |
dogs_info = ""
|
359 |
|
360 |
-
#
|
361 |
-
for i, (cropped_image, detection_confidence, box) in enumerate(dogs):
|
362 |
color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]
|
363 |
|
364 |
-
#
|
365 |
draw.rectangle(box, outline=color, width=4)
|
366 |
-
label = f"Dog {i+1}"
|
367 |
label_bbox = draw.textbbox((0, 0), label, font=font)
|
368 |
label_width = label_bbox[2] - label_bbox[0]
|
369 |
label_height = label_bbox[3] - label_bbox[1]
|
370 |
|
371 |
-
#
|
372 |
label_x = box[0] + 5
|
373 |
label_y = box[1] + 5
|
374 |
draw.rectangle(
|
@@ -379,20 +550,23 @@ def predict(image):
|
|
379 |
)
|
380 |
draw.text((label_x, label_y), label, fill=color, font=font)
|
381 |
|
382 |
-
# Predict breed
|
383 |
-
top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
|
384 |
-
combined_confidence = detection_confidence * top1_prob
|
385 |
-
|
386 |
-
# Format results based on confidence with error handling
|
387 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
if combined_confidence < 0.2:
|
389 |
-
dogs_info +=
|
390 |
elif top1_prob >= 0.45:
|
391 |
breed = topk_breeds[0]
|
392 |
description = get_dog_description(breed)
|
393 |
-
# Handle missing breed description
|
394 |
if description is None:
|
395 |
-
# 如果沒有描述,創建一個基本描述
|
396 |
description = {
|
397 |
"Name": breed,
|
398 |
"Size": "Unknown",
|
@@ -404,7 +578,6 @@ def predict(image):
|
|
404 |
}
|
405 |
dogs_info += format_single_dog_result(breed, description, color)
|
406 |
else:
|
407 |
-
# 修改format_multiple_breeds_result的調用,包含錯誤處理
|
408 |
dogs_info += format_multiple_breeds_result(
|
409 |
topk_breeds,
|
410 |
relative_probs,
|
@@ -422,12 +595,12 @@ def predict(image):
|
|
422 |
)
|
423 |
except Exception as e:
|
424 |
print(f"Error formatting results for dog {i+1}: {str(e)}")
|
425 |
-
dogs_info +=
|
426 |
|
427 |
-
#
|
428 |
html_output = format_multi_dog_container(dogs_info)
|
429 |
|
430 |
-
#
|
431 |
initial_state = {
|
432 |
"dogs_info": dogs_info,
|
433 |
"image": annotated_image,
|
|
|
27 |
format_single_dog_result,
|
28 |
format_multiple_breeds_result,
|
29 |
format_error_message,
|
30 |
+
format_unknown_breed_message,
|
31 |
+
format_not_dog_message,
|
32 |
format_warning_html,
|
33 |
format_multi_dog_container,
|
34 |
format_breed_details_html,
|
|
|
240 |
|
241 |
return probabilities[0], breeds[:3], relative_probs
|
242 |
|
243 |
+
# @spaces.GPU
|
244 |
+
# def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
|
245 |
+
|
246 |
+
# results = model_manager.yolo_model(image, conf=conf_threshold,
|
247 |
+
# iou=iou_threshold)[0]
|
248 |
+
|
249 |
+
# dogs = []
|
250 |
+
# boxes = []
|
251 |
+
# for box in results.boxes:
|
252 |
+
# if box.cls == 16: # COCO dataset class for dog is 16
|
253 |
+
# xyxy = box.xyxy[0].tolist()
|
254 |
+
# confidence = box.conf.item()
|
255 |
+
# boxes.append((xyxy, confidence))
|
256 |
+
|
257 |
+
# if not boxes:
|
258 |
+
# dogs.append((image, 1.0, [0, 0, image.width, image.height]))
|
259 |
+
# else:
|
260 |
+
# nms_boxes = non_max_suppression(boxes, iou_threshold)
|
261 |
+
|
262 |
+
# for box, confidence in nms_boxes:
|
263 |
+
# x1, y1, x2, y2 = box
|
264 |
+
# w, h = x2 - x1, y2 - y1
|
265 |
+
# x1 = max(0, x1 - w * 0.05)
|
266 |
+
# y1 = max(0, y1 - h * 0.05)
|
267 |
+
# x2 = min(image.width, x2 + w * 0.05)
|
268 |
+
# y2 = min(image.height, y2 + h * 0.05)
|
269 |
+
# cropped_image = image.crop((x1, y1, x2, y2))
|
270 |
+
# dogs.append((cropped_image, confidence, [x1, y1, x2, y2]))
|
271 |
+
|
272 |
+
# return dogs
|
273 |
+
|
274 |
@spaces.GPU
|
275 |
def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
|
276 |
+
"""
|
277 |
+
使用YOLO模型檢測圖片中的狗。
|
278 |
+
只保留被識別為狗(class 16)的物體,並標記它們的狀態。
|
279 |
|
280 |
+
Args:
|
281 |
+
image: PIL Image
|
282 |
+
conf_threshold: YOLO檢測的信心度閾值
|
283 |
+
iou_threshold: 非極大值抑制的IoU閾值
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
list: 包含檢測到的狗的列表,每個元素是(cropped_image, confidence, box, is_dog)的元組
|
287 |
+
"""
|
288 |
results = model_manager.yolo_model(image, conf=conf_threshold,
|
289 |
iou=iou_threshold)[0]
|
290 |
|
291 |
dogs = []
|
292 |
boxes = []
|
293 |
+
|
294 |
+
# 只處理被識別為狗的物體
|
295 |
for box in results.boxes:
|
296 |
+
class_id = box.cls.item()
|
297 |
+
if class_id == 16: # COCO dataset中狗的類別是16
|
298 |
xyxy = box.xyxy[0].tolist()
|
299 |
confidence = box.conf.item()
|
300 |
+
boxes.append((xyxy, confidence, True)) # 加入is_dog標記
|
301 |
|
302 |
if not boxes:
|
303 |
+
# 如果沒有檢測到狗,返回整張圖片並標記為非狗
|
304 |
+
return [(image, 1.0, [0, 0, image.width, image.height], False)]
|
305 |
+
|
306 |
+
nms_boxes = non_max_suppression(boxes, iou_threshold)
|
307 |
+
detected_objects = []
|
308 |
+
|
309 |
+
# 處理每個檢測到的狗
|
310 |
+
for box, confidence, is_dog in nms_boxes:
|
311 |
+
x1, y1, x2, y2 = box
|
312 |
+
w, h = x2 - x1, y2 - y1
|
313 |
+
# 擴大檢測框範圍以包含完整的狗
|
314 |
+
x1 = max(0, x1 - w * 0.05)
|
315 |
+
y1 = max(0, y1 - h * 0.05)
|
316 |
+
x2 = min(image.width, x2 + w * 0.05)
|
317 |
+
y2 = min(image.height, y2 + h * 0.05)
|
318 |
+
cropped_image = image.crop((x1, y1, x2, y2))
|
319 |
+
detected_objects.append((cropped_image, confidence, [x1, y1, x2, y2], is_dog))
|
320 |
+
|
321 |
+
return detected_objects
|
322 |
|
323 |
def non_max_suppression(boxes, iou_threshold):
|
324 |
keep = []
|
|
|
375 |
return comparison_data
|
376 |
|
377 |
|
378 |
+
# def predict(image):
|
379 |
+
# """
|
380 |
+
# Main prediction function that handles both single and multiple dog detection.
|
381 |
+
|
382 |
+
# Args:
|
383 |
+
# image: PIL Image or numpy array
|
384 |
+
|
385 |
+
# Returns:
|
386 |
+
# tuple: (html_output, annotated_image, initial_state)
|
387 |
+
# """
|
388 |
+
|
389 |
+
# if image is None:
|
390 |
+
# return format_warning_html("Please upload an image to start."), None, None
|
391 |
+
|
392 |
+
# try:
|
393 |
+
# if isinstance(image, np.ndarray):
|
394 |
+
# image = Image.fromarray(image)
|
395 |
+
|
396 |
+
# # Detect dogs in the image
|
397 |
+
# dogs = detect_multiple_dogs(image)
|
398 |
+
# color_scheme = get_color_scheme(len(dogs) == 1)
|
399 |
+
|
400 |
+
# # Prepare for annotation
|
401 |
+
# annotated_image = image.copy()
|
402 |
+
# draw = ImageDraw.Draw(annotated_image)
|
403 |
+
|
404 |
+
# try:
|
405 |
+
# font = ImageFont.truetype("arial.ttf", 24)
|
406 |
+
# except:
|
407 |
+
# font = ImageFont.load_default()
|
408 |
+
|
409 |
+
# dogs_info = ""
|
410 |
+
|
411 |
+
# # Process each detected dog
|
412 |
+
# for i, (cropped_image, detection_confidence, box) in enumerate(dogs):
|
413 |
+
# color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]
|
414 |
+
|
415 |
+
# # Draw box and label on image
|
416 |
+
# draw.rectangle(box, outline=color, width=4)
|
417 |
+
# label = f"Dog {i+1}"
|
418 |
+
# label_bbox = draw.textbbox((0, 0), label, font=font)
|
419 |
+
# label_width = label_bbox[2] - label_bbox[0]
|
420 |
+
# label_height = label_bbox[3] - label_bbox[1]
|
421 |
+
|
422 |
+
# # Draw label background and text
|
423 |
+
# label_x = box[0] + 5
|
424 |
+
# label_y = box[1] + 5
|
425 |
+
# draw.rectangle(
|
426 |
+
# [label_x - 2, label_y - 2, label_x + label_width + 4, label_y + label_height + 4],
|
427 |
+
# fill='white',
|
428 |
+
# outline=color,
|
429 |
+
# width=2
|
430 |
+
# )
|
431 |
+
# draw.text((label_x, label_y), label, fill=color, font=font)
|
432 |
+
|
433 |
+
# # Predict breed
|
434 |
+
# top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
|
435 |
+
# combined_confidence = detection_confidence * top1_prob
|
436 |
+
|
437 |
+
# # Format results based on confidence with error handling
|
438 |
+
# try:
|
439 |
+
# if combined_confidence < 0.2:
|
440 |
+
# dogs_info += format_error_message(color, i+1)
|
441 |
+
# elif top1_prob >= 0.45:
|
442 |
+
# breed = topk_breeds[0]
|
443 |
+
# description = get_dog_description(breed)
|
444 |
+
# # Handle missing breed description
|
445 |
+
# if description is None:
|
446 |
+
# # 如果沒有描述,創建一個基本描述
|
447 |
+
# description = {
|
448 |
+
# "Name": breed,
|
449 |
+
# "Size": "Unknown",
|
450 |
+
# "Exercise Needs": "Unknown",
|
451 |
+
# "Grooming Needs": "Unknown",
|
452 |
+
# "Care Level": "Unknown",
|
453 |
+
# "Good with Children": "Unknown",
|
454 |
+
# "Description": f"Identified as {breed.replace('_', ' ')}"
|
455 |
+
# }
|
456 |
+
# dogs_info += format_single_dog_result(breed, description, color)
|
457 |
+
# else:
|
458 |
+
# # 修改format_multiple_breeds_result的調用,包含錯誤處理
|
459 |
+
# dogs_info += format_multiple_breeds_result(
|
460 |
+
# topk_breeds,
|
461 |
+
# relative_probs,
|
462 |
+
# color,
|
463 |
+
# i+1,
|
464 |
+
# lambda breed: get_dog_description(breed) or {
|
465 |
+
# "Name": breed,
|
466 |
+
# "Size": "Unknown",
|
467 |
+
# "Exercise Needs": "Unknown",
|
468 |
+
# "Grooming Needs": "Unknown",
|
469 |
+
# "Care Level": "Unknown",
|
470 |
+
# "Good with Children": "Unknown",
|
471 |
+
# "Description": f"Identified as {breed.replace('_', ' ')}"
|
472 |
+
# }
|
473 |
+
# )
|
474 |
+
# except Exception as e:
|
475 |
+
# print(f"Error formatting results for dog {i+1}: {str(e)}")
|
476 |
+
# dogs_info += format_error_message(color, i+1)
|
477 |
+
|
478 |
+
# # Wrap final HTML output
|
479 |
+
# html_output = format_multi_dog_container(dogs_info)
|
480 |
+
|
481 |
+
# # Prepare initial state
|
482 |
+
# initial_state = {
|
483 |
+
# "dogs_info": dogs_info,
|
484 |
+
# "image": annotated_image,
|
485 |
+
# "is_multi_dog": len(dogs) > 1,
|
486 |
+
# "html_output": html_output
|
487 |
+
# }
|
488 |
+
|
489 |
+
# return html_output, annotated_image, initial_state
|
490 |
+
|
491 |
+
# except Exception as e:
|
492 |
+
# error_msg = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
493 |
+
# print(error_msg)
|
494 |
+
# return format_warning_html(error_msg), None, None
|
495 |
+
|
496 |
+
|
497 |
+
@spaces.GPU
|
498 |
def predict(image):
|
499 |
"""
|
500 |
+
主要的預測函數,負責處理狗的檢測和品種辨識。
|
501 |
+
它整合了YOLO的物體檢測和專門的品種分類模型。
|
502 |
+
|
503 |
Args:
|
504 |
+
image: PIL Image 或 numpy array
|
505 |
+
|
506 |
Returns:
|
507 |
tuple: (html_output, annotated_image, initial_state)
|
508 |
"""
|
|
|
509 |
if image is None:
|
510 |
return format_warning_html("Please upload an image to start."), None, None
|
511 |
|
|
|
513 |
if isinstance(image, np.ndarray):
|
514 |
image = Image.fromarray(image)
|
515 |
|
516 |
+
# 檢測圖片中的狗
|
517 |
dogs = detect_multiple_dogs(image)
|
518 |
color_scheme = get_color_scheme(len(dogs) == 1)
|
519 |
|
520 |
+
# 準備標註
|
521 |
annotated_image = image.copy()
|
522 |
draw = ImageDraw.Draw(annotated_image)
|
523 |
|
|
|
528 |
|
529 |
dogs_info = ""
|
530 |
|
531 |
+
# 處理每個檢測到的物體
|
532 |
+
for i, (cropped_image, detection_confidence, box, is_dog) in enumerate(dogs):
|
533 |
color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]
|
534 |
|
535 |
+
# 繪製框和標籤
|
536 |
draw.rectangle(box, outline=color, width=4)
|
537 |
+
label = f"Dog {i+1}" if is_dog else f"Object {i+1}"
|
538 |
label_bbox = draw.textbbox((0, 0), label, font=font)
|
539 |
label_width = label_bbox[2] - label_bbox[0]
|
540 |
label_height = label_bbox[3] - label_bbox[1]
|
541 |
|
542 |
+
# 繪製標籤背景和文字
|
543 |
label_x = box[0] + 5
|
544 |
label_y = box[1] + 5
|
545 |
draw.rectangle(
|
|
|
550 |
)
|
551 |
draw.text((label_x, label_y), label, fill=color, font=font)
|
552 |
|
|
|
|
|
|
|
|
|
|
|
553 |
try:
|
554 |
+
# 首先檢查是否為狗
|
555 |
+
if not is_dog:
|
556 |
+
dogs_info += format_not_dog_message(color, i+1)
|
557 |
+
continue
|
558 |
+
|
559 |
+
# 如果是狗,進行品種預測
|
560 |
+
top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
|
561 |
+
combined_confidence = detection_confidence * top1_prob
|
562 |
+
|
563 |
+
# 根據信心度決定輸出格式
|
564 |
if combined_confidence < 0.2:
|
565 |
+
dogs_info += format_unknown_breed_message(color, i+1)
|
566 |
elif top1_prob >= 0.45:
|
567 |
breed = topk_breeds[0]
|
568 |
description = get_dog_description(breed)
|
|
|
569 |
if description is None:
|
|
|
570 |
description = {
|
571 |
"Name": breed,
|
572 |
"Size": "Unknown",
|
|
|
578 |
}
|
579 |
dogs_info += format_single_dog_result(breed, description, color)
|
580 |
else:
|
|
|
581 |
dogs_info += format_multiple_breeds_result(
|
582 |
topk_breeds,
|
583 |
relative_probs,
|
|
|
595 |
)
|
596 |
except Exception as e:
|
597 |
print(f"Error formatting results for dog {i+1}: {str(e)}")
|
598 |
+
dogs_info += format_unknown_breed_message(color, i+1)
|
599 |
|
600 |
+
# 包裝最終的HTML輸出
|
601 |
html_output = format_multi_dog_container(dogs_info)
|
602 |
|
603 |
+
# 準備初始狀態
|
604 |
initial_state = {
|
605 |
"dogs_info": dogs_info,
|
606 |
"image": annotated_image,
|