import os
import cv2
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
import supervision as sv
from PIL import Image
from huggingface_hub import snapshot_download

# Adjustable parameters for detection
CONFIDENCE_THRESHOLD = 0.1       # Minimum confidence for a detection to be kept
NMS_THRESHOLD = 0                # IoU threshold for NMS; 0 suppresses any pair of overlapping boxes
SLICE_WIDTH = 1024               # Width of each slice, in pixels
SLICE_HEIGHT = 1024              # Height of each slice, in pixels
OVERLAP_WIDTH = 0                # Horizontal overlap between adjacent slices, in pixels
OVERLAP_HEIGHT = 0               # Vertical overlap between adjacent slices, in pixels
ANNOTATION_COLOR = sv.Color.RED  # Annotation color (a supervision Color; the annotator handles channel order)
ANNOTATION_THICKNESS = 4         # Thickness of bounding box lines, in pixels

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download the YOLO model weights from the Hugging Face Hub
repo_id = 'edeler/ICC'  # Hugging Face repository hosting the trained weights
model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
model_path = os.path.join(model_dir, "best.pt")  # Adjust if the weights file is named differently
model = YOLO(model_path).to(device)

# Define the detection function for Gradio
def detect_objects(image: np.ndarray) -> Image.Image:
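    """Run sliced YOLO inference on the uploaded image and return the annotated result as a PIL image."""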
    # Guard against the Detect button firing before an image is uploaded
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Gradio supplies an RGB array; convert to BGR for OpenCV-based processing
    if image.shape[-1] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    
    # Define callback function for slice-based inference
    def callback(image_slice: np.ndarray) -> sv.Detections:
        # Run inference on each slice
        result = model(image_slice)[0]
        # Convert detections to `sv.Detections` format for further processing
        detections = sv.Detections.from_ultralytics(result)
        # Filter detections based on confidence threshold
        return detections[detections.confidence >= CONFIDENCE_THRESHOLD]

    # Initialize InferenceSlicer with adjustable slice dimensions and overlap settings
    slicer = sv.InferenceSlicer(
        callback=callback,
        slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
        overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
        overlap_ratio_wh=None  # Explicitly disable the deprecated ratio-based overlap; overlap_wh above is used instead
    )

    # Run sliced inference over the full image; the slicer maps each
    # tile's detections back to full-image coordinates
    detections = slicer(image)

    # Apply class-aware non-maximum suppression (NMS) to remove duplicate boxes
    detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)

    # Initialize an annotator for oriented bounding boxes (OBB) with the configured color and thickness
    box_annotator = sv.OrientedBoxAnnotator(color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS)

    # Annotate the image with bounding boxes after NMS
    annotated_img = box_annotator.annotate(scene=image.copy(), detections=detections)

    # Convert annotated image to RGB for Gradio display (PIL expects RGB)
    annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_img_rgb)

# Reset function for Gradio UI
def gradio_reset():
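    """Clear both the input image and the detection result."""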
    return gr.update(value=None), gr.update(value=None)

# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")
    
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)
            clear = gr.Button(value="Clear")
            predict = gr.Button(value="Detect", variant="primary")
            
        with gr.Column():
            output_img = gr.Image(label="Detection Result", interactive=False)
    
    # Examples section populated with .jpg images from the app's root directory
    with gr.Accordion("Select an Example Image"):
        example_root = os.path.dirname(os.path.abspath(__file__))  # Directory containing this script
        example_images = [
            os.path.join(example_root, file)
            for file in os.listdir(example_root)
            if file.endswith(".jpg")
        ]
        gr.Examples(
            examples=example_images,
            inputs=[input_img],
        )
    
    # Define button actions
    clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
    predict.click(detect_objects, inputs=[input_img], outputs=[output_img])

# Launch Gradio app
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)