DawnC commited on
Commit
c9e5868
1 Parent(s): 12f4776

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -30
app.py CHANGED
@@ -9,6 +9,12 @@ from torchvision import transforms
9
  from PIL import Image
10
  from data_manager import get_dog_description
11
  from urllib.parse import quote
 
 
 
 
 
 
12
 
13
  dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
14
  "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog",
@@ -121,44 +127,101 @@ def preprocess_image(image):
121
  def get_akc_breeds_link():
122
  return "https://www.akc.org/dog-breeds/"
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def predict(image):
125
  if image is None:
126
  return "Please upload an image to get started.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
127
-
128
  try:
129
- image_tensor = preprocess_image(image)
130
- with torch.no_grad():
131
- output = model(image_tensor)
132
- logits = output[0] if isinstance(output, tuple) else output
133
-
134
- probabilities = F.softmax(logits, dim=1)
135
- topk_probs, topk_indices = torch.topk(probabilities, k=3)
136
-
137
- top1_prob = topk_probs[0][0].item()
138
- topk_breeds = [dog_breeds[idx.item()] for idx in topk_indices[0]]
139
- topk_probs_percent = [f"{prob.item() * 100:.2f}%" for prob in topk_probs[0]]
140
-
141
- if top1_prob >= 0.5:
142
- breed = topk_breeds[0]
143
- description = get_dog_description(breed)
144
- return format_description(description, breed), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
145
-
146
- elif top1_prob < 0.2:
147
- return ("The image is too unclear or the dog breed is not in the dataset. Please upload a clearer image of the dog.",
148
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))
149
- else:
150
- explanation = (
151
- f"The model couldn't confidently identify the breed. Here are the top 3 possible breeds:\n\n"
152
- f"1. **{topk_breeds[0]}** ({topk_probs_percent[0]} confidence)\n"
153
- f"2. **{topk_breeds[1]}** ({topk_probs_percent[1]} confidence)\n"
154
- f"3. **{topk_breeds[2]}** ({topk_probs_percent[2]} confidence)\n\n"
155
- "Click on a button to view more information about the breed."
156
- )
157
- return explanation, gr.update(visible=True, value=f"More about {topk_breeds[0]}"), gr.update(visible=True, value=f"More about {topk_breeds[1]}"), gr.update(visible=True, value=f"More about {topk_breeds[2]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  except Exception as e:
160
  return f"An error occurred: {e}", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
161
 
 
162
  def format_description(description, breed):
163
  if isinstance(description, dict):
164
  formatted_description = "\n\n".join([f"**{key}**: {value}" for key, value in description.items()])
 
9
  from PIL import Image
10
  from data_manager import get_dog_description
11
  from urllib.parse import quote
12
+ os.system('pip install ultralytics')
13
+ from ultralytics import YOLO
14
+
15
+ # 下載YOLOv5預訓練模型
16
+ model_yolo = YOLO('yolov5s.pt') # 使用 YOLOv5 預訓練模型
17
+
18
 
19
  dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
20
  "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog",
 
127
  def get_akc_breeds_link():
128
  return "https://www.akc.org/dog-breeds/"
129
 
130
+ # def predict(image):
131
+ # if image is None:
132
+ # return "Please upload an image to get started.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
133
+
134
+ # try:
135
+ # image_tensor = preprocess_image(image)
136
+ # with torch.no_grad():
137
+ # output = model(image_tensor)
138
+ # logits = output[0] if isinstance(output, tuple) else output
139
+
140
+ # probabilities = F.softmax(logits, dim=1)
141
+ # topk_probs, topk_indices = torch.topk(probabilities, k=3)
142
+
143
+ # top1_prob = topk_probs[0][0].item()
144
+ # topk_breeds = [dog_breeds[idx.item()] for idx in topk_indices[0]]
145
+ # topk_probs_percent = [f"{prob.item() * 100:.2f}%" for prob in topk_probs[0]]
146
+
147
+ # if top1_prob >= 0.5:
148
+ # breed = topk_breeds[0]
149
+ # description = get_dog_description(breed)
150
+ # return format_description(description, breed), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
151
+
152
+ # elif top1_prob < 0.2:
153
+ # return ("The image is too unclear or the dog breed is not in the dataset. Please upload a clearer image of the dog.",
154
+ # gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))
155
+ # else:
156
+ # explanation = (
157
+ # f"The model couldn't confidently identify the breed. Here are the top 3 possible breeds:\n\n"
158
+ # f"1. **{topk_breeds[0]}** ({topk_probs_percent[0]} confidence)\n"
159
+ # f"2. **{topk_breeds[1]}** ({topk_probs_percent[1]} confidence)\n"
160
+ # f"3. **{topk_breeds[2]}** ({topk_probs_percent[2]} confidence)\n\n"
161
+ # "Click on a button to view more information about the breed."
162
+ # )
163
+ # return explanation, gr.update(visible=True, value=f"More about {topk_breeds[0]}"), gr.update(visible=True, value=f"More about {topk_breeds[1]}"), gr.update(visible=True, value=f"More about {topk_breeds[2]}")
164
+
165
+ # except Exception as e:
166
+ # return f"An error occurred: {e}", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
167
+
168
  def predict(image):
169
  if image is None:
170
  return "Please upload an image to get started.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
171
+
172
  try:
173
+ # 使用 YOLO 偵測狗
174
+ results = model_yolo(image)
175
+ dogs = results.xyxy[0] # 提取偵測到的狗的邊界框
176
+
177
+ if len(dogs) == 0:
178
+ return "No dog detected in the image.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
179
+
180
+ explanations = []
181
+ visible_buttons = []
182
+
183
+ for i, box in enumerate(dogs):
184
+ x1, y1, x2, y2 = map(int, box[:4])
185
+ cropped_image = image.crop((x1, y1, x2, y2)) # 裁剪狗區域
186
+ image_tensor = preprocess_image(cropped_image)
187
+
188
+ with torch.no_grad():
189
+ output = model(image_tensor)
190
+ logits = output[0] if isinstance(output, tuple) else output
191
+
192
+ probabilities = F.softmax(logits, dim=1)
193
+ topk_probs, topk_indices = torch.topk(probabilities, k=3)
194
+
195
+ top1_prob = topk_probs[0][0].item()
196
+ topk_breeds = [dog_breeds[idx.item()] for idx in topk_indices[0]]
197
+ topk_probs_percent = [f"{prob.item() * 100:.2f}%" for prob in topk_probs[0]]
198
+
199
+ # 根據信心分數進行判斷
200
+ if top1_prob >= 0.5:
201
+ breed = topk_breeds[0]
202
+ description = get_dog_description(breed)
203
+ explanations.append(f"Detected a dog: **{breed}** with {topk_probs_percent[0]} confidence.")
204
+ elif 0.2 <= top1_prob < 0.5:
205
+ explanation = (
206
+ f"Detected a dog with moderate confidence. Here are the top 3 possible breeds:\n"
207
+ f"1. **{topk_breeds[0]}** ({topk_probs_percent[0]} confidence)\n"
208
+ f"2. **{topk_breeds[1]}** ({topk_probs_percent[1]} confidence)\n"
209
+ f"3. **{topk_breeds[2]}** ({topk_probs_percent[2]} confidence)\n"
210
+ )
211
+ explanations.append(explanation)
212
+ visible_buttons.extend([i+1 for _ in range(3)])
213
+ else:
214
+ explanations.append("The image is too unclear or the breed is not in the dataset. Please upload a clearer image.")
215
+
216
+ # 處理不同情境的結果
217
+ if len(explanations) > 0:
218
+ final_explanation = "\n\n".join(explanations)
219
+ return final_explanation, gr.update(visible=len(visible_buttons) >= 1), gr.update(visible=len(visible_buttons) >= 2), gr.update(visible=len(visible_buttons) >= 3)
220
 
221
  except Exception as e:
222
  return f"An error occurred: {e}", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
223
 
224
+
225
  def format_description(description, breed):
226
  if isinstance(description, dict):
227
  formatted_description = "\n\n".join([f"**{key}**: {value}" for key, value in description.items()])