"""Gradio web app for detecting Interstitial Cells of Cajal (ICC) in images.

YOLO model weights are downloaded from the Hugging Face Hub, the model is run
with tiled (sliced) inference over the uploaded image, and the resulting
detections are drawn onto the image shown in the UI.
"""
import os
import cv2
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
import supervision as sv
from PIL import Image
from huggingface_hub import snapshot_download
# ---------------------------------------------------------------------------
# Detection / display settings — tweak these to tune the app.
# ---------------------------------------------------------------------------
# Minimum confidence score a detection must reach to be kept.
CONFIDENCE_THRESHOLD = 0.1
# IoU threshold for the NMS pass applied after slicing; with 0, any two boxes
# that overlap at all (IoU > 0) are treated as duplicates.
NMS_THRESHOLD = 0
# Tile size (pixels) for sliced inference, and the overlap between tiles.
SLICE_WIDTH, SLICE_HEIGHT = 1024, 1024
OVERLAP_WIDTH, OVERLAP_HEIGHT = 0, 0
# Bounding-box colour and line thickness. ANNOTATION_COLOR is a supervision
# Color object (not a raw BGR tuple); supervision converts it for OpenCV.
ANNOTATION_COLOR = sv.Color.RED
ANNOTATION_THICKNESS = 4

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fetch the YOLO weights from the Hugging Face Hub into ./models/ICC (relative
# to the current working directory) and load the model once at import time.
repo_id = 'edeler/ICC'  # Replace with your Hugging Face repository ID
model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
model_path = os.path.join(model_dir, "best.pt")  # Adjust if filename differs
model = YOLO(model_path).to(device)
# Define the detection function for Gradio
def detect_objects(image: np.ndarray) -> Image.Image:
    """Detect objects in an uploaded image and return an annotated copy.

    The image is split into SLICE_WIDTH x SLICE_HEIGHT tiles (with the
    configured overlap), the YOLO model is run on every tile, the per-tile
    detections are merged and de-duplicated with NMS, and the surviving boxes
    are drawn onto a copy of the image.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``,
            normally RGB ``(H, W, 3)``. Grayscale ``(H, W)`` and RGBA
            ``(H, W, 4)`` arrays are also accepted.

    Returns:
        The annotated image as an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was provided (e.g. "Detect" clicked before an
            upload), so Gradio shows a message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    # Ultralytics treats raw numpy arrays as BGR (OpenCV convention) while
    # Gradio delivers RGB, so normalise every supported layout to 3-channel BGR.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[-1] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    elif image.shape[-1] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    def callback(image_slice: np.ndarray) -> sv.Detections:
        """Run the model on one tile and return its detections."""
        # Pass the threshold to the model itself: otherwise Ultralytics applies
        # its own default confidence cut-off (0.25) first, and any
        # CONFIDENCE_THRESHOLD below that would silently have no effect.
        result = model(image_slice, conf=CONFIDENCE_THRESHOLD, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(result)
        # Explicit filter kept as a safeguard; a no-op when ``conf`` was honoured.
        return detections[detections.confidence >= CONFIDENCE_THRESHOLD]

    # Tile the image using the configured slice size and pixel overlap
    # (ratio-based overlap explicitly disabled).
    slicer = sv.InferenceSlicer(
        callback=callback,
        slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
        overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
        overlap_ratio_wh=None,
    )
    detections = slicer(image)

    # Remove duplicate boxes (e.g. objects split across tile borders), per class.
    detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)

    # NOTE(review): OrientedBoxAnnotator draws oriented (rotated) boxes, which
    # presumably means the ICC model is an OBB model; confirm before swapping it.
    box_annotator = sv.OrientedBoxAnnotator(
        color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS
    )
    annotated_img = box_annotator.annotate(scene=image.copy(), detections=detections)

    # Back to RGB for PIL / Gradio display.
    annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_img_rgb)
# Reset function for Gradio UI
def gradio_reset():
    """Return update payloads that blank out the input and output images.

    One fresh ``gr.update(value=None)`` is built per wired output component,
    so each component receives its own payload.
    """
    return tuple(gr.update(value=None) for _ in range(2))
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)
            clear = gr.Button(value="Clear")
            predict = gr.Button(value="Detect", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="Detection Result", interactive=False)

    # Example images are the JPEGs sitting next to this script. abspath() keeps
    # this independent of the working directory (dirname of a bare filename
    # would be ""), and sorting gives a stable order because os.listdir()
    # returns entries in arbitrary order.
    example_root = os.path.dirname(os.path.abspath(__file__))
    example_images = sorted(
        os.path.join(example_root, file)
        for file in os.listdir(example_root)
        if file.lower().endswith((".jpg", ".jpeg"))
    )
    # Only render the examples section when there is something to show.
    if example_images:
        with gr.Accordion("Select an Example Image"):
            gr.Examples(
                examples=example_images,
                inputs=[input_img],
            )

    # Define button actions (must be registered inside the Blocks context).
    clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
    predict.click(detect_objects, inputs=[input_img], outputs=[output_img])

# Launch Gradio app
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)