# Author: mayrajeo. Commit 0fb9587 ("add app").
import os
import random
import sys
from typing import Optional

import gradio as gr
import numpy as np
import plotly.express as px
import PIL
from sahi.models.yolov8 import *
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import visualize_object_predictions
from ultralytics import YOLO
# Root URL for raw files in the Hugging Face model repository.
# The model name and "best.pt" are appended to it when the weights are loaded.
model_base = (
    "https://huggingface.co/"
    "mayrajeo/marine-vessel-detection/resolve/main/"
)
def _fetch_weights(model_name: str) -> str:
    """Return a local path to the ``best.pt`` checkpoint of *model_name*.

    The checkpoint is downloaded from ``model_base`` on first use and stored
    under ``weights/<model_name>/best.pt``. Each model gets its own directory
    because every checkpoint in the repo has the same file name.
    NOTE(review): ultralytics appears to cache URL downloads by the URL's
    basename. Passing the raw URLs could therefore make every dropdown choice
    reuse whichever ``best.pt`` was fetched first. Confirm for the pinned
    ultralytics version.

    Raises:
        ValueError: if *model_name* is empty or not a plain directory name.
    """
    import urllib.request
    from pathlib import Path

    # Reject names that could escape the weights directory (e.g. '../x').
    if not model_name or Path(model_name).name != model_name or model_name in ('.', '..'):
        raise ValueError(f'Invalid model name: {model_name!r}')

    local_path = Path('weights') / model_name / 'best.pt'
    if not local_path.is_file():
        local_path.parent.mkdir(parents=True, exist_ok=True)
        # model_base already ends with '/', so strip it to avoid '//' in the URL.
        url = f"{model_base.rstrip('/')}/{model_name}/best.pt"
        # Download to a temporary name first, so an interrupted transfer never
        # leaves a truncated checkpoint that a later call would pick up.
        partial_path = local_path.with_name(local_path.name + '.part')
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(local_path)
    return str(local_path)


def inference(
    im: Optional[str] = None,
    model_path: Optional[str] = None,
    conf_thr: float = 0.25,
):
    """Run sliced vessel detection on one image and return a Plotly figure.

    Args:
        im: Path to the input image. The ``gr.Image`` input is configured
            with ``type='filepath'``.
        model_path: Name of a model directory in the repository at
            ``model_base``, e.g. ``'YOLOv8n'``.
        conf_thr: Confidence threshold passed to the detection model.

    Returns:
        A Plotly figure showing the image with the predicted boxes drawn on it,
        with axes, legend and hover disabled.

    Raises:
        gr.Error: if no image or no model was provided.
    """
    if im is None:
        raise gr.Error('Please provide an input image.')
    if not model_path:
        raise gr.Error('Please select a model.')

    # `import PIL` alone does not load the `PIL.Image` submodule, so import it explicitly.
    from PIL import Image as PILImage

    model = Yolov8DetectionModel(model_path=_fetch_weights(model_path),
                                 device='cpu',
                                 confidence_threshold=conf_thr,
                                 image_size=640)
    # Predict on overlapping 320x320 tiles (20% overlap each way) and merge the results.
    res = get_sliced_prediction(im, model, slice_width=320,
                                slice_height=320, overlap_height_ratio=0.2,
                                overlap_width_ratio=0.2, verbose=0)

    # Load the pixels and close the file immediately. Converting to RGB means
    # palette, greyscale and RGBA inputs are drawn and shown as a normal
    # 3-channel image.
    with PILImage.open(im) as img:
        image_array = np.array(img.convert('RGB'))

    visual_result = visualize_object_predictions(image=image_array,
                                                 object_prediction_list=res.object_prediction_list,
                                                 text_size=0.3,
                                                 rect_th=1)
    fig = px.imshow(visual_result['image'])
    fig.update_layout(showlegend=False, hovermode=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
# UI inputs, in the same order as the parameters of `inference`:
# image filepath, model name, confidence threshold.
inputs = [
    gr.Image(label='Input', type='filepath'),
    gr.Dropdown(
        choices=['YOLOv8n', 'YOLOv8s', 'YOLOv8m', 'YOLOv8l', 'YOLOv8x'],
        label='Model',
        value='YOLOv8n',
    ),
    gr.Slider(
        label='Confidence Threshold',
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.25,
    ),
]
# Single output: the Plotly figure returned by `inference`.
outputs = [gr.Plot(label='Predictions')]
# Gradio example rows: one single-element row (the image path) per file in ./examples.
# - Sorted, so the example gallery order is the same on every machine
#   (os.listdir order is arbitrary).
# - Hidden files such as .DS_Store are skipped.
# - A missing folder means no examples instead of a crash at startup.
example_images = (
    [[f'examples/{f}'] for f in sorted(os.listdir('examples')) if not f.startswith('.')]
    if os.path.isdir('examples')
    else []
)
# Build the web UI around `inference` and serve it on all network interfaces.
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title='Boat detection from Sentinel 2 images',
    description="""Detecting marine vessels from Sentinel 2 imagery.
Each example image covers 1500x1500 pixels.""",
    examples=example_images,
    examples_per_page=32,
    allow_flagging='never',
)
demo.launch(share=True, server_name='0.0.0.0')