Kaori1707 commited on
Commit
91a1441
·
verified ·
1 Parent(s): e573c1d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ from detection import PearDetectionModel
4
+ from classification import predict
5
+
6
+ # make streaming interface that reads from camera and displays the output with bounding boxes
7
+ config = {"model_path": "./weights/best.pt", "classes": ['burn_bbox', 'defected_pear', 'defected_pear_bbox', 'normal_pear', 'normal_pear_bbox']}
8
+ model = PearDetectionModel(config)
9
+
10
+
11
+ def classify(image):
12
+ """
13
+ Gradio에서 PIL 이미지를 입력받아 추론 결과를 반환.
14
+
15
+ Args:
16
+ image (PIL.Image): 업로드된 이미지.
17
+
18
+ Returns:
19
+ str: 모델 예측 결과.
20
+ """
21
+ # 임시 파일 저장 후 처리
22
+ image_path = "temp_image.jpg"
23
+ image.save(image_path)
24
+ return predict(image_path)
25
+
26
+ def detect(img):
27
+ cls, xyxy, conf = model.inference(img)
28
+ for box, conf in zip(xyxy, conf):
29
+ cv2.rectangle(
30
+ img,
31
+ (int(box[0]), int(box[1])),
32
+ (int(box[2]), int(box[3])),
33
+ (0, 255, 0),
34
+ 2,
35
+ )
36
+ cv2.putText(
37
+ img,
38
+ f"{conf:.2f}",
39
+ (int(box[0]), int(box[1])),
40
+ cv2.FONT_HERSHEY_SIMPLEX,
41
+ 1,
42
+ (0, 255, 0),
43
+ 2,
44
+ )
45
+ cv2.putText(
46
+ img,
47
+ "Class: Normal Pear" if cls == 0 else "Class: Abnormal Pear",
48
+ (0, 50),
49
+ cv2.FONT_HERSHEY_SIMPLEX,
50
+ 1,
51
+ (0, 255, 0),
52
+ 2,
53
+ )
54
+ return img
55
+
56
+
57
+ css = """.my-group {max-width: 500px !important; max-height: 500px !important;}
58
+ .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
59
+
60
+ with gr.Blocks(css=css) as demo:
61
+ demo.title = "Pear Playground"
62
+ # add markdown
63
+ gr.Markdown("## This is a demo for Pear Playground by AISeed.")
64
+ with gr.Tab(label="Classification"):
65
+ gr.Interface(
66
+ fn=classify,
67
+ inputs=gr.Image(type="pil", label="Upload an image"),
68
+ outputs=gr.Label(num_top_classes=9),
69
+ examples=["examples/1.jpg", "examples/2.jpg"],
70
+ title="비정상 과수 분류기",
71
+ description="경량 모델 ResNet101e 을 활용하여 비정상배 분류"
72
+ )
73
+ with gr.Tab(label="Detection"):
74
+ with gr.Column(elem_classes=["my-column"]):
75
+ with gr.Group(elem_classes=["my-group"]):
76
+
77
+ input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True)
78
+ input_img.stream(
79
+ detect,
80
+ [input_img],
81
+ [input_img],
82
+ time_limit=30,
83
+ stream_every=0.1,
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+
88
+ demo.launch()