# YOLO-ARENA / app.py
# Author: SkalskiP
# Commit 362e68b: allow to set different confidence thresholds per model
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import supervision as sv
from inference import get_model
MARKDOWN = """
<h1 style='text-align: center'>YOLO-ARENA 🏟️</h1>
Welcome to YOLO-Arena! This demo showcases the performance of various YOLO models
pre-trained on the COCO dataset.
- **YOLOv8**
<div style="display: flex; align-items: center;">
<a href="https://github.com/ultralytics/ultralytics" style="margin-right: 10px;">
<img src="https://badges.aleen42.com/src/github.svg">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-yolov8-object-detection-on-custom-dataset.ipynb" style="margin-right: 10px;">
<img src="https://colab.research.google.com/assets/colab-badge.svg">
</a>
</div>
- **YOLOv9**
<div style="display: flex; align-items: center;">
<a href="https://github.com/WongKinYiu/yolov9" style="margin-right: 10px;">
<img src="https://badges.aleen42.com/src/github.svg">
</a>
<a href="https://arxiv.org/abs/2402.13616" style="margin-right: 10px;">
<img src="https://img.shields.io/badge/arXiv-2402.13616-b31b1b.svg">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-yolov9-object-detection-on-custom-dataset.ipynb" style="margin-right: 10px;">
<img src="https://colab.research.google.com/assets/colab-badge.svg">
</a>
</div>
- **YOLOv10**
<div style="display: flex; align-items: center;">
<a href="https://github.com/THU-MIG/yolov10" style="margin-right: 10px;">
<img src="https://badges.aleen42.com/src/github.svg">
</a>
<a href="https://arxiv.org/abs/2405.14458" style="margin-right: 10px;">
<img src="https://img.shields.io/badge/arXiv-2405.14458-b31b1b.svg">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-yolov10-object-detection-on-custom-dataset.ipynb" style="margin-right: 10px;">
<img src="https://colab.research.google.com/assets/colab-badge.svg">
</a>
</div>
Powered by Roboflow [Inference](https://github.com/roboflow/inference) and
[Supervision](https://github.com/roboflow/supervision). 🔥
"""
# Example rows for gr.Examples. Each row must supply exactly one value per
# input component, in the order process_image expects them:
#   [image URL, YOLOv8 confidence, YOLOv9 confidence, YOLOv10 confidence,
#    IoU threshold]
# The trailing 0.5 matches the IoU slider's default value.
IMAGE_EXAMPLES = [
    ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 0.3, 0.1, 0.5],
    ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 0.3, 0.1, 0.5],
    ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 0.3, 0.3, 0.1, 0.5],
]
YOLO_V8_MODEL = get_model(model_id="coco/8")
YOLO_V9_MODEL = get_model(model_id="coco/17")
YOLO_V10_MODEL = get_model(model_id="coco/22")
LABEL_ANNOTATORS = sv.LabelAnnotator(text_color=sv.Color.black())
BOUNDING_BOX_ANNOTATORS = sv.BoundingBoxAnnotator()
def detect_and_annotate(
    model,
    input_image: np.ndarray,
    confidence_threshold: float,
    iou_threshold: float,
    class_id_mapping: Optional[dict] = None
) -> np.ndarray:
    """Run one detector on an image and draw its detections on a copy of it.

    Args:
        model: A model object from ``inference.get_model``. Its ``infer``
            method is called with ``confidence`` and ``iou_threshold``
            keyword arguments, and the first result is used.
        input_image: The image to run inference on. The Gradio image
            components in this app use ``type='pil'``, so in practice this is
            a PIL image even though it is annotated as an array.
        confidence_threshold: Minimum confidence forwarded to ``model.infer``.
        iou_threshold: IoU threshold forwarded to ``model.infer``.
        class_id_mapping: Optional mapping from the model's class ids to other
            class ids. It is applied to ``detections.class_id`` before
            annotation. Label text still comes from the model's own class
            names.

    Returns:
        A copy of ``input_image`` with bounding boxes and
        ``"<class name> (<confidence>)"`` labels drawn on it. The input image
        itself is not modified.
    """
    result = model.infer(
        input_image,
        confidence=confidence_threshold,
        iou_threshold=iou_threshold
    )[0]
    detections = sv.Detections.from_inference(result)

    if class_id_mapping:
        # Force an integer dtype. With zero detections, np.array([]) would
        # otherwise be float64 and no longer be a valid class-id array.
        detections.class_id = np.array(
            [class_id_mapping[class_id] for class_id in detections.class_id],
            dtype=int
        )

    # Indexing Detections with a string returns None when that data field is
    # absent. Some supervision versions produce an empty result with no
    # 'class_name' field, for example when the confidence slider is set high
    # enough that nothing is detected. Fall back to the numeric class ids
    # (an empty list when there are no detections) instead of crashing in zip.
    class_names = detections['class_name']
    if class_names is None:
        class_names = [str(class_id) for class_id in detections.class_id]

    labels = [
        f"{class_name} ({confidence:.2f})"
        for class_name, confidence
        in zip(class_names, detections.confidence)
    ]

    annotated_image = input_image.copy()
    annotated_image = BOUNDING_BOX_ANNOTATORS.annotate(
        scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATORS.annotate(
        scene=annotated_image, detections=detections, labels=labels)
    return annotated_image
def process_image(
    input_image: np.ndarray,
    yolo_v8_confidence_threshold: float,
    yolo_v9_confidence_threshold: float,
    yolo_v10_confidence_threshold: float,
    iou_threshold: float
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Run YOLOv8, YOLOv9 and YOLOv10 on one image and return their results.

    Args:
        input_image: The image from the "Input" component. Because that
            component uses ``type='pil'``, this is a PIL image. It is None
            when no image has been provided.
        yolo_v8_confidence_threshold: Confidence threshold for YOLOv8.
        yolo_v9_confidence_threshold: Confidence threshold for YOLOv9.
        yolo_v10_confidence_threshold: Confidence threshold for YOLOv10.
        iou_threshold: IoU threshold shared by all three models.

    Returns:
        A tuple ``(yolov8_image, yolov9_image, yolov10_image)`` of annotated
        images. If ``input_image`` is None, returns ``(None, None, None)``,
        which clears the three output components.
    """
    if input_image is None:
        # Submit was pressed before an image was provided. Clear the outputs
        # instead of passing None into model inference.
        return None, None, None

    yolo_v8_annotated_image = detect_and_annotate(
        YOLO_V8_MODEL, input_image, yolo_v8_confidence_threshold, iou_threshold)
    yolo_v9_annotated_image = detect_and_annotate(
        YOLO_V9_MODEL, input_image, yolo_v9_confidence_threshold, iou_threshold)
    yolo_v10_annotated_image = detect_and_annotate(
        YOLO_V10_MODEL, input_image, yolo_v10_confidence_threshold, iou_threshold)
    return (
        yolo_v8_annotated_image,
        yolo_v9_annotated_image,
        yolo_v10_annotated_image
    )
# Help text shared by every per-model confidence slider (previously
# copy-pasted three times).
CONFIDENCE_THRESHOLD_INFO = (
    "The confidence threshold for the YOLO model. Lower the threshold to "
    "reduce false negatives, enhancing the model's sensitivity to detect "
    "sought-after objects. Conversely, increase the threshold to minimize false "
    "positives, preventing the model from identifying objects it shouldn't."
)


def _confidence_threshold_slider(label: str) -> gr.Slider:
    """Build a 0-1 confidence-threshold slider (default 0.3) with the given label.

    The slider is created outside the Blocks layout, and the layout places it
    later by calling ``.render()`` on it.
    """
    return gr.Slider(
        minimum=0,
        maximum=1.0,
        value=0.3,
        step=0.01,
        label=label,
        info=CONFIDENCE_THRESHOLD_INFO
    )


yolo_v8_confidence_threshold_component = _confidence_threshold_slider(
    "YOLOv8 Confidence Threshold")
yolo_v9_confidence_threshold_component = _confidence_threshold_slider(
    "YOLOv9 Confidence Threshold")
yolo_v10_confidence_threshold_component = _confidence_threshold_slider(
    "YOLOv10 Confidence Threshold")

# A single IoU threshold, shared by all three models.
iou_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="IoU Threshold",
    info=(
        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
        "making the detection process stricter. On the other hand, increase the value "
        "to allow more overlapping bounding boxes, accommodating a broader range of "
        "detections."
    ))
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Accordion("Configuration", open=False):
        with gr.Row():
            yolo_v8_confidence_threshold_component.render()
            yolo_v9_confidence_threshold_component.render()
            yolo_v10_confidence_threshold_component.render()
            iou_threshold_component.render()
    with gr.Row():
        input_image_component = gr.Image(
            type='pil',
            label='Input'
        )
        yolo_v8_output_image_component = gr.Image(
            type='pil',
            label='YOLOv8'
        )
    with gr.Row():
        yolo_v9_output_image_component = gr.Image(
            type='pil',
            label='YOLOv9'
        )
        yolo_v10_output_image_component = gr.Image(
            type='pil',
            label='YOLOv10'
        )
    submit_button_component = gr.Button(
        value='Submit',
        scale=1,
        variant='primary'
    )

    # Component order must match process_image's parameters and return tuple.
    # The examples and the Submit button share these lists so their wiring
    # cannot drift apart.
    process_image_inputs = [
        input_image_component,
        yolo_v8_confidence_threshold_component,
        yolo_v9_confidence_threshold_component,
        yolo_v10_confidence_threshold_component,
        iou_threshold_component
    ]
    process_image_outputs = [
        yolo_v8_output_image_component,
        yolo_v9_output_image_component,
        yolo_v10_output_image_component
    ]

    gr.Examples(
        fn=process_image,
        examples=IMAGE_EXAMPLES,
        inputs=process_image_inputs,
        outputs=process_image_outputs
    )
    submit_button_component.click(
        fn=process_image,
        inputs=process_image_inputs,
        outputs=process_image_outputs
    )

demo.launch(debug=False, show_error=True, max_threads=1)