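"""Gradio demo for detecting cochlear hair cells (IHC, OHC-1, OHC-2, OHC-3)
with a fine-tuned YOLO11 model.

The app displays the manually annotated ground truth for a chosen example
image, splits the full image into four quadrants, and runs detection on a
selected quadrant.
"""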
import gradio as gr
import cv2
import numpy as np
import os
from ultralytics import YOLO
from PIL import Image
# Load the trained model
model = YOLO('best.pt')
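# NOTE: 'best.pt' is assumed to live next to this script; adjust the path if
# the fine-tuned weights are stored elsewhere.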
# Define class names and colors
class_names = ['IHC', 'OHC-1', 'OHC-2', 'OHC-3']
colors = [
    (255, 255, 255),  # IHC - White
    (255, 0, 0),      # OHC-1 - Red
    (0, 255, 0),      # OHC-2 - Green
    (0, 0, 255),      # OHC-3 - Blue
]
color_codes = {name: color for name, color in zip(class_names, colors)}
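# The color tuples above are RGB: images are converted from OpenCV's BGR to
# RGB before any drawing, so no channel swap is needed when drawing boxes.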
# Function to draw ground truth boxes
def draw_ground_truth(image, annotations):
    image_height, image_width = image.shape[:2]
    image_gt = image.copy()
    # Annotations use YOLO format: normalized (x_center, y_center, width, height)
    for cls_id, x_center, y_center, width, height in annotations:
        # Convert normalized center coordinates to top-left pixel coordinates
        x = int((x_center - width / 2) * image_width)
        y = int((y_center - height / 2) * image_height)
        w = int(width * image_width)
        h = int(height * image_height)
        color = colors[cls_id % len(colors)]
        cv2.rectangle(image_gt, (x, y), (x + w, y + h), color, 2)
    return image_gt
# Function to draw prediction boxes
def draw_predictions(image):
    image_pred = image.copy()
    results = model(image)
    # Extract detections as pixel-space (x1, y1, x2, y2) boxes with class IDs
    boxes = results[0].boxes.xyxy.cpu().numpy()
    classes = results[0].boxes.cls.cpu().numpy()
    names = results[0].names
    for i in range(len(boxes)):
        box = boxes[i]
        class_id = int(classes[i])
        class_name = names[class_id]
        color = color_codes.get(class_name, (255, 255, 255))
        cv2.rectangle(
            image_pred,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color,
            2,
        )
    return image_pred
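# NOTE: draw_predictions() relies on Ultralytics' default confidence threshold
# (0.25); pass e.g. model(image, conf=0.5) to filter out weaker detections.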
# Prediction function for Step 1
def predict(input_image_path):
    # Guard against the image component being cleared
    if input_image_path is None:
        return None
    # Read the image from the file path
    image = cv2.imread(input_image_path)
    # Error handling if image is not loaded
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None
    # Convert color space
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_name = os.path.basename(input_image_path)
    annotation_name = os.path.splitext(image_name)[0] + '.txt'
    annotation_path = f'./examples/Labels/{annotation_name}'
    if os.path.exists(annotation_path):
        # Load annotations (one "class x_center y_center width height" per line)
        annotations = []
        with open(annotation_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    cls_id, x_center, y_center, width, height = map(float, parts)
                    annotations.append((int(cls_id), x_center, y_center, width, height))
        # Draw ground truth on the image
        image_gt = draw_ground_truth(image, annotations)
    else:
        print("Annotation file not found. Displaying original image as labeled image.")
        image_gt = image.copy()
    return Image.fromarray(image_gt)
# Function to split the image into 4 equal parts
def split_image(image):
    h, w = image.shape[:2]
    # For odd h or w, the right/bottom halves are one pixel larger
    splits = [
        image[0:h//2, 0:w//2],   # Top-left
        image[0:h//2, w//2:w],   # Top-right
        image[h//2:h, 0:w//2],   # Bottom-left
        image[h//2:h, w//2:w],   # Bottom-right
    ]
    return splits
# Function to prepare split images
def split_and_prepare(input_image_path):
    if input_image_path is None:
        return None
    # Load the input image
    image = cv2.imread(input_image_path)
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Split the image
    splits = split_image(image)
    splits_pil = [Image.fromarray(split) for split in splits]
    return splits_pil
# Function when a split part is selected
def select_image(splits, index):
    if splits is None:
        return None
    return splits[index]
# Prediction function for selected part
def predict_part(selected_img):
    if selected_img is None:
        return None
    image = np.array(selected_img)
    image_pred = draw_predictions(image)
    return Image.fromarray(image_pred)
# Create the HTML legend
legend_html = "<h3>Color Legend:</h3><div style='display: flex; align-items: center;'>"
for name, color in zip(class_names, colors):
    color_rgb = f'rgb({color[0]},{color[1]},{color[2]})'
    legend_html += (
        f"<div style='margin-right: 15px; display: flex; align-items: center;'>"
        f"<span style='color: {color_rgb}; font-size: 20px;'>█</span>"
        f"<span style='margin-left: 5px;'>{name}</span>"
        f"</div>"
    )
legend_html += "</div>"
# List of example images
example_paths = [
    './examples/Images/11_sample12_40x.png',
    './examples/Images/12_sample11_20x.png',
    './examples/Images/13_sample3_2_folder1_kaylee_20x.png',
    './examples/Images/14_sample3_folder1_kaylee_20x.png',
    './examples/Images/15_sample6_2_folder1_kaylee_20x.png',
    './examples/Images/17_sample8_folder1_kaylee_20x.png',
    './examples/Images/18_sample9_folder1_kaylee_20x.png',
    './examples/Images/20_sample11_folder1_kaylee_20x.png',
    './examples/Images/22_sample13_folder1_kaylee_20x.png',
    './examples/Images/23_sample14_folder1_kaylee_20x.png',
]
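# Example images are expected under ./examples/Images, with matching
# YOLO-format label files under ./examples/Labels (see predict()).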
# Create Gradio interface
with gr.Blocks() as interface:
    gr.HTML("<h1 style='text-align: center;'>Detection of Cochlear Hair Cells Using YOLO11</h1>")
    gr.HTML("<h2 style='text-align: center;'>Cole Krudwig</h2>")
    # Add the color legend
    gr.HTML(legend_html)
    # State variable to store the original splits
    splits_state = gr.State()
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="filepath", label="Full Cochlear Image")
            gr.Examples(
                examples=example_paths,
                inputs=input_image,
                label="Examples",
            )
        with gr.Column():
            output_gt = gr.Image(type="pil", label="Manually Annotated Image Used to Train YOLO11 Model", interactive=False)
    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=output_gt,
    )
    split_button = gr.Button("Split Image")
    # Display split images
    with gr.Row():
        split_image1 = gr.Image(type="pil", label="Part 1", interactive=False)
        split_image2 = gr.Image(type="pil", label="Part 2", interactive=False)
    with gr.Row():
        split_image3 = gr.Image(type="pil", label="Part 3", interactive=False)
        split_image4 = gr.Image(type="pil", label="Part 4", interactive=False)
    # Function to set split images
    def set_split_images(splits):
        if splits is None or len(splits) != 4:
            return [None, None, None, None]
        return splits
    split_button.click(
        fn=split_and_prepare,
        inputs=input_image,
        outputs=splits_state,
    )
    splits_state.change(
        fn=set_split_images,
        inputs=splits_state,
        outputs=[split_image1, split_image2, split_image3, split_image4],
    )
    # Add buttons to select each part
    with gr.Row():
        select_part1 = gr.Button("Select Part 1")
        select_part2 = gr.Button("Select Part 2")
    with gr.Row():
        select_part3 = gr.Button("Select Part 3")
        select_part4 = gr.Button("Select Part 4")
    selected_part = gr.Image(type="pil", label="Select Cropped Cochlear Image for Hair Cell Detection")
    part_pred = gr.Image(type="pil", label="Prediction on Selected Part", interactive=False)
    # Handle select part buttons
    select_part1.click(
        fn=lambda splits: select_image(splits, 0),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part2.click(
        fn=lambda splits: select_image(splits, 1),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part3.click(
        fn=lambda splits: select_image(splits, 2),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part4.click(
        fn=lambda splits: select_image(splits, 3),
        inputs=splits_state,
        outputs=selected_part,
    )
    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=part_pred,
    )
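# launch() starts a local server; pass share=True for a temporary public URL.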
interface.launch()