Spaces:
Runtime error
Runtime error
Create utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import pandas as pd
|
4 |
+
import operator
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import os
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from tensorflow.keras.utils import Sequence
|
9 |
+
from config import yolo_config
|
10 |
+
|
11 |
+
|
12 |
+
def load_weights(model, weights_file_path):
|
13 |
+
conv_layer_size = 110
|
14 |
+
conv_output_idxs = [93, 101, 109]
|
15 |
+
with open(weights_file_path, 'rb') as file:
|
16 |
+
major, minor, revision, seen, _ = np.fromfile(file, dtype=np.int32, count=5)
|
17 |
+
|
18 |
+
bn_idx = 0
|
19 |
+
for conv_idx in range(conv_layer_size):
|
20 |
+
conv_layer_name = f'conv2d_{conv_idx}' if conv_idx > 0 else 'conv2d'
|
21 |
+
bn_layer_name = f'batch_normalization_{bn_idx}' if bn_idx > 0 else 'batch_normalization'
|
22 |
+
|
23 |
+
conv_layer = model.get_layer(conv_layer_name)
|
24 |
+
filters = conv_layer.filters
|
25 |
+
kernel_size = conv_layer.kernel_size[0]
|
26 |
+
input_dims = conv_layer.input_shape[-1]
|
27 |
+
|
28 |
+
if conv_idx not in conv_output_idxs:
|
29 |
+
# darknet bn layer weights: [beta, gamma, mean, variance]
|
30 |
+
bn_weights = np.fromfile(file, dtype=np.float32, count=4 * filters)
|
31 |
+
# tf bn layer weights: [gamma, beta, mean, variance]
|
32 |
+
bn_weights = bn_weights.reshape((4, filters))[[1, 0, 2, 3]]
|
33 |
+
bn_layer = model.get_layer(bn_layer_name)
|
34 |
+
bn_idx += 1
|
35 |
+
else:
|
36 |
+
conv_bias = np.fromfile(file, dtype=np.float32, count=filters)
|
37 |
+
|
38 |
+
# darknet shape: (out_dim, input_dims, height, width)
|
39 |
+
# tf shape: (height, width, input_dims, out_dim)
|
40 |
+
conv_shape = (filters, input_dims, kernel_size, kernel_size)
|
41 |
+
conv_weights = np.fromfile(file, dtype=np.float32, count=np.product(conv_shape))
|
42 |
+
conv_weights = conv_weights.reshape(conv_shape).transpose([2, 3, 1, 0])
|
43 |
+
|
44 |
+
if conv_idx not in conv_output_idxs:
|
45 |
+
conv_layer.set_weights([conv_weights])
|
46 |
+
bn_layer.set_weights(bn_weights)
|
47 |
+
else:
|
48 |
+
conv_layer.set_weights([conv_weights, conv_bias])
|
49 |
+
|
50 |
+
if len(file.read()) == 0:
|
51 |
+
print('all weights read')
|
52 |
+
else:
|
53 |
+
print(f'failed to read all weights, # of unread weights: {len(file.read())}')
|
54 |
+
|
55 |
+
|
56 |
+
def get_detection_data(img, model_outputs, class_names):
|
57 |
+
"""
|
58 |
+
:param img: target raw image
|
59 |
+
:param model_outputs: outputs from inference_model
|
60 |
+
:param class_names: list of object class names
|
61 |
+
:return:
|
62 |
+
"""
|
63 |
+
|
64 |
+
num_bboxes = model_outputs[-1][0]
|
65 |
+
boxes, scores, classes = [output[0][:num_bboxes] for output in model_outputs[:-1]]
|
66 |
+
|
67 |
+
h, w = img.shape[:2]
|
68 |
+
df = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
|
69 |
+
df[['x1', 'x2']] = (df[['x1', 'x2']] * w).astype('int64')
|
70 |
+
df[['y1', 'y2']] = (df[['y1', 'y2']] * h).astype('int64')
|
71 |
+
df['class_name'] = np.array(class_names)[classes.astype('int64')]
|
72 |
+
df['score'] = scores
|
73 |
+
df['w'] = df['x2'] - df['x1']
|
74 |
+
df['h'] = df['y2'] - df['y1']
|
75 |
+
|
76 |
+
print(f'# of bboxes: {num_bboxes}')
|
77 |
+
return df
|
78 |
+
|
79 |
+
def read_annotation_lines(annotation_path, test_size=None, random_seed=5566):
|
80 |
+
with open(annotation_path) as f:
|
81 |
+
lines = f.readlines()
|
82 |
+
if test_size:
|
83 |
+
return train_test_split(lines, test_size=test_size, random_state=random_seed)
|
84 |
+
else:
|
85 |
+
return lines
|
86 |
+
|
87 |
+
def draw_bbox(img, detections, cmap, random_color=True, figsize=(10, 10), show_img=True, show_text=True):
|
88 |
+
"""
|
89 |
+
Draw bounding boxes on the img.
|
90 |
+
:param img: BGR img.
|
91 |
+
:param detections: pandas DataFrame containing detections
|
92 |
+
:param random_color: assign random color for each objects
|
93 |
+
:param cmap: object colormap
|
94 |
+
:param plot_img: if plot img with bboxes
|
95 |
+
:return: None
|
96 |
+
"""
|
97 |
+
img = np.array(img)
|
98 |
+
scale = max(img.shape[0:2]) / 416
|
99 |
+
line_width = int(2 * scale)
|
100 |
+
|
101 |
+
for _, row in detections.iterrows():
|
102 |
+
x1, y1, x2, y2, cls, score, w, h = row.values
|
103 |
+
color = list(np.random.random(size=3) * 255) if random_color else cmap[cls]
|
104 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)
|
105 |
+
if show_text:
|
106 |
+
text = f'{cls} {score:.2f}'
|
107 |
+
font = cv2.FONT_HERSHEY_DUPLEX
|
108 |
+
font_scale = max(0.3 * scale, 0.3)
|
109 |
+
thickness = max(int(1 * scale), 1)
|
110 |
+
(text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0]
|
111 |
+
cv2.rectangle(img, (x1 - line_width//2, y1 - text_height), (x1 + text_width, y1), color, cv2.FILLED)
|
112 |
+
cv2.putText(img, text, (x1, y1), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
113 |
+
if show_img:
|
114 |
+
plt.figure(figsize=figsize)
|
115 |
+
plt.imshow(img)
|
116 |
+
plt.show()
|
117 |
+
return img
|
118 |
+
|
119 |
+
|
120 |
+
class DataGenerator(Sequence):
|
121 |
+
"""
|
122 |
+
Generates data for Keras
|
123 |
+
ref: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
|
124 |
+
"""
|
125 |
+
def __init__(self,
|
126 |
+
annotation_lines,
|
127 |
+
class_name_path,
|
128 |
+
folder_path,
|
129 |
+
max_boxes=100,
|
130 |
+
shuffle=True):
|
131 |
+
self.annotation_lines = annotation_lines
|
132 |
+
self.class_name_path = class_name_path
|
133 |
+
self.num_classes = len([line.strip() for line in open(class_name_path).readlines()])
|
134 |
+
self.num_gpu = yolo_config['num_gpu']
|
135 |
+
self.batch_size = yolo_config['batch_size'] * self.num_gpu
|
136 |
+
self.target_img_size = yolo_config['img_size']
|
137 |
+
self.anchors = np.array(yolo_config['anchors']).reshape((9, 2))
|
138 |
+
self.shuffle = shuffle
|
139 |
+
self.indexes = np.arange(len(self.annotation_lines))
|
140 |
+
self.folder_path = folder_path
|
141 |
+
self.max_boxes = max_boxes
|
142 |
+
self.on_epoch_end()
|
143 |
+
|
144 |
+
def __len__(self):
|
145 |
+
'number of batches per epoch'
|
146 |
+
return int(np.ceil(len(self.annotation_lines) / self.batch_size))
|
147 |
+
|
148 |
+
def __getitem__(self, index):
|
149 |
+
'Generate one batch of data'
|
150 |
+
|
151 |
+
# Generate indexes of the batch
|
152 |
+
idxs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
|
153 |
+
|
154 |
+
# Find list of IDs
|
155 |
+
lines = [self.annotation_lines[i] for i in idxs]
|
156 |
+
|
157 |
+
# Generate data
|
158 |
+
X, y_tensor, y_bbox = self.__data_generation(lines)
|
159 |
+
|
160 |
+
return [X, *y_tensor, y_bbox], np.zeros(len(lines))
|
161 |
+
|
162 |
+
def on_epoch_end(self):
|
163 |
+
'Updates indexes after each epoch'
|
164 |
+
if self.shuffle:
|
165 |
+
np.random.shuffle(self.indexes)
|
166 |
+
|
167 |
+
def __data_generation(self, annotation_lines):
|
168 |
+
"""
|
169 |
+
Generates data containing batch_size samples
|
170 |
+
:param annotation_lines:
|
171 |
+
:return:
|
172 |
+
"""
|
173 |
+
|
174 |
+
X = np.empty((len(annotation_lines), *self.target_img_size), dtype=np.float32)
|
175 |
+
y_bbox = np.empty((len(annotation_lines), self.max_boxes, 5), dtype=np.float32) # x1y1x2y2
|
176 |
+
|
177 |
+
for i, line in enumerate(annotation_lines):
|
178 |
+
img_data, box_data = self.get_data(line)
|
179 |
+
X[i] = img_data
|
180 |
+
y_bbox[i] = box_data
|
181 |
+
|
182 |
+
y_tensor, y_true_boxes_xywh = preprocess_true_boxes(y_bbox, self.target_img_size[:2], self.anchors, self.num_classes)
|
183 |
+
|
184 |
+
return X, y_tensor, y_true_boxes_xywh
|
185 |
+
|
186 |
+
def get_data(self, annotation_line):
|
187 |
+
line = annotation_line.split()
|
188 |
+
img_path = line[0]
|
189 |
+
img = cv2.imread(os.path.join(self.folder_path, img_path))[:, :, ::-1]
|
190 |
+
ih, iw = img.shape[:2]
|
191 |
+
h, w, c = self.target_img_size
|
192 |
+
boxes = np.array([np.array(list(map(float, box.split(',')))) for box in line[1:]], dtype=np.float32) # x1y1x2y2
|
193 |
+
scale_w, scale_h = w / iw, h / ih
|
194 |
+
img = cv2.resize(img, (w, h))
|
195 |
+
image_data = np.array(img) / 255.
|
196 |
+
|
197 |
+
# correct boxes coordinates
|
198 |
+
box_data = np.zeros((self.max_boxes, 5))
|
199 |
+
if len(boxes) > 0:
|
200 |
+
np.random.shuffle(boxes)
|
201 |
+
boxes = boxes[:self.max_boxes]
|
202 |
+
boxes[:, [0, 2]] = boxes[:, [0, 2]] * scale_w # + dx
|
203 |
+
boxes[:, [1, 3]] = boxes[:, [1, 3]] * scale_h # + dy
|
204 |
+
box_data[:len(boxes)] = boxes
|
205 |
+
|
206 |
+
return image_data, box_data
|
207 |
+
|
208 |
+
|
209 |
+
def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
|
210 |
+
'''Preprocess true boxes to training input format
|
211 |
+
Parameters
|
212 |
+
----------
|
213 |
+
true_boxes: array, shape=(bs, max boxes per img, 5)
|
214 |
+
Absolute x_min, y_min, x_max, y_max, class_id relative to input_shape.
|
215 |
+
input_shape: array-like, hw, multiples of 32
|
216 |
+
anchors: array, shape=(N, 2), (9, wh)
|
217 |
+
num_classes: int
|
218 |
+
Returns
|
219 |
+
-------
|
220 |
+
y_true: list of array, shape like yolo_outputs, xywh are reletive value
|
221 |
+
'''
|
222 |
+
|
223 |
+
num_stages = 3 # default setting for yolo, tiny yolo will be 2
|
224 |
+
anchor_mask = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
225 |
+
bbox_per_grid = 3
|
226 |
+
true_boxes = np.array(true_boxes, dtype='float32')
|
227 |
+
true_boxes_abs = np.array(true_boxes, dtype='float32')
|
228 |
+
input_shape = np.array(input_shape, dtype='int32')
|
229 |
+
true_boxes_xy = (true_boxes_abs[..., 0:2] + true_boxes_abs[..., 2:4]) // 2 # (100, 2)
|
230 |
+
true_boxes_wh = true_boxes_abs[..., 2:4] - true_boxes_abs[..., 0:2] # (100, 2)
|
231 |
+
|
232 |
+
# Normalize x,y,w, h, relative to img size -> (0~1)
|
233 |
+
true_boxes[..., 0:2] = true_boxes_xy/input_shape[::-1] # xy
|
234 |
+
true_boxes[..., 2:4] = true_boxes_wh/input_shape[::-1] # wh
|
235 |
+
|
236 |
+
bs = true_boxes.shape[0]
|
237 |
+
grid_sizes = [input_shape//{0:8, 1:16, 2:32}[stage] for stage in range(num_stages)]
|
238 |
+
y_true = [np.zeros((bs,
|
239 |
+
grid_sizes[s][0],
|
240 |
+
grid_sizes[s][1],
|
241 |
+
bbox_per_grid,
|
242 |
+
5+num_classes), dtype='float32')
|
243 |
+
for s in range(num_stages)]
|
244 |
+
# [(?, 52, 52, 3, 5+num_classes) (?, 26, 26, 3, 5+num_classes) (?, 13, 13, 3, 5+num_classes) ]
|
245 |
+
y_true_boxes_xywh = np.concatenate((true_boxes_xy, true_boxes_wh), axis=-1)
|
246 |
+
# Expand dim to apply broadcasting.
|
247 |
+
anchors = np.expand_dims(anchors, 0) # (1, 9 , 2)
|
248 |
+
anchor_maxes = anchors / 2. # (1, 9 , 2)
|
249 |
+
anchor_mins = -anchor_maxes # (1, 9 , 2)
|
250 |
+
valid_mask = true_boxes_wh[..., 0] > 0 # (1, 100)
|
251 |
+
|
252 |
+
for batch_idx in range(bs):
|
253 |
+
# Discard zero rows.
|
254 |
+
wh = true_boxes_wh[batch_idx, valid_mask[batch_idx]] # (# of bbox, 2)
|
255 |
+
num_boxes = len(wh)
|
256 |
+
if num_boxes == 0: continue
|
257 |
+
wh = np.expand_dims(wh, -2) # (# of bbox, 1, 2)
|
258 |
+
box_maxes = wh / 2. # (# of bbox, 1, 2)
|
259 |
+
box_mins = -box_maxes # (# of bbox, 1, 2)
|
260 |
+
|
261 |
+
# Compute IoU between each anchors and true boxes for responsibility assignment
|
262 |
+
intersect_mins = np.maximum(box_mins, anchor_mins) # (# of bbox, 9, 2)
|
263 |
+
intersect_maxes = np.minimum(box_maxes, anchor_maxes)
|
264 |
+
intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.)
|
265 |
+
intersect_area = np.prod(intersect_wh, axis=-1) # (9,)
|
266 |
+
box_area = wh[..., 0] * wh[..., 1] # (# of bbox, 1)
|
267 |
+
anchor_area = anchors[..., 0] * anchors[..., 1] # (1, 9)
|
268 |
+
iou = intersect_area / (box_area + anchor_area - intersect_area) # (# of bbox, 9)
|
269 |
+
|
270 |
+
# Find best anchor for each true box
|
271 |
+
best_anchors = np.argmax(iou, axis=-1) # (# of bbox,)
|
272 |
+
for box_idx in range(num_boxes):
|
273 |
+
best_anchor = best_anchors[box_idx]
|
274 |
+
for stage in range(num_stages):
|
275 |
+
if best_anchor in anchor_mask[stage]:
|
276 |
+
x_offset = true_boxes[batch_idx, box_idx, 0]*grid_sizes[stage][1]
|
277 |
+
y_offset = true_boxes[batch_idx, box_idx, 1]*grid_sizes[stage][0]
|
278 |
+
# Grid Index
|
279 |
+
grid_col = np.floor(x_offset).astype('int32')
|
280 |
+
grid_row = np.floor(y_offset).astype('int32')
|
281 |
+
anchor_idx = anchor_mask[stage].index(best_anchor)
|
282 |
+
class_idx = true_boxes[batch_idx, box_idx, 4].astype('int32')
|
283 |
+
# y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 0] = x_offset - grid_col # x
|
284 |
+
# y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 1] = y_offset - grid_row # y
|
285 |
+
# y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, :4] = true_boxes_abs[batch_idx, box_idx, :4] # abs xywh
|
286 |
+
y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, :2] = true_boxes_xy[batch_idx, box_idx, :] # abs xy
|
287 |
+
y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 2:4] = true_boxes_wh[batch_idx, box_idx, :] # abs wh
|
288 |
+
y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 4] = 1 # confidence
|
289 |
+
|
290 |
+
y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 5+class_idx] = 1 # one-hot encoding
|
291 |
+
# smooth
|
292 |
+
# onehot = np.zeros(num_classes, dtype=np.float)
|
293 |
+
# onehot[class_idx] = 1.0
|
294 |
+
# uniform_distribution = np.full(num_classes, 1.0 / num_classes)
|
295 |
+
# delta = 0.01
|
296 |
+
# smooth_onehot = onehot * (1 - delta) + delta * uniform_distribution
|
297 |
+
# y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 5:] = smooth_onehot
|
298 |
+
|
299 |
+
return y_true, y_true_boxes_xywh
|
300 |
+
|
301 |
+
"""
|
302 |
+
Calculate the AP given the recall and precision array
|
303 |
+
1st) We compute a version of the measured precision/recall curve with
|
304 |
+
precision monotonically decreasing
|
305 |
+
2nd) We compute the AP as the area under this curve by numerical integration.
|
306 |
+
"""
|
307 |
+
def voc_ap(rec, prec):
|
308 |
+
"""
|
309 |
+
--- Official matlab code VOC2012---
|
310 |
+
mrec=[0 ; rec ; 1];
|
311 |
+
mpre=[0 ; prec ; 0];
|
312 |
+
for i=numel(mpre)-1:-1:1
|
313 |
+
mpre(i)=max(mpre(i),mpre(i+1));
|
314 |
+
end
|
315 |
+
i=find(mrec(2:end)~=mrec(1:end-1))+1;
|
316 |
+
ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
|
317 |
+
"""
|
318 |
+
rec.insert(0, 0.0) # insert 0.0 at begining of list
|
319 |
+
rec.append(1.0) # insert 1.0 at end of list
|
320 |
+
mrec = rec[:]
|
321 |
+
prec.insert(0, 0.0) # insert 0.0 at begining of list
|
322 |
+
prec.append(0.0) # insert 0.0 at end of list
|
323 |
+
mpre = prec[:]
|
324 |
+
"""
|
325 |
+
This part makes the precision monotonically decreasing
|
326 |
+
(goes from the end to the beginning)
|
327 |
+
matlab: for i=numel(mpre)-1:-1:1
|
328 |
+
mpre(i)=max(mpre(i),mpre(i+1));
|
329 |
+
"""
|
330 |
+
# matlab indexes start in 1 but python in 0, so I have to do:
|
331 |
+
# range(start=(len(mpre) - 2), end=0, step=-1)
|
332 |
+
# also the python function range excludes the end, resulting in:
|
333 |
+
# range(start=(len(mpre) - 2), end=-1, step=-1)
|
334 |
+
for i in range(len(mpre)-2, -1, -1):
|
335 |
+
mpre[i] = max(mpre[i], mpre[i+1])
|
336 |
+
"""
|
337 |
+
This part creates a list of indexes where the recall changes
|
338 |
+
matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
|
339 |
+
"""
|
340 |
+
i_list = []
|
341 |
+
for i in range(1, len(mrec)):
|
342 |
+
if mrec[i] != mrec[i-1]:
|
343 |
+
i_list.append(i) # if it was matlab would be i + 1
|
344 |
+
"""
|
345 |
+
The Average Precision (AP) is the area under the curve
|
346 |
+
(numerical integration)
|
347 |
+
matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
|
348 |
+
"""
|
349 |
+
ap = 0.0
|
350 |
+
for i in i_list:
|
351 |
+
ap += ((mrec[i]-mrec[i-1])*mpre[i])
|
352 |
+
return ap, mrec, mpre
|
353 |
+
|
354 |
+
"""
|
355 |
+
Draw plot using Matplotlib
|
356 |
+
"""
|
357 |
+
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
|
358 |
+
# sort the dictionary by decreasing value, into a list of tuples
|
359 |
+
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
|
360 |
+
print(sorted_dic_by_value)
|
361 |
+
# unpacking the list of tuples into two lists
|
362 |
+
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
|
363 |
+
#
|
364 |
+
if true_p_bar != "":
|
365 |
+
"""
|
366 |
+
Special case to draw in:
|
367 |
+
- green -> TP: True Positives (object detected and matches ground-truth)
|
368 |
+
- red -> FP: False Positives (object detected but does not match ground-truth)
|
369 |
+
- pink -> FN: False Negatives (object not detected but present in the ground-truth)
|
370 |
+
"""
|
371 |
+
fp_sorted = []
|
372 |
+
tp_sorted = []
|
373 |
+
for key in sorted_keys:
|
374 |
+
fp_sorted.append(dictionary[key] - true_p_bar[key])
|
375 |
+
tp_sorted.append(true_p_bar[key])
|
376 |
+
plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
|
377 |
+
plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
|
378 |
+
# add legend
|
379 |
+
plt.legend(loc='lower right')
|
380 |
+
"""
|
381 |
+
Write number on side of bar
|
382 |
+
"""
|
383 |
+
fig = plt.gcf() # gcf - get current figure
|
384 |
+
axes = plt.gca()
|
385 |
+
r = fig.canvas.get_renderer()
|
386 |
+
for i, val in enumerate(sorted_values):
|
387 |
+
fp_val = fp_sorted[i]
|
388 |
+
tp_val = tp_sorted[i]
|
389 |
+
fp_str_val = " " + str(fp_val)
|
390 |
+
tp_str_val = fp_str_val + " " + str(tp_val)
|
391 |
+
# trick to paint multicolor with offset:
|
392 |
+
# first paint everything and then repaint the first number
|
393 |
+
t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
|
394 |
+
plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
|
395 |
+
if i == (len(sorted_values)-1): # largest bar
|
396 |
+
adjust_axes(r, t, fig, axes)
|
397 |
+
else:
|
398 |
+
plt.barh(range(n_classes), sorted_values, color=plot_color)
|
399 |
+
"""
|
400 |
+
Write number on side of bar
|
401 |
+
"""
|
402 |
+
fig = plt.gcf() # gcf - get current figure
|
403 |
+
axes = plt.gca()
|
404 |
+
r = fig.canvas.get_renderer()
|
405 |
+
for i, val in enumerate(sorted_values):
|
406 |
+
str_val = " " + str(val) # add a space before
|
407 |
+
if val < 1.0:
|
408 |
+
str_val = " {0:.2f}".format(val)
|
409 |
+
t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
|
410 |
+
# re-set axes to show number inside the figure
|
411 |
+
if i == (len(sorted_values)-1): # largest bar
|
412 |
+
adjust_axes(r, t, fig, axes)
|
413 |
+
# set window title
|
414 |
+
fig.canvas.set_window_title(window_title)
|
415 |
+
# write classes in y axis
|
416 |
+
tick_font_size = 12
|
417 |
+
plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
|
418 |
+
"""
|
419 |
+
Re-scale height accordingly
|
420 |
+
"""
|
421 |
+
init_height = fig.get_figheight()
|
422 |
+
# comput the matrix height in points and inches
|
423 |
+
dpi = fig.dpi
|
424 |
+
height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
|
425 |
+
height_in = height_pt / dpi
|
426 |
+
# compute the required figure height
|
427 |
+
top_margin = 0.15 # in percentage of the figure height
|
428 |
+
bottom_margin = 0.05 # in percentage of the figure height
|
429 |
+
figure_height = height_in / (1 - top_margin - bottom_margin)
|
430 |
+
# set new height
|
431 |
+
if figure_height > init_height:
|
432 |
+
fig.set_figheight(figure_height)
|
433 |
+
|
434 |
+
# set plot title
|
435 |
+
plt.title(plot_title, fontsize=14)
|
436 |
+
# set axis titles
|
437 |
+
# plt.xlabel('classes')
|
438 |
+
plt.xlabel(x_label, fontsize='large')
|
439 |
+
# adjust size of window
|
440 |
+
fig.tight_layout()
|
441 |
+
# save the plot
|
442 |
+
fig.savefig(output_path)
|
443 |
+
# show image
|
444 |
+
# if to_show:
|
445 |
+
plt.show()
|
446 |
+
# close the plot
|
447 |
+
# plt.close()
|
448 |
+
|
449 |
+
"""
|
450 |
+
Plot - adjust axes
|
451 |
+
"""
|
452 |
+
def adjust_axes(r, t, fig, axes):
|
453 |
+
# get text width for re-scaling
|
454 |
+
bb = t.get_window_extent(renderer=r)
|
455 |
+
text_width_inches = bb.width / fig.dpi
|
456 |
+
# get axis width in inches
|
457 |
+
current_fig_width = fig.get_figwidth()
|
458 |
+
new_fig_width = current_fig_width + text_width_inches
|
459 |
+
propotion = new_fig_width / current_fig_width
|
460 |
+
# get axis limit
|
461 |
+
x_lim = axes.get_xlim()
|
462 |
+
axes.set_xlim([x_lim[0], x_lim[1]*propotion])
|
463 |
+
|
464 |
+
|
465 |
+
def read_txt_to_list(path):
|
466 |
+
# open txt file lines to a list
|
467 |
+
with open(path) as f:
|
468 |
+
content = f.readlines()
|
469 |
+
# remove whitespace characters like `\n` at the end of each line
|
470 |
+
content = [x.strip() for x in content]
|
471 |
+
return content
|