Spaces:
Running
Running
#!/usr/bin/env python | |
from __future__ import annotations | |
import pathlib | |
import gradio as gr | |
import numpy as np | |
import PIL.Image | |
import torch | |
from sahi.prediction import ObjectPrediction | |
from sahi.utils.cv import visualize_object_predictions | |
from transformers import AutoImageProcessor, DetaForObjectDetection | |
from ultralytics import YOLO | |
# Markdown header rendered at the top of the demo page.
DESCRIPTION = '# Compare DETA and YOLOv8'

# DETA runs on the first CUDA device when available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# DETA: processor (pre/post-processing) and model from the same checkpoint ID.
MODEL_ID = 'jozhang97/deta-swin-large'
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model_deta = DetaForObjectDetection.from_pretrained(MODEL_ID).to(device)

# YOLOv8 detector loaded from the 'yolov8x.pt' weights via ultralytics.
model_yolo = YOLO('yolov8x.pt')
@torch.inference_mode()
def run_deta(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects with DETA and draw the detections on the image.

    Args:
        image_path: Path to the image file to analyse.
        threshold: Score threshold forwarded to the image processor's
            ``post_process_object_detection``.

    Returns:
        The ``'image'`` entry of SAHI's ``visualize_object_predictions``:
        the (RGB-converted) input image with the predictions drawn on it.
    """
    # Load eagerly and release the file handle.  Converting to RGB makes
    # grayscale / palette / RGBA inputs look like the 3-channel images the
    # processor and the drawing code work with (a plain copy for RGB files).
    with PIL.Image.open(image_path) as raw:
        image = raw.convert('RGB')

    inputs = image_processor(images=image, return_tensors='pt').to(device)
    outputs = model_deta(**inputs)

    # PIL's .size is (width, height); reverse it to (height, width).  The
    # tensor is created on the model's device so the rescaling inside
    # post-processing never mixes CPU and GPU tensors.
    target_sizes = torch.tensor([image.size[::-1]], device=device)
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes)[0]

    boxes = results['boxes'].cpu().numpy()
    scores = results['scores'].cpu().numpy()
    cat_ids = results['labels'].cpu().numpy().tolist()

    id2label = model_deta.config.id2label
    preds = []
    for box, score, cat_id in zip(boxes, scores, cat_ids):
        preds.append(
            ObjectPrediction(bbox=np.round(box).astype(int),
                             category_id=cat_id,
                             category_name=id2label[cat_id],
                             score=score))
    return visualize_object_predictions(np.asarray(image), preds)['image']
def run_yolov8(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects with YOLOv8 and draw the detections on the image.

    Args:
        image_path: Path to the image file to analyse.
        threshold: Confidence threshold passed to the model as ``conf``.

    Returns:
        The ``'image'`` entry of SAHI's ``visualize_object_predictions``:
        the (RGB-converted) input image with the predictions drawn on it.
    """
    # Load eagerly, release the file handle, and draw on a 3-channel RGB
    # array regardless of the file's original mode.
    with PIL.Image.open(image_path) as raw:
        image = raw.convert('RGB')

    results = model_yolo(image, imgsz=640, conf=threshold)
    boxes = results[0].boxes.cpu().numpy()
    names = model_yolo.model.names

    # Use the named accessors instead of positional columns of ``.data``.
    # Per the ultralytics Boxes docs, the column layout of ``.data`` can
    # include an extra track-id column, which shifts conf/cls.
    preds = []
    for xyxy, score, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
        cat_id = int(cls)
        preds.append(
            ObjectPrediction(bbox=np.round(xyxy).astype(int),
                             category_id=cat_id,
                             category_name=names[cat_id],
                             score=score))
    return visualize_object_predictions(np.asarray(image), preds)['image']
def run(image_path: str, threshold: float) -> tuple[np.ndarray, np.ndarray]:
    """Run both detectors on one image for side-by-side comparison.

    Args:
        image_path: Path to the input image; may be ``None``/empty when the
            Gradio image input has nothing uploaded.
        threshold: Score threshold applied to both detectors.

    Returns:
        ``(deta_result, yolov8_result)`` annotated image arrays, in that order.

    Raises:
        gr.Error: If no image path was provided, so the UI shows a readable
            message instead of an opaque PIL error.
    """
    if not image_path:
        raise gr.Error('Please upload an image first.')
    return run_deta(image_path, threshold), run_yolov8(image_path, threshold)
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: the input image, the score threshold and the trigger.
        with gr.Column():
            image = gr.Image(label='Input image', type='filepath')
            threshold = gr.Slider(
                label='Score threshold',
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.01,
            )
            run_button = gr.Button('Run')
        # Right column: one annotated result per detector, in the same
        # order that `run` returns them (DETA first, then YOLOv8).
        with gr.Column():
            result_deta = gr.Image(label='Result (DETA)', type='numpy')
            result_yolo = gr.Image(label='Result (YOLOv8)', type='numpy')

    with gr.Row():
        # Every .jpg in ./images becomes an example at threshold 0.5;
        # cache_examples=True asks Gradio to cache their outputs.
        paths = sorted(pathlib.Path('images').glob('*.jpg'))
        gr.Examples(
            examples=[[p.as_posix(), 0.5] for p in paths],
            inputs=[image, threshold],
            outputs=[result_deta, result_yolo],
            fn=run,
            cache_examples=True,
        )

    run_button.click(
        fn=run,
        inputs=[image, threshold],
        outputs=[result_deta, result_yolo],
    )

demo.queue().launch()