# Pear-playground / classification.py
import math
import torch
from timm import create_model
from torchvision import transforms
from PIL import Image
# Default crop ratio
DEFAULT_CROP_PCT = 0.875
# ResNeSt-101e model setup and load
weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # path to the local weights file
model_name = "resnest101e"
model = create_model(model_name, pretrained=False, num_classes=9)  # skip loading pretrained weights
# model.classifier = torch.nn.Linear(model.classifier.in_features, 2)  # switch to binary classification
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))  # load local weights
model.eval()  # set evaluation mode
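
# Optional device handling (a sketch, not part of the original flow): the rest
# of this script assumes CPU inference, so the input tensor built in predict()
# would have to be moved to the same device as the model, e.g.:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = model.to(device)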
# Manually specified class names
# class_labels = ["Abnormal Pear", "Normal Pear"]  # adjust to match the dataset
class_labels = [
    "Normal",
    "Pear scab (black spot)",
    "Fruit skin blotch",
    "Oriental fruit moth",
    "Peach fruit moth",
    "Pear scale insect",
    "Leafroller moths",
    "Other",
    "Fruit skin blackening",
]
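# The model above was created with num_classes=9, so the label list (translated
# from the original Korean) must keep exactly nine entries in training-class order.
assert len(class_labels) == 9, "label count must match the model's num_classes"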

# Preprocessing function
def transforms_imagenet_eval(
    img_path: str,
    img_size: int = 224,
    crop_pct: float = DEFAULT_CROP_PCT,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
    normalize: bool = True,
):
    """
    ImageNet-style image preprocessing.

    Args:
        img_path (str): Path to the image file.
        img_size (int): Center-crop size.
        crop_pct (float): Crop ratio used to derive the resize size.
        mean (tuple): Normalization mean.
        std (tuple): Normalization standard deviation.
        normalize (bool): Whether to apply normalization.

    Returns:
        torch.Tensor: Preprocessed image tensor.
    """
    img = Image.open(img_path).convert("RGB")  # load image and convert to RGB
    scale_size = math.floor(img_size / crop_pct)  # compute the resize size
    # Build the transform pipeline
    tfl = [
        transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    if normalize:
        tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
    transform = transforms.Compose(tfl)
    return transform(img)
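
# Shape sketch: with img_size=224 and crop_pct=0.875 the resize target is
# floor(224 / 0.875) = 256, so the pipeline is a 256x256 resize followed by a
# 224x224 center crop, yielding a [3, 224, 224] float tensor; for example,
# with a hypothetical local file "pear.jpg":
#
#   x = transforms_imagenet_eval("pear.jpg", img_size=224)
#   print(x.shape)  # torch.Size([3, 224, 224])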

# Inference function
def predict(image_path: str):
    """
    Run model inference on the given image file path.

    Args:
        image_path (str): Path to the input image.

    Returns:
        dict: Mapping from class label to predicted probability.
    """
    # Preprocess the image
    input_tensor = transforms_imagenet_eval(
        img_path=image_path,
        img_size=224,
        normalize=True,
    ).unsqueeze(0)  # add batch dimension
    # Run model inference
    with torch.no_grad():
        prediction = model(input_tensor)
        probs = torch.nn.functional.softmax(prediction[0], dim=-1)
        confidences = {class_labels[i]: float(probs[i]) for i in range(len(class_labels))}
    # Return prediction results
    return confidences
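
# Minimal usage sketch, assuming the script is run directly; "sample_pear.jpg"
# is a hypothetical path used only for illustration.
if __name__ == "__main__":
    scores = predict("sample_pear.jpg")
    best_label = max(scores, key=scores.get)
    print(f"Predicted: {best_label} ({scores[best_label]:.3f})")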