# ICC / app.py — Hugging Face Space source file.
# Page metadata captured with this file (not code):
#   author: edeler · commit: d7a6491 "Update app.py" (verified) · size: 4.17 kB
import os
import cv2
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
import supervision as sv
from PIL import Image
from huggingface_hub import snapshot_download
# ---------------------------------------------------------------------------
# Detection tuning
# ---------------------------------------------------------------------------
# Detections scoring below this confidence are discarded by detect_objects.
CONFIDENCE_THRESHOLD = 0.1
# IoU threshold handed to Detections.with_nms(). Presumably a box overlapping a
# higher-confidence box of the same class by more than this IoU is dropped, so
# 0 treats any overlap as a duplicate — confirm this suits touching cells.
NMS_THRESHOLD = 0

# ---------------------------------------------------------------------------
# Tiled (sliced) inference
# ---------------------------------------------------------------------------
# Size of each tile fed to the model, and the overlap between neighbouring
# tiles (0 means tiles abut without overlapping).
SLICE_WIDTH, SLICE_HEIGHT = 1024, 1024
OVERLAP_WIDTH, OVERLAP_HEIGHT = 0, 0

# ---------------------------------------------------------------------------
# Drawing
# ---------------------------------------------------------------------------
# Box colour, given as a supervision Color object (not a raw BGR tuple).
ANNOTATION_COLOR = sv.Color.RED
# Thickness of the drawn box outlines.
ANNOTATION_THICKNESS = 4
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fetch the trained YOLO weights from the Hugging Face Hub into ./models/ICC
# (a path relative to the current working directory), then load them onto
# the selected device. Replace repo_id / the file name for other checkpoints.
repo_id = 'edeler/ICC'
model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
model_path = os.path.join(model_dir, "best.pt")
model = YOLO(model_path)
model = model.to(device)
# Define the detection function for Gradio
def _to_bgr(image: np.ndarray) -> np.ndarray:
    """Return a 3-channel BGR version of an RGB, RGBA or grayscale array."""
    # Gradio delivers RGB; the YOLO model and OpenCV drawing work in BGR.
    if image.ndim == 2 or image.shape[-1] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[-1] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def detect_objects(image: np.ndarray) -> Image.Image:
    """Detect objects in an uploaded image and return it with boxes drawn.

    The image is split into SLICE_WIDTH x SLICE_HEIGHT tiles (with
    OVERLAP_WIDTH / OVERLAP_HEIGHT overlap). Each tile is run through the
    module-level YOLO ``model`` and the per-tile detections are merged. The
    merged set is filtered by CONFIDENCE_THRESHOLD, de-duplicated with NMS
    (NMS_THRESHOLD) and drawn with an oriented-box annotator.

    Args:
        image: Array from ``gr.Image(type="numpy")``, assumed RGB. Grayscale
            and RGBA arrays are also accepted. It is ``None`` when Detect is
            clicked before an image is uploaded.

    Returns:
        The annotated image as an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable
            message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    bgr_image = _to_bgr(image)

    def callback(image_slice: np.ndarray) -> sv.Detections:
        """Run the model on one tile and return its confident detections."""
        # Ultralytics applies its own confidence cut-off (0.25 by default)
        # inside the predict call, so thresholds below that would silently
        # have no effect unless passed to the model itself.
        result = model(image_slice, conf=CONFIDENCE_THRESHOLD)[0]
        detections = sv.Detections.from_ultralytics(result)
        # Keep the explicit filter so the threshold is enforced regardless of
        # how the model call treats ``conf``.
        return detections[detections.confidence >= CONFIDENCE_THRESHOLD]

    slicer = sv.InferenceSlicer(
        callback=callback,
        slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
        overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
        # Pixel overlap (overlap_wh) is used instead of a ratio.
        overlap_ratio_wh=None,
    )
    detections = slicer(bgr_image)

    # Remove duplicate boxes of the same class.
    detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)

    # NOTE(review): an OrientedBoxAnnotator presumably needs oriented-box (OBB)
    # output from the model — confirm best.pt is an OBB checkpoint.
    box_annotator = sv.OrientedBoxAnnotator(
        color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS
    )
    annotated_bgr = box_annotator.annotate(scene=bgr_image.copy(), detections=detections)

    # Back to RGB for PIL / Gradio display.
    return Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))
# Reset function for Gradio UI
def gradio_reset():
    """Clear both image components of the UI.

    Returns:
        A pair of updates that set the value to ``None``: the first for the
        input image, the second for the detection-result image, in the order
        of the ``outputs`` list wired to the Clear button.
    """
    # Build a separate update object per output component.
    input_update = gr.update(value=None)
    output_update = gr.update(value=None)
    return input_update, output_update
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)
            clear = gr.Button(value="Clear")
            predict = gr.Button(value="Detect", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="Detection Result", interactive=False)

    # Offer the JPEG files that sit next to this script as examples.
    # abspath() guards against __file__ being a bare file name (dirname would
    # then be "" and os.listdir("") raises). The suffix match ignores case, and
    # sorting gives a stable order (os.listdir order is arbitrary).
    example_root = os.path.dirname(os.path.abspath(__file__))
    example_images = sorted(
        os.path.join(example_root, file)
        for file in os.listdir(example_root)
        if file.lower().endswith((".jpg", ".jpeg"))
    )
    # Only show the accordion when there is something to pick from.
    if example_images:
        with gr.Accordion("Select an Example Image"):
            gr.Examples(
                examples=example_images,
                inputs=[input_img],
            )

    # Button wiring: Clear resets both images; Detect runs the detector.
    clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
    predict.click(detect_objects, inputs=[input_img], outputs=[output_img])
# Launch the Gradio app only when this file is executed directly
# (e.g. `python app.py`), so importing the module does not start a
# blocking server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)