File size: 8,096 Bytes
5ece1f8
 
 
 
2db7876
29264a2
5ece1f8
 
29264a2
5ece1f8
 
29264a2
5ece1f8
29264a2
 
 
 
5ece1f8
 
 
a637284
 
5ece1f8
 
 
a637284
 
 
 
 
 
5ece1f8
 
 
 
 
 
97786d3
 
5ece1f8
 
 
 
 
97786d3
5ece1f8
29264a2
 
 
 
 
5ece1f8
 
 
29264a2
 
 
 
a637284
29264a2
 
 
fdfc9ab
a637284
29264a2
 
a637284
29264a2
 
 
6f04e08
29264a2
 
 
 
 
 
 
 
 
 
 
 
 
 
6f04e08
fdfc9ab
9840e95
a637284
 
 
 
 
 
 
 
 
 
 
 
 
29264a2
a637284
 
29264a2
 
 
b35ba2e
a637284
 
 
55979c0
a637284
29264a2
 
 
a637284
 
29264a2
 
 
 
 
a637284
29264a2
 
 
a637284
5ece1f8
d011adf
 
0dd8eb1
29264a2
d011adf
 
 
 
 
 
 
0dd8eb1
29264a2
 
 
 
 
 
 
 
 
 
 
 
 
 
5ece1f8
df0c2c4
b68b9d8
 
df0c2c4
0dd8eb1
d011adf
0dd8eb1
29264a2
 
 
df0c2c4
 
a068bb3
29264a2
 
 
 
d8f4585
29264a2
a068bb3
df0c2c4
 
 
 
fdfc9ab
29264a2
 
a637284
29264a2
a637284
29264a2
a637284
 
29264a2
a637284
 
29264a2
a637284
 
 
 
 
 
fdfc9ab
 
 
 
 
 
a637284
 
 
 
29264a2
 
 
 
 
 
 
 
 
 
a068bb3
29264a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a637284
df0c2c4
5ece1f8
29264a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import gradio as gr
import cv2
import numpy as np
import os
from ultralytics import YOLO
from PIL import Image

# Trained YOLO detector. The weights file is loaded once, at import time, and
# the same model object is reused by every prediction the app makes.
model = YOLO('best.pt')

# Hair-cell classes paired with the colour used to draw their boxes. Colours
# are RGB triples: images are converted to RGB before drawing, and the HTML
# legend renders these triples as rgb(...). draw_ground_truth indexes `colors`
# by the numeric class id from the label files, so this order presumably
# matches the training class ids (TODO confirm against the dataset config).
_CLASS_PALETTE = (
    ('IHC', (255, 255, 255)),    # white
    ('OHC-1', (255, 0, 0)),      # red
    ('OHC-2', (0, 255, 0)),      # green
    ('OHC-3', (0, 0, 255)),      # blue
)
class_names = [label for label, _ in _CLASS_PALETTE]
colors = [rgb for _, rgb in _CLASS_PALETTE]
# Name -> colour lookup, used for model predictions (which report class names).
color_codes = dict(_CLASS_PALETTE)

# Function to draw ground truth boxes
def draw_ground_truth(image, annotations):
    """Return a copy of *image* with YOLO-format ground-truth boxes drawn on it.

    Args:
        image: image array; only its first two dimensions (height, width) are
            used for scaling. It is copied, never modified.
        annotations: iterable of ``(class_id, x_center, y_center, width, height)``
            tuples. The four geometry values are fractions of the image
            width/height, as stored in YOLO label files.

    Returns:
        A new array with a 2-pixel rectangle per annotation, coloured with
        ``colors[class_id % len(colors)]``.
    """
    img_h, img_w = image.shape[:2]
    canvas = image.copy()
    palette_size = len(colors)
    for class_id, cx, cy, bw, bh in annotations:
        # Convert the normalised centre/size box to a pixel top-left corner
        # plus pixel width/height (each truncated with int(), as before).
        left = int((cx - bw / 2) * img_w)
        top = int((cy - bh / 2) * img_h)
        right = left + int(bw * img_w)
        bottom = top + int(bh * img_h)
        box_color = colors[class_id % palette_size]
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, 2)
    return canvas

# Function to draw prediction boxes
def draw_predictions(image):
    """Run the YOLO model on *image* and return a copy with predicted boxes drawn.

    Args:
        image: RGB image array. The caller (predict_part) builds it with
            ``np.array(pil_image)``; it is assumed to be uint8 HxWx3
            (TODO confirm for non-RGB uploads). It is not modified.

    Returns:
        A copy of *image* with a 2-pixel rectangle per detection. Each box is
        coloured via ``color_codes`` by its predicted class name, or white for
        names not in that table.
    """
    image_pred = image.copy()

    # Ultralytics treats numpy inputs as BGR (the cv2.imread convention used
    # when training), but this function receives RGB arrays. Swap the channels
    # so the model sees the order it was trained on. Boxes are still drawn on
    # the RGB copy, so the displayed colours are unaffected.
    model_input = image
    if image.ndim == 3 and image.shape[2] == 3:
        model_input = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    results = model(model_input)
    detections = results[0].boxes
    boxes = detections.xyxy.cpu().numpy()   # one (x1, y1, x2, y2) row per box
    classes = detections.cls.cpu().numpy()  # class index per box (float)
    names = results[0].names                # class index -> class name

    for (x1, y1, x2, y2), class_id in zip(boxes, classes):
        class_name = names[int(class_id)]
        color = color_codes.get(class_name, (255, 255, 255))
        cv2.rectangle(
            image_pred,
            (int(x1), int(y1)),
            (int(x2), int(y2)),
            color,
            2,
        )
    return image_pred

# Prediction function for Step 1
def _load_yolo_annotations(annotation_path):
    """Parse a YOLO label file into ``(class_id, x_center, y_center, width, height)`` tuples.

    Lines that do not have exactly five whitespace-separated fields are skipped.
    """
    annotations = []
    with open(annotation_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) == 5:
                cls_id, x_center, y_center, width, height = map(float, parts)
                annotations.append((int(cls_id), x_center, y_center, width, height))
    return annotations


def predict(input_image_path):
    """Load an image and overlay its ground-truth boxes when a label file exists.

    The label file is looked up as ``./examples/Labels/<image stem>.txt``.

    Args:
        input_image_path: path to the image file. May be None or empty, e.g.
            when the Gradio image input is cleared.

    Returns:
        A PIL image (RGB) with ground-truth boxes drawn. If no label file is
        found, the plain image is returned. Returns None when no path is given
        or the image cannot be read.
    """
    # Gradio's .change event also fires with an empty value; cv2.imread(None)
    # would raise, so bail out early.
    if not input_image_path:
        return None

    # Read the image from the file path; cv2.imread returns None on failure.
    image = cv2.imread(input_image_path)
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None

    # OpenCV loads BGR; everything downstream (drawing, PIL) works in RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image_name = os.path.basename(input_image_path)
    annotation_name = os.path.splitext(image_name)[0] + '.txt'
    annotation_path = f'./examples/Labels/{annotation_name}'

    if os.path.exists(annotation_path):
        image_gt = draw_ground_truth(image, _load_yolo_annotations(annotation_path))
    else:
        print("Annotation file not found. Displaying original image as labeled image.")
        image_gt = image.copy()

    return Image.fromarray(image_gt)

# Function to split the image into 4 equal parts
def split_image(image):
    """Cut *image* into four quadrants.

    The quadrants are returned in the order top-left, top-right, bottom-left,
    bottom-right. When a dimension is odd, the extra row/column goes to the
    bottom/right quadrants. Each quadrant is a slice of *image*, not a copy.
    """
    height, width = image.shape[:2]
    mid_y, mid_x = height // 2, width // 2
    quadrants = []
    for row_span in (slice(0, mid_y), slice(mid_y, height)):
        for col_span in (slice(0, mid_x), slice(mid_x, width)):
            quadrants.append(image[row_span, col_span])
    return quadrants

# Function to prepare split images
def split_and_prepare(input_image_path):
    """Load an image and return its four quadrants as PIL images.

    Args:
        input_image_path: path to the image file. May be None or empty when no
            image has been loaded.

    Returns:
        A list of four RGB PIL images ordered top-left, top-right, bottom-left,
        bottom-right. Returns None when no path is given or the file cannot be
        read.
    """
    if not input_image_path:
        return None

    # cv2.imread signals an unreadable/missing file by returning None rather
    # than raising; passing that to cvtColor would crash the callback.
    image = cv2.imread(input_image_path)
    if image is None:
        print("Error: Unable to read image from the provided path.")
        return None

    # OpenCV loads BGR; PIL and the rest of the app expect RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return [Image.fromarray(part) for part in split_image(image)]

# Function when a split part is selected
def select_image(splits, index):
    """Return quadrant number *index* from *splits*.

    Returns None when *splits* is None, i.e. the image has not been split yet.
    """
    return None if splits is None else splits[index]

# Prediction function for selected part
def predict_part(selected_img):
    """Run hair-cell detection on a single (cropped) PIL image.

    Returns a PIL image with the predicted boxes drawn, or None when no image
    is selected.
    """
    if selected_img is not None:
        annotated = draw_predictions(np.array(selected_img))
        return Image.fromarray(annotated)
    return None

# Build the HTML colour legend: for each class, a solid block glyph in the
# class's box colour followed by the class name, laid out in one flex row.
legend_html = (
    "<h3>Color Legend:</h3><div style='display: flex; align-items: center;'>"
    + "".join(
        f"<div style='margin-right: 15px; display: flex; align-items: center;'>"
        f"<span style='color: rgb({red},{green},{blue}); font-size: 20px;'>&#9608;</span>"
        f"<span style='margin-left: 5px;'>{label}</span>"
        f"</div>"
        for label, (red, green, blue) in zip(class_names, colors)
    )
    + "</div>"
)

# Sample images offered in the Gradio "Examples" panel. predict() looks for a
# label file with the same stem under ./examples/Labels/ for each of them.
_EXAMPLE_IMAGE_DIR = './examples/Images'
example_paths = [
    f'{_EXAMPLE_IMAGE_DIR}/{file_name}'
    for file_name in (
        '11_sample12_40x.png',
        '12_sample11_20x.png',
        '13_sample3_2_folder1_kaylee_20x.png',
        '14_sample3_folder1_kaylee_20x.png',
        '15_sample6_2_folder1_kaylee_20x.png',
        '17_sample8_folder1_kaylee_20x.png',
        '18_sample9_folder1_kaylee_20x.png',
        '20_sample11_folder1_kaylee_20x.png',
        '22_sample13_folder1_kaylee_20x.png',
        '23_sample14_folder1_kaylee_20x.png',
    )
]

# Create Gradio interface.
# Flow: loading a full image shows its ground-truth overlay; "Split Image"
# stores four quadrants in splits_state (which refreshes the four previews);
# a "Select Part N" button copies quadrant N into selected_part, whose change
# event runs the YOLO prediction.
with gr.Blocks() as interface:
    gr.HTML("<h1 style='text-align: center;'>Detection of Cochlear Hair Cells Using YOLOv11</h1>")
    gr.HTML("<h2 style='text-align: center;'>Cole Krudwig</h2>")

    # Add the color legend
    gr.HTML(legend_html)

    # Holds the list of four PIL quadrants returned by split_and_prepare
    # (None until the image has been split).
    splits_state = gr.State()

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="filepath", label="Full Cochlear Image")
            gr.Examples(
                examples=example_paths,
                inputs=input_image,
                label="Examples"
            )
        with gr.Column():
            output_gt = gr.Image(type="pil", label="Manually Annotated Image Used to Train YOLO11 Model", interactive=False)

    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=output_gt,
    )

    split_button = gr.Button("Split Image")

    # Display split images
    with gr.Row():
        split_image1 = gr.Image(type="pil", label="Part 1", interactive=False)
        split_image2 = gr.Image(type="pil", label="Part 2", interactive=False)
    with gr.Row():
        split_image3 = gr.Image(type="pil", label="Part 3", interactive=False)
        split_image4 = gr.Image(type="pil", label="Part 4", interactive=False)

    def set_split_images(splits):
        """Map the stored quadrants onto the four preview components.

        Returns four Nones (clearing the previews) unless *splits* holds
        exactly four images.
        """
        if splits is None or len(splits) != 4:
            return [None, None, None, None]
        return splits

    split_button.click(
        fn=split_and_prepare,
        inputs=input_image,
        outputs=splits_state,
    )

    splits_state.change(
        fn=set_split_images,
        inputs=splits_state,
        outputs=[split_image1, split_image2, split_image3, split_image4],
    )

    # Add buttons to select each part
    with gr.Row():
        select_part1 = gr.Button("Select Part 1")
        select_part2 = gr.Button("Select Part 2")
    with gr.Row():
        select_part3 = gr.Button("Select Part 3")
        select_part4 = gr.Button("Select Part 4")

    selected_part = gr.Image(type="pil", label="Select Cropped Cochlear Image for Hair Cell Detection")
    part_pred = gr.Image(type="pil", label="Prediction on Selected Part", interactive=False)

    # Handle select part buttons. Each lambda hard-codes its own quadrant
    # index, so there is no late-binding closure issue.
    select_part1.click(
        fn=lambda splits: select_image(splits, 0),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part2.click(
        fn=lambda splits: select_image(splits, 1),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part3.click(
        fn=lambda splits: select_image(splits, 2),
        inputs=splits_state,
        outputs=selected_part,
    )
    select_part4.click(
        fn=lambda splits: select_image(splits, 3),
        inputs=splits_state,
        outputs=selected_part,
    )

    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=part_pred,
    )

# Launch only after the Blocks context has closed (so the layout is final),
# and only when run as a script rather than imported.
if __name__ == "__main__":
    interface.launch()