edeler committed
Commit b1f8229 · verified · 1 Parent(s): 07e2db9

Create app.py

Files changed (1)
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import cv2
+ import torch
+ import gradio as gr
+ import numpy as np
+ from ultralytics import YOLO
+ import supervision as sv
+ from PIL import Image
+ from huggingface_hub import snapshot_download
+
+ # Select GPU if available, otherwise fall back to CPU
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Adjustable parameters for detection
+ CONFIDENCE_THRESHOLD = 0.1  # Minimum confidence for a detection to be kept
+ NMS_THRESHOLD = 0           # IoU threshold for NMS; 0 suppresses any overlapping duplicates from adjacent slices
+ SLICE_WIDTH = 1024          # Width of each slice
+ SLICE_HEIGHT = 1024         # Height of each slice
+ OVERLAP_WIDTH = 200         # Horizontal overlap between slices
+ OVERLAP_HEIGHT = 200        # Vertical overlap between slices
+ ANNOTATION_COLOR = sv.Color.RED  # Bounding-box color
+ ANNOTATION_THICKNESS = 4    # Thickness of bounding-box lines
+
+ # Download YOLO model weights from the Hugging Face Hub
+ repo_id = 'edeler/ICC'  # Replace with your Hugging Face repository ID
+ model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
+ model_path = os.path.join(model_dir, "best.pt")  # Adjust if the filename differs
+ model = YOLO(model_path).to(device)
+
+ # Detection function wired to the Gradio UI
+ def detect_objects(image: np.ndarray) -> Image.Image:
+     # Gradio supplies an RGB array; convert to BGR for OpenCV processing
+     if image.shape[-1] == 3:
+         image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+     # Callback used for slice-based inference
+     def callback(image_slice: np.ndarray) -> sv.Detections:
+         # Run inference on a single slice
+         result = model(image_slice)[0]
+         # Convert to `sv.Detections` for further processing
+         detections = sv.Detections.from_ultralytics(result)
+         # Keep only detections above the confidence threshold
+         return detections[detections.confidence >= CONFIDENCE_THRESHOLD]
+
+     # Initialize InferenceSlicer with the slice dimensions and overlap settings above
+     slicer = sv.InferenceSlicer(
+         callback=callback,
+         slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
+         overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
+         overlap_ratio_wh=None
+     )
+
+     # Run slicing-based inference over the whole image
+     detections = slicer(image)
+
+     # Apply non-maximum suppression to remove duplicate boxes from overlapping slices
+     detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)
+
+     # Annotator for oriented bounding boxes with the configured color and thickness
+     box_annotator = sv.OrientedBoxAnnotator(color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS)
+
+     # Draw the post-NMS boxes on a copy of the image
+     annotated_img = box_annotator.annotate(scene=image.copy(), detections=detections)
+
+     # Convert back to RGB for Gradio display (PIL expects RGB)
+     annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
+     return Image.fromarray(annotated_img_rgb)
+
+ # Reset both image panels in the Gradio UI
+ def gradio_reset():
+     return gr.update(value=None), gr.update(value=None)
+
+ # Build the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")
+
+     with gr.Row():
+         with gr.Column():
+             input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)
+             clear = gr.Button(value="Clear")
+             predict = gr.Button(value="Detect", variant="primary")
+
+         with gr.Column():
+             output_img = gr.Image(label="Detection Result", interactive=False)
+
+     # Wire up the buttons
+     clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
+     predict.click(detect_objects, inputs=[input_img], outputs=[output_img])
+
+ # Launch the app (0.0.0.0:7860 is the standard Hugging Face Spaces binding)
+ demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
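
Note: to run this as a Space, the repository also needs a requirements.txt covering the imports above. A minimal, unpinned sketch (assuming a supervision release recent enough for InferenceSlicer to accept overlap_wh):

    torch
    gradio
    ultralytics
    supervision
    opencv-python-headless  # headless build is the usual choice for server environments
    numpy
    pillow
    huggingface_hub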
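
For a quick local smoke test outside the UI, something like the following hypothetical snippet works, provided the demo.launch(...) call is first guarded with if __name__ == '__main__': so that importing app.py does not start the server (sample.jpg and annotated.jpg are placeholder paths):

    import cv2
    from app import detect_objects  # assumes the launch call is guarded

    img_bgr = cv2.imread("sample.jpg")                  # load a test image (OpenCV gives BGR)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # detect_objects expects RGB, as Gradio provides
    detect_objects(img_rgb).save("annotated.jpg")       # returns a PIL Image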