File size: 2,595 Bytes
0fb9587
 
 
 
 
 
 
dabf2f5
 
0fb9587
 
 
 
dabf2f5
 
0fb9587
 
 
73aad24
0fb9587
 
dabf2f5
 
085cce1
0fb9587
 
7f5aa48
0fb9587
 
 
 
 
 
 
 
7f5aa48
31848b9
0fb9587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0c8280
 
 
0fb9587
 
1eee909
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import sys, os

import gradio as gr
import plotly.express as px
import numpy as np
import random
from ultralytics import YOLO
#from sahi.models.yolov8 import *
from src.sahi_onnx import *
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import visualize_object_predictions
import PIL

# Directory holding one sub-directory per model; ``inference`` loads
# ``<model_base>/<model>/best.onnx`` and ``<model_base>/<model>/args.yaml``.
# It is resolved relative to this file rather than the current working
# directory, so the app also works when started from another directory.
# Former remote location of the weights, kept for reference:
# https://huggingface.co/mayrajeo/marine-vessel-detection/resolve/main/
model_base = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'onnx_models')

def inference(
    im: 'str | None' = None,
    model_path: str = 'YOLOv8n',
    conf_thr: float = 0.25,
):
    """Detect vessels in one image with sliced ONNX YOLOv8 inference and plot them.

    SAHI runs the selected model on overlapping 320x320 slices of the image
    and combines the per-slice predictions. The detections are then drawn on
    the image and returned as a Plotly figure.

    Args:
        im: Path to the input image. The Gradio ``Image`` input is created with
            ``type='filepath'``. It is ``None`` when no image was provided.
        model_path: Name of the model sub-directory under ``model_base``. The
            directory must contain ``best.onnx`` and ``args.yaml``.
        conf_thr: Confidence threshold handed to the detection model.

    Returns:
        A Plotly figure of the image with the predicted boxes drawn on it, or
        ``None`` when ``im`` is ``None``. Returning ``None`` clears the plot.
    """
    # With an empty image input there is nothing to run; clear the output
    # instead of failing deep inside SAHI.
    if im is None:
        return None

    # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
    # import it explicitly instead of relying on another package to do so.
    from PIL import Image

    # The model is constructed per call, so every request uses the
    # confidence threshold currently selected on the slider.
    # NOTE(review): this also reloads the ONNX weights on every request.
    model = Yolov8onnxDetectionModel(
        model_path=f'{model_base}/{model_path}/best.onnx',
        config_path=f'{model_base}/{model_path}/args.yaml',
        device='cpu',
        confidence_threshold=conf_thr,
        category_mapping={'0': 'Boat'},  # class id 0 is labelled 'Boat'
        image_size=640,
    )

    res = get_sliced_prediction(
        im,
        model,
        slice_width=320,
        slice_height=320,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        verbose=0,
    )

    # Load the image for drawing. Convert it to 3-channel RGB, the same
    # conversion SAHI's reader applies to a path input, so palette, grayscale
    # and RGBA files are drawn on the same image the detector saw. The ``with``
    # closes the file once the pixels have been copied into the array.
    with Image.open(im) as img:
        image_array = np.array(img.convert('RGB'))

    visual_result = visualize_object_predictions(
        image=image_array,
        object_prediction_list=res.object_prediction_list,
        text_size=0.4,
        rect_th=1,
    )

    fig = px.imshow(visual_result['image'])
    # Present the bare image: no legend, no hover tooltips, no axes.
    fig.update_layout(showlegend=False, hovermode=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig

# Model variants offered in the dropdown ('YOLOv8n' ... 'YOLOv8x'); each name
# is a sub-directory of ``model_base``.
_MODEL_CHOICES = [f'YOLOv8{size}' for size in 'nsmlx']

# Gradio input widgets, in the same order as the parameters of ``inference``:
# the image (handed over as a file path), the model name and the
# confidence threshold.
inputs = [
    gr.Image(type='filepath', label='Input'),
    gr.components.Dropdown(_MODEL_CHOICES, value='YOLOv8n', label='Model'),
    gr.components.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.25,
        label='Confidence Threshold',
    ),
]

# Single output: the annotated image rendered as a Plotly figure.
outputs = [gr.Plot(label='Predictions')]

def _list_example_images(directory='examples'):
    """Return Gradio example rows, one ``[path]`` per visible file in *directory*.

    Entries are sorted by name so the example gallery has a stable order;
    ``os.listdir`` returns them in arbitrary order. Hidden files (e.g.
    ``.DS_Store``) and sub-directories are skipped so they are not offered
    as examples. Paths keep the ``'<directory>/<name>'`` form.
    """
    rows = []
    for name in sorted(os.listdir(directory)):
        path = f'{directory}/{name}'
        if name.startswith('.') or not os.path.isfile(path):
            continue
        rows.append([path])
    return rows


example_images = _list_example_images('examples')


# Build the web UI around ``inference`` and start serving it.
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    allow_flagging='never',
    examples=example_images,
    cache_examples=False,
    examples_per_page=10,
    title='Marine vessel detection from Sentinel 2 images',
    description="""Detecting marine vessels from Sentinel 2 imagery. 
    Each example image covers 1500x1500 pixels."""
    )
demo.launch()