"""Gradio demo: detection of cochlear hair cells with a trained YOLO model.

Workflow:
  1. The user selects a full cochlear image; its manual (ground-truth)
     annotations are overlaid when a matching YOLO label file exists.
  2. The image can be split into four equal quadrants.
  3. The user selects one quadrant and the model's predictions are drawn on it.
"""

import os

import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO

import gradio as gr

# Load the trained model weights.
model = YOLO('best.pt')

# Class names and their display colors (RGB), kept in matching order.
class_names = ['IHC', 'OHC-1', 'OHC-2', 'OHC-3']
colors = [
    (255, 255, 255),  # IHC - White
    (255, 0, 0),      # OHC-1 - Red
    (0, 255, 0),      # OHC-2 - Green
    (0, 0, 255),      # OHC-3 - Blue
]
color_codes = {name: color for name, color in zip(class_names, colors)}


def draw_ground_truth(image, annotations):
    """Draw YOLO-format ground-truth boxes on a copy of ``image``.

    Args:
        image: RGB image as a numpy array of shape (H, W, 3).
        annotations: iterable of ``(cls_id, x_center, y_center, width, height)``
            tuples, with coordinates normalized to [0, 1].

    Returns:
        A copy of ``image`` with one rectangle per annotation.
    """
    image_height, image_width = image.shape[:2]
    image_gt = image.copy()
    for cls_id, x_center, y_center, width, height in annotations:
        # Convert normalized center/size to absolute top-left corner + size.
        x = int((x_center - width / 2) * image_width)
        y = int((y_center - height / 2) * image_height)
        w = int(width * image_width)
        h = int(height * image_height)
        color = colors[cls_id % len(colors)]
        cv2.rectangle(image_gt, (x, y), (x + w, y + h), color, 2)
    return image_gt


def draw_predictions(image):
    """Run the YOLO model on ``image`` and draw the predicted boxes.

    Returns a copy of ``image`` with one rectangle per detection, colored
    by class (white fallback for any class not in ``color_codes``).
    """
    image_pred = image.copy()
    results = model(image)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    classes = results[0].boxes.cls.cpu().numpy()
    names = results[0].names
    for box, cls in zip(boxes, classes):
        class_name = names[int(cls)]
        color = color_codes.get(class_name, (255, 255, 255))
        cv2.rectangle(
            image_pred,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color,
            2,
        )
    return image_pred


def predict(input_image_path):
    """Step 1: load the image and overlay its ground-truth annotations.

    Looks for a YOLO label file with the same stem under ./examples/Labels;
    if none exists, the original image is shown unannotated.

    Returns:
        A PIL image, or None when the image cannot be read.
    """
    image = cv2.imread(input_image_path)
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None
    # OpenCV loads BGR; convert to RGB for display/PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image_name = os.path.basename(input_image_path)
    annotation_name = os.path.splitext(image_name)[0] + '.txt'
    annotation_path = f'./examples/Labels/{annotation_name}'

    if os.path.exists(annotation_path):
        # Load annotations: each line is "cls x_center y_center width height".
        annotations = []
        with open(annotation_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    cls_id, x_center, y_center, width, height = map(float, parts)
                    annotations.append(
                        (int(cls_id), x_center, y_center, width, height)
                    )
        image_gt = draw_ground_truth(image, annotations)
    else:
        print("Annotation file not found. Displaying original image as labeled image.")
        image_gt = image.copy()
    return Image.fromarray(image_gt)


def split_image(image):
    """Split ``image`` into 4 equal quadrants: TL, TR, BL, BR."""
    h, w = image.shape[:2]
    return [
        image[0:h // 2, 0:w // 2],  # Top-left
        image[0:h // 2, w // 2:w],  # Top-right
        image[h // 2:h, 0:w // 2],  # Bottom-left
        image[h // 2:h, w // 2:w],  # Bottom-right
    ]


def split_and_prepare(input_image_path):
    """Load the image at ``input_image_path`` and return its 4 quadrants as PIL images."""
    if input_image_path is None:
        return None
    image = cv2.imread(input_image_path)
    if image is None:
        # BUGFIX: guard unreadable paths (predict() already does this);
        # previously cvtColor would raise on a None image.
        print("Error: Unable to read image from the provided path.")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return [Image.fromarray(part) for part in split_image(image)]


def select_image(splits, index):
    """Return quadrant ``index`` from ``splits`` (None if the image was never split)."""
    if splits is None:
        return None
    return splits[index]


def predict_part(selected_img):
    """Run detection on the selected quadrant and return the annotated PIL image."""
    if selected_img is None:
        return None
    image = np.array(selected_img)
    return Image.fromarray(draw_predictions(image))


# Build an HTML legend mapping each class name to its display color.
# NOTE(review): the original markup was lost to formatting mangling; the tags
# below are a reconstruction of the evident layout (label + colored swatches).
legend_html = (
    "<div style='display:flex; align-items:center; flex-wrap:wrap;'>"
    "<b style='margin-right:10px;'>Color Legend:</b>"
)
for name, color in zip(class_names, colors):
    color_rgb = f'rgb({color[0]},{color[1]},{color[2]})'
    legend_html += (
        f"<div style='display:flex; align-items:center; margin-right:15px;'>"
        f"<div style='width:15px; height:15px; background-color:{color_rgb}; "
        f"border:1px solid #000; margin-right:5px;'></div>"
        f"<span>{name}</span>"
        f"</div>"
    )
legend_html += "</div>"

# Example images bundled with the demo.
example_paths = [
    './examples/Images/11_sample12_40x.png',
    './examples/Images/12_sample11_20x.png',
    './examples/Images/13_sample3_2_folder1_kaylee_20x.png',
    './examples/Images/14_sample3_folder1_kaylee_20x.png',
    './examples/Images/15_sample6_2_folder1_kaylee_20x.png',
    './examples/Images/17_sample8_folder1_kaylee_20x.png',
    './examples/Images/18_sample9_folder1_kaylee_20x.png',
    './examples/Images/20_sample11_folder1_kaylee_20x.png',
    './examples/Images/22_sample13_folder1_kaylee_20x.png',
    './examples/Images/23_sample14_folder1_kaylee_20x.png',
]

# Create the Gradio interface.
with gr.Blocks() as interface:
    # Title and author (markup reconstructed — original tags were mangled away).
    gr.HTML(
        "<h1 style='text-align:center;'>"
        "Detection of Cochlear Hair Cells Using YOLOv11</h1>"
    )
    gr.HTML("<h3 style='text-align:center;'>Cole Krudwig</h3>")

    # Color legend.
    gr.HTML(legend_html)

    # State variable holding the original 4 quadrants (list of PIL images).
    splits_state = gr.State()

    with gr.Row():
        with gr.Column():
            # BUGFIX: label typo "Cochelar" -> "Cochlear".
            input_image = gr.Image(type="filepath", label="Full Cochlear Image")
            gr.Examples(examples=example_paths, inputs=input_image, label="Examples")
        with gr.Column():
            output_gt = gr.Image(
                type="pil",
                label="Manually Annotated Image Used to Train YOLO11 Model",
                interactive=False,
            )

    # Selecting an image immediately shows its ground-truth annotations.
    input_image.change(fn=predict, inputs=input_image, outputs=output_gt)

    split_button = gr.Button("Split Image")

    # Display the 4 quadrants in a 2x2 grid.
    with gr.Row():
        split_image1 = gr.Image(type="pil", label="Part 1", interactive=False)
        split_image2 = gr.Image(type="pil", label="Part 2", interactive=False)
    with gr.Row():
        split_image3 = gr.Image(type="pil", label="Part 3", interactive=False)
        split_image4 = gr.Image(type="pil", label="Part 4", interactive=False)

    def set_split_images(splits):
        """Fan the 4 quadrants out to the 4 preview slots (None clears them)."""
        if splits is None or len(splits) != 4:
            return [None, None, None, None]
        return splits

    split_button.click(
        fn=split_and_prepare,
        inputs=input_image,
        outputs=splits_state,
    )
    splits_state.change(
        fn=set_split_images,
        inputs=splits_state,
        outputs=[split_image1, split_image2, split_image3, split_image4],
    )

    # Buttons to select each quadrant for prediction.
    with gr.Row():
        select_part1 = gr.Button("Select Part 1")
        select_part2 = gr.Button("Select Part 2")
    with gr.Row():
        select_part3 = gr.Button("Select Part 3")
        select_part4 = gr.Button("Select Part 4")

    selected_part = gr.Image(
        type="pil",
        label="Select Cropped Cochlear Image for Hair Cell Detection",
    )
    part_pred = gr.Image(
        type="pil",
        label="Prediction on Selected Part",
        interactive=False,
    )

    # Each button pulls one quadrant out of the stored state.
    select_part1.click(
        fn=lambda splits: select_image(splits, 0),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part2.click(
        fn=lambda splits: select_image(splits, 1),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part3.click(
        fn=lambda splits: select_image(splits, 2),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part4.click(
        fn=lambda splits: select_image(splits, 3),
        inputs=splits_state,
        outputs=selected_part,
    )

    # Any change to the selected quadrant triggers model inference.
    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=part_pred,
    )

interface.launch()