Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -95,7 +95,6 @@ class BaseModel(nn.Module):
|
|
95 |
self.feature_dim = self.backbone.classifier[1].in_features
|
96 |
self.backbone.classifier = nn.Identity()
|
97 |
|
98 |
-
# 動態計算 num_heads
|
99 |
self.num_heads = max(1, min(8, self.feature_dim // 64))
|
100 |
self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
|
101 |
|
@@ -122,16 +121,16 @@ model = BaseModel(num_classes=num_classes, device=device)
|
|
122 |
checkpoint = torch.load('best_model_81_dog.pth', map_location=torch.device('cpu'))
|
123 |
model.load_state_dict(checkpoint['model_state_dict'])
|
124 |
|
125 |
-
#
|
126 |
model.eval()
|
127 |
|
128 |
# Image preprocessing function
|
129 |
def preprocess_image(image):
|
130 |
-
#
|
131 |
if isinstance(image, np.ndarray):
|
132 |
image = Image.fromarray(image)
|
133 |
|
134 |
-
#
|
135 |
transform = transforms.Compose([
|
136 |
transforms.Resize((224, 224)),
|
137 |
transforms.ToTensor(),
|
@@ -140,38 +139,12 @@ def preprocess_image(image):
|
|
140 |
|
141 |
return transform(image).unsqueeze(0)
|
142 |
|
143 |
-
# def predict(image):
|
144 |
-
# try:
|
145 |
-
# image_tensor = preprocess_image(image)
|
146 |
-
# with torch.no_grad():
|
147 |
-
# logits, _ = model(image_tensor)
|
148 |
-
# _, predicted = torch.max(logits, 1)
|
149 |
-
|
150 |
-
# breed = dog_breeds[predicted.item()] # Map label to breed name
|
151 |
-
|
152 |
-
# # Retrieve breed description
|
153 |
-
# description = get_dog_description(breed)
|
154 |
-
|
155 |
-
# # Formatting the description for better display
|
156 |
-
# if isinstance(description, dict):
|
157 |
-
# description_str = f"**Breed**: {description['Breed']}\n\n"
|
158 |
-
# description_str += f"**Size**: {description['Size']}\n\n"
|
159 |
-
# description_str += f"**Lifespan**: {description['Lifespan']}\n\n"
|
160 |
-
# description_str += f"**Temperament**: {description['Temperament']}\n\n"
|
161 |
-
# description_str += f"**Care Level**: {description['Care Level']}\n\n"
|
162 |
-
# description_str += f"**Good with Children**: {description['Good with Children']}\n\n"
|
163 |
-
# description_str += f"**Exercise Needs**: {description['Exercise Needs']}\n\n"
|
164 |
-
# description_str += f"**Grooming Needs**: {description['Grooming Needs']}\n\n"
|
165 |
-
# description_str += f"**Description**: {description['Description']}\n\n"
|
166 |
-
# else:
|
167 |
-
# description_str = description
|
168 |
-
|
169 |
-
# return description_str
|
170 |
-
# except Exception as e:
|
171 |
-
# return f"An error occurred: {e}"
|
172 |
|
173 |
def get_akc_link(breed):
|
174 |
-
|
|
|
|
|
|
|
175 |
return f"https://www.akc.org/dog-breeds/{formatted_breed}/"
|
176 |
|
177 |
def predict(image):
|
@@ -184,30 +157,23 @@ def predict(image):
|
|
184 |
else:
|
185 |
logits = output
|
186 |
_, predicted = torch.max(logits, 1)
|
187 |
-
breed = dog_breeds[predicted.item()]
|
188 |
|
189 |
-
# Retrieve breed description
|
190 |
description = get_dog_description(breed)
|
191 |
-
|
192 |
-
# Generate AKC link
|
193 |
akc_link = get_akc_link(breed)
|
194 |
|
195 |
-
# Formatting the description for better display
|
196 |
if isinstance(description, dict):
|
197 |
-
description_str = f"**
|
198 |
-
description_str += f"**Size**: {description['Size']}\n\n"
|
199 |
-
description_str += f"**Lifespan**: {description['Lifespan']}\n\n"
|
200 |
-
description_str += f"**Temperament**: {description['Temperament']}\n\n"
|
201 |
-
description_str += f"**Care Level**: {description['Care Level']}\n\n"
|
202 |
-
description_str += f"**Good with Children**: {description['Good with Children']}\n\n"
|
203 |
-
description_str += f"**Exercise Needs**: {description['Exercise Needs']}\n\n"
|
204 |
-
description_str += f"**Grooming Needs**: {description['Grooming Needs']}\n\n"
|
205 |
-
description_str += f"**Description**: {description['Description']}\n\n"
|
206 |
else:
|
207 |
description_str = description
|
208 |
|
209 |
-
# Add AKC link
|
210 |
-
description_str += f"\n\n
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
return description_str
|
213 |
except Exception as e:
|
@@ -226,7 +192,6 @@ iface = gr.Interface(
|
|
226 |
'French_Bulldog.jpeg',
|
227 |
'Samoyed.jpg'],
|
228 |
css = """
|
229 |
-
/* 新增樣式 */
|
230 |
.container {
|
231 |
max-width: 900px;
|
232 |
margin: 0 auto;
|
|
|
95 |
self.feature_dim = self.backbone.classifier[1].in_features
|
96 |
self.backbone.classifier = nn.Identity()
|
97 |
|
|
|
98 |
self.num_heads = max(1, min(8, self.feature_dim // 64))
|
99 |
self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
|
100 |
|
|
|
121 |
checkpoint = torch.load('best_model_81_dog.pth', map_location=torch.device('cpu'))
|
122 |
model.load_state_dict(checkpoint['model_state_dict'])
|
123 |
|
124 |
+
# evaluation mode
|
125 |
model.eval()
|
126 |
|
127 |
# Image preprocessing function
|
128 |
def preprocess_image(image):
|
129 |
+
# If the image is numpy.ndarray turn into PIL.Image
|
130 |
if isinstance(image, np.ndarray):
|
131 |
image = Image.fromarray(image)
|
132 |
|
133 |
+
# Use torchvision.transforms to process images
|
134 |
transform = transforms.Compose([
|
135 |
transforms.Resize((224, 224)),
|
136 |
transforms.ToTensor(),
|
|
|
139 |
|
140 |
return transform(image).unsqueeze(0)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
def get_akc_link(breed):
|
144 |
+
# Remove any non-English characters and convert to lowercase
|
145 |
+
formatted_breed = ''.join(c for c in breed if ord(c) < 128).lower()
|
146 |
+
# Replace spaces with hyphens and remove any remaining special characters
|
147 |
+
formatted_breed = '-'.join(word for word in formatted_breed.split() if word.isalnum())
|
148 |
return f"https://www.akc.org/dog-breeds/{formatted_breed}/"
|
149 |
|
150 |
def predict(image):
|
|
|
157 |
else:
|
158 |
logits = output
|
159 |
_, predicted = torch.max(logits, 1)
|
160 |
+
breed = dog_breeds[predicted.item()]
|
161 |
|
|
|
162 |
description = get_dog_description(breed)
|
|
|
|
|
163 |
akc_link = get_akc_link(breed)
|
164 |
|
|
|
165 |
if isinstance(description, dict):
|
166 |
+
description_str = "\n\n".join([f"**{key}**: {value}" for key, value in description.items()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
else:
|
168 |
description_str = description
|
169 |
|
170 |
+
# Add AKC link as an option
|
171 |
+
description_str += f"\n\n**Want to learn more?** [View detailed information about {breed} on the AKC website]({akc_link})"
|
172 |
+
|
173 |
+
# Add disclaimer
|
174 |
+
disclaimer = ("\n\n*Disclaimer: The external link provided leads to the American Kennel Club (AKC) website. "
|
175 |
+
"We are not responsible for the content on external sites. Please refer to the AKC's terms of use and privacy policy.*")
|
176 |
+
description_str += disclaimer
|
177 |
|
178 |
return description_str
|
179 |
except Exception as e:
|
|
|
192 |
'French_Bulldog.jpeg',
|
193 |
'Samoyed.jpg'],
|
194 |
css = """
|
|
|
195 |
.container {
|
196 |
max-width: 900px;
|
197 |
margin: 0 auto;
|