Spaces:
Sleeping
Sleeping
import cv2 | |
import gradio as gr | |
from detection import PearDetectionModel | |
from classification import predict | |
# Gradio app with two tabs: single-image classification, and live webcam
# detection that draws the detector's bounding boxes on each streamed frame.
config = dict(
    model_path="./weights/best.pt",
    classes=[
        "burn_bbox",
        "defected_pear",
        "defected_pear_bbox",
        "normal_pear",
        "normal_pear_bbox",
    ],
)
# Built once at import time; every call to `detect` reuses this instance.
model = PearDetectionModel(config)
def classify(image):
    """Classify a PIL image uploaded through Gradio.

    ``predict`` takes a file path, so the image is written to a unique
    temporary JPEG file. That file is removed once the prediction returns.

    Args:
        image (PIL.Image.Image): The uploaded image.

    Returns:
        Whatever ``predict`` returns for the image; it is shown by the
        ``gr.Label`` output of the Classification tab.

    Raises:
        ValueError: If no image was supplied.
    """
    import os
    import tempfile

    if image is None:
        raise ValueError("No image was provided to classify.")

    # JPEG cannot store alpha or palette data. Pillow raises on e.g. RGBA,
    # so normalise such modes to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # A fixed file name would be shared by concurrent requests. One user's
    # upload could then be overwritten and classified as another's, and the
    # file would linger on disk. Use a unique file and always clean it up.
    fd, image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(image_path)
        return predict(image_path)
    finally:
        os.remove(image_path)
def detect(img):
    """Annotate one streamed webcam frame with the detector's output.

    Draws a rectangle and a confidence score for every detected box, plus a
    single class banner in the top-left corner.

    Args:
        img: Frame delivered by the streaming ``gr.Image(type="numpy")``
            component. It is drawn on in place. ``None`` is passed through.

    Returns:
        The same frame object with the annotations drawn on it, or ``None``
        when no frame was supplied.
    """
    # Defensive guard: skip empty frames instead of crashing inside the model
    # or OpenCV calls.
    if img is None:
        return None

    color = (0, 255, 0)  # green in both RGB and BGR channel order
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2

    # Use a distinct loop name so the per-box score does not shadow the
    # detector's full score sequence.
    cls, xyxy, scores = model.inference(img)
    for box, score in zip(xyxy, scores):
        x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
        cv2.putText(img, f"{score:.2f}", (x1, y1), font, font_scale, color, thickness)

    # NOTE(review): the type of `cls` returned by `model.inference` is not
    # visible here. If it is a per-box sequence, `cls == 0` is ambiguous for a
    # NumPy array with several elements (raises) and always False for a list.
    # Also, index 0 in `config["classes"]` is 'burn_bbox', not a normal-pear
    # class. Confirm that the "0 means normal" mapping is intended.
    label = "Class: Normal Pear" if cls == 0 else "Class: Abnormal Pear"
    cv2.putText(img, label, (0, 50), font, font_scale, color, thickness)
    return img
# Custom CSS for the Detection tab: caps the size of the webcam group
# (`.my-group`) and centres the contents of its column (`.my-column`) with
# flexbox. The stylesheet must end with a closing brace; a stray ";" after the
# last rule is not valid CSS.
css = """.my-group {max-width: 500px !important; max-height: 500px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
with gr.Blocks(css=css) as demo:
    demo.title = "Pear Playground"
    gr.Markdown("## This is a demo for Pear Playground by AISeed.")

    # Classification tab: an Interface wrapping `classify`, fed a PIL image and
    # rendering the result in a Label component.
    with gr.Tab(label="Classification"):
        gr.Interface(
            fn=classify,
            inputs=gr.Image(type="pil", label="Upload an image"),
            outputs=gr.Label(num_top_classes=9),
            examples=["examples/1.jpg", "examples/2.jpg"],
            title="비정상 과수 분류기",
            description="경량 모델 ResNet101e 을 활용하여 비정상배 분류",
        )

    # Detection tab: webcam frames are streamed through `detect`, and the
    # annotated frame is written back into the same Image component.
    with gr.Tab(label="Detection"):
        with gr.Column(elem_classes=["my-column"]), gr.Group(elem_classes=["my-group"]):
            input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True)
            input_img.stream(
                fn=detect,
                inputs=[input_img],
                outputs=[input_img],
                time_limit=30,
                stream_every=0.1,
            )
if __name__ == "__main__":
    # Start the server. Per Gradio's launch() API, `allowed_paths` also permits
    # serving files from the local examples directory.
    served_dirs = ["./examples"]
    demo.launch(allowed_paths=served_dirs)