princeml committed
Commit 36fe33c · 1 Parent(s): a4dd585

Create utils.py

Files changed (1): utils.py (+471, -0)
utils.py ADDED
import numpy as np
import cv2
import pandas as pd
import operator
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import Sequence
from config import yolo_config

def load_weights(model, weights_file_path):
    conv_layer_size = 110
    conv_output_idxs = [93, 101, 109]
    with open(weights_file_path, 'rb') as file:
        major, minor, revision, seen, _ = np.fromfile(file, dtype=np.int32, count=5)

        bn_idx = 0
        for conv_idx in range(conv_layer_size):
            conv_layer_name = f'conv2d_{conv_idx}' if conv_idx > 0 else 'conv2d'
            bn_layer_name = f'batch_normalization_{bn_idx}' if bn_idx > 0 else 'batch_normalization'

            conv_layer = model.get_layer(conv_layer_name)
            filters = conv_layer.filters
            kernel_size = conv_layer.kernel_size[0]
            input_dims = conv_layer.input_shape[-1]

            if conv_idx not in conv_output_idxs:
                # darknet bn layer weights: [beta, gamma, mean, variance]
                bn_weights = np.fromfile(file, dtype=np.float32, count=4 * filters)
                # tf bn layer weights: [gamma, beta, mean, variance]
                bn_weights = bn_weights.reshape((4, filters))[[1, 0, 2, 3]]
                bn_layer = model.get_layer(bn_layer_name)
                bn_idx += 1
            else:
                conv_bias = np.fromfile(file, dtype=np.float32, count=filters)

            # darknet shape: (out_dim, input_dims, height, width)
            # tf shape: (height, width, input_dims, out_dim)
            conv_shape = (filters, input_dims, kernel_size, kernel_size)
            conv_weights = np.fromfile(file, dtype=np.float32, count=np.prod(conv_shape))
            conv_weights = conv_weights.reshape(conv_shape).transpose([2, 3, 1, 0])

            if conv_idx not in conv_output_idxs:
                conv_layer.set_weights([conv_weights])
                bn_layer.set_weights(bn_weights)
            else:
                conv_layer.set_weights([conv_weights, conv_bias])

        # read the remainder once; calling file.read() twice would always report 0 unread bytes
        remaining = file.read()
        if len(remaining) == 0:
            print('all weights read')
        else:
            print(f'failed to read all weights, # of unread bytes: {len(remaining)}')

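# Illustrative usage sketch (all names below are placeholders): the model is a Keras
# YOLOv4 graph built elsewhere in this repo, and 'yolov4.weights' stands in for a
# darknet weights file on disk.
#
#   model = build_yolov4_model()            # hypothetical builder, not defined in utils.py
#   load_weights(model, 'yolov4.weights')   # copies the darknet weights into the Keras layers
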
def get_detection_data(img, model_outputs, class_names):
    """
    :param img: target raw image
    :param model_outputs: outputs from inference_model
    :param class_names: list of object class names
    :return: a pandas DataFrame with one detection per row
    """
    num_bboxes = model_outputs[-1][0]
    boxes, scores, classes = [output[0][:num_bboxes] for output in model_outputs[:-1]]

    h, w = img.shape[:2]
    df = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
    df[['x1', 'x2']] = (df[['x1', 'x2']] * w).astype('int64')
    df[['y1', 'y2']] = (df[['y1', 'y2']] * h).astype('int64')
    df['class_name'] = np.array(class_names)[classes.astype('int64')]
    df['score'] = scores
    df['w'] = df['x2'] - df['x1']
    df['h'] = df['y2'] - df['y1']

    print(f'# of bboxes: {num_bboxes}')
    return df


def read_annotation_lines(annotation_path, test_size=None, random_seed=5566):
    with open(annotation_path) as f:
        lines = f.readlines()
    if test_size:
        return train_test_split(lines, test_size=test_size, random_state=random_seed)
    else:
        return lines

def draw_bbox(img, detections, cmap, random_color=True, figsize=(10, 10), show_img=True, show_text=True):
    """
    Draw bounding boxes on the img.
    :param img: BGR img.
    :param detections: pandas DataFrame containing detections
    :param cmap: object colormap
    :param random_color: assign a random color to each object
    :param figsize: figure size for the matplotlib plot
    :param show_img: whether to plot the img with its bboxes
    :param show_text: whether to draw the class name and score next to each box
    :return: the annotated image
    """
    img = np.array(img)
    scale = max(img.shape[0:2]) / 416
    line_width = int(2 * scale)

    for _, row in detections.iterrows():
        x1, y1, x2, y2, cls, score, w, h = row.values
        color = list(np.random.random(size=3) * 255) if random_color else cmap[cls]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)
        if show_text:
            text = f'{cls} {score:.2f}'
            font = cv2.FONT_HERSHEY_DUPLEX
            font_scale = max(0.3 * scale, 0.3)
            thickness = max(int(1 * scale), 1)
            (text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0]
            cv2.rectangle(img, (x1 - line_width // 2, y1 - text_height), (x1 + text_width, y1), color, cv2.FILLED)
            cv2.putText(img, text, (x1, y1), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
    if show_img:
        plt.figure(figsize=figsize)
        plt.imshow(img)
        plt.show()
    return img

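# Illustrative detection/drawing flow (a sketch: the image path, inference_model and
# class_names are placeholders, and the inference model is assumed to return
# [boxes, scores, classes, num_boxes] in the layout get_detection_data() expects):
#
#   raw_img = cv2.imread('sample.jpg')[:, :, ::-1]        # placeholder image path
#   outputs = inference_model.predict(batch_of_one_img)   # hypothetical inference model
#   detections = get_detection_data(raw_img, outputs, class_names)
#   img_with_boxes = draw_bbox(raw_img, detections, cmap=None, random_color=True)
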
class DataGenerator(Sequence):
    """
    Generates data for Keras
    ref: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    """
    def __init__(self,
                 annotation_lines,
                 class_name_path,
                 folder_path,
                 max_boxes=100,
                 shuffle=True):
        self.annotation_lines = annotation_lines
        self.class_name_path = class_name_path
        self.num_classes = len([line.strip() for line in open(class_name_path).readlines()])
        self.num_gpu = yolo_config['num_gpu']
        self.batch_size = yolo_config['batch_size'] * self.num_gpu
        self.target_img_size = yolo_config['img_size']
        self.anchors = np.array(yolo_config['anchors']).reshape((9, 2))
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.annotation_lines))
        self.folder_path = folder_path
        self.max_boxes = max_boxes
        self.on_epoch_end()

    def __len__(self):
        'Number of batches per epoch'
        return int(np.ceil(len(self.annotation_lines) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        idxs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Find list of IDs
        lines = [self.annotation_lines[i] for i in idxs]

        # Generate data
        X, y_tensor, y_bbox = self.__data_generation(lines)

        return [X, *y_tensor, y_bbox], np.zeros(len(lines))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, annotation_lines):
        """
        Generates data containing batch_size samples
        :param annotation_lines:
        :return:
        """
        X = np.empty((len(annotation_lines), *self.target_img_size), dtype=np.float32)
        y_bbox = np.empty((len(annotation_lines), self.max_boxes, 5), dtype=np.float32)  # x1y1x2y2

        for i, line in enumerate(annotation_lines):
            img_data, box_data = self.get_data(line)
            X[i] = img_data
            y_bbox[i] = box_data

        y_tensor, y_true_boxes_xywh = preprocess_true_boxes(y_bbox, self.target_img_size[:2], self.anchors, self.num_classes)

        return X, y_tensor, y_true_boxes_xywh

    def get_data(self, annotation_line):
        line = annotation_line.split()
        img_path = line[0]
        img = cv2.imread(os.path.join(self.folder_path, img_path))[:, :, ::-1]
        ih, iw = img.shape[:2]
        h, w, c = self.target_img_size
        boxes = np.array([np.array(list(map(float, box.split(',')))) for box in line[1:]], dtype=np.float32)  # x1y1x2y2
        scale_w, scale_h = w / iw, h / ih
        img = cv2.resize(img, (w, h))
        image_data = np.array(img) / 255.

        # correct boxes coordinates
        box_data = np.zeros((self.max_boxes, 5))
        if len(boxes) > 0:
            np.random.shuffle(boxes)
            boxes = boxes[:self.max_boxes]
            boxes[:, [0, 2]] = boxes[:, [0, 2]] * scale_w  # + dx
            boxes[:, [1, 3]] = boxes[:, [1, 3]] * scale_h  # + dy
            box_data[:len(boxes)] = boxes

        return image_data, box_data

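# Illustrative construction (paths are placeholders; each annotation line is expected to
# look like "img.jpg x1,y1,x2,y2,class_id ..." so that get_data() above can parse it):
#
#   train_lines, val_lines = read_annotation_lines('anno.txt', test_size=0.1)
#   train_gen = DataGenerator(train_lines, class_name_path='classes.txt', folder_path='imgs/')
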
def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
    '''Preprocess true boxes to training input format
    Parameters
    ----------
    true_boxes: array, shape=(bs, max boxes per img, 5)
        Absolute x_min, y_min, x_max, y_max, class_id relative to input_shape.
    input_shape: array-like, hw, multiples of 32
    anchors: array, shape=(N, 2), (9, wh)
    num_classes: int
    Returns
    -------
    y_true: list of arrays, shaped like yolo_outputs, xywh are relative values
    '''
    num_stages = 3  # default setting for yolo, tiny yolo will be 2
    anchor_mask = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    bbox_per_grid = 3
    true_boxes = np.array(true_boxes, dtype='float32')
    true_boxes_abs = np.array(true_boxes, dtype='float32')
    input_shape = np.array(input_shape, dtype='int32')
    true_boxes_xy = (true_boxes_abs[..., 0:2] + true_boxes_abs[..., 2:4]) // 2  # (100, 2)
    true_boxes_wh = true_boxes_abs[..., 2:4] - true_boxes_abs[..., 0:2]  # (100, 2)

    # Normalize x, y, w, h relative to img size -> (0~1)
    true_boxes[..., 0:2] = true_boxes_xy / input_shape[::-1]  # xy
    true_boxes[..., 2:4] = true_boxes_wh / input_shape[::-1]  # wh

    bs = true_boxes.shape[0]
    grid_sizes = [input_shape // {0: 8, 1: 16, 2: 32}[stage] for stage in range(num_stages)]
    y_true = [np.zeros((bs,
                        grid_sizes[s][0],
                        grid_sizes[s][1],
                        bbox_per_grid,
                        5 + num_classes), dtype='float32')
              for s in range(num_stages)]
    # [(?, 52, 52, 3, 5+num_classes) (?, 26, 26, 3, 5+num_classes) (?, 13, 13, 3, 5+num_classes)]
    y_true_boxes_xywh = np.concatenate((true_boxes_xy, true_boxes_wh), axis=-1)
    # Expand dim to apply broadcasting.
    anchors = np.expand_dims(anchors, 0)  # (1, 9, 2)
    anchor_maxes = anchors / 2.  # (1, 9, 2)
    anchor_mins = -anchor_maxes  # (1, 9, 2)
    valid_mask = true_boxes_wh[..., 0] > 0  # (1, 100)

    for batch_idx in range(bs):
        # Discard zero rows.
        wh = true_boxes_wh[batch_idx, valid_mask[batch_idx]]  # (# of bbox, 2)
        num_boxes = len(wh)
        if num_boxes == 0:
            continue
        wh = np.expand_dims(wh, -2)  # (# of bbox, 1, 2)
        box_maxes = wh / 2.  # (# of bbox, 1, 2)
        box_mins = -box_maxes  # (# of bbox, 1, 2)

        # Compute IoU between each anchor and the true boxes for responsibility assignment
        intersect_mins = np.maximum(box_mins, anchor_mins)  # (# of bbox, 9, 2)
        intersect_maxes = np.minimum(box_maxes, anchor_maxes)
        intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.)
        intersect_area = np.prod(intersect_wh, axis=-1)  # (9,)
        box_area = wh[..., 0] * wh[..., 1]  # (# of bbox, 1)
        anchor_area = anchors[..., 0] * anchors[..., 1]  # (1, 9)
        iou = intersect_area / (box_area + anchor_area - intersect_area)  # (# of bbox, 9)

        # Find best anchor for each true box
        best_anchors = np.argmax(iou, axis=-1)  # (# of bbox,)
        for box_idx in range(num_boxes):
            best_anchor = best_anchors[box_idx]
            for stage in range(num_stages):
                if best_anchor in anchor_mask[stage]:
                    x_offset = true_boxes[batch_idx, box_idx, 0] * grid_sizes[stage][1]
                    y_offset = true_boxes[batch_idx, box_idx, 1] * grid_sizes[stage][0]
                    # Grid Index
                    grid_col = np.floor(x_offset).astype('int32')
                    grid_row = np.floor(y_offset).astype('int32')
                    anchor_idx = anchor_mask[stage].index(best_anchor)
                    class_idx = true_boxes[batch_idx, box_idx, 4].astype('int32')
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 0] = x_offset - grid_col  # x
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 1] = y_offset - grid_row  # y
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, :4] = true_boxes_abs[batch_idx, box_idx, :4]  # abs xywh
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, :2] = true_boxes_xy[batch_idx, box_idx, :]  # abs xy
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 2:4] = true_boxes_wh[batch_idx, box_idx, :]  # abs wh
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 4] = 1  # confidence
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 5 + class_idx] = 1  # one-hot encoding
                    # label smoothing (disabled):
                    # onehot = np.zeros(num_classes, dtype=np.float)
                    # onehot[class_idx] = 1.0
                    # uniform_distribution = np.full(num_classes, 1.0 / num_classes)
                    # delta = 0.01
                    # smooth_onehot = onehot * (1 - delta) + delta * uniform_distribution
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 5:] = smooth_onehot

    return y_true, y_true_boxes_xywh

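# Shape sketch, following the code above: with input_shape (416, 416), 9 anchors and
# max_boxes=100, a call like
#   y_true, y_true_boxes_xywh = preprocess_true_boxes(y_bbox, (416, 416), anchors, num_classes)
# returns y_true as three arrays of shape (bs, 52, 52, 3, 5+num_classes),
# (bs, 26, 26, 3, 5+num_classes) and (bs, 13, 13, 3, 5+num_classes), plus
# y_true_boxes_xywh of shape (bs, 100, 4) holding the absolute center-xy and wh.
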
"""
Calculate the AP given the recall and precision array
    1st) We compute a version of the measured precision/recall curve with
         precision monotonically decreasing
    2nd) We compute the AP as the area under this curve by numerical integration.
"""
def voc_ap(rec, prec):
    """
    --- Official matlab code VOC2012---
    mrec=[0 ; rec ; 1];
    mpre=[0 ; prec ; 0];
    for i=numel(mpre)-1:-1:1
        mpre(i)=max(mpre(i),mpre(i+1));
    end
    i=find(mrec(2:end)~=mrec(1:end-1))+1;
    ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
    """
    rec.insert(0, 0.0)  # insert 0.0 at beginning of list
    rec.append(1.0)  # insert 1.0 at end of list
    mrec = rec[:]
    prec.insert(0, 0.0)  # insert 0.0 at beginning of list
    prec.append(0.0)  # insert 0.0 at end of list
    mpre = prec[:]
    """
    This part makes the precision monotonically decreasing
    (goes from the end to the beginning)
    matlab: for i=numel(mpre)-1:-1:1
                mpre(i)=max(mpre(i),mpre(i+1));
    """
    # matlab indexes start in 1 but python in 0, so I have to do:
    #     range(start=(len(mpre) - 2), end=0, step=-1)
    # also the python function range excludes the end, resulting in:
    #     range(start=(len(mpre) - 2), end=-1, step=-1)
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    """
    This part creates a list of indexes where the recall changes
    matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
    """
    i_list = []
    for i in range(1, len(mrec)):
        if mrec[i] != mrec[i - 1]:
            i_list.append(i)  # if it was matlab would be i + 1
    """
    The Average Precision (AP) is the area under the curve
    (numerical integration)
    matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
    """
    ap = 0.0
    for i in i_list:
        ap += ((mrec[i] - mrec[i - 1]) * mpre[i])
    return ap, mrec, mpre

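# Worked example: voc_ap([0.5, 1.0], [1.0, 0.5]) pads to mrec = [0, 0.5, 1.0, 1.0] and,
# after the monotonic pass, mpre = [1.0, 1.0, 0.5, 0.0], so
#   ap = (0.5 - 0) * 1.0 + (1.0 - 0.5) * 0.5 = 0.75.
# Note that voc_ap mutates the rec/prec lists it is given.
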
"""
Draw plot using Matplotlib
"""
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
    # sort the dictionary by decreasing value, into a list of tuples
    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
    print(sorted_dic_by_value)
    # unpacking the list of tuples into two lists
    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
    #
    if true_p_bar != "":
        """
        Special case to draw in:
            - green -> TP: True Positives (object detected and matches ground-truth)
            - red -> FP: False Positives (object detected but does not match ground-truth)
            - pink -> FN: False Negatives (object not detected but present in the ground-truth)
        """
        fp_sorted = []
        tp_sorted = []
        for key in sorted_keys:
            fp_sorted.append(dictionary[key] - true_p_bar[key])
            tp_sorted.append(true_p_bar[key])
        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
        # add legend
        plt.legend(loc='lower right')
        """
        Write number on side of bar
        """
        fig = plt.gcf()  # gcf - get current figure
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            fp_val = fp_sorted[i]
            tp_val = tp_sorted[i]
            fp_str_val = " " + str(fp_val)
            tp_str_val = fp_str_val + " " + str(tp_val)
            # trick to paint multicolor with offset:
            # first paint everything and then repaint the first number
            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
            if i == (len(sorted_values) - 1):  # largest bar
                adjust_axes(r, t, fig, axes)
    else:
        plt.barh(range(n_classes), sorted_values, color=plot_color)
        """
        Write number on side of bar
        """
        fig = plt.gcf()  # gcf - get current figure
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            str_val = " " + str(val)  # add a space before
            if val < 1.0:
                str_val = " {0:.2f}".format(val)
            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
            # re-set axes to show number inside the figure
            if i == (len(sorted_values) - 1):  # largest bar
                adjust_axes(r, t, fig, axes)
    # set window title
    fig.canvas.set_window_title(window_title)
    # write classes in y axis
    tick_font_size = 12
    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
    """
    Re-scale height accordingly
    """
    init_height = fig.get_figheight()
    # compute the matrix height in points and inches
    dpi = fig.dpi
    height_pt = n_classes * (tick_font_size * 1.4)  # 1.4 (some spacing)
    height_in = height_pt / dpi
    # compute the required figure height
    top_margin = 0.15  # in percentage of the figure height
    bottom_margin = 0.05  # in percentage of the figure height
    figure_height = height_in / (1 - top_margin - bottom_margin)
    # set new height
    if figure_height > init_height:
        fig.set_figheight(figure_height)

    # set plot title
    plt.title(plot_title, fontsize=14)
    # set axis titles
    # plt.xlabel('classes')
    plt.xlabel(x_label, fontsize='large')
    # adjust size of window
    fig.tight_layout()
    # save the plot
    fig.savefig(output_path)
    # show image
    # if to_show:
    plt.show()
    # close the plot
    # plt.close()


"""
Plot - adjust axes
"""
def adjust_axes(r, t, fig, axes):
    # get text width for re-scaling
    bb = t.get_window_extent(renderer=r)
    text_width_inches = bb.width / fig.dpi
    # get axis width in inches
    current_fig_width = fig.get_figwidth()
    new_fig_width = current_fig_width + text_width_inches
    proportion = new_fig_width / current_fig_width
    # get axis limit
    x_lim = axes.get_xlim()
    axes.set_xlim([x_lim[0], x_lim[1] * proportion])


def read_txt_to_list(path):
    # open txt file lines to a list
    with open(path) as f:
        content = f.readlines()
    # remove whitespace characters like `\n` at the end of each line
    content = [x.strip() for x in content]
    return content
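A minimal end-to-end sketch of how these utilities fit together, assuming a Keras training graph whose inputs match what DataGenerator.__getitem__ yields ([image, *y_true, true_boxes] against a dummy zero target) and a config.py that provides yolo_config. build_training_model and every path below are placeholders rather than names defined in utils.py:

from utils import read_annotation_lines, DataGenerator, load_weights

train_lines, val_lines = read_annotation_lines('data/anno.txt', test_size=0.1)  # placeholder annotation file
train_gen = DataGenerator(train_lines, class_name_path='data/classes.txt', folder_path='data/imgs')
val_gen = DataGenerator(val_lines, class_name_path='data/classes.txt', folder_path='data/imgs', shuffle=False)

training_model = build_training_model()           # hypothetical builder for the YOLOv4 training graph
load_weights(training_model, 'yolov4.weights')    # optionally start from darknet weights
training_model.fit(train_gen, validation_data=val_gen, epochs=10)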