import math import torch from timm import create_model from torchvision import transforms from PIL import Image # 기본 크롭 비율 DEFAULT_CROP_PCT = 0.875 # EfficientNet-B0 모델 설정 및 로드 weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt" # 로컬 가중치 파일 경로 model_name = "resnest101e" model = create_model(model_name, pretrained=False,num_classes=9) # 사전 학습 로드 생략 #model.classifier = torch.nn.Linear(model.classifier.in_features, 2) # 이진 분류로 수정 model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # 로컬 가중치 로드 model.eval() # 평가 모드 설정 # 클래스 이름 수동 지정 # class_labels = ["Abormal Pear", "Normal Pear"] # 데이터셋에 맞게 수정 class_labels = ["정상", "흑성병","과피얼룩","복숭아 순나방","복숭아 삼식나방","배 깍지벌레","잎말이 나방류", "기타","과피흑변"] # 전처리 함수 def transforms_imagenet_eval( img_path: str, img_size: int = 224, crop_pct: float = DEFAULT_CROP_PCT, mean: tuple = (0.485, 0.456, 0.406), std: tuple = (0.229, 0.224, 0.225), normalize: bool = True, ): """ ImageNet 스타일 이미지 전처리 함수. Args: img_path (str): 이미지 경로. img_size (int): 크롭 크기. crop_pct (float): 크롭 비율. mean (tuple): 정규화 평균. std (tuple): 정규화 표준편차. normalize (bool): 정규화 여부. Returns: torch.Tensor: 전처리된 이미지 텐서. """ img = Image.open(img_path).convert("RGB") # 이미지 로드 및 RGB 변환 scale_size = math.floor(img_size / crop_pct) # 리사이즈 크기 계산 # Transform 설정 tfl = [ transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(img_size), transforms.ToTensor(), ] if normalize: tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))] transform = transforms.Compose(tfl) return transform(img) # 추론 함수 def predict(image_path: str): """ 주어진 이미지 파일 경로를 받아 모델 추론을 수행. Args: image_path (str): 입력 이미지 경로. Returns: str: 모델 예측 결과. """ # 이미지 전처리 input_tensor = transforms_imagenet_eval( img_path=image_path, img_size=224, normalize=True ).unsqueeze(0) # 배치 차원 추가 # 모델 추론 with torch.no_grad(): prediction = model(input_tensor) probs = torch.nn.functional.softmax(prediction[0], dim=-1) confidences = {class_labels[i]: float(probs[i]) for i in range(9)} # 예측 결과 반환 return confidences