|
|
|
|
|
import pathlib |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
from sahi.prediction import ObjectPrediction |
|
from sahi.utils.cv import visualize_object_predictions |
|
from transformers import AutoImageProcessor, DetaForObjectDetection |
|
|
|
# Markdown heading passed to the demo page.
DESCRIPTION = '# DETA (Detection Transformers with Assignment)'

# Identifier handed to ``from_pretrained`` for both the processor and model.
MODEL_ID = 'jozhang97/deta-swin-large'

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Loaded once at import time and shared by every call to ``run``.
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DetaForObjectDetection.from_pretrained(MODEL_ID)
model.to(device)
|
|
|
|
|
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects in an image file and return the annotated image.

    Args:
        image_path: Path of the image file to open with PIL.
        threshold: Score threshold forwarded to the image processor's
            ``post_process_object_detection`` step.

    Returns:
        The ``'image'`` entry of the result returned by sahi's
        ``visualize_object_predictions`` for the input image and the
        detections found in it.
    """
    # Open inside a context manager so the file handle is released promptly.
    # Convert to RGB so grayscale, RGBA and palette uploads reach the
    # processor and the visualiser with a consistent 3-channel layout.
    with PIL.Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    inputs = image_processor(images=image, return_tensors='pt').to(device)
    outputs = model(**inputs)

    # PIL reports (width, height); the reversed tuple is what gets passed on
    # as the target size for rescaling the predicted boxes.
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes)[0]

    # Convert everything to plain Python values so sahi receives native
    # ints/floats rather than NumPy arrays and scalars.
    boxes = results['boxes'].cpu().numpy().round().astype(int).tolist()
    scores = results['scores'].cpu().numpy().tolist()
    cat_ids = results['labels'].cpu().numpy().tolist()

    id2label = model.config.id2label
    preds = [
        ObjectPrediction(bbox=box,
                         category_id=cat_id,
                         category_name=id2label[cat_id],
                         score=score)
        for box, score, cat_id in zip(boxes, scores, cat_ids)
    ]

    return visualize_object_predictions(np.asarray(image), preds)['image']
|
|
|
|
|
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from syntax (inputs in a column, result beside it) -- confirm.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Input image', type='filepath')
            threshold = gr.Slider(
                label='Score threshold',
                minimum=0,
                maximum=1,
                value=0.1,
                step=0.01,
            )
            run_button = gr.Button('Run')
        result = gr.Image(label='Result', type='numpy')

    with gr.Row():
        # Every JPEG under ./images becomes an example row paired with a
        # score threshold of 0.1.
        paths = sorted(pathlib.Path('images').glob('*.jpg'))
        gr.Examples(
            examples=[[path.as_posix(), 0.1] for path in paths],
            inputs=[image, threshold],
            outputs=result,
            fn=run,
            cache_examples=True,
        )

    run_button.click(fn=run, inputs=[image, threshold], outputs=result)

# Enable the request queue, then start the app.
demo.queue().launch()
|
|