import gradio as gr
import cv2
import numpy as np
import os
from ultralytics import YOLO
from PIL import Image

# Load the trained model
model = YOLO('best.pt')

# Define class names and colors
class_names = ['IHC', 'OHC-1', 'OHC-2', 'OHC-3']
colors = [
    (255, 255, 255),  # IHC - White
    (255, 0, 0),      # OHC-1 - Red
    (0, 255, 0),      # OHC-2 - Green
    (0, 0, 255)       # OHC-3 - Blue
]
color_codes = {name: color for name, color in zip(class_names, colors)}

# Function to draw ground truth boxes
def draw_ground_truth(image, annotations):
    image_height, image_width = image.shape[:2]
    image_gt = image.copy()
    for cls_id, x_center, y_center, width, height in annotations:
        x = int((x_center - width / 2) * image_width)
        y = int((y_center - height / 2) * image_height)
        w = int(width * image_width)
        h = int(height * image_height)
        color = colors[cls_id % len(colors)]
        cv2.rectangle(image_gt, (x, y), (x + w, y + h), color, 2)
    return image_gt

# Function to draw prediction boxes
def draw_predictions(image):
    image_pred = image.copy()
    results = model(image)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    classes = results[0].boxes.cls.cpu().numpy()
    names = results[0].names
    for i in range(len(boxes)):
        box = boxes[i]
        class_id = int(classes[i])
        class_name = names[class_id]
        color = color_codes.get(class_name, (255, 255, 255))
        cv2.rectangle(
            image_pred,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color,
            2
        )
    return image_pred

# Prediction function for Step 1
def predict(input_image_path):
    # Read the image from the file path
    image = cv2.imread(input_image_path)
    # Error handling if image is not loaded
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None
    # Convert color space
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image_name = os.path.basename(input_image_path)
    annotation_name = os.path.splitext(image_name)[0] + '.txt'
    annotation_path = f'./examples/Labels/{annotation_name}'

    if os.path.exists(annotation_path):
        # Load annotations
        annotations = []
        with open(annotation_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    cls_id, x_center, y_center, width, height = map(float, parts)
                    annotations.append((int(cls_id), x_center, y_center, width, height))
        # Draw ground truth on the image
        image_gt = draw_ground_truth(image, annotations)
    else:
        print("Annotation file not found. Displaying original image as labeled image.")
        image_gt = image.copy()

    return Image.fromarray(image_gt)

# Function to split the image into 4 equal parts
def split_image(image):
    h, w = image.shape[:2]
    splits = [
        image[0:h//2, 0:w//2],  # Top-left
        image[0:h//2, w//2:w],  # Top-right
        image[h//2:h, 0:w//2],  # Bottom-left
        image[h//2:h, w//2:w],  # Bottom-right
    ]
    return splits

# Function to prepare split images
def split_and_prepare(input_image_path):
    if input_image_path is None:
        return None
    # Load the input image
    image = cv2.imread(input_image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Split the image
    splits = split_image(image)
    splits_pil = [Image.fromarray(split) for split in splits]
    return splits_pil

# Function when a split part is selected
def select_image(splits, index):
    if splits is None:
        return None
    return splits[index]

# Prediction function for selected part
def predict_part(selected_img):
    if selected_img is None:
        return None
    image = np.array(selected_img)
    image_pred = draw_predictions(image)
    return Image.fromarray(image_pred)

# Create the HTML legend
legend_html = "