DawnC commited on
Commit
2196d2b
1 Parent(s): 6196e20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -41
app.py CHANGED
@@ -27,6 +27,8 @@ from html_templates import (
27
  format_single_dog_result,
28
  format_multiple_breeds_result,
29
  format_error_message,
 
 
30
  format_warning_html,
31
  format_multi_dog_container,
32
  format_breed_details_html,
@@ -238,36 +240,85 @@ def predict_single_dog(image):
238
 
239
  return probabilities[0], breeds[:3], relative_probs
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  @spaces.GPU
242
  def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
 
 
 
243
 
 
 
 
 
 
 
 
 
244
  results = model_manager.yolo_model(image, conf=conf_threshold,
245
  iou=iou_threshold)[0]
246
 
247
  dogs = []
248
  boxes = []
 
 
249
  for box in results.boxes:
250
- if box.cls == 16: # COCO dataset class for dog is 16
 
251
  xyxy = box.xyxy[0].tolist()
252
  confidence = box.conf.item()
253
- boxes.append((xyxy, confidence))
254
 
255
  if not boxes:
256
- dogs.append((image, 1.0, [0, 0, image.width, image.height]))
257
- else:
258
- nms_boxes = non_max_suppression(boxes, iou_threshold)
259
-
260
- for box, confidence in nms_boxes:
261
- x1, y1, x2, y2 = box
262
- w, h = x2 - x1, y2 - y1
263
- x1 = max(0, x1 - w * 0.05)
264
- y1 = max(0, y1 - h * 0.05)
265
- x2 = min(image.width, x2 + w * 0.05)
266
- y2 = min(image.height, y2 + h * 0.05)
267
- cropped_image = image.crop((x1, y1, x2, y2))
268
- dogs.append((cropped_image, confidence, [x1, y1, x2, y2]))
269
-
270
- return dogs
 
 
 
 
271
 
272
  def non_max_suppression(boxes, iou_threshold):
273
  keep = []
@@ -324,17 +375,137 @@ def create_breed_comparison(breed1: str, breed2: str) -> dict:
324
  return comparison_data
325
 
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  def predict(image):
328
  """
329
- Main prediction function that handles both single and multiple dog detection.
330
-
 
331
  Args:
332
- image: PIL Image or numpy array
333
-
334
  Returns:
335
  tuple: (html_output, annotated_image, initial_state)
336
  """
337
-
338
  if image is None:
339
  return format_warning_html("Please upload an image to start."), None, None
340
 
@@ -342,11 +513,11 @@ def predict(image):
342
  if isinstance(image, np.ndarray):
343
  image = Image.fromarray(image)
344
 
345
- # Detect dogs in the image
346
  dogs = detect_multiple_dogs(image)
347
  color_scheme = get_color_scheme(len(dogs) == 1)
348
 
349
- # Prepare for annotation
350
  annotated_image = image.copy()
351
  draw = ImageDraw.Draw(annotated_image)
352
 
@@ -357,18 +528,18 @@ def predict(image):
357
 
358
  dogs_info = ""
359
 
360
- # Process each detected dog
361
- for i, (cropped_image, detection_confidence, box) in enumerate(dogs):
362
  color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]
363
 
364
- # Draw box and label on image
365
  draw.rectangle(box, outline=color, width=4)
366
- label = f"Dog {i+1}"
367
  label_bbox = draw.textbbox((0, 0), label, font=font)
368
  label_width = label_bbox[2] - label_bbox[0]
369
  label_height = label_bbox[3] - label_bbox[1]
370
 
371
- # Draw label background and text
372
  label_x = box[0] + 5
373
  label_y = box[1] + 5
374
  draw.rectangle(
@@ -379,20 +550,23 @@ def predict(image):
379
  )
380
  draw.text((label_x, label_y), label, fill=color, font=font)
381
 
382
- # Predict breed
383
- top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
384
- combined_confidence = detection_confidence * top1_prob
385
-
386
- # Format results based on confidence with error handling
387
  try:
 
 
 
 
 
 
 
 
 
 
388
  if combined_confidence < 0.2:
389
- dogs_info += format_error_message(color, i+1)
390
  elif top1_prob >= 0.45:
391
  breed = topk_breeds[0]
392
  description = get_dog_description(breed)
393
- # Handle missing breed description
394
  if description is None:
395
- # 如果沒有描述,創建一個基本描述
396
  description = {
397
  "Name": breed,
398
  "Size": "Unknown",
@@ -404,7 +578,6 @@ def predict(image):
404
  }
405
  dogs_info += format_single_dog_result(breed, description, color)
406
  else:
407
- # 修改format_multiple_breeds_result的調用,包含錯誤處理
408
  dogs_info += format_multiple_breeds_result(
409
  topk_breeds,
410
  relative_probs,
@@ -422,12 +595,12 @@ def predict(image):
422
  )
423
  except Exception as e:
424
  print(f"Error formatting results for dog {i+1}: {str(e)}")
425
- dogs_info += format_error_message(color, i+1)
426
 
427
- # Wrap final HTML output
428
  html_output = format_multi_dog_container(dogs_info)
429
 
430
- # Prepare initial state
431
  initial_state = {
432
  "dogs_info": dogs_info,
433
  "image": annotated_image,
 
27
  format_single_dog_result,
28
  format_multiple_breeds_result,
29
  format_error_message,
30
+ format_unknown_breed_message,
31
+ format_not_dog_message,
32
  format_warning_html,
33
  format_multi_dog_container,
34
  format_breed_details_html,
 
240
 
241
  return probabilities[0], breeds[:3], relative_probs
242
 
243
+ # @spaces.GPU
244
+ # def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
245
+
246
+ # results = model_manager.yolo_model(image, conf=conf_threshold,
247
+ # iou=iou_threshold)[0]
248
+
249
+ # dogs = []
250
+ # boxes = []
251
+ # for box in results.boxes:
252
+ # if box.cls == 16: # COCO dataset class for dog is 16
253
+ # xyxy = box.xyxy[0].tolist()
254
+ # confidence = box.conf.item()
255
+ # boxes.append((xyxy, confidence))
256
+
257
+ # if not boxes:
258
+ # dogs.append((image, 1.0, [0, 0, image.width, image.height]))
259
+ # else:
260
+ # nms_boxes = non_max_suppression(boxes, iou_threshold)
261
+
262
+ # for box, confidence in nms_boxes:
263
+ # x1, y1, x2, y2 = box
264
+ # w, h = x2 - x1, y2 - y1
265
+ # x1 = max(0, x1 - w * 0.05)
266
+ # y1 = max(0, y1 - h * 0.05)
267
+ # x2 = min(image.width, x2 + w * 0.05)
268
+ # y2 = min(image.height, y2 + h * 0.05)
269
+ # cropped_image = image.crop((x1, y1, x2, y2))
270
+ # dogs.append((cropped_image, confidence, [x1, y1, x2, y2]))
271
+
272
+ # return dogs
273
+
274
  @spaces.GPU
275
  def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
276
+ """
277
+ 使用YOLO模型檢測圖片中的狗。
278
+ 只保留被識別為狗(class 16)的物體,並標記它們的狀態。
279
 
280
+ Args:
281
+ image: PIL Image
282
+ conf_threshold: YOLO檢測的信心度閾值
283
+ iou_threshold: 非極大值抑制的IoU閾值
284
+
285
+ Returns:
286
+ list: 包含檢測到的狗的列表,每個元素是(cropped_image, confidence, box, is_dog)的元組
287
+ """
288
  results = model_manager.yolo_model(image, conf=conf_threshold,
289
  iou=iou_threshold)[0]
290
 
291
  dogs = []
292
  boxes = []
293
+
294
+ # 只處理被識別為狗的物體
295
  for box in results.boxes:
296
+ class_id = box.cls.item()
297
+ if class_id == 16: # COCO dataset中狗的類別是16
298
  xyxy = box.xyxy[0].tolist()
299
  confidence = box.conf.item()
300
+ boxes.append((xyxy, confidence, True)) # 加入is_dog標記
301
 
302
  if not boxes:
303
+ # 如果沒有檢測到狗,返回整張圖片並標記為非狗
304
+ return [(image, 1.0, [0, 0, image.width, image.height], False)]
305
+
306
+ nms_boxes = non_max_suppression(boxes, iou_threshold)
307
+ detected_objects = []
308
+
309
+ # 處理每個檢測到的狗
310
+ for box, confidence, is_dog in nms_boxes:
311
+ x1, y1, x2, y2 = box
312
+ w, h = x2 - x1, y2 - y1
313
+ # 擴大檢測框範圍以包含完整的狗
314
+ x1 = max(0, x1 - w * 0.05)
315
+ y1 = max(0, y1 - h * 0.05)
316
+ x2 = min(image.width, x2 + w * 0.05)
317
+ y2 = min(image.height, y2 + h * 0.05)
318
+ cropped_image = image.crop((x1, y1, x2, y2))
319
+ detected_objects.append((cropped_image, confidence, [x1, y1, x2, y2], is_dog))
320
+
321
+ return detected_objects
322
 
323
  def non_max_suppression(boxes, iou_threshold):
324
  keep = []
 
375
  return comparison_data
376
 
377
 
378
+ # def predict(image):
379
+ # """
380
+ # Main prediction function that handles both single and multiple dog detection.
381
+
382
+ # Args:
383
+ # image: PIL Image or numpy array
384
+
385
+ # Returns:
386
+ # tuple: (html_output, annotated_image, initial_state)
387
+ # """
388
+
389
+ # if image is None:
390
+ # return format_warning_html("Please upload an image to start."), None, None
391
+
392
+ # try:
393
+ # if isinstance(image, np.ndarray):
394
+ # image = Image.fromarray(image)
395
+
396
+ # # Detect dogs in the image
397
+ # dogs = detect_multiple_dogs(image)
398
+ # color_scheme = get_color_scheme(len(dogs) == 1)
399
+
400
+ # # Prepare for annotation
401
+ # annotated_image = image.copy()
402
+ # draw = ImageDraw.Draw(annotated_image)
403
+
404
+ # try:
405
+ # font = ImageFont.truetype("arial.ttf", 24)
406
+ # except:
407
+ # font = ImageFont.load_default()
408
+
409
+ # dogs_info = ""
410
+
411
+ # # Process each detected dog
412
+ # for i, (cropped_image, detection_confidence, box) in enumerate(dogs):
413
+ # color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]
414
+
415
+ # # Draw box and label on image
416
+ # draw.rectangle(box, outline=color, width=4)
417
+ # label = f"Dog {i+1}"
418
+ # label_bbox = draw.textbbox((0, 0), label, font=font)
419
+ # label_width = label_bbox[2] - label_bbox[0]
420
+ # label_height = label_bbox[3] - label_bbox[1]
421
+
422
+ # # Draw label background and text
423
+ # label_x = box[0] + 5
424
+ # label_y = box[1] + 5
425
+ # draw.rectangle(
426
+ # [label_x - 2, label_y - 2, label_x + label_width + 4, label_y + label_height + 4],
427
+ # fill='white',
428
+ # outline=color,
429
+ # width=2
430
+ # )
431
+ # draw.text((label_x, label_y), label, fill=color, font=font)
432
+
433
+ # # Predict breed
434
+ # top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
435
+ # combined_confidence = detection_confidence * top1_prob
436
+
437
+ # # Format results based on confidence with error handling
438
+ # try:
439
+ # if combined_confidence < 0.2:
440
+ # dogs_info += format_error_message(color, i+1)
441
+ # elif top1_prob >= 0.45:
442
+ # breed = topk_breeds[0]
443
+ # description = get_dog_description(breed)
444
+ # # Handle missing breed description
445
+ # if description is None:
446
+ # # 如果沒有描述,創建一個基本描述
447
+ # description = {
448
+ # "Name": breed,
449
+ # "Size": "Unknown",
450
+ # "Exercise Needs": "Unknown",
451
+ # "Grooming Needs": "Unknown",
452
+ # "Care Level": "Unknown",
453
+ # "Good with Children": "Unknown",
454
+ # "Description": f"Identified as {breed.replace('_', ' ')}"
455
+ # }
456
+ # dogs_info += format_single_dog_result(breed, description, color)
457
+ # else:
458
+ # # 修改format_multiple_breeds_result的調用,包含錯誤處理
459
+ # dogs_info += format_multiple_breeds_result(
460
+ # topk_breeds,
461
+ # relative_probs,
462
+ # color,
463
+ # i+1,
464
+ # lambda breed: get_dog_description(breed) or {
465
+ # "Name": breed,
466
+ # "Size": "Unknown",
467
+ # "Exercise Needs": "Unknown",
468
+ # "Grooming Needs": "Unknown",
469
+ # "Care Level": "Unknown",
470
+ # "Good with Children": "Unknown",
471
+ # "Description": f"Identified as {breed.replace('_', ' ')}"
472
+ # }
473
+ # )
474
+ # except Exception as e:
475
+ # print(f"Error formatting results for dog {i+1}: {str(e)}")
476
+ # dogs_info += format_error_message(color, i+1)
477
+
478
+ # # Wrap final HTML output
479
+ # html_output = format_multi_dog_container(dogs_info)
480
+
481
+ # # Prepare initial state
482
+ # initial_state = {
483
+ # "dogs_info": dogs_info,
484
+ # "image": annotated_image,
485
+ # "is_multi_dog": len(dogs) > 1,
486
+ # "html_output": html_output
487
+ # }
488
+
489
+ # return html_output, annotated_image, initial_state
490
+
491
+ # except Exception as e:
492
+ # error_msg = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
493
+ # print(error_msg)
494
+ # return format_warning_html(error_msg), None, None
495
+
496
+
497
+ @spaces.GPU
498
  def predict(image):
499
  """
500
+ 主要的預測函數,負責處理狗的檢測和品種辨識。
501
+ 它整合了YOLO的物體檢測和專門的品種分類模型。
502
+
503
  Args:
504
+ image: PIL Image numpy array
505
+
506
  Returns:
507
  tuple: (html_output, annotated_image, initial_state)
508
  """
 
509
  if image is None:
510
  return format_warning_html("Please upload an image to start."), None, None
511
 
 
513
  if isinstance(image, np.ndarray):
514
  image = Image.fromarray(image)
515
 
516
+ # 檢測圖片中的狗
517
  dogs = detect_multiple_dogs(image)
518
  color_scheme = get_color_scheme(len(dogs) == 1)
519
 
520
+ # 準備標註
521
  annotated_image = image.copy()
522
  draw = ImageDraw.Draw(annotated_image)
523
 
 
528
 
529
  dogs_info = ""
530
 
531
+ # 處理每個檢測到的物體
532
+ for i, (cropped_image, detection_confidence, box, is_dog) in enumerate(dogs):
533
  color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]
534
 
535
+ # 繪製框和標籤
536
  draw.rectangle(box, outline=color, width=4)
537
+ label = f"Dog {i+1}" if is_dog else f"Object {i+1}"
538
  label_bbox = draw.textbbox((0, 0), label, font=font)
539
  label_width = label_bbox[2] - label_bbox[0]
540
  label_height = label_bbox[3] - label_bbox[1]
541
 
542
+ # 繪製標籤背景和文字
543
  label_x = box[0] + 5
544
  label_y = box[1] + 5
545
  draw.rectangle(
 
550
  )
551
  draw.text((label_x, label_y), label, fill=color, font=font)
552
 
 
 
 
 
 
553
  try:
554
+ # 首先檢查是否為狗
555
+ if not is_dog:
556
+ dogs_info += format_not_dog_message(color, i+1)
557
+ continue
558
+
559
+ # 如果是狗,進行品種預測
560
+ top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
561
+ combined_confidence = detection_confidence * top1_prob
562
+
563
+ # 根據信心度決定輸出格式
564
  if combined_confidence < 0.2:
565
+ dogs_info += format_unknown_breed_message(color, i+1)
566
  elif top1_prob >= 0.45:
567
  breed = topk_breeds[0]
568
  description = get_dog_description(breed)
 
569
  if description is None:
 
570
  description = {
571
  "Name": breed,
572
  "Size": "Unknown",
 
578
  }
579
  dogs_info += format_single_dog_result(breed, description, color)
580
  else:
 
581
  dogs_info += format_multiple_breeds_result(
582
  topk_breeds,
583
  relative_probs,
 
595
  )
596
  except Exception as e:
597
  print(f"Error formatting results for dog {i+1}: {str(e)}")
598
+ dogs_info += format_unknown_breed_message(color, i+1)
599
 
600
+ # 包裝最終的HTML輸出
601
  html_output = format_multi_dog_container(dogs_info)
602
 
603
+ # 準備初始狀態
604
  initial_state = {
605
  "dogs_info": dogs_info,
606
  "image": annotated_image,