Spaces:
Sleeping
Sleeping
Update smart_breed_matcher.py
Browse files- smart_breed_matcher.py +28 -2
smart_breed_matcher.py
CHANGED
@@ -18,19 +18,40 @@ def gpu_init_wrapper(func):
|
|
18 |
return func(*args, **kwargs)
|
19 |
return wrapper
|
20 |
|
21 |
-
|
22 |
class SmartBreedMatcher:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def __init__(self, dog_data: List[Tuple]):
|
24 |
self.dog_data = dog_data
|
25 |
-
self.model =
|
26 |
self._embedding_cache = {}
|
27 |
self._clear_cache()
|
28 |
|
|
|
|
|
|
|
|
|
|
|
29 |
def _clear_cache(self):
|
30 |
self._embedding_cache = {}
|
31 |
|
32 |
|
|
|
33 |
def _get_cached_embedding(self, text: str) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
34 |
if text not in self._embedding_cache:
|
35 |
self._embedding_cache[text] = self.model.encode(text)
|
36 |
return self._embedding_cache[text]
|
@@ -75,6 +96,8 @@ class SmartBreedMatcher:
|
|
75 |
List[Tuple[str, float]]: 相似品種列表,包含品種名稱和相似度分數
|
76 |
"""
|
77 |
try:
|
|
|
|
|
78 |
target_breed = next((breed for breed in self.dog_data if breed[1] == breed_name), None)
|
79 |
if not target_breed:
|
80 |
return []
|
@@ -868,8 +891,11 @@ class SmartBreedMatcher:
|
|
868 |
}
|
869 |
|
870 |
@gpu_init_wrapper
|
|
|
871 |
def match_user_preference(self, description: str, top_n: int = 10) -> List[Dict]:
|
872 |
try:
|
|
|
|
|
873 |
# 獲取場景權重
|
874 |
weights = self._detect_scenario(description)
|
875 |
matches = []
|
|
|
18 |
return func(*args, **kwargs)
|
19 |
return wrapper
|
20 |
|
|
|
21 |
class SmartBreedMatcher:
|
22 |
+
def _safe_prediction(self, func):
|
23 |
+
@wraps(func)
|
24 |
+
def wrapper(*args, **kwargs):
|
25 |
+
try:
|
26 |
+
return func(*args, **kwargs)
|
27 |
+
except RuntimeError as e:
|
28 |
+
if "CUDA" in str(e):
|
29 |
+
print("GPU 操作失敗,嘗試使用 CPU")
|
30 |
+
return func(*args, **kwargs)
|
31 |
+
raise
|
32 |
+
return wrapper
|
33 |
+
|
34 |
def __init__(self, dog_data: List[Tuple]):
|
35 |
self.dog_data = dog_data
|
36 |
+
self.model = None
|
37 |
self._embedding_cache = {}
|
38 |
self._clear_cache()
|
39 |
|
40 |
+
def _initialize_model(self):
|
41 |
+
"""延遲初始化模型,只在需要時才創建"""
|
42 |
+
if self.model is None:
|
43 |
+
self.model = SentenceTransformer('all-mpnet-base-v2')
|
44 |
+
|
45 |
def _clear_cache(self):
|
46 |
self._embedding_cache = {}
|
47 |
|
48 |
|
49 |
+
@spaces.GPU
|
50 |
def _get_cached_embedding(self, text: str) -> torch.Tensor:
|
51 |
+
"""使用 GPU 裝飾器確保在正確的時機初始化 CUDA"""
|
52 |
+
if self.model is None:
|
53 |
+
self._initialize_model()
|
54 |
+
|
55 |
if text not in self._embedding_cache:
|
56 |
self._embedding_cache[text] = self.model.encode(text)
|
57 |
return self._embedding_cache[text]
|
|
|
96 |
List[Tuple[str, float]]: 相似品種列表,包含品種名稱和相似度分數
|
97 |
"""
|
98 |
try:
|
99 |
+
if self.model is None:
|
100 |
+
self._initialize_model()
|
101 |
target_breed = next((breed for breed in self.dog_data if breed[1] == breed_name), None)
|
102 |
if not target_breed:
|
103 |
return []
|
|
|
891 |
}
|
892 |
|
893 |
@gpu_init_wrapper
|
894 |
+
@_safe_prediction
|
895 |
def match_user_preference(self, description: str, top_n: int = 10) -> List[Dict]:
|
896 |
try:
|
897 |
+
if self.model is None:
|
898 |
+
self._initialize_model()
|
899 |
# 獲取場景權重
|
900 |
weights = self._detect_scenario(description)
|
901 |
matches = []
|