import os

import cv2
import torch
import gradio as gr
import numpy as np
import supervision as sv
from ultralytics import YOLO
from PIL import Image
from huggingface_hub import snapshot_download
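
# Inference and annotation settings. SLICE_WIDTH/SLICE_HEIGHT set the tile size
# used for sliced inference; OVERLAP_WIDTH/OVERLAP_HEIGHT give the pixel overlap
# between neighbouring tiles.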
CONFIDENCE_THRESHOLD = 0.1  # Discard per-slice detections below this score.
NMS_THRESHOLD = 0           # IoU threshold for NMS; 0 collapses any overlapping boxes.
SLICE_WIDTH = 1024
SLICE_HEIGHT = 1024
OVERLAP_WIDTH = 0
OVERLAP_HEIGHT = 0
ANNOTATION_COLOR = sv.Color.RED
ANNOTATION_THICKNESS = 4

# Run on the GPU when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
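
# Fetch the ICC detection weights from the Hugging Face Hub (cached under
# ./models/ICC) and load them onto the selected device.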
repo_id = 'edeler/ICC'
model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
model_path = os.path.join(model_dir, "best.pt")
model = YOLO(model_path).to(device)


def detect_objects(image: np.ndarray) -> Image.Image:
    """Run sliced inference on an uploaded image and return an annotated PIL image."""
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Gradio supplies RGB arrays; OpenCV and the model pipeline expect BGR.
    if image.shape[-1] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
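
    # Per-tile callback for InferenceSlicer: run the model on a single slice and
    # keep only detections at or above the confidence threshold.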
    def callback(image_slice: np.ndarray) -> sv.Detections:
        result = model(image_slice)[0]
        detections = sv.Detections.from_ultralytics(result)
        return detections[detections.confidence >= CONFIDENCE_THRESHOLD]
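
    # Tile the image into SLICE_WIDTH x SLICE_HEIGHT crops. Overlap is given in
    # pixels via overlap_wh, so the ratio-based overlap is disabled.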
    slicer = sv.InferenceSlicer(
        callback=callback,
        slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
        overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
        overlap_ratio_wh=None,
    )

    detections = slicer(image)

    # Merge duplicate detections from overlapping tiles with class-aware NMS.
    detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)
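
    # Annotate with OrientedBoxAnnotator, which draws the oriented (OBB)
    # boxes produced by the model on a copy of the image.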
    box_annotator = sv.OrientedBoxAnnotator(color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS)
    annotated_img = box_annotator.annotate(scene=image.copy(), detections=detections)

    # Convert back to RGB for display in the Gradio image component.
    annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_img_rgb)


def gradio_reset():
    """Clear both the input and the output image components."""
    return gr.update(value=None), gr.update(value=None)
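

# Build the Gradio UI: an image input with Clear/Detect controls on the left
# and the annotated detection result on the right.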
with gr.Blocks() as demo:
    gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)
            clear = gr.Button(value="Clear")
            predict = gr.Button(value="Detect", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Detection Result", interactive=False)
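
    # Offer any .jpg files that ship alongside this script as one-click examples.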
    with gr.Accordion("Select an Example Image"):
        example_root = os.path.dirname(os.path.abspath(__file__))
        example_images = [
            os.path.join(example_root, file)
            for file in os.listdir(example_root)
            if file.endswith(".jpg")
        ]
        gr.Examples(
            examples=example_images,
            inputs=[input_img],
        )

    # Wire the buttons to their callbacks.
    clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
    predict.click(detect_objects, inputs=[input_img], outputs=[output_img])
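

# Bind to all interfaces so the app is reachable from outside the host
# (e.g. when run inside a container).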
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)