Alexander Fengler commited on
Commit
588ce8d
1 Parent(s): c28e79f

add dashboard

Browse files
Files changed (4) hide show
  1. app.py +22 -8
  2. inference.py +65 -2
  3. metrics.py +211 -0
  4. visualization_tests.ipynb +0 -0
app.py CHANGED
@@ -24,6 +24,10 @@ import glob
24
  from inference import inference_frame,inference_frame_serial
25
  from inference import inference_frame_par_ready
26
  from inference import process_frame
 
 
 
 
27
  import os
28
  import pathlib
29
  import multiprocessing as mp
@@ -33,6 +37,11 @@ from time import time
33
  REPO_ID='SharkSpace/videos_examples'
34
  snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
35
 
 
 
 
 
 
36
  def process_video(input_video, out_fps = 'auto', skip_frames = 7):
37
  cap = cv2.VideoCapture(input_video)
38
 
@@ -55,26 +64,31 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
55
  while iterating:
56
  if (cnt % skip_frames) == 0:
57
  # flip frame vertically
58
- display_frame = inference_frame_serial(frame)
59
  video.write(cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB))
60
- print('sending frame')
 
 
 
 
 
61
  print(cnt)
62
- yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None
63
  cnt += 1
64
  iterating, frame = cap.read()
65
 
66
  video.release()
67
- yield None, None, output_path
68
 
69
- with gr.Blocks() as demo:
70
  with gr.Row():
71
  input_video = gr.Video(label="Input")
72
  output_video = gr.Video(label="Output Video")
73
 
74
  with gr.Row():
75
- processed_frames = gr.Image(label="Live Frame")
76
- # graphs = gr.Image(label="Graphs")
77
  original_frames = gr.Image(label="Original Frame")
 
 
78
 
79
  with gr.Row():
80
  paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
@@ -82,7 +96,7 @@ with gr.Blocks() as demo:
82
  examples = gr.Examples(samples, inputs=input_video)
83
  process_video_btn = gr.Button("Process Video")
84
 
85
- process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video])
86
 
87
  demo.queue()
88
  if os.getenv('SYSTEM') == 'spaces':
 
24
  from inference import inference_frame,inference_frame_serial
25
  from inference import inference_frame_par_ready
26
  from inference import process_frame
27
+ from inference import classes
28
+ from inference import class_sizes_lower
29
+ from metrics import process_results_for_plot
30
+ from metrics import prediction_dashboard
31
  import os
32
  import pathlib
33
  import multiprocessing as mp
 
37
  REPO_ID='SharkSpace/videos_examples'
38
  snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
39
 
40
+ theme = gr.themes.Soft(
41
+ primary_hue="sky",
42
+ neutral_hue="slate",
43
+ )
44
+
45
  def process_video(input_video, out_fps = 'auto', skip_frames = 7):
46
  cap = cv2.VideoCapture(input_video)
47
 
 
64
  while iterating:
65
  if (cnt % skip_frames) == 0:
66
  # flip frame vertically
67
+ display_frame, result = inference_frame_serial(frame)
68
  video.write(cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB))
69
+ #print(result)
70
+ top_pred = process_results_for_plot(predictions = result.numpy(),
71
+ classes = classes,
72
+ class_sizes = class_sizes_lower)
73
+ pred_dashbord = prediction_dashboard(top_pred = top_pred)
74
+ #print('sending frame')
75
  print(cnt)
76
+ yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None, pred_dashbord
77
  cnt += 1
78
  iterating, frame = cap.read()
79
 
80
  video.release()
81
+ yield None, None, output_path, None
82
 
83
+ with gr.Blocks(theme=theme) as demo:
84
  with gr.Row():
85
  input_video = gr.Video(label="Input")
86
  output_video = gr.Video(label="Output Video")
87
 
88
  with gr.Row():
 
 
89
  original_frames = gr.Image(label="Original Frame")
90
+ dashboard = gr.Image(label="Dashboard")
91
+ processed_frames = gr.Image(label="Shark Engine")
92
 
93
  with gr.Row():
94
  paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
 
96
  examples = gr.Examples(samples, inputs=input_video)
97
  process_video_btn = gr.Button("Process Video")
98
 
99
+ process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video, dashboard])
100
 
101
  demo.queue()
102
  if os.getenv('SYSTEM') == 'spaces':
inference.py CHANGED
@@ -15,7 +15,7 @@ from huggingface_hub import hf_hub_download
15
  from huggingface_hub import snapshot_download
16
  from time import time
17
 
18
- classes= ['Beach',
19
  'Sea',
20
  'Wave',
21
  'Rock',
@@ -68,6 +68,69 @@ classes= ['Beach',
68
  'Tiger shark',
69
  'Bull shark']*3
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  REPO_ID = "SharkSpace/maskformer_model"
73
  FILENAME = "mask2former"
@@ -101,7 +164,7 @@ palette = visualizer.dataset_meta.get('palette', None)
101
  print(len(classes))
102
  print(len(palette))
103
  def inference_frame_serial(image, visualize = True):
104
- start = time()
105
  result = inference_detector(model, image)
106
  #print(f'inference time: {time()-start}')
107
  # show the results
 
15
  from huggingface_hub import snapshot_download
16
  from time import time
17
 
18
+ classes = ['Beach',
19
  'Sea',
20
  'Wave',
21
  'Rock',
 
68
  'Tiger shark',
69
  'Bull shark']*3
70
 
71
+ class_sizes = {'Beach': None,
72
+ 'Sea': None,
73
+ 'Wave': None,
74
+ 'Rock': None,
75
+ 'Breaking wave': None,
76
+ 'Reflection of the sea': None,
77
+ 'Foam': None,
78
+ 'Algae': None,
79
+ 'Vegetation': None,
80
+ 'Watermark': None,
81
+ 'Bird': {'feet':[1, 3], 'meter': [0.3, 0.9], 'kg': [0.5, 1.5], 'pounds': [1, 3]},
82
+ 'Ship': {'feet':[10, 100], 'meter': [3, 30], 'kg': [1000, 100000], 'pounds': [2200, 220000]},
83
+ 'Boat': {'feet':[10, 45], 'meter': [3, 15], 'kg': [750, 80000], 'pounds': [1500, 160000]},
84
+ 'Car': {'feet':[10, 20], 'meter': [3, 6], 'kg': [1000, 2000], 'pounds': [2200, 4400]},
85
+ 'Kayak': {'feet':[10, 20], 'meter': [3, 6], 'kg': [50, 300], 'pounds': [100, 600]},
86
+ "Shark's line": None,
87
+ 'Dock': None,
88
+ 'Dog': {'feet':[1, 3], 'meter': [0.3, 0.9], 'kg': [10, 50], 'pounds': [20, 100]},
89
+ 'Unidentifiable shade': None,
90
+ 'Bird shadow': None,
91
+ 'Boat shadow': None,
92
+ 'Kayal shade': None,
93
+ 'Surfer shadow': None,
94
+ 'Shark shadow': None,
95
+ 'Surfboard shadow': None,
96
+ 'Crocodile': {'feet':[10, 20], 'meter': [3, 6], 'kg': [410, 1000], 'pounds': [900, 2200]},
97
+ 'Sea cow': {'feet':[9,12], 'meter': [3, 4], 'kg': [400, 590], 'pounds': [900, 1300]},
98
+ 'Stingray': {'feet':[2, 7.5], 'meter': [0.6, 2.5], 'kg': [100, 300], 'pounds': [220, 770]},
99
+ 'Person': {'feet':[5, 7], 'meter': [1.5, 2.1], 'kg': [50, 150], 'pounds': [110, 300]},
100
+ 'Ocean': None,
101
+ 'Surfer': {'feet':[5, 7], 'meter': [1.5, 2.1], 'kg': [50, 150], 'pounds': [110, 300]},
102
+ 'Surfer': {'feet':[5, 7], 'meter': [1.5, 2.1], 'kg': [50, 150], 'pounds': [110, 300]},
103
+ 'Fish': {'feet':[1, 3], 'meter': [0.3, 0.9], 'kg': [20, 150], 'pounds': [40, 300]},
104
+ 'Killer whale': {'feet':[10, 20], 'meter': [3, 6], 'kg': [3600, 5400], 'pounds': [8000, 12000]},
105
+ 'Whale': {'feet':[15, 30], 'meter': [4.5, 10], 'kg': [2500, 80000], 'pounds': [55000, 176000]},
106
+ 'Dolphin': {'feet':[6.6, 13.1], 'meter': [2, 4], 'kg': [150, 650], 'pounds': [330, 1430]},
107
+ 'Miscellaneous': None,
108
+ 'Unidentifiable shark': {'feet': [2, 15], 'meter': [0.6, 4.5], 'kg': [50, 1000], 'pounds': [110, 2200]},
109
+ 'Carpet shark': {'feet': [4, 10], 'meter': [1.25, 3], 'kg': [50, 1000], 'pounds': [110, 2200]}, # Prob incorrect
110
+ 'Dusty shark': {'feet': [9, 14], 'meter': [3, 4.25], 'kg': [160, 180], 'pounds': [350, 400]},
111
+ 'Blue shark': {'feet': [7.9, 12.5], 'meter': [2.4, 3], 'kg': [60, 120], 'pounds': [130, 260]},
112
+ 'Great white shark': {'feet': [13.1, 20], 'meter': [4, 6], 'kg': [680, 1800], 'pounds': [1500, 4000]},
113
+ 'Copper shark': {'feet': [7.2, 10.8], 'meter': [2.2, 3.3], 'kg': [130, 300], 'pounds': [290, 660]},
114
+ 'Nurse shark': {'feet': [7.9, 9.8], 'meter': [2.4, 3], 'kg': [90, 115], 'pounds': [200, 250]},
115
+ 'Silky shark': {'feet': [6.6, 8.2], 'meter': [2, 2.5], 'kg': [300, 380], 'pounds': [660, 840]},
116
+ 'Leopard shark': {'feet': [3.9, 4.9], 'meter': [1.2, 1.5], 'kg': [11, 20], 'pounds': [22, 44]},
117
+ 'Shortfin mako shark': {'feet': [10.5, 12], 'meter': [3.2, 3.6], 'kg': [60, 135], 'pounds': [130, 300]},
118
+ 'Hammerhead shark': {'feet': [4.9, 20], 'meter': [1.5, 6.1], 'kg': [230, 450], 'pounds': [500, 1000]},
119
+ 'Oceanic whitetip shark': {'feet': [5.9, 9.8], 'meter': [1.8, 3], 'kg': [36, 170], 'pounds': [80, 375]},
120
+ 'Blacktip shark': {'feet': [4.9, 6.6], 'meter': [1.5, 2], 'kg': [40, 100], 'pounds': [90, 220]},
121
+ 'Tiger shark': {'feet': [9.8, 18], 'meter': [3, 5.5], 'kg': [385, 635], 'pounds': [850, 1400]},
122
+ 'Bull shark': {'feet': [7.9, 11.2], 'meter': [2.4, 3.4], 'kg': [200, 315], 'pounds': [440, 690]},
123
+ }
124
+
125
+ class_sizes_lower = {k.lower(): v for k, v in class_sizes.items()}
126
+
127
+ classes_is_shark = [1 if 'shark' in x.lower() else 0 for x in classes]
128
+ classes_is_human = [1 if 'person' or 'surfer' in x.lower() else 0 for x in classes]
129
+ classes_is_unknown = [1 if 'unidentifiable' in x.lower() else 0 for x in classes]
130
+
131
+ classes_is_shark_id = [i for i, x in enumerate(classes_is_shark) if x == 1]
132
+ classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
133
+ classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
134
 
135
  REPO_ID = "SharkSpace/maskformer_model"
136
  FILENAME = "mask2former"
 
164
  print(len(classes))
165
  print(len(palette))
166
  def inference_frame_serial(image, visualize = True):
167
+ #start = time()
168
  result = inference_detector(model, image)
169
  #print(f'inference time: {time()-start}')
170
  # show the results
metrics.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+
5
+ def get_top_predictions(prediction = None, threshold = 0.7):
6
+ if prediction is None:
7
+ return None, None
8
+ else:
9
+ sorted_scores_ids = prediction.pred_instances.scores.argsort()[::-1]
10
+ sorted_scores = prediction.pred_instances.scores[sorted_scores_ids]
11
+ sorted_predictions = prediction.pred_instances.labels[sorted_scores_ids]
12
+ return {'pred_above_thresh': sorted_predictions[sorted_scores > threshold],
13
+ 'pred_above_thresh_id': sorted_scores_ids[sorted_scores > threshold],
14
+ 'pred_above_thresh_scores': sorted_scores[sorted_scores > threshold],
15
+ 'pred_above_thresh_bboxes': prediction.pred_instances['bboxes'][sorted_scores_ids][sorted_scores > threshold]}
16
+
17
+ def add_class_labels(top_pred = {}, class_labels = None):
18
+ if class_labels == None:
19
+ print('No class labels provided, returning original dictionary')
20
+ return top_pred
21
+ else:
22
+ top_pred['pred_above_thresh_labels'] = [class_labels[x].lower() for x in top_pred['pred_above_thresh']]
23
+ top_pred['any_detection'] = len(top_pred['pred_above_thresh_labels']) > 0
24
+ if top_pred['any_detection']:
25
+ # Get shark / human / unknown vectors
26
+ top_pred['is_shark'] = np.array([1 if 'shark' in x else 0 for x in top_pred['pred_above_thresh_labels']])
27
+ top_pred['is_human'] = np.array([1 if 'person' in x else 1 if 'surfer' in x else 0 for x in top_pred['pred_above_thresh_labels']])
28
+ top_pred['is_unknown'] = np.array([1 if 'unidentifiable' in x else 0 for x in top_pred['pred_above_thresh_labels']])
29
+ # Get shark / human / unknown numbers of detections
30
+ top_pred['shark_n'] = np.sum(top_pred['is_shark'])
31
+ top_pred['human_n'] = np.sum(top_pred['is_human'])
32
+ top_pred['unknown_n'] = np.sum(top_pred['is_unknown'])
33
+ else:
34
+ # Get shark / human / unknown vectors
35
+ top_pred['is_shark'] = None
36
+ top_pred['is_human'] = None
37
+ top_pred['is_unknown'] = None
38
+ # Get shark / human / unknown numbers of detections
39
+ top_pred['shark_n'] = 0
40
+ top_pred['human_n'] = 0
41
+ top_pred['unknown_n'] = 0
42
+ return top_pred
43
+
44
+ def add_class_sizes(top_pred = {}, class_sizes = None):
45
+ size_list = []
46
+ shark_size_list = []
47
+ if top_pred['any_detection']:
48
+ for tmp_pred in top_pred['pred_above_thresh_labels']:
49
+ tmp_class_sizes = class_sizes[tmp_pred.lower()]
50
+ if tmp_class_sizes == None:
51
+ size_list.append(None)
52
+ else:
53
+ size_list.append(tmp_class_sizes['feet'])
54
+
55
+ if 'shark' in tmp_pred.lower():
56
+ shark_size_list.append(np.mean(tmp_class_sizes['feet']))
57
+
58
+ top_pred['pred_above_thresh_sizes'] = size_list
59
+
60
+ if top_pred['shark_n'] > 0:
61
+ top_pred['biggest_shark_size'] = np.max(shark_size_list)
62
+ else:
63
+ top_pred['biggest_shark_size'] = None
64
+ else:
65
+ top_pred['pred_above_thresh_sizes'] = None
66
+ top_pred['biggest_shark_size'] = None
67
+ return top_pred
68
+
69
+ def add_class_weights(top_pred = {}, class_weights = None):
70
+ weight_list = []
71
+ shark_weight_list = []
72
+ if top_pred['any_detection']:
73
+ for tmp_pred in top_pred['pred_above_thresh_labels']:
74
+ tmp_class_weights = class_weights[tmp_pred.lower()]
75
+ if tmp_class_weights == None:
76
+ weight_list.append(None)
77
+ else:
78
+ weight_list.append(tmp_class_weights['pounds'])
79
+
80
+ if 'shark' in tmp_pred.lower():
81
+ shark_weight_list.append(np.mean(tmp_class_weights['pounds']))
82
+
83
+ top_pred['pred_above_thresh_weights'] = weight_list
84
+
85
+ if top_pred['shark_n'] > 0:
86
+ top_pred['biggest_shark_weight'] = np.max(shark_weight_list)
87
+ else:
88
+ top_pred['biggest_shark_weight'] = None
89
+ else:
90
+ top_pred['pred_above_thresh_weights'] = None
91
+ top_pred['biggest_shark_weight'] = None
92
+ return top_pred
93
+
94
+ # Sizes
95
+ def get_min_distance_shark_person(top_pred, class_sizes = None, dangerous_distance = 100):
96
+ min_dist = 99999
97
+ dist_calculated = False
98
+ # Calculate distance for every pairing of human and shark
99
+ # and accumulate the min distance
100
+ for i, tmp_shark in enumerate(top_pred['is_shark']):
101
+ for j, tmp_person in enumerate(top_pred['is_human']):
102
+ if tmp_shark == 1 and tmp_person == 1:
103
+ dist_calculated = True
104
+ #print(top_pred['pred_above_thresh_bboxes'][i])
105
+ #print(top_pred['pred_above_thresh_bboxes'][j])
106
+ tmp_dist_feed = _calculate_dist_estimate(top_pred['pred_above_thresh_bboxes'][i],
107
+ top_pred['pred_above_thresh_bboxes'][j],
108
+ [top_pred['pred_above_thresh_labels'][i], top_pred['pred_above_thresh_labels'][j]],
109
+ class_sizes,
110
+ measurement = 'feet')
111
+ #print(tmp_dist_feed)
112
+ min_dist = min(min_dist, tmp_dist_feed)
113
+ else:
114
+ pass
115
+ return {'min_dist': str(round(min_dist,1)) + ' feet' if dist_calculated else '',
116
+ 'any_dist_calculated': dist_calculated,
117
+ 'dangerous_dist': min_dist < dangerous_distance}
118
+
119
+ def _calculate_dist_estimate(bbox1, bbox2, labels, class_sizes = None, measurement = 'feet'):
120
+ class_feet_size_mean = np.array([class_sizes[labels[0]][measurement][0],
121
+ class_sizes[labels[1]][measurement][0]]).mean()
122
+ box_pixel_size_mean = np.array([np.linalg.norm(bbox1[[0, 1]] - bbox1[[2, 3]]),
123
+ np.linalg.norm(bbox2[[0, 1]] - bbox2[[2, 3]])]).mean()
124
+
125
+ # Calculate the max size of the two boxes
126
+ box_center_1 = np.array([(bbox1[2] - bbox1[0])/2 + bbox1[0],
127
+ (bbox1[3] - bbox1[1])/2 + bbox1[1]])
128
+ box_center_2 = np.array([(bbox2[2] - bbox2[0])/2 + bbox2[0],
129
+ (bbox2[3] - bbox2[1])/2 + bbox2[1]])
130
+
131
+ # Return ratio distance
132
+ return np.linalg.norm(box_center_1 - box_center_2) / box_pixel_size_mean * class_feet_size_mean
133
+
134
+ # bboxes info!
135
+ # 1 x1 (left, lower pixel number)
136
+ # 2 y1 (top , lower pixel number)
137
+ # 3 x2 (right, higher pixel number)
138
+ # 4 y2 (bottom, higher pixel number)
139
+
140
+ def process_results_for_plot(predictions = None, threshold = 0.5, classes = None,
141
+ class_sizes = None, dangerous_distance = 100):
142
+
143
+ top_pred = get_top_predictions(predictions, threshold = threshold)
144
+ top_pred = add_class_labels(top_pred, class_labels = classes)
145
+ top_pred = add_class_sizes(top_pred, class_sizes = class_sizes)
146
+ top_pred = add_class_weights(top_pred, class_weights = class_sizes)
147
+
148
+ if len(top_pred['pred_above_thresh']) > 0:
149
+ min_dist = get_min_distance_shark_person(top_pred, class_sizes = class_sizes)
150
+ else:
151
+ min_dist = {'any_dist_calculated': False,
152
+ 'min_dist': '',
153
+ 'dangerous_dist': False}
154
+
155
+ return {'min_dist_str': min_dist['min_dist'],
156
+ 'shark_sighted': top_pred['shark_n'] > 0,
157
+ 'human_sighted': top_pred['human_n'] > 0,
158
+ 'shark_n': top_pred['shark_n'],
159
+ 'human_n': top_pred['human_n'],
160
+ 'human_and_shark': (top_pred['shark_n'] > 0) and (top_pred['human_n'] > 0),
161
+ 'dangerous_dist': min_dist['dangerous_dist'],
162
+ 'dist_calculated': min_dist['any_dist_calculated'],
163
+ 'biggest_shark_size': '' if top_pred['biggest_shark_size'] == None else str(round(top_pred['biggest_shark_size'],1)) + ' feet',
164
+ 'biggest_shark_weight': '' if top_pred['biggest_shark_weight'] == None else str(round(top_pred['biggest_shark_weight'],1)) + ' pounds',
165
+ }
166
+
167
+ def prediction_dashboard(top_pred = None):
168
+ # Bullet points:
169
+ shark_sighted = 'Shark Detected: ' + str(top_pred['shark_sighted'])
170
+ human_sighted = 'Number of Humans: ' + str(top_pred['human_n'])
171
+
172
+ shark_size_estimate = 'Biggest shark size: ' + str(top_pred['biggest_shark_size'])
173
+ shark_weight_estimate = 'Biggest shark weight: ' + str(top_pred['biggest_shark_weight'])
174
+
175
+ danger_level = 'Danger Level: '
176
+ danger_level += 'High' if top_pred['dangerous_dist'] else 'Low'
177
+
178
+ danger_color = 'orangered' if top_pred['dangerous_dist'] else 'yellowgreen'
179
+
180
+ # Create a list of strings to plot
181
+ strings = [shark_sighted, human_sighted, shark_size_estimate, shark_weight_estimate, danger_level]
182
+
183
+ # Create a figure and axis
184
+ fig, ax = plt.subplots()
185
+ fig.set_facecolor((35/255,40/255,54/255))
186
+
187
+ # Hide axes
188
+ ax.axis('off')
189
+
190
+ # Position for starting to place text, starting from top
191
+ y_pos = 0.7
192
+
193
+ # Iterate through list and place each item as text on the plot
194
+ for s in strings:
195
+ if 'danger' in s.lower():
196
+ ax.text(0.05, y_pos, s, transform=ax.transAxes, fontsize=16, color=danger_color)
197
+ else:
198
+ ax.text(0.05, y_pos, s, transform=ax.transAxes, fontsize=16, color=(0, 204/255, 153/255))
199
+ y_pos -= 0.1 # move down for next item
200
+
201
+ # plt.tight_layout()
202
+ # If we haven't already shown or saved the plot, then we need to
203
+ # draw the figure first...
204
+ fig.canvas.draw();
205
+
206
+ # Now we can save it to a numpy array.
207
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
208
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
209
+ plt.close()
210
+ #plt.savefig('tmp.png', format='png')
211
+ return data #plt.show()
visualization_tests.ipynb CHANGED
The diff for this file is too large to render. See raw diff