Kaori1707 committed
Commit
69ef5c2
1 Parent(s): 91a1441

Upload 10 files

update app and weights

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/2.JPG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
+ ---
+ title: Pear Playground
+ emoji: 🐨
+ colorFrom: gray
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.8.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: Defects Detection on Pear
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,88 +1,88 @@
+ import cv2
+ import gradio as gr
+ from detection import PearDetectionModel
+ from classification import predict
+
+ # Streaming interface that reads from the camera and displays the output with bounding boxes.
+ config = {
+     "model_path": "./weights/best.pt",
+     "classes": ["burn_bbox", "defected_pear", "defected_pear_bbox", "normal_pear", "normal_pear_bbox"],
+ }
+ model = PearDetectionModel(config)
+
+
+ def classify(image):
+     """
+     Take a PIL image from Gradio and return the inference result.
+
+     Args:
+         image (PIL.Image): The uploaded image.
+
+     Returns:
+         dict: Per-class confidence scores from the model.
+     """
+     # Save to a temporary file, then run inference on it.
+     image_path = "temp_image.jpg"
+     image.save(image_path)
+     return predict(image_path)
+
+
+ def detect(img):
+     cls, xyxy, confs = model.inference(img)
+     for box, conf in zip(xyxy, confs):
+         cv2.rectangle(
+             img,
+             (int(box[0]), int(box[1])),
+             (int(box[2]), int(box[3])),
+             (0, 255, 0),
+             2,
+         )
+         cv2.putText(
+             img,
+             f"{conf:.2f}",
+             (int(box[0]), int(box[1])),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             1,
+             (0, 255, 0),
+             2,
+         )
+     # Overall verdict for the frame, drawn once.
+     cv2.putText(
+         img,
+         "Class: Normal Pear" if cls == 0 else "Class: Abnormal Pear",
+         (0, 50),
+         cv2.FONT_HERSHEY_SIMPLEX,
+         1,
+         (0, 255, 0),
+         2,
+     )
+     return img
+
+
+ css = """.my-group {max-width: 500px !important; max-height: 500px !important;}
+ .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
+
+ with gr.Blocks(css=css) as demo:
+     demo.title = "Pear Playground"
+     gr.Markdown("## This is a demo for Pear Playground by AISeed.")
+     with gr.Tab(label="Classification"):
+         gr.Interface(
+             fn=classify,
+             inputs=gr.Image(type="pil", label="Upload an image"),
+             outputs=gr.Label(num_top_classes=9),
+             examples=["examples/1.JPG", "examples/2.JPG"],  # match the upper-case filenames in examples/
+             title="Abnormal Pear Classifier",
+             description="Classification of abnormal pears with the lightweight ResNeSt101e model",
+         )
+     with gr.Tab(label="Detection"):
+         with gr.Column(elem_classes=["my-column"]):
+             with gr.Group(elem_classes=["my-group"]):
+                 input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True)
+                 input_img.stream(
+                     detect,
+                     [input_img],
+                     [input_img],
+                     time_limit=30,
+                     stream_every=0.1,
+                 )
+
+ if __name__ == "__main__":
+     demo.launch()
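
For reference, a minimal sketch of what the Classification tab does with an upload — not part of the commit; the example path is an assumption based on the examples/ folder added here, and importing app builds the Blocks UI without launching it:

from PIL import Image
from app import classify

# classify() saves the PIL image to temp_image.jpg and forwards it to predict().
confidences = classify(Image.open("examples/1.JPG"))
print(max(confidences, key=confidences.get))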
classification.py ADDED
@@ -0,0 +1,86 @@
+ import math
+ import torch
+ from timm import create_model
+ from torchvision import transforms
+ from PIL import Image
+
+ # Default center-crop fraction.
+ DEFAULT_CROP_PCT = 0.875
+
+ # Set up and load the ResNeSt-101e classification model.
+ weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local weights file path
+ model_name = "resnest101e"
+ model = create_model(model_name, pretrained=False, num_classes=9)  # skip downloading pretrained weights
+ # model.classifier = torch.nn.Linear(model.classifier.in_features, 2)  # switch to binary classification (unused)
+ model.load_state_dict(torch.load(weights_path, map_location="cpu"))  # load local weights (CPU-safe)
+ model.eval()  # evaluation mode
+
+ # Class names, specified manually (Korean labels).
+ # class_labels = ["Abnormal Pear", "Normal Pear"]  # adjust to match the dataset (unused)
+ class_labels = ["정상", "흑성병", "과피얼룩", "복숭아 순나방", "복숭아 삼식나방", "배 깍지벌레", "잎말이 나방류", "기타", "과피흑변"]
+
+ # Preprocessing function.
+ def transforms_imagenet_eval(
+     img_path: str,
+     img_size: int = 224,
+     crop_pct: float = DEFAULT_CROP_PCT,
+     mean: tuple = (0.485, 0.456, 0.406),
+     std: tuple = (0.229, 0.224, 0.225),
+     normalize: bool = True,
+ ):
+     """
+     ImageNet-style image preprocessing.
+
+     Args:
+         img_path (str): Path to the image.
+         img_size (int): Crop size.
+         crop_pct (float): Crop fraction.
+         mean (tuple): Normalization mean.
+         std (tuple): Normalization standard deviation.
+         normalize (bool): Whether to normalize.
+
+     Returns:
+         torch.Tensor: The preprocessed image tensor.
+     """
+     img = Image.open(img_path).convert("RGB")  # load the image and convert to RGB
+     scale_size = math.floor(img_size / crop_pct)  # resize target before the center crop
+
+     # Build the transform pipeline.
+     tfl = [
+         transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR),
+         transforms.CenterCrop(img_size),
+         transforms.ToTensor(),
+     ]
+     if normalize:
+         tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
+
+     transform = transforms.Compose(tfl)
+     return transform(img)
+
+ # Inference function.
+ def predict(image_path: str):
+     """
+     Run model inference on the image at the given file path.
+
+     Args:
+         image_path (str): Path to the input image.
+
+     Returns:
+         dict: Per-class confidence scores.
+     """
+     # Preprocess the image.
+     input_tensor = transforms_imagenet_eval(
+         img_path=image_path,
+         img_size=224,
+         normalize=True,
+     ).unsqueeze(0)  # add a batch dimension
+
+     # Run inference.
+     with torch.no_grad():
+         prediction = model(input_tensor)
+
+     probs = torch.nn.functional.softmax(prediction[0], dim=-1)
+     confidences = {class_labels[i]: float(probs[i]) for i in range(len(class_labels))}
+
+     # Return the prediction result.
+     return confidences
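
A hedged usage sketch for predict — not part of the commit; the image path is an assumption based on the examples/ folder added in this commit:

from classification import predict

# predict() returns a dict mapping each of the 9 class labels to a confidence.
confidences = predict("examples/1.JPG")
top = max(confidences, key=confidences.get)
print(f"{top}: {confidences[top]:.3f}")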
detection.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ from ultralytics import YOLO
+
+ class PearDetectionModel:
+     def __init__(self, config) -> None:
+         self.device = (
+             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         )
+         self.model = YOLO(config["model_path"], task="detect")
+
+         self.names = config["classes"]
+
+     def detect(self, img):
+         results = self.model.predict(img)
+         return results[0].boxes.cpu().numpy()
+
+     def inference(self, img):
+         pred = self.detect(img)
+
+         # Filter low-confidence boxes: keep boxes above 0.5 confidence when a
+         # "burn_bbox" was detected, otherwise require 0.8.
+         labels = [self.names[int(cat)] for cat in pred.cls]
+         threshold = 0.5 if "burn_bbox" in labels else 0.8
+         pred = pred[pred.conf > threshold]
+         labels = [self.names[int(cat)] for cat in pred.cls]
+
+         # Return 1 (abnormal) if any "burn_bbox" remains after filtering, else 0 (normal).
+         if any(label == "burn_bbox" for label in labels):
+             return 1, pred.xyxy, pred.conf
+         else:
+             return 0, pred.xyxy, pred.conf
+
+     def _preprocess(self, img):
+         pass
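
A hedged usage sketch for PearDetectionModel — not part of the commit; the weights path and class list are taken from app.py above, and the image path is an assumption:

import cv2
from detection import PearDetectionModel

detector = PearDetectionModel({
    "model_path": "./weights/best.pt",
    "classes": ["burn_bbox", "defected_pear", "defected_pear_bbox",
                "normal_pear", "normal_pear_bbox"],
})
frame = cv2.imread("examples/2.JPG")  # BGR numpy array, as the webcam stream provides
cls, xyxy, conf = detector.inference(frame)
print("abnormal" if cls == 1 else "normal", f"{len(xyxy)} box(es)")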
examples/1.JPG ADDED
examples/2.JPG ADDED

Git LFS Details

  • SHA256: c78d2392c6a44b0a386c64d766789e2a4e426d45d76508da6fdeb9d5a01a5c6c
  • Pointer size: 132 Bytes
  • Size of remote file: 8.13 MB
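
Since the oid in a Git LFS pointer is the SHA-256 of the file content, a downloaded copy can be checked against the details above — a sketch, assuming the details refer to examples/2.JPG (the file newly tracked by LFS in .gitattributes):

import hashlib

expected = "c78d2392c6a44b0a386c64d766789e2a4e426d45d76508da6fdeb9d5a01a5c6c"
with open("examples/2.JPG", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "hash mismatch")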
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ ultralytics>=8.2.34
+ opencv-python>=4.1.2
+ timm>=1.0.1
weights/EfficientNetb0_weight_Pear_classification.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f7b383cf8f5d904d85f052264f220020a3bbaf5234f11647c514490f9ad176e
+ size 16347397
weights/best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4554fe5814e40ca157b622e61793fda17be7559ba5733d339e205414c38c1aa1
+ size 22522979
weights/resnest101e.in1k_weight_Pear_classification.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b43ed2b7473a0fc89756da00c6b4c58eca0689bba43c0b5e9679d8c12775b00a
+ size 185853431