|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
import os |
|
from ultralytics import YOLO |
|
from PIL import Image |
|
|
|
|
|
# Load the trained YOLO weights; 'best.pt' is expected next to this script.
model = YOLO('best.pt')




# Class labels in annotation-file order: inner hair cells (IHC) and the
# three outer hair cell rows (OHC-1..OHC-3).
class_names = ['IHC', 'OHC-1', 'OHC-2', 'OHC-3']

# One RGB color per class, same order as class_names:
# white, red, green, blue (images are converted to RGB before drawing).
colors = [

    (255, 255, 255),

    (255, 0, 0),

    (0, 255, 0),

    (0, 0, 255)

]

# Lookup: class name -> RGB color, used when drawing model predictions.
color_codes = {name: color for name, color in zip(class_names, colors)}
|
|
|
|
|
def draw_ground_truth(image, annotations):
    """Overlay ground-truth boxes on a copy of *image* and return it.

    Each annotation is a (class_id, x_center, y_center, width, height)
    tuple with coordinates normalized to [0, 1] (YOLO label format).
    The input image is not modified.
    """
    img_h, img_w = image.shape[:2]
    canvas = image.copy()
    for cls_id, cx, cy, bw, bh in annotations:
        # Convert normalized center/size to pixel corner coordinates.
        left = int((cx - bw / 2) * img_w)
        top = int((cy - bh / 2) * img_h)
        right = left + int(bw * img_w)
        bottom = top + int(bh * img_h)
        # Wrap class ids beyond the palette back into range.
        box_color = colors[cls_id % len(colors)]
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, 2)
    return canvas
|
|
|
|
|
def draw_predictions(image):
    """Run the module-level YOLO model on *image* and draw detections.

    Returns a copy of the image with one 2-px rectangle per detected box,
    colored by class via color_codes (white fallback for unknown names).
    """
    annotated = image.copy()
    result = model(image)[0]
    xyxy = result.boxes.xyxy.cpu().numpy()
    cls_ids = result.boxes.cls.cpu().numpy()
    id_to_name = result.names
    for box, cls in zip(xyxy, cls_ids):
        label = id_to_name[int(cls)]
        box_color = color_codes.get(label, (255, 255, 255))
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), box_color, 2)
    return annotated
|
|
|
|
|
def predict(input_image_path):
    """Load the selected image and overlay its ground-truth annotations.

    Looks for a YOLO-format label file with the same basename under
    ./examples/Labels/; if none exists, the original image is returned
    unannotated. Returns a PIL Image, or None when no image can be read.
    """
    # Fix: gr.Image.change fires with None when the image is cleared;
    # cv2.imread(None) raises TypeError, so guard first (matches
    # split_and_prepare's behavior).
    if input_image_path is None:
        return None

    image = cv2.imread(input_image_path)
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None

    # OpenCV loads BGR; convert so drawn colors match the RGB legend/PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The label file shares the image's basename with a .txt extension.
    image_name = os.path.basename(input_image_path)
    annotation_name = os.path.splitext(image_name)[0] + '.txt'
    annotation_path = f'./examples/Labels/{annotation_name}'

    if os.path.exists(annotation_path):
        annotations = []
        with open(annotation_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                # Valid lines: class_id cx cy w h (normalized floats).
                if len(parts) == 5:
                    cls_id, x_center, y_center, width, height = map(float, parts)
                    annotations.append((int(cls_id), x_center, y_center, width, height))

        image_gt = draw_ground_truth(image, annotations)
    else:
        print("Annotation file not found. Displaying original image as labeled image.")
        image_gt = image.copy()

    return Image.fromarray(image_gt)
|
|
|
|
|
def split_image(image):
    """Cut *image* into four quadrants.

    Returns [top-left, top-right, bottom-left, bottom-right] as views into
    the original array. For odd dimensions the bottom/right quadrants get
    the extra row/column.
    """
    mid_y = image.shape[0] // 2
    mid_x = image.shape[1] // 2
    quadrants = []
    for rows in (slice(None, mid_y), slice(mid_y, None)):
        for cols in (slice(None, mid_x), slice(mid_x, None)):
            quadrants.append(image[rows, cols])
    return quadrants
|
|
|
|
|
def split_and_prepare(input_image_path):
    """Read the image at *input_image_path* and return its four quadrants.

    Returns a list of four PIL Images (top-left, top-right, bottom-left,
    bottom-right), or None when no path is given or the file is unreadable.
    """
    if input_image_path is None:
        return None

    image = cv2.imread(input_image_path)
    # Fix: cv2.imread returns None on failure; cvtColor would then raise
    # (predict already guards this — made consistent here).
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None

    # Convert BGR -> RGB so PIL previews show correct colors.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    splits = split_image(image)
    return [Image.fromarray(split) for split in splits]
|
|
|
|
|
def select_image(splits, index):
    """Return the quadrant at *index*, or None when no splits exist yet."""
    return None if splits is None else splits[index]
|
|
|
|
|
def predict_part(selected_img):
    """Run detection on the chosen quadrant and return an annotated PIL image.

    Returns None when no quadrant has been selected yet.
    """
    if selected_img is None:
        return None

    annotated = draw_predictions(np.array(selected_img))
    return Image.fromarray(annotated)
|
|
|
|
|
# Build an inline-HTML color legend: one colored block + label per class.
_legend_items = "".join(
    "<div style='margin-right: 15px; display: flex; align-items: center;'>"
    f"<span style='color: rgb({r},{g},{b}); font-size: 20px;'>█</span>"
    f"<span style='margin-left: 5px;'>{name}</span>"
    "</div>"
    for name, (r, g, b) in zip(class_names, colors)
)
legend_html = (
    "<h3>Color Legend:</h3>"
    "<div style='display: flex; align-items: center;'>"
    + _legend_items
    + "</div>"
)
|
|
|
|
|
# Sample cochlear images shipped with the app, shown in the gr.Examples
# gallery; matching label files live under ./examples/Labels/.
example_paths = [

    './examples/Images/11_sample12_40x.png',

    './examples/Images/12_sample11_20x.png',

    './examples/Images/13_sample3_2_folder1_kaylee_20x.png',

    './examples/Images/14_sample3_folder1_kaylee_20x.png',

    './examples/Images/15_sample6_2_folder1_kaylee_20x.png',

    './examples/Images/17_sample8_folder1_kaylee_20x.png',

    './examples/Images/18_sample9_folder1_kaylee_20x.png',

    './examples/Images/20_sample11_folder1_kaylee_20x.png',

    './examples/Images/22_sample13_folder1_kaylee_20x.png',

    './examples/Images/23_sample14_folder1_kaylee_20x.png',

]
|
|
|
|
|
with gr.Blocks() as interface:
    # Page header.
    gr.HTML("<h1 style='text-align: center;'>Detection of Cochlear Hair Cells Using YOLOv11</h1>")
    gr.HTML("<h2 style='text-align: center;'>Cole Krudwig</h2>")

    # Class/color legend built at module level.
    gr.HTML(legend_html)

    # Holds the list of four PIL quadrants produced by "Split Image".
    splits_state = gr.State()

    with gr.Row():
        with gr.Column():
            # Fix: label typo "Cochelar" -> "Cochlear" (spelled correctly
            # elsewhere in this UI).
            input_image = gr.Image(type="filepath", label="Full Cochlear Image")
            gr.Examples(
                examples=example_paths,
                inputs=input_image,
                label="Examples"
            )
        with gr.Column():
            output_gt = gr.Image(type="pil", label="Manually Annotated Image Used to Train YOLO11 Model", interactive=False)

    # Show ground-truth annotations as soon as an image is chosen/changed.
    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=output_gt,
    )

    split_button = gr.Button("Split Image")

    # Four read-only previews, laid out 2x2 to mirror the quadrant split.
    with gr.Row():
        split_image1 = gr.Image(type="pil", label="Part 1", interactive=False)
        split_image2 = gr.Image(type="pil", label="Part 2", interactive=False)
    with gr.Row():
        split_image3 = gr.Image(type="pil", label="Part 3", interactive=False)
        split_image4 = gr.Image(type="pil", label="Part 4", interactive=False)

    def set_split_images(splits):
        # Fan the stored quadrant list out to the four preview slots;
        # clear all four when the state is empty or malformed.
        if splits is None or len(splits) != 4:
            return [None, None, None, None]
        return splits

    split_button.click(
        fn=split_and_prepare,
        inputs=input_image,
        outputs=splits_state,
    )

    # Refresh the previews whenever the stored splits change.
    splits_state.change(
        fn=set_split_images,
        inputs=splits_state,
        outputs=[split_image1, split_image2, split_image3, split_image4],
    )

    with gr.Row():
        select_part1 = gr.Button("Select Part 1")
        select_part2 = gr.Button("Select Part 2")
    with gr.Row():
        select_part3 = gr.Button("Select Part 3")
        select_part4 = gr.Button("Select Part 4")

    selected_part = gr.Image(type="pil", label="Select Cropped Cochlear Image for Hair Cell Detection")
    part_pred = gr.Image(type="pil", label="Prediction on Selected Part", interactive=False)

    # One button per quadrant; each copies its quadrant into selected_part.
    select_part1.click(
        fn=lambda splits: select_image(splits, 0),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part2.click(
        fn=lambda splits: select_image(splits, 1),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part3.click(
        fn=lambda splits: select_image(splits, 2),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part4.click(
        fn=lambda splits: select_image(splits, 3),
        inputs=splits_state,
        outputs=selected_part,
    )

    # Run detection automatically once a quadrant is selected.
    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=part_pred,
    )

interface.launch()