Update classification.py
classification.py  (+85 -85)

The only functional change in this commit is on line 15: torch.load now receives map_location=torch.device('cpu'), so weights saved from a GPU machine can be loaded in a CPU-only environment such as this Space.
@@ -15 +15 @@
-model.load_state_dict(torch.load(weights_path))  # load local weights
+model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))  # load local weights
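If the Space might later run with a GPU, a common pattern is to pick the map location at runtime instead of hard-coding the CPU. A minimal sketch under that assumption, reusing weights_path and model from the file below (the device variable is illustrative, not part of this commit):

import torch

# Fall back to CPU when CUDA is unavailable; use the GPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)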
The full updated file:

import math
import torch
from timm import create_model
from torchvision import transforms
from PIL import Image

# Default crop ratio
DEFAULT_CROP_PCT = 0.875

# ResNeSt-101e model setup and loading
weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local weights file path
model_name = "resnest101e"
model = create_model(model_name, pretrained=False, num_classes=9)  # skip loading pretrained weights
#model.classifier = torch.nn.Linear(model.classifier.in_features, 2)  # switch to binary classification
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))  # load local weights
model.eval()  # set evaluation mode

# Manually specified class names
# class_labels = ["Abnormal Pear", "Normal Pear"]  # adjust to match the dataset
class_labels = ["Normal", "Pear scab (black spot)", "Fruit skin blotch", "Oriental fruit moth", "Peach fruit moth", "Pear scale insect", "Leafroller moths", "Other", "Fruit skin blackening"]

# Preprocessing function
def transforms_imagenet_eval(
    img_path: str,
    img_size: int = 224,
    crop_pct: float = DEFAULT_CROP_PCT,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
    normalize: bool = True,
):
    """
    ImageNet-style image preprocessing.

    Args:
        img_path (str): Path to the image.
        img_size (int): Crop size.
        crop_pct (float): Crop ratio.
        mean (tuple): Normalization mean.
        std (tuple): Normalization standard deviation.
        normalize (bool): Whether to normalize.

    Returns:
        torch.Tensor: Preprocessed image tensor.
    """
    img = Image.open(img_path).convert("RGB")  # load the image and convert to RGB
    scale_size = math.floor(img_size / crop_pct)  # compute the resize size

    # Set up the transforms
    tfl = [
        transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    if normalize:
        tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]

    transform = transforms.Compose(tfl)
    return transform(img)

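# With the defaults above, scale_size = floor(224 / 0.875) = 256: the input
# image is first resized to 256x256, then center-cropped to 224x224, which
# is the standard ImageNet evaluation recipe.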
# Inference function
def predict(image_path: str):
    """
    Run model inference on the given image file path.

    Args:
        image_path (str): Path to the input image.

    Returns:
        dict: Mapping from class label to confidence score.
    """
    # Preprocess the image
    input_tensor = transforms_imagenet_eval(
        img_path=image_path,
        img_size=224,
        normalize=True
    ).unsqueeze(0)  # add batch dimension

    # Model inference
    with torch.no_grad():
        prediction = model(input_tensor)

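    # prediction[0] drops the batch dimension; softmax converts the nine
    # raw logits into probabilities that sum to 1.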
    probs = torch.nn.functional.softmax(prediction[0], dim=-1)
    confidences = {class_labels[i]: float(probs[i]) for i in range(9)}

    # Return the prediction result
    return confidences
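For reference, a minimal way to exercise this module (the sample path ./samples/pear.jpg is illustrative, not part of the commit):

from classification import predict

confidences = predict("./samples/pear.jpg")  # hypothetical sample image
best = max(confidences, key=confidences.get)  # top-1 label
print(f"{best}: {confidences[best]:.3f}")

Because predict returns a label-to-probability dict, its output can also be passed directly to a Gradio gr.Label component if the Space wraps this module in a UI.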