DawnC commited on
Commit
1eb7f90
1 Parent(s): 4de47b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -10
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import torch.nn as nn
5
  import gradio as gr
6
  from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
 
7
  import torch.nn.functional as F
8
  from torchvision import transforms
9
  from PIL import Image, ImageDraw, ImageFont, ImageFilter
@@ -163,17 +164,138 @@ def _predict_single_dog(image):
163
  return top1_prob, topk_breeds, topk_probs_percent
164
 
165
 
166
- async def detect_multiple_dogs(image, conf_threshold=0.25, iou_threshold=0.3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
168
  dogs = []
 
 
 
169
  for box in results.boxes:
170
- if box.cls == 16: # COCO 資料集中狗的類別是 16
171
  xyxy = box.xyxy[0].tolist()
172
  confidence = box.conf.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  cropped_image = image.crop((xyxy[0], xyxy[1], xyxy[2], xyxy[3]))
174
  dogs.append((cropped_image, confidence, xyxy))
175
- return dogs
176
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  async def predict(image):
179
  if image is None:
@@ -183,12 +305,15 @@ async def predict(image):
183
  if isinstance(image, np.ndarray):
184
  image = Image.fromarray(image)
185
 
186
- dogs = await detect_multiple_dogs(image, conf_threshold=0.15, iou_threshold=0.3)
187
 
188
- if len(dogs) <= 1:
189
- return await process_single_dog(image)
 
 
 
190
 
191
- # 多狗情境
192
  color_list = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#800080', '#FFA500']
193
  explanations = []
194
  buttons = []
@@ -196,7 +321,7 @@ async def predict(image):
196
  draw = ImageDraw.Draw(annotated_image)
197
  font = ImageFont.load_default()
198
 
199
- for i, (cropped_image, _, box) in enumerate(dogs):
200
  top1_prob, topk_breeds, topk_probs_percent = await predict_single_dog(cropped_image)
201
  color = color_list[i % len(color_list)]
202
  draw.rectangle(box, outline=color, width=3)
@@ -239,7 +364,7 @@ async def predict(image):
239
 
240
  except Exception as e:
241
  error_msg = f"An error occurred: {str(e)}"
242
- print(error_msg) # 添加日誌輸出
243
  return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
244
 
245
 
 
4
  import torch.nn as nn
5
  import gradio as gr
6
  from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
7
+ from torchvision.ops import nms
8
  import torch.nn.functional as F
9
  from torchvision import transforms
10
  from PIL import Image, ImageDraw, ImageFont, ImageFilter
 
164
  return top1_prob, topk_breeds, topk_probs_percent
165
 
166
 
167
+ # async def detect_multiple_dogs(image, conf_threshold=0.25, iou_threshold=0.4):
168
+ # results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
169
+ # dogs = []
170
+ # for box in results.boxes:
171
+ # if box.cls == 16: # COCO 資料集中狗的類別是 16
172
+ # xyxy = box.xyxy[0].tolist()
173
+ # confidence = box.conf.item()
174
+ # cropped_image = image.crop((xyxy[0], xyxy[1], xyxy[2], xyxy[3]))
175
+ # dogs.append((cropped_image, confidence, xyxy))
176
+ # return dogs
177
+
178
+
179
+ # async def predict(image):
180
+ # if image is None:
181
+ # return "Please upload an image to start.", None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
182
+
183
+ # try:
184
+ # if isinstance(image, np.ndarray):
185
+ # image = Image.fromarray(image)
186
+
187
+ # dogs = await detect_multiple_dogs(image, conf_threshold=0.25, iou_threshold=0.4)
188
+
189
+ # if len(dogs) <= 1:
190
+ # return await process_single_dog(image)
191
+
192
+ # # 多狗情境
193
+ # color_list = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#800080', '#FFA500']
194
+ # explanations = []
195
+ # buttons = []
196
+ # annotated_image = image.copy()
197
+ # draw = ImageDraw.Draw(annotated_image)
198
+ # font = ImageFont.load_default()
199
+
200
+ # for i, (cropped_image, _, box) in enumerate(dogs):
201
+ # top1_prob, topk_breeds, topk_probs_percent = await predict_single_dog(cropped_image)
202
+ # color = color_list[i % len(color_list)]
203
+ # draw.rectangle(box, outline=color, width=3)
204
+ # draw.text((box[0], box[1]), f"Dog {i+1}", fill=color, font=font)
205
+
206
+ # breed = topk_breeds[0]
207
+ # if top1_prob >= 0.5:
208
+ # description = get_dog_description(breed)
209
+ # formatted_description = format_description(description, breed)
210
+ # explanations.append(f"Dog {i+1}: {formatted_description}")
211
+ # elif top1_prob >= 0.2:
212
+ # dog_explanation = f"Dog {i+1}: Top 3 possible breeds:\n"
213
+ # dog_explanation += "\n".join([f"{j+1}. **{breed}** ({prob} confidence)" for j, (breed, prob) in enumerate(zip(topk_breeds[:3], topk_probs_percent[:3]))])
214
+ # explanations.append(dog_explanation)
215
+ # buttons.extend([gr.update(visible=True, value=f"Dog {i+1}: More about {breed}") for breed in topk_breeds[:3]])
216
+ # else:
217
+ # explanations.append(f"Dog {i+1}: The image is unclear or the breed is not in the dataset.")
218
+
219
+ # final_explanation = "\n\n".join(explanations)
220
+ # if buttons:
221
+ # final_explanation += "\n\nClick on a button to view more information about the breed."
222
+ # initial_state = {
223
+ # "explanation": final_explanation,
224
+ # "buttons": buttons,
225
+ # "show_back": True
226
+ # }
227
+ # return (final_explanation, annotated_image,
228
+ # buttons[0] if len(buttons) > 0 else gr.update(visible=False),
229
+ # buttons[1] if len(buttons) > 1 else gr.update(visible=False),
230
+ # buttons[2] if len(buttons) > 2 else gr.update(visible=False),
231
+ # gr.update(visible=True),
232
+ # initial_state)
233
+ # else:
234
+ # initial_state = {
235
+ # "explanation": final_explanation,
236
+ # "buttons": [],
237
+ # "show_back": False
238
+ # }
239
+ # return final_explanation, annotated_image, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), initial_state
240
+
241
+ # except Exception as e:
242
+ # error_msg = f"An error occurred: {str(e)}"
243
+ # print(error_msg) # 添加日誌輸出
244
+ # return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
245
+
246
+ async def detect_multiple_dogs(image, conf_threshold=0.25, iou_threshold=0.4, merge_threshold=0.3):
247
  results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
248
  dogs = []
249
+ boxes = []
250
+ confidences = []
251
+
252
  for box in results.boxes:
253
+ if box.cls == 16: # COCO dataset class for dog is 16
254
  xyxy = box.xyxy[0].tolist()
255
  confidence = box.conf.item()
256
+ boxes.append(torch.tensor(xyxy))
257
+ confidences.append(confidence)
258
+
259
+ if boxes:
260
+ boxes = torch.stack(boxes)
261
+ confidences = torch.tensor(confidences)
262
+
263
+ # Apply NMS
264
+ keep = nms(boxes, confidences, iou_threshold)
265
+
266
+ for i in keep:
267
+ xyxy = boxes[i].tolist()
268
+ confidence = confidences[i].item()
269
  cropped_image = image.crop((xyxy[0], xyxy[1], xyxy[2], xyxy[3]))
270
  dogs.append((cropped_image, confidence, xyxy))
271
+
272
+ # Merge nearby boxes
273
+ merged_dogs = []
274
+ while dogs:
275
+ base_dog = dogs.pop(0)
276
+ base_box = torch.tensor(base_dog[2])
277
+ to_merge = [base_dog]
278
+
279
+ i = 0
280
+ while i < len(dogs):
281
+ compare_box = torch.tensor(dogs[i][2])
282
+ iou = box_iou(base_box.unsqueeze(0), compare_box.unsqueeze(0)).item()
283
+ if iou > merge_threshold:
284
+ to_merge.append(dogs.pop(i))
285
+ else:
286
+ i += 1
287
+
288
+ if len(to_merge) == 1:
289
+ merged_dogs.append(base_dog)
290
+ else:
291
+ merged_box = torch.cat([torch.tensor(dog[2]).unsqueeze(0) for dog in to_merge]).mean(0)
292
+ merged_confidence = max(dog[1] for dog in to_merge)
293
+ merged_image = image.crop(merged_box.tolist())
294
+ merged_dogs.append((merged_image, merged_confidence, merged_box.tolist()))
295
+
296
+ return merged_dogs
297
+
298
+ return []
299
 
300
  async def predict(image):
301
  if image is None:
 
305
  if isinstance(image, np.ndarray):
306
  image = Image.fromarray(image)
307
 
308
+ dogs = await detect_multiple_dogs(image, conf_threshold=0.25, iou_threshold=0.4, merge_threshold=0.3)
309
 
310
+ if len(dogs) == 0:
311
+ return "No dogs detected in the image.", image, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
312
+
313
+ if len(dogs) == 1:
314
+ return await process_single_dog(dogs[0][0]) # Pass the cropped image of the single detected dog
315
 
316
+ # Multi-dog scenario
317
  color_list = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#800080', '#FFA500']
318
  explanations = []
319
  buttons = []
 
321
  draw = ImageDraw.Draw(annotated_image)
322
  font = ImageFont.load_default()
323
 
324
+ for i, (cropped_image, confidence, box) in enumerate(dogs):
325
  top1_prob, topk_breeds, topk_probs_percent = await predict_single_dog(cropped_image)
326
  color = color_list[i % len(color_list)]
327
  draw.rectangle(box, outline=color, width=3)
 
364
 
365
  except Exception as e:
366
  error_msg = f"An error occurred: {str(e)}"
367
+ print(error_msg) # Add log output
368
  return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
369
 
370