piperod91 commited on
Commit
b3cb6e3
1 Parent(s): 109b444

adding cockpit view, code needs clean up

Browse files
Files changed (2) hide show
  1. app.py +71 -10
  2. inference.py +4 -3
app.py CHANGED
@@ -33,15 +33,63 @@ import pathlib
33
  import multiprocessing as mp
34
  from time import time
35
 
36
-
37
- REPO_ID='SharkSpace/videos_examples'
38
- snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
39
 
40
  theme = gr.themes.Soft(
41
  primary_hue="sky",
42
  neutral_hue="slate",
43
  )
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def process_video(input_video, out_fps = 'auto', skip_frames = 7):
46
  cap = cv2.VideoCapture(input_video)
47
 
@@ -70,10 +118,23 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
70
  top_pred = process_results_for_plot(predictions = result.numpy(),
71
  classes = classes,
72
  class_sizes = class_sizes_lower)
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  pred_dashbord = prediction_dashboard(top_pred = top_pred)
74
  #print('sending frame')
75
  print(cnt)
76
- yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None, pred_dashbord
77
  cnt += 1
78
  iterating, frame = cap.read()
79
 
@@ -81,15 +142,15 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
81
  yield None, None, output_path, None
82
 
83
  with gr.Blocks(theme=theme) as demo:
84
- with gr.Row():
85
  input_video = gr.Video(label="Input")
 
86
  output_video = gr.Video(label="Output Video")
87
-
88
- with gr.Row():
89
- original_frames = gr.Image(label="Original Frame")
90
  dashboard = gr.Image(label="Dashboard")
91
- processed_frames = gr.Image(label="Shark Engine")
92
-
 
 
93
  with gr.Row():
94
  paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
95
  samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
 
33
  import multiprocessing as mp
34
  from time import time
35
 
36
+ if not os.path.exists('videos_example'):
37
+ REPO_ID='SharkSpace/videos_examples'
38
+ snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
39
 
40
  theme = gr.themes.Soft(
41
  primary_hue="sky",
42
  neutral_hue="slate",
43
  )
44
 
45
+
46
+
47
+ def add_border(frame, color = (255, 0, 0), thickness = 2):
48
+ # Add a red border to the image
49
+ relative = max(frame.shape[0],frame.shape[1])
50
+ top = int(relative*0.025)
51
+ bottom = int(relative*0.025)
52
+ left = int(relative*0.025)
53
+ right = int(relative*0.025)
54
+ # Add the border to the image
55
+ bordered_image = cv2.copyMakeBorder(frame, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
56
+
57
+ return bordered_image
58
+
59
+ def overlay_text_on_image(image, text_list, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=0.5, font_thickness=1, margin=10, color=(255, 255, 255)):
60
+ relative = min(image.shape[0],image.shape[1])
61
+ y0, dy = margin, int(relative*0.1) # start y position and line gap
62
+ for i, line in enumerate(text_list):
63
+ y = y0 + i * dy
64
+ text_width, _ = cv2.getTextSize(line, font, font_size, font_thickness)[0]
65
+ cv2.putText(image, line, (image.shape[1] - text_width - margin, y), font, font_size, color, font_thickness, lineType=cv2.LINE_AA)
66
+ return image
67
+
68
+ def draw_cockpit(frame, top_pred,cnt):
69
+ # Bullet points:
70
+ high_danger_color = (255,0,0)
71
+ low_danger_color = yellowgreen = (154,205,50)
72
+ shark_sighted = 'Shark Detected: ' + str(top_pred['shark_sighted'])
73
+ human_sighted = 'Number of Humans: ' + str(top_pred['human_n'])
74
+ shark_size_estimate = 'Biggest shark size: ' + str(top_pred['biggest_shark_size'])
75
+ shark_weight_estimate = 'Biggest shark weight: ' + str(top_pred['biggest_shark_weight'])
76
+ danger_level = 'Danger Level: '
77
+ danger_level += 'High' if top_pred['dangerous_dist'] else 'Low'
78
+ danger_color = 'orangered' if top_pred['dangerous_dist'] else 'yellowgreen'
79
+ # Create a list of strings to plot
80
+ strings = [shark_sighted, human_sighted, shark_size_estimate, shark_weight_estimate, danger_level]
81
+ relative = max(frame.shape[0],frame.shape[1])
82
+ if top_pred['shark_sighted'] and top_pred['dangerous_dist'] and cnt%2 == 0:
83
+ relative = max(frame.shape[0],frame.shape[1])
84
+ frame = add_border(frame, color=high_danger_color, thickness=int(relative*0.025))
85
+ elif top_pred['shark_sighted'] and not top_pred['dangerous_dist'] and cnt%2 == 0:
86
+ relative = max(frame.shape[0],frame.shape[1])
87
+ frame = add_border(frame, color=low_danger_color, thickness=int(relative*0.025))
88
+ overlay_text_on_image(frame, strings, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=1.5, font_thickness=3, margin=int(relative*0.05), color=(255, 255, 255))
89
+ return frame
90
+
91
+
92
+
93
  def process_video(input_video, out_fps = 'auto', skip_frames = 7):
94
  cap = cv2.VideoCapture(input_video)
95
 
 
118
  top_pred = process_results_for_plot(predictions = result.numpy(),
119
  classes = classes,
120
  class_sizes = class_sizes_lower)
121
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
122
+ prediction_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
123
+
124
+
125
+ frame = cv2.resize(frame, (int(width*4), int(height*4)))
126
+
127
+ if cnt*skip_frames %2==0 and top_pred['shark_sighted']:
128
+ prediction_frame = cv2.resize(prediction_frame, (int(width*4), int(height*4)))
129
+ frame =prediction_frame
130
+
131
+ if top_pred['shark_sighted']:
132
+ frame = draw_cockpit(frame, top_pred,cnt*skip_frames)
133
+
134
  pred_dashbord = prediction_dashboard(top_pred = top_pred)
135
  #print('sending frame')
136
  print(cnt)
137
+ yield prediction_frame,frame , None, pred_dashbord
138
  cnt += 1
139
  iterating, frame = cap.read()
140
 
 
142
  yield None, None, output_path, None
143
 
144
  with gr.Blocks(theme=theme) as demo:
145
+ with gr.Row().style(equal_height=True,height='50%'):
146
  input_video = gr.Video(label="Input")
147
+ processed_frames = gr.Image(label="Shark Engine")
148
  output_video = gr.Video(label="Output Video")
 
 
 
149
  dashboard = gr.Image(label="Dashboard")
150
+
151
+ with gr.Row(height='200%',width='200%'):
152
+ original_frames = gr.Image(label="Original Frame", width='200%', height='200%').style( width=1000)
153
+
154
  with gr.Row():
155
  paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
156
  samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
inference.py CHANGED
@@ -132,10 +132,11 @@ classes_is_shark_id = [i for i, x in enumerate(classes_is_shark) if x == 1]
132
  classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
133
  classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
134
 
135
- REPO_ID = "SharkSpace/maskformer_model"
136
- FILENAME = "mask2former"
137
 
138
- snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
 
 
 
139
 
140
  # Choose to use a config and initialize the detector
141
  config_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'
 
132
  classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
133
  classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
134
 
 
 
135
 
136
+ if not os.path.exists('model'):
137
+ REPO_ID = "SharkSpace/maskformer_model"
138
+ FILENAME = "mask2former"
139
+ snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
140
 
141
  # Choose to use a config and initialize the detector
142
  config_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'