import gradio as gr
import cv2
import numpy as np
import os
from ultralytics import YOLO
from PIL import Image
# Load the trained model
model = YOLO('best.pt')
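# Assumption: 'best.pt' (the fine-tuned checkpoint) sits next to this script;
# adjust the path if the weights live elsewhere.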
# Define class names and colors
class_names = ['IHC', 'OHC-1', 'OHC-2', 'OHC-3']
colors = [
    (255, 255, 255),  # IHC   - White
    (255, 0, 0),      # OHC-1 - Red
    (0, 255, 0),      # OHC-2 - Green
    (0, 0, 255),      # OHC-3 - Blue
]
color_codes = {name: color for name, color in zip(class_names, colors)}
# Draw ground-truth boxes from YOLO-format annotations:
# (cls_id, x_center, y_center, width, height), all normalized to [0, 1]
def draw_ground_truth(image, annotations):
    image_height, image_width = image.shape[:2]
    image_gt = image.copy()
    for cls_id, x_center, y_center, width, height in annotations:
        # Convert normalized center/size to pixel corner coordinates
        x = int((x_center - width / 2) * image_width)
        y = int((y_center - height / 2) * image_height)
        w = int(width * image_width)
        h = int(height * image_height)
        color = colors[cls_id % len(colors)]
        cv2.rectangle(image_gt, (x, y), (x + w, y + h), color, 2)
    return image_gt
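# Worked example (illustrative): for a 1000x800 image, the label line
# "0 0.5 0.5 0.2 0.25" maps to a 200x200-pixel box centered at (500, 400),
# i.e. top-left corner (400, 300).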
# Run the model on an image and draw its predicted boxes
def draw_predictions(image):
    image_pred = image.copy()
    results = model(image)
    boxes = results[0].boxes.xyxy.cpu().numpy()   # pixel-space (x1, y1, x2, y2) corners
    classes = results[0].boxes.cls.cpu().numpy()
    names = results[0].names
    for i in range(len(boxes)):
        box = boxes[i]
        class_id = int(classes[i])
        class_name = names[class_id]
        # Fall back to white for any class missing from the legend
        color = color_codes.get(class_name, (255, 255, 255))
        cv2.rectangle(
            image_pred,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color,
            2
        )
    return image_pred
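# Two hedged notes on the Ultralytics API (assumptions, not behavior this app uses):
# per-box confidences are available via results[0].boxes.conf, and a score
# threshold can be passed at call time, e.g. model(image, conf=0.5). Also,
# Ultralytics documents raw numpy inputs as BGR (OpenCV convention); since this
# app feeds RGB arrays, converting with cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# before model() may be safer if red/blue detections look swapped.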
# Step 1: load the full image and overlay its ground-truth annotations
def predict(input_image_path):
    # Read the image from the file path (OpenCV loads in BGR order)
    image = cv2.imread(input_image_path)
    # Error handling if the image could not be loaded
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Look up the matching YOLO label file by image basename
    image_name = os.path.basename(input_image_path)
    annotation_name = os.path.splitext(image_name)[0] + '.txt'
    annotation_path = f'./examples/Labels/{annotation_name}'
    if os.path.exists(annotation_path):
        # Each line: "cls_id x_center y_center width height" (normalized)
        annotations = []
        with open(annotation_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    cls_id, x_center, y_center, width, height = map(float, parts)
                    annotations.append((int(cls_id), x_center, y_center, width, height))
        # Draw the ground truth on the image
        image_gt = draw_ground_truth(image, annotations)
    else:
        print("Annotation file not found. Displaying the original image as the labeled image.")
        image_gt = image.copy()
    return Image.fromarray(image_gt)
# Function to split the image into 4 equal parts
def split_image(image):
    h, w = image.shape[:2]
    splits = [
        image[0:h//2, 0:w//2],  # Top-left
        image[0:h//2, w//2:w],  # Top-right
        image[h//2:h, 0:w//2],  # Bottom-left
        image[h//2:h, w//2:w],  # Bottom-right
    ]
    return splits
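# Note: for odd dimensions the integer division makes the bottom/right parts
# one pixel larger than the top/left ones; no pixels are dropped.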
# Load an image from disk and return its four quadrants as PIL images
def split_and_prepare(input_image_path):
    if input_image_path is None:
        return None
    # Load the input image, guarding against unreadable paths
    image = cv2.imread(input_image_path)
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Split the image into quadrants
    splits = split_image(image)
    splits_pil = [Image.fromarray(split) for split in splits]
    return splits_pil
# Return the split part selected by one of the buttons
def select_image(splits, index):
    if splits is None:
        return None
    return splits[index]
# Prediction function for the selected part
def predict_part(selected_img):
    if selected_img is None:
        return None
    image = np.array(selected_img)
    image_pred = draw_predictions(image)
    return Image.fromarray(image_pred)
# Create the HTML color legend
legend_html = "<h3>Color Legend:</h3><div style='display: flex; align-items: center;'>"
for name, color in zip(class_names, colors):
    color_rgb = f'rgb({color[0]},{color[1]},{color[2]})'
    legend_html += (
        f"<div style='margin-right: 15px; display: flex; align-items: center;'>"
        f"<span style='color: {color_rgb}; font-size: 20px;'>&#9608;</span>"
        f"<span style='margin-left: 5px;'>{name}</span>"
        f"</div>"
    )
legend_html += "</div>"
# List of example images
example_paths = [
    './examples/Images/11_sample12_40x.png',
    './examples/Images/12_sample11_20x.png',
    './examples/Images/13_sample3_2_folder1_kaylee_20x.png',
    './examples/Images/14_sample3_folder1_kaylee_20x.png',
    './examples/Images/15_sample6_2_folder1_kaylee_20x.png',
    './examples/Images/17_sample8_folder1_kaylee_20x.png',
    './examples/Images/18_sample9_folder1_kaylee_20x.png',
    './examples/Images/20_sample11_folder1_kaylee_20x.png',
    './examples/Images/22_sample13_folder1_kaylee_20x.png',
    './examples/Images/23_sample14_folder1_kaylee_20x.png',
]
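# Assumption: each example image has a matching YOLO label file under
# ./examples/Labels/ (same basename, .txt extension), which predict() overlays.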
# Create the Gradio interface
with gr.Blocks() as interface:
    gr.HTML("<h1 style='text-align: center;'>Detection of Cochlear Hair Cells Using YOLO11</h1>")
    gr.HTML("<h2 style='text-align: center;'>Cole Krudwig</h2>")

    # Add the color legend
    gr.HTML(legend_html)

    # State variable to store the original splits
    splits_state = gr.State()

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="filepath", label="Full Cochlear Image")
            gr.Examples(
                examples=example_paths,
                inputs=input_image,
                label="Examples"
            )
        with gr.Column():
            output_gt = gr.Image(type="pil", label="Manually Annotated Image Used to Train the YOLO11 Model", interactive=False)

    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=output_gt,
    )
    split_button = gr.Button("Split Image")

    # Display the split images
    with gr.Row():
        split_image1 = gr.Image(type="pil", label="Part 1", interactive=False)
        split_image2 = gr.Image(type="pil", label="Part 2", interactive=False)
    with gr.Row():
        split_image3 = gr.Image(type="pil", label="Part 3", interactive=False)
        split_image4 = gr.Image(type="pil", label="Part 4", interactive=False)

    # Function to set the split images
    def set_split_images(splits):
        if splits is None or len(splits) != 4:
            return [None, None, None, None]
        return splits

    split_button.click(
        fn=split_and_prepare,
        inputs=input_image,
        outputs=splits_state,
    )
    splits_state.change(
        fn=set_split_images,
        inputs=splits_state,
        outputs=[split_image1, split_image2, split_image3, split_image4],
    )
    # Add buttons to select each part
    with gr.Row():
        select_part1 = gr.Button("Select Part 1")
        select_part2 = gr.Button("Select Part 2")
    with gr.Row():
        select_part3 = gr.Button("Select Part 3")
        select_part4 = gr.Button("Select Part 4")

    selected_part = gr.Image(type="pil", label="Select Cropped Cochlear Image for Hair Cell Detection")
    part_pred = gr.Image(type="pil", label="Prediction on Selected Part", interactive=False)

    # Handle the select-part buttons
    select_part1.click(
        fn=lambda splits: select_image(splits, 0),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part2.click(
        fn=lambda splits: select_image(splits, 1),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part3.click(
        fn=lambda splits: select_image(splits, 2),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part4.click(
        fn=lambda splits: select_image(splits, 3),
        inputs=splits_state,
        outputs=selected_part,
    )
    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=part_pred,
    )
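
# Note (assumption, not part of the original app): interface.launch(share=True)
# would also publish a temporary public share URL when running locally.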
interface.launch()