import os
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
import time
import traceback
import spaces
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.ops import nms, box_iou
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont, ImageFilter
from breed_health_info import breed_health_info
from breed_noise_info import breed_noise_info
from dog_database import get_dog_description
from scoring_calculation_system import UserPreferences
from recommendation_html_format import format_recommendation_html, get_breed_recommendations
from history_manager import UserHistoryManager
from search_history import create_history_tab, create_history_component
from styles import get_css_styles
from breed_detection import create_detection_tab
from breed_comparison import create_comparison_tab
from breed_recommendation import create_recommendation_tab
from html_templates import (
format_description_html,
format_single_dog_result,
format_multiple_breeds_result,
format_unknown_breed_message,
format_not_dog_message,
format_hint_html,
format_multi_dog_container,
format_breed_details_html,
get_color_scheme,
get_akc_breeds_link
)
from urllib.parse import quote
from ultralytics import YOLO
from functools import wraps
# Module-level history manager instance (not referenced by the code visible in
# this part of the file; presumably used by the history tab -- confirm).
history_manager = UserHistoryManager()
# Class labels for the breed classifier. len(dog_breeds) is passed as the
# classifier's num_classes, and predicted class indices are mapped back to
# names by indexing into this list, so entries must not be reordered, added or
# removed independently of the model.
# NOTE(review): presumably this order matches the label order used when the
# checkpoint was trained -- confirm before editing.
dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
"Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog", "Bichon_Frise",
"Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull", "Bouvier_Des_Flandres",
"Brabancon_Griffon", "Brittany_Spaniel", "Cardigan", "Chesapeake_Bay_Retriever",
"Chihuahua", "Dachshund", "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
"English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog", "German_Shepherd",
"German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane", "Great_Pyrenees",
"Greater_Swiss_Mountain_Dog","Havanese", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
"Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound", "Japanese_Spaniel",
"Kerry_Blue_Terrier", "Labrador_Retriever", "Lakeland_Terrier", "Leonberg", "Lhasa",
"Maltese_Dog", "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
"Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke", "Pomeranian",
"Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed",
"Scotch_Terrier", "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog", "Shiba_Inu",
"Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel",
"Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound", "Weimaraner",
"Welsh_Springer_Spaniel", "West_Highland_White_Terrier", "Yorkshire_Terrier",
"Affenpinscher", "Basenji", "Basset", "Beagle", "Black-and-Tan_Coonhound", "Bloodhound",
"Bluetick", "Borzoi", "Boxer", "Briard", "Bull_Mastiff", "Cairn", "Chow", "Clumber",
"Cocker_Spaniel", "Collie", "Curly-Coated_Retriever", "Dhole", "Dingo",
"Flat-Coated_Retriever", "Giant_Schnauzer", "Golden_Retriever", "Groenendael", "Keeshond",
"Kelpie", "Komondor", "Kuvasz", "Malamute", "Malinois", "Miniature_Pinscher",
"Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
"Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier", "Standard_Poodle",
"Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
"Wire-Haired_Fox_Terrier"]
class MultiHeadAttention(nn.Module):
    """Attention-style mixing layer applied to a batch of flat feature vectors.

    Each input vector of width ``in_dim`` is projected to
    ``num_heads * head_dim`` values, split into ``num_heads`` chunks of
    ``head_dim`` values, mixed across the chunks, and projected back to
    ``in_dim``.

    Args:
        in_dim: Width of the incoming (and outgoing) feature vectors.
        num_heads: Number of chunks the projected vector is split into.
    """

    def __init__(self, in_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # Floor division may not divide in_dim evenly, so the internal width
        # (scaled_dim) can be smaller than in_dim; fc_in / fc_out bridge that.
        self.head_dim = max(1, in_dim // num_heads)
        self.scaled_dim = self.head_dim * num_heads

        inner = self.scaled_dim
        self.fc_in = nn.Linear(in_dim, inner)
        self.query = nn.Linear(inner, inner)
        self.key = nn.Linear(inner, inner)
        self.value = nn.Linear(inner, inner)
        self.fc_out = nn.Linear(inner, in_dim)

    def forward(self, x):
        """Mix the features of ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape (N, in_dim); the ``view`` calls below require
                exactly one feature vector per batch row.

        Returns:
            Tensor of shape (N, in_dim).
        """
        batch = x.shape[0]
        chunked = (batch, self.num_heads, self.head_dim)

        projected = self.fc_in(x)
        q = self.query(projected).view(chunked)
        k = self.key(projected).view(chunked)
        v = self.value(projected).view(chunked)

        scores = torch.einsum("nqd,nkd->nqk", q, k)
        weights = F.softmax(scores / (self.head_dim ** 0.5), dim=2)
        # NOTE(review): "k" and "v" are distinct einsum indices here, so both
        # are summed out independently: each output row equals
        # (sum of softmax weights, i.e. ~1) * (sum of the value chunks).
        # The usual attention contraction would be "nqk,nkd->nqd". The
        # expression is kept as-is because changing it would change the
        # outputs produced by already-trained weights.
        mixed = torch.einsum("nqk,nvd->nqd", weights, v)

        return self.fc_out(mixed.reshape(batch, self.scaled_dim))
class BaseModel(nn.Module):
    """Breed classifier: EfficientNetV2-M backbone, attention block, linear head.

    The backbone's own classifier is replaced with ``nn.Identity`` so that it
    outputs its feature vector, which is passed through ``MultiHeadAttention``
    and then a LayerNorm / Dropout / Linear head.

    Args:
        num_classes: Number of output classes of the linear head.
        device: Device the module is moved to and inputs are sent to. The
            default is evaluated once, when the class is defined.
    """

    def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
        super().__init__()
        self.device = device

        backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
        # Read the feature width off the stock head before discarding it.
        self.feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # One head per 64 features, clamped to the range [1, 8].
        heads = min(8, self.feature_dim // 64)
        self.num_heads = max(1, heads)
        self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes),
        )
        self.to(device)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Input batch accepted by the backbone; it is moved to
                ``self.device`` first.

        Returns:
            tuple: ``(logits, attended_features)``.
        """
        backbone_features = self.backbone(x.to(self.device))
        attended = self.attention(backbone_features)
        return self.classifier(attended), attended
class ModelManager:
    """Singleton that owns the device choice and the lazily created models.

    ``__new__`` always hands back the same instance, so every
    ``ModelManager()`` call in the process shares one device, one YOLO
    detector and one breed classifier. Models are only built the first time
    the corresponding property is read.
    """

    _instance = None
    _initialized = False
    _yolo_model = None
    _breed_model = None
    _device = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call even though __new__
        # returns the shared instance, so guard against re-initialisation.
        if not ModelManager._initialized:
            self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            ModelManager._initialized = True

    @property
    def device(self):
        """Return the torch device, selecting it on first use if still unset."""
        if self._device is None:
            self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return self._device

    @property
    def yolo_model(self):
        """Return the YOLO detector, creating it on first access."""
        if self._yolo_model is None:
            self._yolo_model = YOLO('yolov8x.pt')
        return self._yolo_model

    @property
    def breed_model(self):
        """Return the breed classifier, building and loading it on first access.

        The model is cached only after the checkpoint has been loaded and the
        model switched to eval mode. If loading raises, nothing is cached, so
        the next access retries instead of silently returning a model whose
        head was never loaded.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise (for example
            ``FileNotFoundError`` for a missing checkpoint, ``KeyError`` if
            the checkpoint has no ``'base_model'`` entry).
        """
        if self._breed_model is None:
            model = BaseModel(
                num_classes=len(dog_breeds),
                device=self.device
            ).to(self.device)
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(
                '124_best_model_dog.pth',
                map_location=self.device  # load tensors straight onto our device
            )
            # strict=False tolerates key mismatches; report them so a wrong
            # or partial checkpoint does not go unnoticed.
            load_result = model.load_state_dict(checkpoint['base_model'], strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    "Warning: breed model checkpoint mismatch - "
                    f"missing keys: {list(load_result.missing_keys)}, "
                    f"unexpected keys: {list(load_result.unexpected_keys)}"
                )
            model.eval()
            # Publish only a fully initialised model.
            self._breed_model = model
        return self._breed_model
# Shared module-level manager: the prediction functions in this file read the
# device and the lazily loaded models through it.
model_manager = ModelManager()
# Image preprocessing function
def preprocess_image(image):
    """Turn an image into a normalized, batched tensor for the breed classifier.

    Args:
        image: PIL Image, or a numpy array accepted by ``Image.fromarray``.

    Returns:
        Tensor of shape (1, 3, 224, 224): the image resized to 224x224,
        converted to a tensor and normalized per channel.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Normalize below is configured for exactly three channels, so grayscale,
    # palette or RGBA input has to be converted before the transform runs.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel mean/std; these are the values conventionally used with
        # ImageNet-pretrained torchvision models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)
@spaces.GPU
def predict_single_dog(image):
    """Classify the breed of a single (cropped) dog image.

    Args:
        image: PIL Image or numpy array.

    Returns:
        tuple: ``(top1_prob, topk_breeds, relative_probs)`` where
        ``top1_prob`` is the softmax probability of the best class,
        ``topk_breeds`` holds the three most likely breed names and
        ``relative_probs`` holds their probabilities renormalized over those
        three, formatted as percentage strings.
    """
    batch = preprocess_image(image).to(model_manager.device)

    with torch.no_grad():
        # The model returns (logits, features); only the logits are needed.
        logits = model_manager.breed_model(batch)[0]
        probs = F.softmax(logits, dim=1)
        top5_prob, top5_idx = torch.topk(probs, k=5)

    probabilities = top5_prob[0].tolist()
    breeds = [dog_breeds[class_index] for class_index in top5_idx[0].tolist()]

    # Relative probabilities are computed over the top three candidates only.
    leading = probabilities[:3]
    sum_probs = sum(leading)
    relative_probs = [f"{(prob/sum_probs * 100):.2f}%" for prob in leading]

    # Debug output
    print("\nClassifier Predictions:")
    for breed, prob in zip(breeds[:5], probabilities[:5]):
        print(f"{breed}: {prob:.4f}")

    return probabilities[0], breeds[:3], relative_probs
@spaces.GPU
def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
    """Detect dogs in an image with the YOLO model and crop each one out.

    Only detections of class 16 (the dog class in the COCO label set) are
    kept. Overlapping detections are reduced with ``non_max_suppression`` and
    each surviving box is enlarged by 5% of its width/height on every side,
    clamped to the image bounds, before cropping.

    Args:
        image: PIL Image.
        conf_threshold: Confidence threshold passed to YOLO.
        iou_threshold: IoU threshold used both by YOLO and by the extra
            non-maximum suppression pass.

    Returns:
        list: ``(cropped_image, confidence, box, is_dog)`` tuples. When no dog
        is found, a single entry with the whole image, confidence 1.0 and
        ``is_dog=False`` is returned.
    """
    results = model_manager.yolo_model(image, conf=conf_threshold,
                                       iou=iou_threshold)[0]

    # Keep only detections classified as dog; the trailing True is is_dog.
    candidates = [
        (detection.xyxy[0].tolist(), detection.conf.item(), True)
        for detection in results.boxes
        if detection.cls.item() == 16
    ]

    if not candidates:
        # Nothing recognised as a dog: hand back the full image, flagged.
        return [(image, 1.0, [0, 0, image.width, image.height], False)]

    detected_objects = []
    for (left, top, right, bottom), confidence, is_dog in non_max_suppression(candidates, iou_threshold):
        # Pad the box a little so the crop contains the whole dog.
        pad_x = (right - left) * 0.05
        pad_y = (bottom - top) * 0.05
        left = max(0, left - pad_x)
        top = max(0, top - pad_y)
        right = min(image.width, right + pad_x)
        bottom = min(image.height, bottom + pad_y)

        crop = image.crop((left, top, right, bottom))
        detected_objects.append((crop, confidence, [left, top, right, bottom], is_dog))
    return detected_objects
def non_max_suppression(boxes, iou_threshold):
    """Greedy non-maximum suppression over scored boxes.

    Args:
        boxes: Iterable of tuples whose first item is an ``[x1, y1, x2, y2]``
            box and whose second item is its confidence score.
        iou_threshold: A box is dropped when its IoU with an already kept,
            higher-scoring box is greater than or equal to this value.

    Returns:
        list: The kept tuples, ordered by descending confidence.
    """
    survivors = []
    by_confidence = sorted(boxes, key=lambda entry: entry[1], reverse=True)
    for candidate in by_confidence:
        # A candidate survives only if it overlaps no kept box too much.
        overlaps_ok = all(
            calculate_iou(kept[0], candidate[0]) < iou_threshold
            for kept in survivors
        )
        if overlaps_ok:
            survivors.append(candidate)
    return survivors
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: Box as ``[x1, y1, x2, y2]``.
        box2: Box as ``[x1, y1, x2, y2]``.

    Returns:
        float: Intersection area divided by union area, or 0.0 when the union
        area is not positive (e.g. both boxes are degenerate), which would
        otherwise be a division by zero.
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection

    if union <= 0:
        return 0.0
    return intersection / float(union)
def create_breed_comparison(breed1: str, breed2: str) -> dict:
    """Build a numeric side-by-side comparison of two breeds.

    Each breed's description is looked up with ``get_dog_description`` and its
    categorical fields are mapped to integers. A value that is not in the
    mapping, a field that is absent, or a breed with no description at all
    falls back to 2 (the middle of each scale).

    Args:
        breed1: Name of the first breed.
        breed2: Name of the second breed.

    Returns:
        dict: Maps each breed name to a dict with the keys ``Size``,
        ``Exercise_Needs``, ``Care_Level``, ``Grooming_Needs``,
        ``Good_with_Children`` (bool) and ``Original_Data`` (the description
        dict, or an empty dict when the breed has no description).
    """
    # Categorical value -> numeric score.
    value_mapping = {
        'Size': {'Small': 1, 'Medium': 2, 'Large': 3, 'Giant': 4},
        'Exercise_Needs': {'Low': 1, 'Moderate': 2, 'High': 3, 'Very High': 4},
        'Care_Level': {'Low': 1, 'Moderate': 2, 'High': 3},
        'Grooming_Needs': {'Low': 1, 'Moderate': 2, 'High': 3}
    }

    comparison_data = {}
    for breed in (breed1, breed2):
        # The lookup can come back empty for an unknown breed; treat that as
        # "no data" rather than failing on the subscripts below.
        info = get_dog_description(breed) or {}
        comparison_data[breed] = {
            'Size': value_mapping['Size'].get(info.get('Size'), 2),  # default: Medium
            'Exercise_Needs': value_mapping['Exercise_Needs'].get(info.get('Exercise Needs'), 2),  # default: Moderate
            'Care_Level': value_mapping['Care_Level'].get(info.get('Care Level'), 2),
            'Grooming_Needs': value_mapping['Grooming_Needs'].get(info.get('Grooming Needs'), 2),
            'Good_with_Children': info.get('Good with Children') == 'Yes',
            'Original_Data': info
        }
    return comparison_data
def _fallback_description(breed):
    """Build a placeholder description for a breed with no database entry."""
    return {
        "Name": breed,
        "Size": "Unknown",
        "Exercise Needs": "Unknown",
        "Grooming Needs": "Unknown",
        "Care Level": "Unknown",
        "Good with Children": "Unknown",
        "Description": f"Identified as {breed.replace('_', ' ')}"
    }


def _load_label_font(size=24):
    """Return Arial at ``size``, or Pillow's default font if it cannot be loaded."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except (OSError, ImportError):
        # OSError: font file not found/unreadable; ImportError: Pillow built
        # without FreeType support. Anything else should not be swallowed.
        return ImageFont.load_default()


def _draw_detection(draw, box, label, color, font):
    """Draw a bounding box plus a white label badge in its top-left corner."""
    draw.rectangle(box, outline=color, width=4)

    label_bbox = draw.textbbox((0, 0), label, font=font)
    label_width = label_bbox[2] - label_bbox[0]
    label_height = label_bbox[3] - label_bbox[1]

    label_x = box[0] + 5
    label_y = box[1] + 5
    draw.rectangle(
        [label_x - 2, label_y - 2, label_x + label_width + 4, label_y + label_height + 4],
        fill='white',
        outline=color,
        width=2
    )
    draw.text((label_x, label_y), label, fill=color, font=font)


@spaces.GPU
def predict(image):
    """Detect dogs in an image, classify each one and build the result view.

    YOLO detection (``detect_multiple_dogs``) is combined with the breed
    classifier (``predict_single_dog``). Every detection is outlined on a
    copy of the image. Detections not flagged as dogs get a "not a dog"
    message and are not classified.

    For dogs the output depends on confidence:
      * detection confidence * top-1 probability < 0.2 -> unknown breed;
      * top-1 probability >= 0.45 -> a single-breed result;
      * otherwise -> the top three candidate breeds.

    Args:
        image: PIL Image or numpy array (``None`` when nothing was uploaded).

    Returns:
        tuple: ``(html_output, annotated_image, initial_state)``. On missing
        input or on error, a hint/error HTML string and ``None, None``.
    """
    if image is None:
        return format_hint_html("Please upload an image to start."), None, None

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        dogs = detect_multiple_dogs(image)
        single_dog = len(dogs) == 1
        # A single colour when there is one detection, otherwise a sequence
        # of colours that is cycled through per detection.
        color_scheme = get_color_scheme(single_dog)

        annotated_image = image.copy()
        draw = ImageDraw.Draw(annotated_image)
        font = _load_label_font()

        sections = []
        for i, (cropped_image, detection_confidence, box, is_dog) in enumerate(dogs):
            color = color_scheme if single_dog else color_scheme[i % len(color_scheme)]
            label = f"Dog {i+1}" if is_dog else f"Object {i+1}"
            _draw_detection(draw, box, label, color, font)

            try:
                if not is_dog:
                    sections.append(format_not_dog_message(color, i+1))
                    continue

                top1_prob, topk_breeds, relative_probs = predict_single_dog(cropped_image)
                combined_confidence = detection_confidence * top1_prob

                if combined_confidence < 0.2:
                    sections.append(format_unknown_breed_message(color, i+1))
                elif top1_prob >= 0.45:
                    breed = topk_breeds[0]
                    description = get_dog_description(breed)
                    if description is None:
                        description = _fallback_description(breed)
                    sections.append(format_single_dog_result(breed, description, color))
                else:
                    sections.append(format_multiple_breeds_result(
                        topk_breeds,
                        relative_probs,
                        color,
                        i+1,
                        lambda breed: get_dog_description(breed) or _fallback_description(breed)
                    ))
            except Exception as e:
                # One failing dog must not discard the results for the others.
                print(f"Error formatting results for dog {i+1}: {str(e)}")
                sections.append(format_unknown_breed_message(color, i+1))

        dogs_info = "".join(sections)
        html_output = format_multi_dog_container(dogs_info)

        initial_state = {
            "dogs_info": dogs_info,
            "image": annotated_image,
            "is_multi_dog": len(dogs) > 1,
            "html_output": html_output
        }
        return html_output, annotated_image, initial_state

    except Exception as e:
        # NOTE(review): the full traceback is returned in the user-facing
        # message; consider logging it only.
        error_msg = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return format_hint_html(error_msg), None, None
def show_details_html(choice, previous_output, initial_state):
    """Generate the detailed HTML view for a selected breed.

    Args:
        choice: str, selected option; the breed name is whatever follows
            "More about " in it.
        previous_output: str, previously shown HTML, returned unchanged when
            nothing is selected.
        initial_state: dict of state information, or ``None`` when no state
            exists yet. It is updated in place with the rendered description.

    Returns:
        tuple: ``(html_output, gradio_update, updated_state)``.
    """
    if not choice:
        return previous_output, gr.update(visible=True), initial_state

    # The state can be None (predict() returns None as state when there is no
    # image or an error occurred); start from an empty state in that case so
    # the details can still be rendered and recorded.
    if initial_state is None:
        initial_state = {}

    try:
        breed = choice.split("More about ")[-1]
        description = get_dog_description(breed)
        html_output = format_breed_details_html(description, breed)

        # Record what is currently displayed.
        initial_state["current_description"] = html_output
        initial_state["original_buttons"] = initial_state.get("buttons", [])

        return html_output, gr.update(visible=True), initial_state
    except Exception as e:
        error_msg = f"An error occurred while showing details: {e}"
        print(error_msg)
        return format_hint_html(error_msg), gr.update(visible=True), initial_state
def main():
with gr.Blocks(css=get_css_styles()) as iface:
gr.HTML("""
Powered by AI • Breed Recognition • Smart Matching • Companion Guide
🐾 PawMatch AI
Your Smart Dog Breed Guide