# app.py, as shown on the Hugging Face file page of user mayrajeo
# Last commit: dabf2f5 ("Update app.py"); file size 2.53 kB
import sys, os
import gradio as gr
import plotly.express as px
import numpy as np
import random
from ultralytics import YOLO
#from sahi.models.yolov8 import *
from src.sahi_onnx import *
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import visualize_object_predictions
import PIL
# Local directory with one sub-folder per model variant; each sub-folder is
# expected to hold `best.onnx` and `args.yaml` (see inference()).
# Alternative remote base that was used before switching to local ONNX files:
#   https://huggingface.co/mayrajeo/marine-vessel-detection/resolve/main/
model_base = "onnx_models"
def inference(
    im=None,
    model_path: str = 'YOLOv8n',
    conf_thr: float = 0.25,
):
    """Run sliced (SAHI) vessel detection on an image and return a Plotly figure.

    Args:
        im: Path to the input image (the UI uses ``gr.Image(type='filepath')``),
            or None when the user submitted without an image.
        model_path: Model variant name; selects the sub-directory
            ``{model_base}/{model_path}/`` that holds ``best.onnx`` and
            ``args.yaml``.
        conf_thr: Value passed as ``confidence_threshold`` to the detection
            model.

    Returns:
        A Plotly figure of the image with the predicted boxes drawn on it,
        with legend, hover and axes hidden.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Explicit submodule import: `import PIL` alone does not guarantee that
    # `PIL.Image` is loaded.
    from PIL import Image

    if im is None:
        # Without this guard, a submit with no image fails deep inside SAHI
        # with an unhelpful error.
        raise gr.Error('Please provide an input image.')

    # NOTE(review): the model is constructed on every request; consider
    # caching it if start-up cost becomes noticeable.
    model = Yolov8onnxDetectionModel(
        model_path=f'{model_base}/{model_path}/best.onnx',
        config_path=f'{model_base}/{model_path}/args.yaml',
        device='cpu',
        confidence_threshold=conf_thr,
        image_size=640,
    )

    # Predict on overlapping 320x320 tiles (20% overlap in each direction) and
    # let SAHI merge the per-tile detections back into full-image coordinates.
    res = get_sliced_prediction(
        im,
        model,
        slice_width=320,
        slice_height=320,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        verbose=0,
    )

    # Use a context manager so the file handle is closed. Convert to 3-channel
    # RGB so palette, greyscale or RGBA inputs are drawn correctly.
    with Image.open(im) as img:
        image_array = np.array(img.convert('RGB'))

    visual_result = visualize_object_predictions(
        image=image_array,
        object_prediction_list=res.object_prediction_list,
        text_size=0.3,
        rect_th=1,
    )

    fig = px.imshow(visual_result['image'])
    fig.update_layout(showlegend=False, hovermode=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
# Input widgets, in the same order as inference()'s parameters:
# image file path, model variant, and confidence threshold.
inputs = [
    gr.Image(type='filepath', label='Input'),
    gr.Dropdown(
        ['YOLOv8n', 'YOLOv8s', 'YOLOv8m', 'YOLOv8l', 'YOLOv8x'],
        value='YOLOv8n',
        label='Model',
    ),
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.25,
        step=0.05,
        label='Confidence Threshold',
    ),
]
# Single output: the Plotly figure returned by inference().
outputs = [gr.Plot(label='Predictions')]
# One example row per image file in ./examples. Each row is a one-element list
# holding the path for the image input.
# os.listdir() returns entries in arbitrary order, so sort them for a stable
# gallery. Skip hidden files (e.g. .DS_Store, .gitkeep) and sub-directories,
# which cannot be used as example images.
_EXAMPLES_DIR = 'examples'
example_images = [
    [os.path.join(_EXAMPLES_DIR, fname)]
    for fname in sorted(os.listdir(_EXAMPLES_DIR))
    if not fname.startswith('.')
    and os.path.isfile(os.path.join(_EXAMPLES_DIR, fname))
]
# Wire inference() into a Gradio web UI and start serving it.
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title='Marine vessel detection from Sentinel 2 images',
    description="""Detecting marine vessels from Sentinel 2 imagery.
Each example image covers 1500x1500 pixels.""",
    examples=example_images,
    examples_per_page=10,
    cache_examples=False,
    allow_flagging='never',
)
demo.launch()