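"""Gradio app for detecting and quantifying Interstitial Cells of Cajal (ICC)
with a YOLO model, using slice-based inference for large images."""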
import os
import cv2
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
import supervision as sv
from PIL import Image
from huggingface_hub import snapshot_download
import spaces # Import spaces for ZeroGPU compatibility
# Adjustable parameters for detection
CONFIDENCE_THRESHOLD = 0.1 # Confidence threshold for detections
NMS_THRESHOLD = 0 # IoU threshold for non-maximum suppression
SLICE_WIDTH = 1024 # Width of each slice
SLICE_HEIGHT = 1024 # Height of each slice
OVERLAP_WIDTH = 0 # Overlap width between slices
OVERLAP_HEIGHT = 0 # Overlap height between slices
ANNOTATION_COLOR = sv.Color.RED # Red annotation color (supervision Color object)
ANNOTATION_THICKNESS = 4 # Thickness of bounding box lines
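# Note on the values above: with zero slice overlap, objects that straddle a tile
# boundary may be split or missed, and with an NMS IoU threshold of 0 any
# overlapping boxes of the same class collapse to the single highest-confidence one.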
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Download YOLO model weights from Hugging Face Hub
repo_id = 'edeler/ICC' # Replace with your Hugging Face repository ID
model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
model_path = os.path.join(model_dir, "best.pt") # Adjust if filename differs
model = YOLO(model_path).to(device)
# Define the detection function for Gradio
@spaces.GPU # Decorator to allocate a GPU for ZeroGPU-enabled Spaces
def detect_objects(image: np.ndarray) -> tuple[Image.Image, str]:
    # Convert to BGR for OpenCV processing (Gradio supplies an RGB array)
    if image.shape[-1] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Callback used for slice-based inference
    def callback(image_slice: np.ndarray) -> sv.Detections:
        # Run inference on each slice
        result = model(image_slice)[0]
        # Convert detections to `sv.Detections` format for further processing
        detections = sv.Detections.from_ultralytics(result)
        # Filter detections based on the confidence threshold
        return detections[detections.confidence >= CONFIDENCE_THRESHOLD]
    # Initialize InferenceSlicer with adjustable slice dimensions and overlap settings
    slicer = sv.InferenceSlicer(
        callback=callback,
        slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
        overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
        overlap_ratio_wh=None
    )
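    # The slicer runs the callback on each tile and shifts the per-tile detections
    # back into full-image coordinates before combining them.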
    # Perform slicing-based inference on the entire image
    detections = slicer(image)
    # Apply Non-Maximum Suppression (NMS) to the detections to avoid duplicate boxes
    detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)
    # Count total detections after NMS
    total_detections = len(detections)
    # Initialize an annotator for bounding boxes with specified color and thickness
    box_annotator = sv.OrientedBoxAnnotator(color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS)
    # Annotate the image with bounding boxes after NMS
    annotated_img = box_annotator.annotate(scene=image.copy(), detections=detections)
    # Convert annotated image to RGB for Gradio display (PIL expects RGB)
    annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    # Return the annotated image and the total count of detections
    return Image.fromarray(annotated_img_rgb), f"Total Detections: {total_detections}"
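
# Illustrative sketch of calling the detector outside Gradio (assumes an
# "example.jpg" next to this script; not part of the Space itself):
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   annotated, summary = detect_objects(img)
#   annotated.save("annotated.jpg")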
# Reset function for Gradio UI
def gradio_reset():
    return gr.update(value=None), gr.update(value=None), gr.update(value="")
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")

    # Define the input image before it is referenced by the Examples section
    input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)

    # Example images, listed right after the title
    with gr.Accordion("Select an Example Image"):
        example_root = os.path.dirname(__file__) # Root directory of the app
        example_images = [os.path.join(example_root, file) for file in os.listdir(example_root) if file.endswith(".jpg")]
        gr.Examples(
            examples=example_images,
            inputs=[input_img],
        )
    with gr.Row():
        with gr.Column():
            clear = gr.Button(value="Clear")
            predict = gr.Button(value="Detect", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="Detection Result", interactive=False)
            detection_count = gr.Textbox(label="Detection Summary", interactive=False)

    # Define button actions
    clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img, detection_count])
    predict.click(detect_objects, inputs=[input_img], outputs=[output_img, detection_count])
# Launch Gradio app
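# Binding to 0.0.0.0 makes the server reachable from outside the container;
# 7860 is the port Hugging Face Spaces expects.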
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)