DawnC commited on
Commit
be7eec2
1 Parent(s): c601c51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -10
app.py CHANGED
@@ -591,13 +591,14 @@ class BaseModel(nn.Module):
591
 
592
  class ModelManager:
593
  """
594
- 模型管理器:負責AI模型的初始化和管理
595
- 使用單例模式確保只有一個實例在管理所有模型
596
  """
597
  _instance = None
598
  _initialized = False
599
  _yolo_model = None
600
  _breed_model = None
 
601
 
602
  def __new__(cls):
603
  if cls._instance is None:
@@ -607,8 +608,20 @@ class ModelManager:
607
  def __init__(self):
608
  # 避免重複初始化
609
  if not ModelManager._initialized:
 
 
610
  ModelManager._initialized = True
611
 
 
 
 
 
 
 
 
 
 
 
612
  @property
613
  def yolo_model(self):
614
  """
@@ -623,18 +636,23 @@ class ModelManager:
623
  def breed_model(self):
624
  """
625
  延遲初始化品種分類模型
626
- 只有在第一次使用時才會創建實例
627
  """
628
  if self._breed_model is None:
629
- self._breed_model = BaseModel(num_classes=len(dog_breeds),
630
- device=device).to(device)
631
- checkpoint = torch.load('124_best_model_dog.pth',
632
- map_location=device)
633
- self._breed_model.load_state_dict(checkpoint['base_model'],
634
- strict=False)
 
 
 
 
635
  self._breed_model.eval()
636
  return self._breed_model
637
 
 
638
  model_manager = ModelManager()
639
 
640
 
@@ -663,7 +681,7 @@ def predict_single_dog(image):
663
  tuple: (top1_prob, topk_breeds, relative_probs)
664
  """
665
 
666
- image_tensor = preprocess_image(image).to(device)
667
 
668
  with torch.no_grad():
669
  # Get model outputs (只使用logits,不需要features)
 
591
 
592
  class ModelManager:
593
  """
594
+ 模型管理器:負責AI模型的初始化、設備管理和資源控制
595
+ 使用單例模式確保整個應用程序中只有一個實例
596
  """
597
  _instance = None
598
  _initialized = False
599
  _yolo_model = None
600
  _breed_model = None
601
+ _device = None
602
 
603
  def __new__(cls):
604
  if cls._instance is None:
 
608
  def __init__(self):
609
  # 避免重複初始化
610
  if not ModelManager._initialized:
611
+ # 初始化設備,這會在第一次創建實例時執行
612
+ self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
613
  ModelManager._initialized = True
614
 
615
+ @property
616
+ def device(self):
617
+ """
618
+ 提供對設備的訪問
619
+ 確保在需要時設備已經被初始化
620
+ """
621
+ if self._device is None:
622
+ self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
623
+ return self._device
624
+
625
  @property
626
  def yolo_model(self):
627
  """
 
636
  def breed_model(self):
637
  """
638
  延遲初始化品種分類模型
639
+ 只有在第一次使用時才會創建實例並移動到正確的設備上
640
  """
641
  if self._breed_model is None:
642
+ self._breed_model = BaseModel(
643
+ num_classes=len(dog_breeds),
644
+ device=self.device # 使用我們的device屬性
645
+ ).to(self.device)
646
+
647
+ checkpoint = torch.load(
648
+ '124_best_model_dog.pth',
649
+ map_location=self.device # 確保checkpoint加載到正確的設備
650
+ )
651
+ self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
652
  self._breed_model.eval()
653
  return self._breed_model
654
 
655
+
656
  model_manager = ModelManager()
657
 
658
 
 
681
  tuple: (top1_prob, topk_breeds, relative_probs)
682
  """
683
 
684
+ image_tensor = preprocess_image(image).to(model_manager.device)
685
 
686
  with torch.no_grad():
687
  # Get model outputs (只使用logits,不需要features)