|
import sys, os |
|
|
|
import gradio as gr |
|
import plotly.express as px |
|
import numpy as np |
|
import random |
|
from ultralytics import YOLO |
|
from sahi.models.yolov8 import * |
|
from sahi.predict import get_sliced_prediction |
|
from sahi.utils.cv import visualize_object_predictions |
|
import PIL |
|
|
|
model_base = "https://huggingface.co/mayrajeo/marine-vessel-detection/resolve/main/" |
|
|
|
def inference(
    im: str = None,
    model_path: str = None,
    conf_thr: float = 0.25,
):
    """Detect vessels in one image with sliced YOLOv8 inference and plot them.

    The image is processed as overlapping 320x320 slices (20% overlap) via
    sahi's ``get_sliced_prediction``, using the selected YOLOv8 checkpoint on
    CPU, and the resulting predictions are drawn onto the image.

    Args:
        im: Filesystem path of the uploaded image (the Image input is created
            with ``type='filepath'``); ``None`` when nothing was uploaded.
        model_path: Name of the model sub-folder under ``model_base``
            (e.g. ``'YOLOv8n'``); its ``best.pt`` checkpoint is loaded.
        conf_thr: Confidence threshold passed to the detection model.

    Returns:
        A Plotly figure of the annotated image with axes, legend and hover
        disabled.

    Raises:
        gr.Error: If no image or no model was provided, so the UI shows a
            readable message instead of a stack trace.
    """
    # `import PIL` alone does not guarantee the `PIL.Image` submodule is
    # loaded; import it explicitly instead of relying on a side effect.
    from PIL import Image

    if im is None:
        raise gr.Error('Please upload an image.')
    if not model_path:
        raise gr.Error('Please select a model.')

    # model_base already ends with '/', strip it so the URL has no '//'.
    weights_url = f"{model_base.rstrip('/')}/{model_path}/best.pt"
    model = Yolov8DetectionModel(
        model_path=weights_url,
        device='cpu',
        confidence_threshold=conf_thr,
        image_size=640,
    )

    res = get_sliced_prediction(
        im,
        model,
        slice_width=320,
        slice_height=320,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        verbose=0,
    )

    # Close the file handle deterministically, and convert to 3-channel RGB so
    # grayscale / RGBA uploads are drawn and displayed as ordinary colour images.
    with Image.open(im) as img:
        image = np.array(img.convert('RGB'))

    visual_result = visualize_object_predictions(
        image=image,
        object_prediction_list=res.object_prediction_list,
        text_size=0.3,
        rect_th=1,
    )

    fig = px.imshow(visual_result['image'])
    fig.update_layout(showlegend=False, hovermode=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
|
|
|
# UI inputs, in the positional order `inference` receives them:
# image file path, model variant name, confidence threshold.
# Uses the top-level `gr.*` component aliases consistently (the original mixed
# `gr.Image` with `gr.components.Dropdown` / `gr.components.Slider`).
inputs = [
    gr.Image(type='filepath', label='Input'),
    gr.Dropdown(
        ['YOLOv8n', 'YOLOv8s', 'YOLOv8m', 'YOLOv8l', 'YOLOv8x'],
        value='YOLOv8n',
        label='Model',
    ),
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.25,
        step=0.05,
        label='Confidence Threshold',
    ),
]
|
|
|
# A single output component: the Plotly figure produced by `inference`.
outputs = [gr.Plot(label='Predictions')]
|
|
|
def list_example_images(folder='examples'):
    """Return Gradio example rows for the image files found in *folder*.

    Each row is a one-element list because only the first input (the image)
    is prefilled. Entries are sorted by name so the gallery order is stable
    (``os.listdir`` order is arbitrary). Hidden files and sub-directories are
    skipped so stray entries such as ``.DS_Store`` don't become broken
    examples. Returns ``None`` (meaning "no examples") when *folder* does not
    exist, instead of crashing the app at startup.
    """
    if not os.path.isdir(folder):
        return None
    return [
        [f'{folder}/{name}']
        for name in sorted(os.listdir(folder))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(folder, name))
    ]


example_images = list_example_images()
|
|
|
|
|
|
|
# Assemble the web UI (image + model choice + confidence slider in,
# annotated plot out) and keep a handle on it as `demo`.
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title='Boat detection from Sentinel 2 images',
    description="""Detecting marine vessels from Sentinel 2 imagery.

Each example image covers 1500x1500 pixels.""",
    examples=example_images,
    examples_per_page=32,
    allow_flagging='never',
)

# NOTE(review): server_name='0.0.0.0' binds every network interface and
# share=True additionally requests a public share link — confirm both are
# intended for this deployment.
demo.launch(server_name='0.0.0.0', share=True)