YOLO-World / app.py
SkalskiP's picture
EfficientSAM support added
39840e5
raw history blame
No virus
3.09 kB
from typing import List
import torch
import gradio as gr
import numpy as np
import supervision as sv
from inference.models import YOLOWorld
from utils.efficient_sam import load, inference_with_box
MARKDOWN = """
# YOLO-World 🔥 [with Efficient-SAM]
This is a demo of zero-shot instance segmentation using [YOLO-World](https://github.com/AILab-CVC/YOLO-World) and [Efficient-SAM](https://github.com/yformer/EfficientSAM).
Powered by Roboflow [Inference](https://github.com/roboflow/inference) and [Supervision](https://github.com/roboflow/supervision).
"""

# Use the GPU when PyTorch reports one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Both models are built once, at import time, and reused by every call to
# process_image (which reads these module-level names).
EFFICIENT_SAM_MODEL = load(device=DEVICE)
YOLO_WORLD_MODEL = YOLOWorld(model_id="yolo_world/l")

# Annotators used to draw masks, boxes and labels on the output image.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
def process_categories(categories: str) -> List[str]:
    """Split a comma-separated category string into clean class names.

    Whitespace around every entry is stripped. Empty entries (from empty
    input, doubled commas, or a trailing comma such as ``"person, dog,"``) are
    dropped, so a blank string is never handed to the detector as a class name.

    Args:
        categories: Raw user input, e.g. ``"person, dog, car"``.

    Returns:
        The stripped, non-empty category names in their original order.
    """
    stripped = (category.strip() for category in categories.split(','))
    return [category for category in stripped if category]
def process_image(
    input_image: np.ndarray,
    categories: str,
    confidence_threshold: float = 0.003,
    iou_threshold: float = 0.5,
    with_segmentation: bool = True,
    with_confidence: bool = True
) -> np.ndarray:
    """Detect (and optionally segment) the requested categories in an image.

    YOLO-World is prompted with the parsed category names. Its detections are
    filtered with non-max suppression, and each remaining box is optionally
    turned into a mask with Efficient-SAM. Masks, boxes and labels are then
    drawn on a copy of the input.

    Args:
        input_image: Image array from the ``gr.Image(type='numpy')`` component
            (presumably H x W x C RGB; confirm against the Gradio version).
        categories: Comma-separated category names typed by the user.
        confidence_threshold: Minimum confidence passed to YOLO-World's ``infer``.
        iou_threshold: IoU threshold used for non-max suppression.
        with_segmentation: When True, attach one Efficient-SAM mask per box.
        with_confidence: When True, append the confidence score to each label.

    Returns:
        An annotated copy of ``input_image``; the input array itself is not
        modified.

    Raises:
        gr.Error: If no image was supplied or no non-empty category was given,
            so Gradio shows a readable message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Blank names (empty textbox, doubled or trailing commas) are not useful prompts.
    class_names = [name for name in process_categories(categories) if name]
    if not class_names:
        raise gr.Error("Please enter at least one category.")

    YOLO_WORLD_MODEL.set_classes(class_names)
    results = YOLO_WORLD_MODEL.infer(input_image, confidence=confidence_threshold)
    detections = sv.Detections.from_inference(results).with_nms(iou_threshold)

    if with_segmentation:
        detections.mask = _segment_detections(input_image, detections)

    labels = _build_labels(class_names, detections, with_confidence)

    output_image = input_image.copy()
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image


def _segment_detections(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
    """Return one Efficient-SAM mask per detection box, stacked along axis 0."""
    if len(detections) == 0:
        # np.array([]) would have shape (0,); keep supervision's (N, H, W)
        # boolean mask layout even when nothing was detected.
        height, width = image.shape[:2]
        return np.empty((0, height, width), dtype=bool)
    masks = [
        inference_with_box(
            image,
            np.array([[x_min, y_min], [x_max, y_max]]),
            EFFICIENT_SAM_MODEL,
            DEVICE,
        )
        for x_min, y_min, x_max, y_max in detections.xyxy
    ]
    return np.array(masks)


def _build_labels(
    class_names: List[str],
    detections: sv.Detections,
    with_confidence: bool
) -> List[str]:
    """Return one label per detection: the class name, optionally with its confidence."""
    return [
        f"{class_names[class_id]}: {confidence:.2f}" if with_confidence else f"{class_names[class_id]}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]
# Gradio UI: image in/out side by side, a categories textbox and a Submit
# button wired to process_image (which runs with its default thresholds).
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        input_image_component = gr.Image(
            type='numpy',
            label='Input Image'
        )
        output_image_component = gr.Image(
            type='numpy',
            label='Output Image'
        )
    with gr.Row():
        categories_text_component = gr.Textbox(
            label='Categories',
            placeholder='comma separated list of categories',
            scale=5
        )
        submit_button_component = gr.Button('Submit', scale=1)

    # Event listeners must be registered inside the Blocks context.
    submit_button_component.click(
        fn=process_image,
        inputs=[input_image_component, categories_text_component],
        outputs=output_image_component
    )

# Launch only when run as a script (`python app.py`), so importing the module
# (tests, Gradio reload mode) builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)