Spaces:
Sleeping
Sleeping
File size: 2,867 Bytes
541eb9d 69ef5c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import math
import torch
from timm import create_model
from torchvision import transforms
from PIL import Image
# Default center-crop ratio for ImageNet-style eval preprocessing
DEFAULT_CROP_PCT = 0.875

# Model setup: ResNeSt-101e fine-tuned for 9-way pear disease classification,
# loaded from a local checkpoint (no pretrained download).
weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local weights file
model_name = "resnest101e"
model = create_model(model_name, pretrained=False, num_classes=9)
# NOTE(review): torch.load unpickles arbitrary objects; if the checkpoint is a
# plain state dict and the torch version supports it, prefer weights_only=True.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))  # load local weights on CPU
model.eval()  # inference mode (disables dropout/batchnorm updates)

# Class labels; order must match the label indices used at training time.
class_labels = ["์ ์", "ํ์ฑ๋ณ","๊ณผํผ์ผ๋ฃฉ","๋ณต์ญ์ ์๋๋ฐฉ","๋ณต์ญ์ ์ผ์๋๋ฐฉ","๋ฐฐ ๊น์ง๋ฒ๋ ","์๋ง์ด ๋๋ฐฉ๋ฅ", "๊ธฐํ","๊ณผํผํ๋ณ"]
# ์ ์ฒ๋ฆฌ ํจ์
def transforms_imagenet_eval(
    img_path: str,
    img_size: int = 224,
    crop_pct: float = DEFAULT_CROP_PCT,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
    normalize: bool = True,
):
    """Preprocess an image file ImageNet-eval style.

    The image is loaded as RGB, resized to a square of side
    ``floor(img_size / crop_pct)`` (note: aspect ratio is not preserved),
    center-cropped to ``img_size``, converted to a tensor, and optionally
    normalized with the given per-channel statistics.

    Args:
        img_path: Path to the input image file.
        img_size: Side length of the final center crop.
        crop_pct: Crop ratio used to derive the pre-crop resize size.
        mean: Per-channel normalization mean.
        std: Per-channel normalization standard deviation.
        normalize: Whether to apply mean/std normalization.

    Returns:
        torch.Tensor: Preprocessed image tensor.
    """
    resize_to = math.floor(img_size / crop_pct)  # e.g. 224 / 0.875 -> 256
    steps = [
        transforms.Resize((resize_to, resize_to), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)))
    pipeline = transforms.Compose(steps)
    image = Image.open(img_path).convert("RGB")
    return pipeline(image)
# ์ถ๋ก ํจ์
def predict(image_path: str):
    """Run model inference on a single image file.

    Args:
        image_path: Path to the input image.

    Returns:
        dict: Mapping of class label (str) to softmax probability (float).
    """
    # Preprocess and add a batch dimension -> (1, 3, 224, 224)
    input_tensor = transforms_imagenet_eval(
        img_path=image_path,
        img_size=224,
        normalize=True
    ).unsqueeze(0)

    # Inference without gradient tracking
    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.nn.functional.softmax(logits[0], dim=-1)

    # Pair labels with probabilities via zip instead of a hard-coded range(9),
    # so the mapping stays correct if the label list ever changes length.
    confidences = {label: float(p) for label, p in zip(class_labels, probs)}
    return confidences