mayrajeo's picture
Update app.py
f0c8280
raw
history blame
No virus
2.29 kB
import sys, os
import random
from typing import Optional

import gradio as gr
import numpy as np
import PIL
import PIL.Image
import plotly.express as px
from ultralytics import YOLO
from sahi.models.yolov8 import *
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import visualize_object_predictions
# Base URL of the model repository. inference() looks for the weights of the
# selected variant under ``<variant>/best.pt`` relative to this URL.
model_base = "https://huggingface.co/mayrajeo/marine-vessel-detection/resolve/main/"
def inference(
    im: Optional[str] = None,
    model_path: str = 'YOLOv8n',
    conf_thr: float = 0.25,
):
    """Run sliced vessel detection on an image and return the annotated plot.

    Args:
        im: Path to the input image file.
        model_path: Name of the model variant; it selects the sub-directory
            of ``model_base`` from which ``best.pt`` is loaded.
        conf_thr: Confidence threshold handed to the detection model.

    Returns:
        A Plotly figure showing the image with the predictions drawn on it,
        with legend, hover and both axes disabled.

    Raises:
        gr.Error: If no image was supplied.
    """
    if im is None:
        # Without this guard the failure surfaces deep inside the prediction
        # code with a message that means nothing to the user.
        raise gr.Error('Please provide an input image.')

    # model_base ends with '/', so strip it to avoid a '//' in the final URL.
    # NOTE(review): every variant's weights file is called 'best.pt'; if the
    # loader caches downloads by file name, a previously fetched variant may
    # be reused after switching models -- verify against the loader.
    weights_url = f"{model_base.rstrip('/')}/{model_path}/best.pt"
    model = Yolov8DetectionModel(
        model_path=weights_url,
        device='cpu',
        confidence_threshold=conf_thr,
        image_size=640,
    )

    res = get_sliced_prediction(
        im,
        model,
        slice_width=320,
        slice_height=320,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        verbose=0,
    )

    # Close the file handle as soon as the pixels are copied out. Converting
    # to RGB gives the drawing step a 3-channel array whatever the mode of
    # the source file (greyscale, palette, RGBA, ...).
    with PIL.Image.open(im) as img:
        pixels = np.array(img.convert('RGB'))

    visual_result = visualize_object_predictions(
        image=pixels,
        object_prediction_list=res.object_prediction_list,
        text_size=0.3,
        rect_th=1,
    )

    fig = px.imshow(visual_result['image'])
    fig.update_layout(showlegend=False, hovermode=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
# Model variants selectable in the UI; the chosen name is passed to the
# prediction function as its second argument.
_model_choices = ['YOLOv8n', 'YOLOv8s', 'YOLOv8m', 'YOLOv8l', 'YOLOv8x']

# Input components, in the order the prediction function receives them:
# image file path, model variant, confidence threshold.
inputs = [
    gr.Image(label='Input', type='filepath'),
    gr.components.Dropdown(_model_choices, label='Model', value='YOLOv8n'),
    gr.components.Slider(
        label='Confidence Threshold',
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.25,
    ),
]

# Single output: the plot returned by the prediction function.
outputs = [gr.Plot(label='Predictions')]
def _list_example_images(directory='examples'):
    """Return one single-item example row per regular file in *directory*.

    Entries are sorted by name because ``os.listdir`` gives no ordering
    guarantee. Hidden entries and sub-directories are skipped so they do not
    show up as broken examples. A missing directory yields an empty list
    instead of stopping the app at start-up.
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        return []
    return [
        [f'{directory}/{name}']
        for name in names
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    ]


example_images = _list_example_images()

# Build the UI and start serving it.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    allow_flagging='never',
    # Pass None rather than an empty list when there is nothing to show.
    examples=example_images or None,
    cache_examples=False,
    examples_per_page=10,
    title='Marine vessel detection from Sentinel 2 images',
    description="""Detecting marine vessels from Sentinel 2 imagery.
Each example image covers 1500x1500 pixels."""
).launch()