Alexander Fengler committed
Commit: 588ce8d · Parent: c28e79f

add dashboard

Files changed:
- app.py +22 -8
- inference.py +65 -2
- metrics.py +211 -0
- visualization_tests.ipynb +0 -0
app.py CHANGED

@@ -24,6 +24,10 @@ import glob
 from inference import inference_frame,inference_frame_serial
 from inference import inference_frame_par_ready
 from inference import process_frame
+from inference import classes
+from inference import class_sizes_lower
+from metrics import process_results_for_plot
+from metrics import prediction_dashboard
 import os
 import pathlib
 import multiprocessing as mp
@@ -33,6 +37,11 @@ from time import time
 REPO_ID='SharkSpace/videos_examples'
 snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
 
+theme = gr.themes.Soft(
+    primary_hue="sky",
+    neutral_hue="slate",
+)
+
 def process_video(input_video, out_fps = 'auto', skip_frames = 7):
     cap = cv2.VideoCapture(input_video)
 
@@ -55,26 +64,31 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
     while iterating:
         if (cnt % skip_frames) == 0:
             # flip frame vertically
-            display_frame = inference_frame_serial(frame)
+            display_frame, result = inference_frame_serial(frame)
             video.write(cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB))
-            print(
+            #print(result)
+            top_pred = process_results_for_plot(predictions = result.numpy(),
+                                                classes = classes,
+                                                class_sizes = class_sizes_lower)
+            pred_dashbord = prediction_dashboard(top_pred = top_pred)
+            #print('sending frame')
             print(cnt)
-            yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None
+            yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None, pred_dashbord
         cnt += 1
         iterating, frame = cap.read()
 
     video.release()
-    yield None, None, output_path
+    yield None, None, output_path, None
 
-with gr.Blocks() as demo:
+with gr.Blocks(theme=theme) as demo:
     with gr.Row():
         input_video = gr.Video(label="Input")
         output_video = gr.Video(label="Output Video")
 
     with gr.Row():
-        processed_frames = gr.Image(label="Live Frame")
-        # graphs = gr.Image(label="Graphs")
         original_frames = gr.Image(label="Original Frame")
+        dashboard = gr.Image(label="Dashboard")
+        processed_frames = gr.Image(label="Shark Engine")
 
     with gr.Row():
         paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
@@ -82,7 +96,7 @@ with gr.Blocks() as demo:
         examples = gr.Examples(samples, inputs=input_video)
     process_video_btn = gr.Button("Process Video")
 
-    process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video])
+    process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video, dashboard])
 
 demo.queue()
 if os.getenv('SYSTEM') == 'spaces':
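A quick way to sanity-check the new dashboard path is to run the three calls this commit adds to process_video() on a single frame, outside Gradio. The sketch below is only illustrative: it assumes videos_example/ has already been populated (app.py does this via snapshot_download), that importing inference downloads and builds the detector as the Space does on startup, and the dashboard_preview.png output is an invention for the demo, not something app.py writes.

import pathlib

import cv2

from inference import inference_frame_serial, classes, class_sizes_lower
from metrics import process_results_for_plot, prediction_dashboard

# Grab any example clip already pulled into videos_example/ (path is an assumption).
sample = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))[0]

cap = cv2.VideoCapture(str(sample))
ok, frame = cap.read()
cap.release()

if ok:
    # The same three per-frame steps the reworked process_video() now runs.
    display_frame, result = inference_frame_serial(frame)
    top_pred = process_results_for_plot(predictions=result.numpy(),
                                        classes=classes,
                                        class_sizes=class_sizes_lower)
    dashboard_img = prediction_dashboard(top_pred=top_pred)  # RGB uint8 array
    # Matplotlib hands back RGB; OpenCV writes BGR, hence the conversion.
    cv2.imwrite('dashboard_preview.png', cv2.cvtColor(dashboard_img, cv2.COLOR_RGB2BGR))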
inference.py CHANGED

@@ -15,7 +15,7 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub import snapshot_download
 from time import time
 
-classes= ['Beach',
+classes = ['Beach',
           'Sea',
           'Wave',
           'Rock',
@@ -68,6 +68,69 @@ classes= ['Beach',
           'Tiger shark',
           'Bull shark']*3
 
+class_sizes = {'Beach': None,
+               'Sea': None,
+               'Wave': None,
+               'Rock': None,
+               'Breaking wave': None,
+               'Reflection of the sea': None,
+               'Foam': None,
+               'Algae': None,
+               'Vegetation': None,
+               'Watermark': None,
+               'Bird': {'feet':[1, 3], 'meter': [0.3, 0.9], 'kg': [0.5, 1.5], 'pounds': [1, 3]},
+               'Ship': {'feet':[10, 100], 'meter': [3, 30], 'kg': [1000, 100000], 'pounds': [2200, 220000]},
+               'Boat': {'feet':[10, 45], 'meter': [3, 15], 'kg': [750, 80000], 'pounds': [1500, 160000]},
+               'Car': {'feet':[10, 20], 'meter': [3, 6], 'kg': [1000, 2000], 'pounds': [2200, 4400]},
+               'Kayak': {'feet':[10, 20], 'meter': [3, 6], 'kg': [50, 300], 'pounds': [100, 600]},
+               "Shark's line": None,
+               'Dock': None,
+               'Dog': {'feet':[1, 3], 'meter': [0.3, 0.9], 'kg': [10, 50], 'pounds': [20, 100]},
+               'Unidentifiable shade': None,
+               'Bird shadow': None,
+               'Boat shadow': None,
+               'Kayal shade': None,
+               'Surfer shadow': None,
+               'Shark shadow': None,
+               'Surfboard shadow': None,
+               'Crocodile': {'feet':[10, 20], 'meter': [3, 6], 'kg': [410, 1000], 'pounds': [900, 2200]},
+               'Sea cow': {'feet':[9,12], 'meter': [3, 4], 'kg': [400, 590], 'pounds': [900, 1300]},
+               'Stingray': {'feet':[2, 7.5], 'meter': [0.6, 2.5], 'kg': [100, 300], 'pounds': [220, 770]},
+               'Person': {'feet':[5, 7], 'meter': [1.5, 2.1], 'kg': [50, 150], 'pounds': [110, 300]},
+               'Ocean': None,
+               'Surfer': {'feet':[5, 7], 'meter': [1.5, 2.1], 'kg': [50, 150], 'pounds': [110, 300]},
+               'Surfer': {'feet':[5, 7], 'meter': [1.5, 2.1], 'kg': [50, 150], 'pounds': [110, 300]},
+               'Fish': {'feet':[1, 3], 'meter': [0.3, 0.9], 'kg': [20, 150], 'pounds': [40, 300]},
+               'Killer whale': {'feet':[10, 20], 'meter': [3, 6], 'kg': [3600, 5400], 'pounds': [8000, 12000]},
+               'Whale': {'feet':[15, 30], 'meter': [4.5, 10], 'kg': [2500, 80000], 'pounds': [55000, 176000]},
+               'Dolphin': {'feet':[6.6, 13.1], 'meter': [2, 4], 'kg': [150, 650], 'pounds': [330, 1430]},
+               'Miscellaneous': None,
+               'Unidentifiable shark': {'feet': [2, 15], 'meter': [0.6, 4.5], 'kg': [50, 1000], 'pounds': [110, 2200]},
+               'Carpet shark': {'feet': [4, 10], 'meter': [1.25, 3], 'kg': [50, 1000], 'pounds': [110, 2200]}, # Prob incorrect
+               'Dusty shark': {'feet': [9, 14], 'meter': [3, 4.25], 'kg': [160, 180], 'pounds': [350, 400]},
+               'Blue shark': {'feet': [7.9, 12.5], 'meter': [2.4, 3], 'kg': [60, 120], 'pounds': [130, 260]},
+               'Great white shark': {'feet': [13.1, 20], 'meter': [4, 6], 'kg': [680, 1800], 'pounds': [1500, 4000]},
+               'Copper shark': {'feet': [7.2, 10.8], 'meter': [2.2, 3.3], 'kg': [130, 300], 'pounds': [290, 660]},
+               'Nurse shark': {'feet': [7.9, 9.8], 'meter': [2.4, 3], 'kg': [90, 115], 'pounds': [200, 250]},
+               'Silky shark': {'feet': [6.6, 8.2], 'meter': [2, 2.5], 'kg': [300, 380], 'pounds': [660, 840]},
+               'Leopard shark': {'feet': [3.9, 4.9], 'meter': [1.2, 1.5], 'kg': [11, 20], 'pounds': [22, 44]},
+               'Shortfin mako shark': {'feet': [10.5, 12], 'meter': [3.2, 3.6], 'kg': [60, 135], 'pounds': [130, 300]},
+               'Hammerhead shark': {'feet': [4.9, 20], 'meter': [1.5, 6.1], 'kg': [230, 450], 'pounds': [500, 1000]},
+               'Oceanic whitetip shark': {'feet': [5.9, 9.8], 'meter': [1.8, 3], 'kg': [36, 170], 'pounds': [80, 375]},
+               'Blacktip shark': {'feet': [4.9, 6.6], 'meter': [1.5, 2], 'kg': [40, 100], 'pounds': [90, 220]},
+               'Tiger shark': {'feet': [9.8, 18], 'meter': [3, 5.5], 'kg': [385, 635], 'pounds': [850, 1400]},
+               'Bull shark': {'feet': [7.9, 11.2], 'meter': [2.4, 3.4], 'kg': [200, 315], 'pounds': [440, 690]},
+               }
+
+class_sizes_lower = {k.lower(): v for k, v in class_sizes.items()}
+
+classes_is_shark = [1 if 'shark' in x.lower() else 0 for x in classes]
+classes_is_human = [1 if 'person' or 'surfer' in x.lower() else 0 for x in classes]
+classes_is_unknown = [1 if 'unidentifiable' in x.lower() else 0 for x in classes]
+
+classes_is_shark_id = [i for i, x in enumerate(classes_is_shark) if x == 1]
+classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
+classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
 
 REPO_ID = "SharkSpace/maskformer_model"
 FILENAME = "mask2former"
@@ -101,7 +164,7 @@ palette = visualizer.dataset_meta.get('palette', None)
 print(len(classes))
 print(len(palette))
 def inference_frame_serial(image, visualize = True):
-    start = time()
+    #start = time()
     result = inference_detector(model, image)
     #print(f'inference time: {time()-start}')
     # show the results
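One detail worth flagging in the new module-level masks: `1 if 'person' or 'surfer' in x.lower() else 0` always yields 1, because the non-empty literal 'person' is truthy on its own, so classes_is_human marks every class as human. The per-frame logic in metrics.add_class_labels tests each substring separately and is not affected, and the duplicate 'Surfer' entry in class_sizes is harmless since both values are identical. A minimal corrected sketch for the mask, keeping the commit's names; the fix is a suggestion, not part of the diff:

# Test each substring against the lowercased class name explicitly.
classes_is_human = [1 if ('person' in x.lower() or 'surfer' in x.lower()) else 0
                    for x in classes]
classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]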
metrics.py ADDED (new file)

import numpy as np
import matplotlib.pyplot as plt


def get_top_predictions(prediction = None, threshold = 0.7):
    if prediction is None:
        return None, None
    else:
        sorted_scores_ids = prediction.pred_instances.scores.argsort()[::-1]
        sorted_scores = prediction.pred_instances.scores[sorted_scores_ids]
        sorted_predictions = prediction.pred_instances.labels[sorted_scores_ids]
        return {'pred_above_thresh': sorted_predictions[sorted_scores > threshold],
                'pred_above_thresh_id': sorted_scores_ids[sorted_scores > threshold],
                'pred_above_thresh_scores': sorted_scores[sorted_scores > threshold],
                'pred_above_thresh_bboxes': prediction.pred_instances['bboxes'][sorted_scores_ids][sorted_scores > threshold]}

def add_class_labels(top_pred = {}, class_labels = None):
    if class_labels == None:
        print('No class labels provided, returning original dictionary')
        return top_pred
    else:
        top_pred['pred_above_thresh_labels'] = [class_labels[x].lower() for x in top_pred['pred_above_thresh']]
        top_pred['any_detection'] = len(top_pred['pred_above_thresh_labels']) > 0
        if top_pred['any_detection']:
            # Get shark / human / unknown vectors
            top_pred['is_shark'] = np.array([1 if 'shark' in x else 0 for x in top_pred['pred_above_thresh_labels']])
            top_pred['is_human'] = np.array([1 if 'person' in x else 1 if 'surfer' in x else 0 for x in top_pred['pred_above_thresh_labels']])
            top_pred['is_unknown'] = np.array([1 if 'unidentifiable' in x else 0 for x in top_pred['pred_above_thresh_labels']])
            # Get shark / human / unknown numbers of detections
            top_pred['shark_n'] = np.sum(top_pred['is_shark'])
            top_pred['human_n'] = np.sum(top_pred['is_human'])
            top_pred['unknown_n'] = np.sum(top_pred['is_unknown'])
        else:
            # Get shark / human / unknown vectors
            top_pred['is_shark'] = None
            top_pred['is_human'] = None
            top_pred['is_unknown'] = None
            # Get shark / human / unknown numbers of detections
            top_pred['shark_n'] = 0
            top_pred['human_n'] = 0
            top_pred['unknown_n'] = 0
    return top_pred

def add_class_sizes(top_pred = {}, class_sizes = None):
    size_list = []
    shark_size_list = []
    if top_pred['any_detection']:
        for tmp_pred in top_pred['pred_above_thresh_labels']:
            tmp_class_sizes = class_sizes[tmp_pred.lower()]
            if tmp_class_sizes == None:
                size_list.append(None)
            else:
                size_list.append(tmp_class_sizes['feet'])

                if 'shark' in tmp_pred.lower():
                    shark_size_list.append(np.mean(tmp_class_sizes['feet']))

        top_pred['pred_above_thresh_sizes'] = size_list

        if top_pred['shark_n'] > 0:
            top_pred['biggest_shark_size'] = np.max(shark_size_list)
        else:
            top_pred['biggest_shark_size'] = None
    else:
        top_pred['pred_above_thresh_sizes'] = None
        top_pred['biggest_shark_size'] = None
    return top_pred

def add_class_weights(top_pred = {}, class_weights = None):
    weight_list = []
    shark_weight_list = []
    if top_pred['any_detection']:
        for tmp_pred in top_pred['pred_above_thresh_labels']:
            tmp_class_weights = class_weights[tmp_pred.lower()]
            if tmp_class_weights == None:
                weight_list.append(None)
            else:
                weight_list.append(tmp_class_weights['pounds'])

                if 'shark' in tmp_pred.lower():
                    shark_weight_list.append(np.mean(tmp_class_weights['pounds']))

        top_pred['pred_above_thresh_weights'] = weight_list

        if top_pred['shark_n'] > 0:
            top_pred['biggest_shark_weight'] = np.max(shark_weight_list)
        else:
            top_pred['biggest_shark_weight'] = None
    else:
        top_pred['pred_above_thresh_weights'] = None
        top_pred['biggest_shark_weight'] = None
    return top_pred

# Sizes
def get_min_distance_shark_person(top_pred, class_sizes = None, dangerous_distance = 100):
    min_dist = 99999
    dist_calculated = False
    # Calculate distance for every pairing of human and shark
    # and accumulate the min distance
    for i, tmp_shark in enumerate(top_pred['is_shark']):
        for j, tmp_person in enumerate(top_pred['is_human']):
            if tmp_shark == 1 and tmp_person == 1:
                dist_calculated = True
                #print(top_pred['pred_above_thresh_bboxes'][i])
                #print(top_pred['pred_above_thresh_bboxes'][j])
                tmp_dist_feed = _calculate_dist_estimate(top_pred['pred_above_thresh_bboxes'][i],
                                                         top_pred['pred_above_thresh_bboxes'][j],
                                                         [top_pred['pred_above_thresh_labels'][i], top_pred['pred_above_thresh_labels'][j]],
                                                         class_sizes,
                                                         measurement = 'feet')
                #print(tmp_dist_feed)
                min_dist = min(min_dist, tmp_dist_feed)
            else:
                pass
    return {'min_dist': str(round(min_dist,1)) + ' feet' if dist_calculated else '',
            'any_dist_calculated': dist_calculated,
            'dangerous_dist': min_dist < dangerous_distance}

def _calculate_dist_estimate(bbox1, bbox2, labels, class_sizes = None, measurement = 'feet'):
    class_feet_size_mean = np.array([class_sizes[labels[0]][measurement][0],
                                     class_sizes[labels[1]][measurement][0]]).mean()
    box_pixel_size_mean = np.array([np.linalg.norm(bbox1[[0, 1]] - bbox1[[2, 3]]),
                                    np.linalg.norm(bbox2[[0, 1]] - bbox2[[2, 3]])]).mean()

    # Calculate the max size of the two boxes
    box_center_1 = np.array([(bbox1[2] - bbox1[0])/2 + bbox1[0],
                             (bbox1[3] - bbox1[1])/2 + bbox1[1]])
    box_center_2 = np.array([(bbox2[2] - bbox2[0])/2 + bbox2[0],
                             (bbox2[3] - bbox2[1])/2 + bbox2[1]])

    # Return ratio distance
    return np.linalg.norm(box_center_1 - box_center_2) / box_pixel_size_mean * class_feet_size_mean

# bboxes info!
# 1 x1 (left, lower pixel number)
# 2 y1 (top , lower pixel number)
# 3 x2 (right, higher pixel number)
# 4 y2 (bottom, higher pixel number)

def process_results_for_plot(predictions = None, threshold = 0.5, classes = None,
                             class_sizes = None, dangerous_distance = 100):

    top_pred = get_top_predictions(predictions, threshold = threshold)
    top_pred = add_class_labels(top_pred, class_labels = classes)
    top_pred = add_class_sizes(top_pred, class_sizes = class_sizes)
    top_pred = add_class_weights(top_pred, class_weights = class_sizes)

    if len(top_pred['pred_above_thresh']) > 0:
        min_dist = get_min_distance_shark_person(top_pred, class_sizes = class_sizes)
    else:
        min_dist = {'any_dist_calculated': False,
                    'min_dist': '',
                    'dangerous_dist': False}

    return {'min_dist_str': min_dist['min_dist'],
            'shark_sighted': top_pred['shark_n'] > 0,
            'human_sighted': top_pred['human_n'] > 0,
            'shark_n': top_pred['shark_n'],
            'human_n': top_pred['human_n'],
            'human_and_shark': (top_pred['shark_n'] > 0) and (top_pred['human_n'] > 0),
            'dangerous_dist': min_dist['dangerous_dist'],
            'dist_calculated': min_dist['any_dist_calculated'],
            'biggest_shark_size': '' if top_pred['biggest_shark_size'] == None else str(round(top_pred['biggest_shark_size'],1)) + ' feet',
            'biggest_shark_weight': '' if top_pred['biggest_shark_weight'] == None else str(round(top_pred['biggest_shark_weight'],1)) + ' pounds',
            }

def prediction_dashboard(top_pred = None):
    # Bullet points:
    shark_sighted = 'Shark Detected: ' + str(top_pred['shark_sighted'])
    human_sighted = 'Number of Humans: ' + str(top_pred['human_n'])

    shark_size_estimate = 'Biggest shark size: ' + str(top_pred['biggest_shark_size'])
    shark_weight_estimate = 'Biggest shark weight: ' + str(top_pred['biggest_shark_weight'])

    danger_level = 'Danger Level: '
    danger_level += 'High' if top_pred['dangerous_dist'] else 'Low'

    danger_color = 'orangered' if top_pred['dangerous_dist'] else 'yellowgreen'

    # Create a list of strings to plot
    strings = [shark_sighted, human_sighted, shark_size_estimate, shark_weight_estimate, danger_level]

    # Create a figure and axis
    fig, ax = plt.subplots()
    fig.set_facecolor((35/255, 40/255, 54/255))

    # Hide axes
    ax.axis('off')

    # Position for starting to place text, starting from top
    y_pos = 0.7

    # Iterate through list and place each item as text on the plot
    for s in strings:
        if 'danger' in s.lower():
            ax.text(0.05, y_pos, s, transform=ax.transAxes, fontsize=16, color=danger_color)
        else:
            ax.text(0.05, y_pos, s, transform=ax.transAxes, fontsize=16, color=(0, 204/255, 153/255))
        y_pos -= 0.1  # move down for next item

    # plt.tight_layout()
    # If we haven't already shown or saved the plot, then we need to
    # draw the figure first...
    fig.canvas.draw()

    # Now we can save it to a numpy array.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    #plt.savefig('tmp.png', format='png')
    return data #plt.show()
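The distance heuristic in _calculate_dist_estimate converts pixels to feet by scaling the gap between the two box centres with the ratio of the classes' nominal real-world size (the lower bound of their 'feet' range, so estimates skew short) to the mean box diagonal in pixels. A small worked example with made-up boxes and a two-entry stand-in for inference.class_sizes_lower; it assumes it is run from the repo root so metrics.py imports:

import numpy as np

from metrics import _calculate_dist_estimate

# Hypothetical detections, [x1, y1, x2, y2] in pixels.
shark_box = np.array([100.0, 100.0, 300.0, 250.0])   # diagonal = 250 px
person_box = np.array([600.0, 400.0, 660.0, 540.0])  # diagonal ~ 152 px

# Subset stand-in with the same structure as class_sizes_lower.
sizes = {'great white shark': {'feet': [13.1, 20]},
         'person': {'feet': [5, 7]}}

dist = _calculate_dist_estimate(shark_box, person_box,
                                ['great white shark', 'person'],
                                sizes, measurement='feet')

# centre gap ~ 521 px, mean diagonal ~ 201 px, mean lower-bound size 9.05 ft
# => 521 / 201 * 9.05, i.e. roughly 23.5 feet, well inside the 100 ft danger default
print(round(dist, 1))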
visualization_tests.ipynb CHANGED

The diff for this notebook is too large to render; see the raw diff.