|
import gradio as gr |
|
import sahi.utils |
|
from sahi import AutoDetectionModel |
|
import sahi.predict |
|
import sahi.slicing |
|
from PIL import Image |
|
import numpy |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
|
|
|
|
# Inference resolution handed to the YOLOv5 detection model.
IMAGE_SIZE = 640

# Download (or reuse the cached copy of) the fine-tuned YOLOv5 checkpoint
# from the Hugging Face Hub.
model_path = hf_hub_download(
    "kadirnar/deprem_model_v1",
    filename="last.pt",
    revision="main",
)

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    current_device = "cuda"
else:
    current_device = "cpu"

# Single shared SAHI detection model used by every inference request;
# detections below 0.5 confidence are discarded.
model = AutoDetectionModel.from_pretrained(
    model_type="yolov5",
    model_path=model_path,
    device=current_device,
    confidence_threshold=0.5,
    image_size=IMAGE_SIZE,
)
|
|
|
|
|
def sahi_yolo_inference(
    model_type,
    image,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="GREEDYNMM",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run detection on ``image`` with the module-level ``model`` and return an annotated image.

    Args:
        model_type: Dropdown selection. Any value containing ``"SAHI"`` uses
            sliced inference (``sahi.predict.get_sliced_prediction``). Any other
            value, including ``None`` when nothing was selected, runs a single
            full-image prediction (``sahi.predict.get_prediction``).
        image: Input image (a PIL image, per the Gradio input config), or ``None``.
        slice_height, slice_width: Slice size in pixels. It is cast to ``int``
            because Gradio ``Number`` inputs may deliver floats. Used only in
            SAHI mode.
        overlap_height_ratio, overlap_width_ratio: Fractional overlap between
            neighbouring slices. Used only in SAHI mode.
        postprocess_type, postprocess_match_metric, postprocess_match_threshold,
        postprocess_class_agnostic: Forwarded unchanged to SAHI's slice-merging
            post-processing. Used only in SAHI mode.

    Returns:
        A ``PIL.Image.Image`` with the predicted objects drawn on it, or ``None``
        when no image was supplied (Gradio then shows an empty output).
    """
    if image is None:
        # Nothing uploaded: leave the output empty instead of crashing inside SAHI.
        return None

    # The model-type dropdown can be submitted unset (None); `"SAHI" in None`
    # would raise TypeError, so treat a missing selection as plain YOLOv5.
    use_sahi = bool(model_type) and "SAHI" in model_type

    if use_sahi:
        prediction_result = sahi.predict.get_sliced_prediction(
            image=image,
            detection_model=model,
            slice_height=int(slice_height),
            slice_width=int(slice_width),
            overlap_height_ratio=overlap_height_ratio,
            overlap_width_ratio=overlap_width_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )
    else:
        prediction_result = sahi.predict.get_prediction(
            image=image, detection_model=model
        )

    # Both branches produce a result with `object_prediction_list`; draw it once.
    visual_result = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result.object_prediction_list,
    )
    return Image.fromarray(visual_result["image"])
|
|
|
|
|
# Gradio input components, in the same order as the parameters of
# sahi_yolo_inference (model_type, image, slice_height, ...).
inputs = [
    # Give the selector a default so the function never receives an unset value
    # on a first submit.
    gr.Dropdown(
        choices=["YOLOv5", "YOLOv5 + SAHI"],
        value="YOLOv5 + SAHI",
        label="Choose Model Type",
    ),
    gr.inputs.Image(type="pil", label="Original Image"),
    gr.inputs.Number(default=512, label="slice_height"),
    gr.inputs.Number(default=512, label="slice_width"),
    gr.inputs.Number(default=0.2, label="overlap_height_ratio"),
    gr.inputs.Number(default=0.2, label="overlap_width_ratio"),
    gr.inputs.Dropdown(
        ["NMS", "GREEDYNMM"],
        type="value",
        default="GREEDYNMM",
        label="postprocess_type",
    ),
    # This dropdown feeds `postprocess_match_metric`; it was previously
    # mislabelled as a second "postprocess_type".
    gr.inputs.Dropdown(
        ["IOU", "IOS"],
        type="value",
        default="IOS",
        label="postprocess_match_metric",
    ),
    gr.inputs.Number(default=0.5, label="postprocess_match_threshold"),
    gr.inputs.Checkbox(default=True, label="postprocess_class_agnostic"),
]
|
|
|
# Single output panel that shows the annotated PIL image returned by
# sahi_yolo_inference.
outputs = [gr.outputs.Image(label="Output", type="pil")]
|
|
|
# Text shown on the Gradio page: heading, subtitle, and an HTML footer with links.
title = "Small Object Detection with SAHI + YOLOv5"

description = (
    "SAHI + YOLOv5 demo for small object detection. "
    "Upload an image or click an example image to use."
)

article = (
    "<p style='text-align: center'>SAHI is a lightweight vision library for performing "
    "large scale object detection/ instance segmentation.. "
    "<a href='https://github.com/obss/sahi'>SAHI Github</a> | "
    "<a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | "
    "<a href='https://github.com/fcakyon/yolov5-pip'>YOLOv5 Github</a> </p>"
)
|
# Clickable example rows. Each row must list one value per Gradio input, in
# input order: model_type, image, slice_height, slice_width,
# overlap_height_ratio, overlap_width_ratio, postprocess_type,
# postprocess_match_metric, postprocess_match_threshold,
# postprocess_class_agnostic. The model_type column was missing, which shifted
# every value one input to the left (image path into the dropdown, etc.).
examples = [
    ["YOLOv5 + SAHI", "apple_tree.jpg", 256, 256, 0.2, 0.2, "GREEDYNMM", "IOS", 0.5, True],
    ["YOLOv5 + SAHI", "highway.jpg", 256, 256, 0.2, 0.2, "GREEDYNMM", "IOS", 0.5, True],
    ["YOLOv5 + SAHI", "highway2.jpg", 512, 512, 0.2, 0.2, "GREEDYNMM", "IOS", 0.5, True],
    ["YOLOv5 + SAHI", "highway3.jpg", 512, 512, 0.2, 0.2, "GREEDYNMM", "IOS", 0.5, True],
]
|
|
|
# Wire the inference function to the UI components and start the server.
# debug=True and enable_queue=True are passed straight to Interface.launch.
demo = gr.Interface(
    fn=sahi_yolo_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
)
demo.launch(debug=True, enable_queue=True)
|
|