DawnC commited on
Commit
57c59a8
1 Parent(s): f5a076e

Update smart_breed_matcher.py

Browse files
Files changed (1) hide show
  1. smart_breed_matcher.py +14 -12
smart_breed_matcher.py CHANGED
@@ -18,18 +18,20 @@ def gpu_init_wrapper(func):
18
  return func(*args, **kwargs)
19
  return wrapper
20
 
21
- class SmartBreedMatcher:
22
- def _safe_prediction(self, func):
23
- @wraps(func)
24
- def wrapper(*args, **kwargs):
25
- try:
 
 
 
 
26
  return func(*args, **kwargs)
27
- except RuntimeError as e:
28
- if "CUDA" in str(e):
29
- print("GPU 操作失敗,嘗試使用 CPU")
30
- return func(*args, **kwargs)
31
- raise
32
- return wrapper
33
 
34
  def __init__(self, dog_data: List[Tuple]):
35
  self.dog_data = dog_data
@@ -891,7 +893,7 @@ class SmartBreedMatcher:
891
  }
892
 
893
  @gpu_init_wrapper
894
- @_safe_prediction
895
  def match_user_preference(self, description: str, top_n: int = 10) -> List[Dict]:
896
  try:
897
  if self.model is None:
 
18
  return func(*args, **kwargs)
19
  return wrapper
20
 
21
+ def safe_prediction(func):
22
+ """錯誤處理裝飾器,提供 GPU 到 CPU 的降級機制"""
23
+ @wraps(func)
24
+ def wrapper(*args, **kwargs):
25
+ try:
26
+ return func(*args, **kwargs)
27
+ except RuntimeError as e:
28
+ if "CUDA" in str(e):
29
+ print("GPU 操作失敗,嘗試使用 CPU")
30
  return func(*args, **kwargs)
31
+ raise
32
+ return wrapper
33
+
34
+ class SmartBreedMatcher:
 
 
35
 
36
  def __init__(self, dog_data: List[Tuple]):
37
  self.dog_data = dog_data
 
893
  }
894
 
895
  @gpu_init_wrapper
896
+ @safe_prediction
897
  def match_user_preference(self, description: str, top_n: int = 10) -> List[Dict]:
898
  try:
899
  if self.model is None: