import gradio as gr
import sahi
from ultralyticsplus import YOLO
# Download sample images
sahi.utils.file.download_from_url(
    "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
    "highway.jpg",
)
sahi.utils.file.download_from_url(
    "https://raw.githubusercontent.com/obss/sahi/main/tests/data/small-vehicles1.jpeg",
    "small-vehicles1.jpeg",
)
sahi.utils.file.download_from_url(
    "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/zidane.jpg",
    "zidane.jpg",
)
# List of YOLOv8 segmentation models
model_names = [
    "yolov8n-seg.pt",
    "yolov8s-seg.pt",
    "yolov8m-seg.pt",
    "yolov8l-seg.pt",
    "yolov8x-seg.pt",
]
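# Load the default model once at startup; ultralytics downloads the
# checkpoint automatically on first use if it is not already cached locally.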
current_model_name = "yolov8m-seg.pt"
model = YOLO(current_model_name)
def yolov8_inference(
    image: str,
    model_name: str = "yolov8m-seg.pt",
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """
    Run YOLOv8 segmentation inference and return the mask (or bounding box)
    coordinates and label name for each detected object.
    Args:
        image: Path to the input image.
        model_name: Name of the YOLOv8 segmentation checkpoint to use.
        image_size: Inference image size in pixels.
        conf_threshold: Confidence threshold for detections.
        iou_threshold: IOU threshold for non-maximum suppression.
    Returns:
        A list of dicts, each with a "label" and either "mask_coords"
        or "bbox_coords".
    """
    global model
    global current_model_name
    # Reload the model only if a different checkpoint was selected
    if model_name != current_model_name:
        model = YOLO(model_name)
        current_model_name = model_name
    # Set the confidence and IOU thresholds
    model.overrides["conf"] = conf_threshold
    model.overrides["iou"] = iou_threshold
    # Run prediction; ultralytics returns a list of Results objects
    results = model.predict(image, imgsz=image_size)
    # Collect one entry per detected object
    output = []
    for result in results:
        # Check if segmentation masks are available
        if result.masks is not None:
            # Export each mask as a polygon of (x, y) pixel coordinates
            for mask_xy, cls in zip(result.masks.xy, result.boxes.cls):
                label = model.names[int(cls)]
                output.append({"label": label, "mask_coords": mask_xy.tolist()})
        else:
            # No masks available: fall back to bounding box coordinates
            for box_xyxy, cls in zip(result.boxes.xyxy, result.boxes.cls):
                label = model.names[int(cls)]
                output.append({"label": label, "bbox_coords": box_xyxy.tolist()})
    return output
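# For reference, the returned structure looks like this (illustrative values
# only; the actual coordinates depend on the image and model):
# [
#     {"label": "person", "mask_coords": [[x1, y1], [x2, y2], ...]},
#     {"label": "car", "bbox_coords": [x1, y1, x2, y2]},
# ]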
# Define Gradio interface inputs and outputs
inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        model_names,
        value=current_model_name,
        label="Model type",
    ),
    gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
    gr.Slider(
        minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"
    ),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]
# Output is a dictionary containing label names and coordinates of masks or boxes
outputs = gr.JSON(label="Output Masks and Labels")
title = "Ultralytics YOLOv8 Segmentation Demo"
# Example images for the interface
examples = [
    ["zidane.jpg", "yolov8m-seg.pt", 640, 0.6, 0.45],
    ["highway.jpg", "yolov8m-seg.pt", 640, 0.25, 0.45],
    ["small-vehicles1.jpeg", "yolov8m-seg.pt", 640, 0.25, 0.45],
]
# Build the Gradio demo app
demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    cache_examples=False,  # Set to False to avoid caching issues
    theme="default",
)
# Launch the app
demo_app.queue().launch(debug=True) |
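# Assuming this script is saved as app.py, run it with `python app.py`;
# Gradio serves the demo at http://127.0.0.1:7860 by default.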