Spaces:
Runtime error
Runtime error
adding cockpit view, code needs clean up
Browse files- app.py +71 -10
- inference.py +4 -3
app.py
CHANGED
@@ -33,15 +33,63 @@ import pathlib
|
|
33 |
import multiprocessing as mp
|
34 |
from time import time
|
35 |
|
36 |
-
|
37 |
-
REPO_ID='SharkSpace/videos_examples'
|
38 |
-
snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
|
39 |
|
40 |
theme = gr.themes.Soft(
|
41 |
primary_hue="sky",
|
42 |
neutral_hue="slate",
|
43 |
)
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def process_video(input_video, out_fps = 'auto', skip_frames = 7):
|
46 |
cap = cv2.VideoCapture(input_video)
|
47 |
|
@@ -70,10 +118,23 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
|
|
70 |
top_pred = process_results_for_plot(predictions = result.numpy(),
|
71 |
classes = classes,
|
72 |
class_sizes = class_sizes_lower)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
pred_dashbord = prediction_dashboard(top_pred = top_pred)
|
74 |
#print('sending frame')
|
75 |
print(cnt)
|
76 |
-
yield
|
77 |
cnt += 1
|
78 |
iterating, frame = cap.read()
|
79 |
|
@@ -81,15 +142,15 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
|
|
81 |
yield None, None, output_path, None
|
82 |
|
83 |
with gr.Blocks(theme=theme) as demo:
|
84 |
-
with gr.Row():
|
85 |
input_video = gr.Video(label="Input")
|
|
|
86 |
output_video = gr.Video(label="Output Video")
|
87 |
-
|
88 |
-
with gr.Row():
|
89 |
-
original_frames = gr.Image(label="Original Frame")
|
90 |
dashboard = gr.Image(label="Dashboard")
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
with gr.Row():
|
94 |
paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
|
95 |
samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
|
|
|
33 |
import multiprocessing as mp
|
34 |
from time import time
|
35 |
|
36 |
+
# Fetch the example-videos dataset from the Hugging Face Hub exactly once:
# skip the download when a previous run already created ./videos_example.
if not os.path.exists('videos_example'):
    REPO_ID='SharkSpace/videos_examples'
    # Auth token is read from the SHARK_MODEL env var — presumably the
    # dataset repo is private; confirm against the Space's secrets config.
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
|
39 |
|
40 |
theme = gr.themes.Soft(
|
41 |
primary_hue="sky",
|
42 |
neutral_hue="slate",
|
43 |
)
|
44 |
|
45 |
+
|
46 |
+
|
47 |
+
def add_border(frame, color = (255, 0, 0), thickness = None):
    """Draw a solid-color alert border around *frame*.

    Parameters:
        frame: image array (numpy-style ``.shape``) to surround.
        color: border color tuple; default red.
        thickness: border width in pixels applied to every side. When
            ``None`` (the default) it is 2.5% of the frame's larger
            dimension — identical to the previous hard-coded behavior.
            (Previously this parameter was accepted but silently
            ignored; every caller in this file passes exactly the
            2.5% value, so honoring it changes nothing here.)

    Returns:
        A new, larger image with the border added on all four sides.
    """
    if thickness is None:
        # Scale the border with the frame so it stays visible at any resolution.
        relative = max(frame.shape[0], frame.shape[1])
        thickness = int(relative * 0.025)
    # Same width on every side; BORDER_CONSTANT fills with the solid color.
    bordered_image = cv2.copyMakeBorder(frame, thickness, thickness, thickness, thickness,
                                        cv2.BORDER_CONSTANT, value=color)
    return bordered_image
|
58 |
+
|
59 |
+
def overlay_text_on_image(image, text_list, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=0.5, font_thickness=1, margin=10, color=(255, 255, 255)):
    """Render each string in *text_list* right-aligned down the image.

    Lines start at *margin* pixels from the top and are spaced 10% of the
    shorter image dimension apart; every line is right-aligned *margin*
    pixels from the right edge. The image is modified in place and also
    returned for convenience.
    """
    shortest_side = min(image.shape[0], image.shape[1])
    line_gap = int(shortest_side * 0.1)  # vertical distance between successive baselines
    baseline_y = margin                  # first line sits at the top margin
    for line in text_list:
        # getTextSize -> ((width, height), baseline); only the width matters
        # for right-alignment.
        (text_width, _), _ = cv2.getTextSize(line, font, font_size, font_thickness)
        x = image.shape[1] - text_width - margin
        cv2.putText(image, line, (x, baseline_y), font, font_size, color, font_thickness, lineType=cv2.LINE_AA)
        baseline_y += line_gap
    return image
|
67 |
+
|
68 |
+
def draw_cockpit(frame, top_pred,cnt):
    """Overlay the shark-detection "cockpit" HUD onto *frame*.

    Draws a status text block (detection flag, human count, biggest shark
    size/weight, danger level) in the top-right corner and, on every other
    frame (``cnt % 2 == 0``, producing a blink effect), adds a colored
    alert border: red when a shark is within dangerous distance,
    yellow-green when a shark is sighted but not close.

    Parameters:
        frame: image array (numpy-style ``.shape``).
        top_pred: dict with keys 'shark_sighted', 'human_n',
            'biggest_shark_size', 'biggest_shark_weight', 'dangerous_dist'.
        cnt: frame counter driving the border blink.

    Returns:
        The annotated frame — a new, larger array when a border was added,
        otherwise the input mutated in place by the text overlay.
    """
    high_danger_color = (255,0,0)    # red: shark dangerously close
    low_danger_color = (154,205,50)  # yellow-green: shark sighted, not close

    # Status lines shown in the top-right corner.
    danger_level = 'Danger Level: '
    danger_level += 'High' if top_pred['dangerous_dist'] else 'Low'
    strings = [
        'Shark Detected: ' + str(top_pred['shark_sighted']),
        'Number of Humans: ' + str(top_pred['human_n']),
        'Biggest shark size: ' + str(top_pred['biggest_shark_size']),
        'Biggest shark weight: ' + str(top_pred['biggest_shark_weight']),
        danger_level,
    ]

    # Border thickness and text margin scale with the ORIGINAL frame size;
    # computed once, before any border enlarges the frame (this matches the
    # previous code, which also measured before add_border).
    relative = max(frame.shape[0],frame.shape[1])

    # Blink the alert border on every other frame; the two previous
    # copy-paste branches differed only in color.
    if top_pred['shark_sighted'] and cnt%2 == 0:
        border_color = high_danger_color if top_pred['dangerous_dist'] else low_danger_color
        frame = add_border(frame, color=border_color, thickness=int(relative*0.025))

    overlay_text_on_image(frame, strings, font=cv2.FONT_HERSHEY_SIMPLEX,
                          font_size=1.5, font_thickness=3,
                          margin=int(relative*0.05), color=(255, 255, 255))
    return frame
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
def process_video(input_video, out_fps = 'auto', skip_frames = 7):
|
94 |
cap = cv2.VideoCapture(input_video)
|
95 |
|
|
|
118 |
top_pred = process_results_for_plot(predictions = result.numpy(),
|
119 |
classes = classes,
|
120 |
class_sizes = class_sizes_lower)
|
121 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
122 |
+
prediction_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
|
123 |
+
|
124 |
+
|
125 |
+
frame = cv2.resize(frame, (int(width*4), int(height*4)))
|
126 |
+
|
127 |
+
if cnt*skip_frames %2==0 and top_pred['shark_sighted']:
|
128 |
+
prediction_frame = cv2.resize(prediction_frame, (int(width*4), int(height*4)))
|
129 |
+
frame =prediction_frame
|
130 |
+
|
131 |
+
if top_pred['shark_sighted']:
|
132 |
+
frame = draw_cockpit(frame, top_pred,cnt*skip_frames)
|
133 |
+
|
134 |
pred_dashbord = prediction_dashboard(top_pred = top_pred)
|
135 |
#print('sending frame')
|
136 |
print(cnt)
|
137 |
+
yield prediction_frame,frame , None, pred_dashbord
|
138 |
cnt += 1
|
139 |
iterating, frame = cap.read()
|
140 |
|
|
|
142 |
yield None, None, output_path, None
|
143 |
|
144 |
with gr.Blocks(theme=theme) as demo:
|
145 |
+
with gr.Row().style(equal_height=True,height='50%'):
|
146 |
input_video = gr.Video(label="Input")
|
147 |
+
processed_frames = gr.Image(label="Shark Engine")
|
148 |
output_video = gr.Video(label="Output Video")
|
|
|
|
|
|
|
149 |
dashboard = gr.Image(label="Dashboard")
|
150 |
+
|
151 |
+
with gr.Row(height='200%',width='200%'):
|
152 |
+
original_frames = gr.Image(label="Original Frame", width='200%', height='200%').style( width=1000)
|
153 |
+
|
154 |
with gr.Row():
|
155 |
paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
|
156 |
samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
|
inference.py
CHANGED
@@ -132,10 +132,11 @@ classes_is_shark_id = [i for i, x in enumerate(classes_is_shark) if x == 1]
|
|
132 |
classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
|
133 |
classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
|
134 |
|
135 |
-
REPO_ID = "SharkSpace/maskformer_model"
|
136 |
-
FILENAME = "mask2former"
|
137 |
|
138 |
-
|
|
|
|
|
|
|
139 |
|
140 |
# Choose to use a config and initialize the detector
|
141 |
config_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'
|
|
|
132 |
classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
|
133 |
classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
|
134 |
|
|
|
|
|
135 |
|
136 |
+
# Download the segmentation model weights from the Hugging Face Hub once;
# skip when ./model already exists from a previous run.
if not os.path.exists('model'):
    REPO_ID = "SharkSpace/maskformer_model"
    # NOTE(review): FILENAME is not used in this visible span — presumably
    # leftover from an earlier hf_hub_download call; verify before removing.
    FILENAME = "mask2former"
    # Auth token comes from the SHARK_MODEL env var (private model repo —
    # TODO confirm); the whole repo snapshot lands in ./model/.
    snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
|
140 |
|
141 |
# Choose to use a config and initialize the detector
|
142 |
config_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'
|