# Hugging Face Spaces status banner ("Spaces: Sleeping") captured by page
# scraping; kept as a comment so the file remains valid Python.
import math
import torch
from timm import create_model
from torchvision import transforms
from PIL import Image

# Default ImageNet eval crop ratio (used by transforms_imagenet_eval below).
DEFAULT_CROP_PCT = 0.875

# Class labels for the pear classifier. The strings are kept byte-for-byte as
# shipped with the checkpoint (they appear mojibake'd here, but the model's
# output indices are bound to exactly this ordering).
class_labels = ["์ ์", "ํ์ฑ๋ณ","๊ณผํผ์ผ๋ฃฉ","๋ณต์ญ์ ์๋๋ฐฉ","๋ณต์ญ์ ์ผ์๋๋ฐฉ","๋ฐฐ ๊น์ง๋ฒ๋ ","์๋ง์ด ๋๋ฐฉ๋ฅ", "๊ธฐํ","๊ณผํผํ๋ณ"]

# ResNeSt-101e backbone with a classification head sized to the label list
# (fix: was a hard-coded 9, duplicated with class_labels); weights come from a
# local fine-tuned checkpoint rather than the timm hub.
weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local checkpoint path
model_name = "resnest101e"
model = create_model(model_name, pretrained=False, num_classes=len(class_labels))
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from a trusted source. weights_only=True would be safer for a plain
# state_dict; confirm the installed torch version supports it before enabling.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# ์ ์ฒ๋ฆฌ ํจ์ | |
def transforms_imagenet_eval(
    img_path: str,
    img_size: int = 224,
    crop_pct: float = DEFAULT_CROP_PCT,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
    normalize: bool = True,
):
    """ImageNet-style evaluation preprocessing for a single image file.

    Args:
        img_path (str): Path to the input image.
        img_size (int): Side length of the square center crop.
        crop_pct (float): Crop ratio; the image is first resized to a
            floor(img_size / crop_pct) square before cropping.
        mean (tuple): Per-channel normalization mean.
        std (tuple): Per-channel normalization standard deviation.
        normalize (bool): Whether to apply mean/std normalization.

    Returns:
        torch.Tensor: Preprocessed image tensor of shape (3, img_size, img_size).
    """
    # Load and force 3-channel RGB (handles grayscale / RGBA inputs).
    image = Image.open(img_path).convert("RGB")
    # Pre-crop resize target derived from the crop ratio.
    resize_to = math.floor(img_size / crop_pct)

    steps = [
        transforms.Resize((resize_to, resize_to), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)))

    return transforms.Compose(steps)(image)
# ์ถ๋ก ํจ์ | |
def predict(image_path: str):
    """Run inference on one image and return per-class confidences.

    Args:
        image_path (str): Path to the input image file.

    Returns:
        dict[str, float]: Mapping from class label to softmax probability.
    """
    # Preprocess and add the batch dimension the model expects.
    input_tensor = transforms_imagenet_eval(
        img_path=image_path,
        img_size=224,
        normalize=True,
    ).unsqueeze(0)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        logits = model(input_tensor)
    probs = torch.nn.functional.softmax(logits[0], dim=-1)

    # Fix: pair labels with probabilities via zip instead of the original
    # hard-coded range(9), so the mapping stays correct if class_labels changes.
    return {label: float(p) for label, p in zip(class_labels, probs)}