DawnC committed
Commit 5045da7
1 Parent(s): 9f7a41d

Update app.py

Files changed (1)
  1. app.py +69 -48
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
  import time
  import traceback
  import spaces
- from torchvision.models import convnext_base, ConvNeXt_Base_Weights
+ import timm
  from torchvision.ops import nms, box_iou
  import torch.nn.functional as F
  from torchvision import transforms
@@ -72,56 +72,77 @@ dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staff
  class MultiHeadAttention(nn.Module):

      def __init__(self, in_dim, num_heads=8):
+         """
+         Initializes the MultiHeadAttention module.
+         Args:
+             in_dim (int): Dimension of the input features.
+             num_heads (int): Number of attention heads. Defaults to 8.
+         """
          super().__init__()
          self.num_heads = num_heads
-         self.head_dim = max(1, in_dim // num_heads)
-         self.scaled_dim = self.head_dim * num_heads
-         self.fc_in = nn.Linear(in_dim, self.scaled_dim)
-         self.query = nn.Linear(self.scaled_dim, self.scaled_dim)
-         self.key = nn.Linear(self.scaled_dim, self.scaled_dim)
-         self.value = nn.Linear(self.scaled_dim, self.scaled_dim)
-         self.fc_out = nn.Linear(self.scaled_dim, in_dim)
+         self.head_dim = max(1, in_dim // num_heads)  # Dimension per head
+         self.scaled_dim = self.head_dim * num_heads  # Total dimension after splitting into heads
+         self.fc_in = nn.Linear(in_dim, self.scaled_dim)  # Project input to scaled_dim
+         self.query = nn.Linear(self.scaled_dim, self.scaled_dim)  # Query projection
+         self.key = nn.Linear(self.scaled_dim, self.scaled_dim)  # Key projection
+         self.value = nn.Linear(self.scaled_dim, self.scaled_dim)  # Value projection
+         self.fc_out = nn.Linear(self.scaled_dim, in_dim)  # Project output back to in_dim

      def forward(self, x):
-         N = x.shape[0]
-         x = self.fc_in(x)
-         q = self.query(x).view(N, self.num_heads, self.head_dim)
-         k = self.key(x).view(N, self.num_heads, self.head_dim)
-         v = self.value(x).view(N, self.num_heads, self.head_dim)
+         """
+         Forward pass for the multi-head attention mechanism.
+         Args:
+             x (Tensor): Input tensor of shape (batch_size, in_dim).
+         Returns:
+             Tensor: Output tensor after applying the attention mechanism.
+         """
+         N = x.shape[0]  # Batch size
+         x = self.fc_in(x)  # Project input to scaled_dim
+         q = self.query(x).view(N, self.num_heads, self.head_dim)  # Compute queries
+         k = self.key(x).view(N, self.num_heads, self.head_dim)  # Compute keys
+         v = self.value(x).view(N, self.num_heads, self.head_dim)  # Compute values

-         energy = torch.einsum("nqd,nkd->nqk", [q, k])
-         attention = F.softmax(energy / (self.head_dim ** 0.5), dim=2)
+         # Calculate attention scores
+         energy = torch.einsum("nqd,nkd->nqk", [q, k])  # Dot product between queries and keys
+         attention = F.softmax(energy / (self.head_dim ** 0.5), dim=2)  # Softmax with scaling

+         # Compute the weighted sum of values based on the attention scores
          out = torch.einsum("nqk,nvd->nqd", [attention, v])
-         out = out.reshape(N, self.scaled_dim)
-         out = self.fc_out(out)
+         out = out.reshape(N, self.scaled_dim)  # Concatenate all heads
+         out = self.fc_out(out)  # Project back to the original input dimension
          return out

+
  class BaseModel(nn.Module):

      def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
          super().__init__()
          self.device = device

-         # 1. Initialize the backbone
-         self.backbone = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
-         self.backbone.classifier = nn.Identity()  # Remove the original classifier
+         # 1. Initialize the backbone; num_classes=0 removes the classifier layer
+         self.backbone = timm.create_model(
+             'convnextv2_base',
+             pretrained=True,
+             num_classes=0
+         )

          # 2. Use test data to determine the actual feature dimension
          with torch.no_grad():  # No need to compute gradients
              dummy_input = torch.randn(1, 3, 224, 224)  # Create an example input
              features = self.backbone(dummy_input)

              if len(features.shape) > 2:  # If the features are multi-dimensional
                  features = features.mean([-2, -1])  # Apply global average pooling

              self.feature_dim = features.shape[1]  # Get the correct feature dimension

-         print(f"Feature Dim: {self.feature_dim}")  # For debugging
+         print(f"Feature Dimension from V2 backbone: {self.feature_dim}")

          # 3. Set up the multi-head attention layer
          self.num_heads = max(1, min(8, self.feature_dim // 64))
          self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

          # 4. Set up the classifier
          self.classifier = nn.Sequential(
              nn.LayerNorm(self.feature_dim),
              nn.Dropout(0.3),
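
Note: a minimal shape check for the MultiHeadAttention block above, not part of the commit. in_dim=1024 mirrors the ConvNeXt V2 base feature width but is only an illustrative value; the module keeps input and output dimensions identical.

    import torch

    attn = MultiHeadAttention(in_dim=1024, num_heads=8)   # head_dim = 1024 // 8 = 128
    pooled = torch.randn(4, 1024)                          # four pooled feature vectors
    out = attn(pooled)
    print(out.shape)                                       # torch.Size([4, 1024]), same shape as the input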
@@ -130,27 +151,27 @@ class BaseModel(nn.Module):

      def forward(self, x):
          """
-         Forward propagation of the model.
+         The forward propagation process combines V2's FCCA and the multi-head attention mechanism.
          Args:
              x (Tensor): Input image tensor with shape [batch_size, channels, height, width]
          Returns:
              Tuple[Tensor, Tensor]: Classification logits and attention features.
          """
          x = x.to(self.device)

          # 1. Extract base features
          features = self.backbone(x)

          # 2. Process feature dimensions
          if len(features.shape) > 2:
              # If the feature dimensions are [batch_size, channels, height, width],
              # convert them to [batch_size, channels]
              features = features.mean([-2, -1])  # Use global average pooling

          # 3. Apply the attention mechanism
          attended_features = self.attention(features)

          # 4. Final classification
          logits = self.classifier(attended_features)

          return logits, attended_features
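
Note: a minimal usage sketch of the updated BaseModel, not part of the commit. It assumes timm can fetch the convnextv2_base weights (pretrained=True downloads them on first use) and that the classifier head, truncated in this diff, ends in a Linear layer sized to num_classes; 120 classes and CPU are illustrative values only.

    import torch

    model = BaseModel(num_classes=120, device='cpu')  # prints the probed feature dimension
    model.eval()

    images = torch.randn(2, 3, 224, 224)              # dummy batch of two 224x224 RGB images
    with torch.no_grad():
        logits, attended = model(images)

    print(logits.shape)    # [2, 120] if the truncated classifier ends in nn.Linear(feature_dim, num_classes)
    print(attended.shape)  # [2, feature_dim], e.g. [2, 1024] for convnextv2_base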
@@ -211,7 +232,7 @@ class ModelManager:
          ).to(self.device)

          checkpoint = torch.load(
-             'ConvNextBase_best_model_dog.pth',
+             'ConvNextV2Base_best_model_dog.pth',
              map_location=self.device  # Make sure the checkpoint loads onto the correct device
          )
          self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
@@ -271,7 +292,7 @@ def predict_single_dog(image):
      return probabilities[0], breeds[:3], relative_probs

  @spaces.GPU
- def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
+ def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.3):
      """
      Use the YOLO model to detect dogs in the image.
      Only objects identified as dogs (class 16) are kept, and their status is flagged.
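
Note: the default iou_threshold drops from 0.55 to 0.3, so overlapping boxes are suppressed more aggressively. A standalone torchvision.ops.nms illustration of the effect; the boxes and scores below are made up and this is not the app's full filtering pipeline.

    import torch
    from torchvision.ops import nms

    boxes = torch.tensor([[  0.,   0., 100., 100.],
                          [ 30.,   0., 130., 100.],    # IoU with the first box is about 0.54
                          [200., 200., 300., 300.]])
    scores = torch.tensor([0.9, 0.8, 0.7])

    print(nms(boxes, scores, iou_threshold=0.55))  # tensor([0, 1, 2]) -- the overlap is still tolerated
    print(nms(boxes, scores, iou_threshold=0.3))   # tensor([0, 2])    -- the overlapping box is dropped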
@@ -310,10 +331,10 @@ def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
              x1, y1, x2, y2 = box
              w, h = x2 - x1, y2 - y1
              # Expand the detection box so the whole dog is included
-             x1 = max(0, x1 - w * 0.05)
-             y1 = max(0, y1 - h * 0.05)
-             x2 = min(image.width, x2 + w * 0.05)
-             y2 = min(image.height, y2 + h * 0.05)
+             x1 = max(0, x1 - w * 0.02)
+             y1 = max(0, y1 - h * 0.02)
+             x2 = min(image.width, x2 + w * 0.02)
+             y2 = min(image.height, y2 + h * 0.02)
              cropped_image = image.crop((x1, y1, x2, y2))
              detected_objects.append((cropped_image, confidence, [x1, y1, x2, y2], is_dog))
 
@@ -442,9 +463,9 @@ def predict(image):
              combined_confidence = detection_confidence * top1_prob

              # Decide the output format based on the confidence level
-             if combined_confidence < 0.2:
+             if combined_confidence < 0.15:
                  dogs_info += format_unknown_breed_message(color, i+1)
-             elif top1_prob >= 0.45:
+             elif top1_prob >= 0.4:
                  breed = topk_breeds[0]
                  description = get_dog_description(breed)
                  if description is None:
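
Note: the unknown-breed cutoff moves from 0.2 to 0.15 and the single-breed cutoff from 0.45 to 0.4, so slightly less confident predictions are still reported. A hypothetical summary of the branching (breed_output_mode is not part of app.py, and the final branch is assumed to show the top-k alternatives):

    def breed_output_mode(combined_confidence, top1_prob):
        if combined_confidence < 0.15:   # too uncertain overall
            return "unknown_breed"
        elif top1_prob >= 0.4:           # confident single prediction
            return "single_breed"
        return "top_k_breeds"            # assumption: otherwise show several candidates

    print(breed_output_mode(0.10, 0.50))  # unknown_breed
    print(breed_output_mode(0.30, 0.55))  # single_breed
    print(breed_output_mode(0.30, 0.30))  # top_k_breeds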
@@ -555,7 +576,7 @@ def main():
          'Border_Collie.jpg',
          'Golden_Retriever.jpeg',
          'Saint_Bernard.jpeg',
-         'Samoyed.jpg',
+         'Samoyed.jpeg',
          'French_Bulldog.jpeg'
      ]
      detection_components = create_detection_tab(predict, example_images)
 