import math
import torch
from timm import create_model
from torchvision import transforms
from PIL import Image

# Default crop ratio for eval preprocessing
DEFAULT_CROP_PCT = 0.875

# ResNeSt-101e model setup and weight loading
weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local weights file path
model_name = "resnest101e"
model = create_model(model_name, pretrained=False, num_classes=9)  # skip downloading pretrained weights
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))  # load local weights
model.eval()  # set evaluation mode
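
# Optional (sketch, not in the original script): run on GPU when available.
# predict() below would then also need to move input_tensor to the same device.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)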

# Class names, specified manually to match the training dataset
# (approximate English glosses: Normal, Pear scab, Skin blotch, Oriental fruit
#  moth, Peach fruit moth, Pear scale insect, Leafroller moths, Other, Skin blackening)
class_labels = ["정상", "흑성병", "과피얼룩", "복숭아 순나방", "복숭아 삼식나방", "배 깍지벌레", "잎말이 나방류", "기타", "과피흑변"]

# ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def transforms_imagenet_eval(
    img_path: str,
    img_size: int = 224,
    crop_pct: float = DEFAULT_CROP_PCT,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
    normalize: bool = True,
):
    """
    ImageNet ์Šคํƒ€์ผ ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜.

    Args:
        img_path (str): ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ.
        img_size (int): ํฌ๋กญ ํฌ๊ธฐ.
        crop_pct (float): ํฌ๋กญ ๋น„์œจ.
        mean (tuple): ์ •๊ทœํ™” ํ‰๊ท .
        std (tuple): ์ •๊ทœํ™” ํ‘œ์ค€ํŽธ์ฐจ.
        normalize (bool): ์ •๊ทœํ™” ์—ฌ๋ถ€.

    Returns:
        torch.Tensor: ์ „์ฒ˜๋ฆฌ๋œ ์ด๋ฏธ์ง€ ํ…์„œ.
    """
    img = Image.open(img_path).convert("RGB")  # ์ด๋ฏธ์ง€ ๋กœ๋“œ ๋ฐ RGB ๋ณ€ํ™˜
    scale_size = math.floor(img_size / crop_pct)  # ๋ฆฌ์‚ฌ์ด์ฆˆ ํฌ๊ธฐ ๊ณ„์‚ฐ

    # Build the transform pipeline
    tfl = [
        # Note: a (h, w) size tuple resizes to exactly that size, ignoring aspect
        # ratio; timm's stock eval transform instead resizes the shorter side only.
        transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    if normalize:
        tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]

    transform = transforms.Compose(tfl)
    return transform(img)
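
# Quick sanity check (sketch): with the defaults above, the transform resizes to
# 256x256 (floor(224 / 0.875)) and returns a (3, 224, 224) float tensor.
# "sample_pear.jpg" is a placeholder path, not part of the original script.
# t = transforms_imagenet_eval("sample_pear.jpg")
# assert t.shape == (3, 224, 224)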

# Inference function
def predict(image_path: str):
    """
    Run model inference on an image given its file path.

    Args:
        image_path (str): Path to the input image.

    Returns:
        dict: Mapping from class label to predicted confidence.
    """
    # Preprocess the image
    input_tensor = transforms_imagenet_eval(
        img_path=image_path,
        img_size=224,
        normalize=True
    ).unsqueeze(0)  # add batch dimension

    # ๋ชจ๋ธ ์ถ”๋ก 
    with torch.no_grad():
        prediction = model(input_tensor)
    
    probs = torch.nn.functional.softmax(prediction[0], dim=-1)
    confidences = {class_labels[i]: float(probs[i]) for i in range(9)}

    # ์˜ˆ์ธก ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜
    return confidences
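
# Usage example (sketch): "sample_pear.jpg" is a placeholder path, not part of
# the original script; replace it with a real image file before running.
if __name__ == "__main__":
    confidences = predict("sample_pear.jpg")
    # Print classes sorted by confidence, highest first
    for label, score in sorted(confidences.items(), key=lambda kv: kv[1], reverse=True):
        print(f"{label}: {score:.4f}")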