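# Hugging Face Space app for YOLO-World-Seg (Gradio).
# The OpenMMLab stack and the local package are installed at startup below,
# a common pattern on Spaces whose base image does not ship these dependencies.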
import os

os.system("mim install 'mmengine>=0.7.0'")
os.system("mim install mmcv")
os.system("mim install 'mmdet>=3.0.0'")
os.system("pip install -e .")
import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import Compose
from mmengine.runner import Runner
from mmengine.runner.amp import autocast
from mmyolo.registry import RUNNERS  # noqa: F401 -- kept so mmyolo is imported before the runner is built
from torchvision.ops import nms
import supervision as sv
import spaces
import gradio as gr
TITLE = """ | |
# YOLO-World-Seg | |
This is a demo of zero-shot object detection and instance segmentation using only | |
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) done via newest model YOLO-World-Seg. | |
Annototions Powered by [Supervision](https://github.com/roboflow/supervision). | |
""" | |
# Each example row maps to the gr.Examples inputs below:
# (image, categories, confidence threshold, IoU threshold, NMS threshold).
EXAMPLES = [
    ["https://media.roboflow.com/efficient-sam/corgi.jpg", "dog", 0.5, 0.5, 0.5],
    ["https://media.roboflow.com/efficient-sam/horses.jpg", "horses", 0.5, 0.5, 0.5],
    ["https://media.roboflow.com/efficient-sam/bears.jpg", "bear", 0.5, 0.5, 0.5],
]
# Supervision annotators: bounding boxes, centered text labels, and
# per-instance masks colored by detection index.
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
def load_runner():
    # Build an mmengine Runner from the YOLO-World-Seg config and load the
    # LVIS-finetuned segmentation checkpoint.
    cfg = Config.fromfile(
        "./configs/segmentation/yolo_world_seg_l_dual_vlpan_2e-4_80e_8gpus_seghead_finetune_lvis.py"
    )
    cfg.work_dir = "."
    cfg.load_from = "yolo_world_seg_l_dual_vlpan_2e-4_80e_8gpus_seghead_finetune_lvis-5a642d30.pth"
    runner = Runner.from_cfg(cfg)
    runner.call_hook("before_run")
    runner.load_or_resume()
    # Reuse the test-time preprocessing pipeline for single-image inference.
    runner.pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    runner.model.eval()
    return runner
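# Note: run_image() calls load_runner() on every request, so config parsing and
# checkpoint loading are repeated per call. A long-lived deployment could cache
# the runner once at module scope instead (e.g. RUNNER = load_runner()).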
@spaces.GPU  # request a GPU for the duration of this call on ZeroGPU hardware
def run_image(
    input_image,
    class_names="person,car,bus,truck",
    score_thr=0.05,
    iou_thr=0.5,
    nms_thr=0.5,
    max_num_boxes=100,
):
    runner = load_runner()
    image_path = "./work_dirs/input.png"
    os.makedirs("./work_dirs", exist_ok=True)
    input_image.save(image_path)
    # One prompt list per class, plus a trailing blank prompt appended after
    # the class prompts.
    texts = [[t.strip()] for t in class_names.split(",")] + [[" "]]
    data_info = runner.pipeline(dict(img_id=0, img_path=image_path, texts=texts))
    data_batch = dict(
        inputs=data_info["inputs"].unsqueeze(0),
        data_samples=[data_info["data_samples"]],
    )
    with autocast(enabled=False), torch.no_grad():
        output = runner.model.test_step(data_batch)[0]
    runner.model.class_names = texts
    pred_instances = output.pred_instances

    # Class-agnostic NMS, then score filtering, then keep the top-k detections.
    keep_idxs = nms(pred_instances.bboxes, pred_instances.scores, iou_threshold=iou_thr)
    pred_instances = pred_instances[keep_idxs]
    pred_instances = pred_instances[pred_instances.scores.float() > score_thr]
    if len(pred_instances.scores) > max_num_boxes:
        indices = pred_instances.scores.float().topk(max_num_boxes)[1]
        pred_instances = pred_instances[indices]
    output.pred_instances = pred_instances

    result = pred_instances.cpu().numpy()
    detections = sv.Detections(
        xyxy=result["bboxes"],
        class_id=result["labels"],
        confidence=result["scores"],
        mask=result["masks"],
    )
    detections = detections.with_nms(threshold=nms_thr)
    # Label each detection with its prompted class name and confidence.
    labels = [
        f"{texts[class_id][0]} {confidence:.3f}"
        for class_id, confidence
        in zip(detections.class_id, detections.confidence)
    ]
    svimage = np.array(input_image)
    svimage = box_annotator.annotate(svimage, detections)
    svimage = label_annotator.annotate(svimage, detections, labels)
    svimage = mask_annotator.annotate(svimage, detections)
    return svimage
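# Standalone usage sketch outside Gradio (hypothetical file names; assumes the
# config and checkpoint referenced in load_runner() are present locally):
#
#     from PIL import Image
#     annotated = run_image(Image.open("corgi.jpg"), "dog", 0.3, 0.5, 0.5, 100)
#     Image.fromarray(annotated).save("annotated.png")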
confidence_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.3,
    step=0.01,
    label="Confidence Threshold",
    info=(
        "The confidence threshold for the YOLO-World model. Lower the threshold to "
        "reduce false negatives, enhancing the model's sensitivity to detect "
        "sought-after objects. Conversely, increase the threshold to minimize false "
        "positives, preventing the model from identifying objects it shouldn't."
    ))
iou_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="IoU Threshold",
    info=(
        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
        "making the detection process stricter. Increase the value to allow more "
        "overlapping bounding boxes, accommodating a broader range of detections."
    ))
nms_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="NMS Threshold",
    info=(
        "The IoU threshold used by the second, per-class non-maximum suppression "
        "pass (Supervision's with_nms). A lower value suppresses more overlapping "
        "boxes, making detection stricter; a higher value permits more overlap, "
        "allowing a wider variety of detections."
    ))
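# How the three thresholds interact in run_image(): torchvision NMS first prunes
# overlapping boxes using the IoU threshold, the confidence threshold then drops
# low-scoring detections (capped at max_num_boxes), and Supervision's with_nms()
# applies a final per-class pass using the NMS threshold.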
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Accordion("Configuration", open=False):
        confidence_threshold_component.render()
        iou_threshold_component.render()
        nms_threshold_component.render()
    with gr.Tab(label="Image"):
        with gr.Row():
            input_image_component = gr.Image(
                type='pil',
                label='Input Image'
            )
            output_image_component = gr.Image(
                type='numpy',
                label='Output Image'
            )
        with gr.Row():
            image_categories_text_component = gr.Textbox(
                label='Categories',
                placeholder='comma separated list of categories',
                scale=7
            )
            image_submit_button_component = gr.Button(
                value='Submit',
                scale=1,
                variant='primary'
            )
        gr.Examples(
            fn=run_image,
            examples=EXAMPLES,
            inputs=[
                input_image_component,
                image_categories_text_component,
                confidence_threshold_component,
                iou_threshold_component,
                nms_threshold_component
            ],
            outputs=output_image_component
        )
        image_submit_button_component.click(
            fn=run_image,
            inputs=[
                input_image_component,
                image_categories_text_component,
                confidence_threshold_component,
                iou_threshold_component,
                nms_threshold_component
            ],
            outputs=output_image_component
        )

demo.launch(debug=False, show_error=True)