Alexander Fengler commited on
Commit
6454b14
1 Parent(s): 021ea63

layout improvements and faster outputs

Browse files
Files changed (4) hide show
  1. app.py +47 -169
  2. app3.py +0 -79
  3. app_legacy.py +209 -0
  4. inference.py +3 -13
app.py CHANGED
@@ -1,14 +1,6 @@
1
- import gradio as gr
2
- import os
3
- import subprocess
4
-
5
- from huggingface_hub import snapshot_download
6
-
7
-
8
- REPO_ID='SharkSpace/videos_examples'
9
- snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
10
-
11
 
 
 
12
  if os.getenv('SYSTEM') == 'spaces':
13
 
14
  subprocess.call('pip install -U openmim'.split())
@@ -19,8 +11,10 @@ if os.getenv('SYSTEM') == 'spaces':
19
  subprocess.call('mim install mmdet'.split())
20
  subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
21
  subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())
22
-
23
 
 
 
 
24
  import cv2
25
  import dotenv
26
  dotenv.load_dotenv()
@@ -33,177 +27,61 @@ from inference import process_frame
33
  import os
34
  import pathlib
35
  import multiprocessing as mp
36
-
37
  from time import time
38
 
39
 
40
- def analize_video_serial(x):
41
- print(x)
42
- path = '/tmp/test/'
43
- os.makedirs(path, exist_ok=True)
44
- videos = len(os.listdir(path))
45
- path = f'{path}{videos}'
46
- os.makedirs(path, exist_ok=True)
47
- outname = f'{path}_processed.mp4'
48
- if os.path.exists(outname):
49
- print('video already processed')
50
- return outname
51
- cap = cv2.VideoCapture(x)
52
- counter = 0
53
- import pdb;pdb.set_trace()
54
- while(cap.isOpened()):
55
- ret, frame = cap.read()
56
- yield None, frame
57
- if ret==True:
58
- name = os.path.join(path,f'{counter:05d}.png')
59
- frame = inference_frame_serial(frame)
60
- # write the flipped frame
61
-
62
- cv2.imwrite(name, frame)
63
- counter +=1
64
-
65
- #yield None,frame
66
- else:
67
- break
68
- # Release everything if job is finished
69
- print(path)
70
- os.system(f'''ffmpeg -framerate 20 -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
71
- return outname,frame
72
-
73
 
 
 
74
 
 
 
 
 
 
 
 
75
 
 
 
76
 
77
- def analyze_video_parallel(x, skip_frames = 5,
78
- frame_rate_out = 8, batch_size = 16):
79
- print(x)
80
 
81
- #Define path to saved images
82
- path = '/tmp/test/'
83
- os.makedirs(path, exist_ok=True)
84
-
85
- # Define name of current video as number of videos in path
86
- n_videos_in_path = len(os.listdir(path))
87
- path = f'{path}{n_videos_in_path}'
88
- os.makedirs(path, exist_ok=True)
89
 
90
- # Define name of output video
91
- outname = f'{path}_processed.mp4'
 
 
 
 
 
 
 
 
92
 
93
- if os.path.exists(outname):
94
- print('video already processed')
95
- return outname
96
-
97
- cap = cv2.VideoCapture(x)
98
- counter = 0
99
- pred_results_all = []
100
- frames_all = []
101
- while(cap.isOpened()):
102
- frames = []
103
- #start = time()
104
-
105
- while len(frames) < batch_size:
106
- #start = time()
107
- ret, frame = cap.read()
108
- if ret == False:
109
- break
110
- elif counter % skip_frames == 0:
111
- frames.append(frame)
112
- counter += 1
113
-
114
- #print(f'read time: {time()-start}')
115
-
116
- frames_all.extend(frames)
117
-
118
- # Get timing for inference
119
- start = time()
120
- print('len frames passed: ', len(frames))
121
-
122
- if len(frames) > 0:
123
- pred_results = inference_frame_par_ready(frames)
124
- print(f'inference time: {time()-start}')
125
- pred_results_all.extend(pred_results)
126
 
127
- # break while loop when return of the image reader is False
128
- if ret == False:
129
- break
130
-
131
- print('exited prediction loop')
132
- # Release everything if job is finished
133
- cap.release()
134
-
135
- start = time()
136
- pool = mp.Pool(mp.cpu_count()-2)
137
- pool_out = pool.map(process_frame,
138
- list(zip(pred_results_all,
139
- frames_all,
140
- [i for i in range(len(pred_results_all))])))
141
- pool.close()
142
- print(f'pool time: {time()-start}')
143
 
144
- start = time()
145
- counter = 0
146
- for pool_out_tmp in pool_out:
147
- name = os.path.join(path,f'{counter:05d}.png')
148
- cv2.imwrite(name, pool_out_tmp)
149
- counter +=1
150
- yield None,pool_out_tmp
151
-
152
- print(f'write time: {time()-start}')
153
-
154
- # Create video from predicted images
155
- print(path)
156
- os.system(f'''ffmpeg -framerate {frame_rate_out} -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
157
- return outname, pool_out_tmp
158
-
159
 
160
- def set_example_image(example: list) -> dict:
161
- return gr.Video.update(value=example[0])
162
-
163
- def show_video(example: list) -> dict:
164
- return gr.Video.update(value=example[0])
165
-
166
- with gr.Blocks(title='Shark Patrol',theme=gr.themes.Soft(),live=True,) as demo:
167
- gr.Markdown("Alpha Demo of the Sharkpatrol Oceanlife Detector.")
168
- with gr.Tab("Preloaded Examples"):
169
-
170
- with gr.Row():
171
- video_example = gr.Video(source='upload',include_audio=False,stream=True)
172
- with gr.Row():
173
- paths = sorted(pathlib.Path('videos_example/').rglob('*rgb.mp4'))
174
- example_preds = gr.Dataset(components=[video_example],
175
- samples=[[path.as_posix()]
176
- for path in paths])
177
- example_preds.click(fn=show_video,
178
- inputs=example_preds,
179
- outputs=video_example)
180
 
181
- with gr.Tab("Test your own Video"):
182
- with gr.Row():
183
- video_input = gr.Video(source='upload',include_audio=False)
184
- #video_input.style(witdh='50%',height='50%')
185
- image_temp = gr.Image()
186
- with gr.Row():
187
- video_output = gr.Video()
188
-
189
- #video_output.style(witdh='50%',height='50%')
190
-
191
- video_button = gr.Button("Analyze your Video")
192
- with gr.Row():
193
- paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
194
- example_images = gr.Dataset(components=[video_input],
195
- samples=[[path.as_posix()]
196
- for path in paths if 'raw_videos' in str(path)])
197
 
198
- video_button.click(analize_video_serial, inputs=video_input, outputs=[video_output,image_temp])
199
-
200
- example_images.click(fn=set_example_image,
201
- inputs=example_images,
202
- outputs=video_input)
203
-
204
-
205
  demo.queue()
206
- if os.getenv('SYSTEM') == 'spaces':
207
- demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
208
- else:
209
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import subprocess
3
+ import os
4
  if os.getenv('SYSTEM') == 'spaces':
5
 
6
  subprocess.call('pip install -U openmim'.split())
 
11
  subprocess.call('mim install mmdet'.split())
12
  subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
13
  subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())
 
14
 
15
+ import gradio as gr
16
+
17
+ from huggingface_hub import snapshot_download
18
  import cv2
19
  import dotenv
20
  dotenv.load_dotenv()
 
27
  import os
28
  import pathlib
29
  import multiprocessing as mp
 
30
  from time import time
31
 
32
 
33
+ REPO_ID='SharkSpace/videos_examples'
34
+ snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
37
+ cap = cv2.VideoCapture(input_video)
38
 
39
+ output_path = "output.mp4"
40
+ if out_fps != 'auto' and type(out_fps) == int:
41
+ fps = int(out_fps)
42
+ else:
43
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
44
+ if out_fps == 'auto':
45
+ fps = int(fps / skip_frames)
46
 
47
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
48
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
49
 
50
+ video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
 
 
51
 
52
+ iterating, frame = cap.read()
53
+ cnt = 0
 
 
 
 
 
 
54
 
55
+ while iterating:
56
+ if (cnt % skip_frames) == 0:
57
+ # flip frame vertically
58
+ display_frame = inference_frame_serial(frame)
59
+ video.write(cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB))
60
+ print('sending frame')
61
+ print(cnt)
62
+ yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None
63
+ cnt += 1
64
+ iterating, frame = cap.read()
65
 
66
+ video.release()
67
+ yield None, None, output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ with gr.Blocks() as demo:
70
+ with gr.Row():
71
+ input_video = gr.Video(label="Input")
72
+ output_video = gr.Video(label="Output Video")
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ with gr.Row():
75
+ processed_frames = gr.Image(label="Live Frame")
76
+ original_frames = gr.Image(label="Original Frame")
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ with gr.Row():
79
+ paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
80
+ samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
81
+ examples = gr.Examples(samples, inputs=input_video)
82
+ process_video_btn = gr.Button("Process Video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
 
 
 
 
 
 
86
  demo.queue()
87
+ demo.launch()
 
 
 
app3.py DELETED
@@ -1,79 +0,0 @@
1
-
2
- import subprocess
3
- import os
4
- if os.getenv('SYSTEM') == 'spaces':
5
-
6
- subprocess.call('pip install -U openmim'.split())
7
- subprocess.call('pip install python-dotenv'.split())
8
- subprocess.call('pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113'.split())
9
- subprocess.call('mim install mmcv>=2.0.0'.split())
10
- subprocess.call('mim install mmengine'.split())
11
- subprocess.call('mim install mmdet'.split())
12
- subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
13
- subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())
14
-
15
- import gradio as gr
16
-
17
- from huggingface_hub import snapshot_download
18
- import cv2
19
- import dotenv
20
- dotenv.load_dotenv()
21
- import numpy as np
22
- import gradio as gr
23
- import glob
24
- from inference import inference_frame,inference_frame_serial
25
- from inference import inference_frame_par_ready
26
- from inference import process_frame
27
- import os
28
- import pathlib
29
- import multiprocessing as mp
30
- from time import time
31
-
32
-
33
- REPO_ID='SharkSpace/videos_examples'
34
- snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
35
-
36
-
37
-
38
-
39
-
40
-
41
-
42
- def process_video(input_video):
43
- cap = cv2.VideoCapture(input_video)
44
-
45
- output_path = "output.mp4"
46
-
47
- fps = int(cap.get(cv2.CAP_PROP_FPS))
48
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
49
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
50
-
51
- video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
52
-
53
- iterating, frame = cap.read()
54
- while iterating:
55
- # flip frame vertically
56
- display_frame = inference_frame_serial(frame)
57
- video.write(frame)
58
- yield display_frame, None
59
- iterating, frame = cap.read()
60
-
61
- video.release()
62
- yield display_frame, output_path
63
-
64
- with gr.Blocks() as demo:
65
- with gr.Row():
66
- input_video = gr.Video(label="input")
67
- processed_frames = gr.Image(label="last frame")
68
- output_video = gr.Video(label="output")
69
-
70
- with gr.Row():
71
- paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
72
- samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
73
- examples = gr.Examples(samples, inputs=input_video)
74
- process_video_btn = gr.Button("process video")
75
-
76
- process_video_btn.click(process_video, input_video, [processed_frames, output_video])
77
-
78
- demo.queue()
79
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_legacy.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import subprocess
4
+
5
+ from huggingface_hub import snapshot_download
6
+
7
+
8
+ REPO_ID='SharkSpace/videos_examples'
9
+ snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
10
+
11
+
12
+ if os.getenv('SYSTEM') == 'spaces':
13
+
14
+ subprocess.call('pip install -U openmim'.split())
15
+ subprocess.call('pip install python-dotenv'.split())
16
+ subprocess.call('pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113'.split())
17
+ subprocess.call('mim install mmcv>=2.0.0'.split())
18
+ subprocess.call('mim install mmengine'.split())
19
+ subprocess.call('mim install mmdet'.split())
20
+ subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
21
+ subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())
22
+
23
+
24
+ import cv2
25
+ import dotenv
26
+ dotenv.load_dotenv()
27
+ import numpy as np
28
+ import gradio as gr
29
+ import glob
30
+ from inference import inference_frame,inference_frame_serial
31
+ from inference import inference_frame_par_ready
32
+ from inference import process_frame
33
+ import os
34
+ import pathlib
35
+ import multiprocessing as mp
36
+
37
+ from time import time
38
+
39
+
40
+ def analize_video_serial(x):
41
+ print(x)
42
+ path = '/tmp/test/'
43
+ os.makedirs(path, exist_ok=True)
44
+ videos = len(os.listdir(path))
45
+ path = f'{path}{videos}'
46
+ os.makedirs(path, exist_ok=True)
47
+ outname = f'{path}_processed.mp4'
48
+ if os.path.exists(outname):
49
+ print('video already processed')
50
+ return outname
51
+ cap = cv2.VideoCapture(x)
52
+ counter = 0
53
+ import pdb;pdb.set_trace()
54
+ while(cap.isOpened()):
55
+ ret, frame = cap.read()
56
+ yield None, frame
57
+ if ret==True:
58
+ name = os.path.join(path,f'{counter:05d}.png')
59
+ frame = inference_frame_serial(frame)
60
+ # write the flipped frame
61
+
62
+ cv2.imwrite(name, frame)
63
+ counter +=1
64
+
65
+ #yield None,frame
66
+ else:
67
+ break
68
+ # Release everything if job is finished
69
+ print(path)
70
+ os.system(f'''ffmpeg -framerate 20 -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
71
+ return outname,frame
72
+
73
+
74
+
75
+
76
+
77
+ def analyze_video_parallel(x, skip_frames = 5,
78
+ frame_rate_out = 8, batch_size = 16):
79
+ print(x)
80
+
81
+ #Define path to saved images
82
+ path = '/tmp/test/'
83
+ os.makedirs(path, exist_ok=True)
84
+
85
+ # Define name of current video as number of videos in path
86
+ n_videos_in_path = len(os.listdir(path))
87
+ path = f'{path}{n_videos_in_path}'
88
+ os.makedirs(path, exist_ok=True)
89
+
90
+ # Define name of output video
91
+ outname = f'{path}_processed.mp4'
92
+
93
+ if os.path.exists(outname):
94
+ print('video already processed')
95
+ return outname
96
+
97
+ cap = cv2.VideoCapture(x)
98
+ counter = 0
99
+ pred_results_all = []
100
+ frames_all = []
101
+ while(cap.isOpened()):
102
+ frames = []
103
+ #start = time()
104
+
105
+ while len(frames) < batch_size:
106
+ #start = time()
107
+ ret, frame = cap.read()
108
+ if ret == False:
109
+ break
110
+ elif counter % skip_frames == 0:
111
+ frames.append(frame)
112
+ counter += 1
113
+
114
+ #print(f'read time: {time()-start}')
115
+
116
+ frames_all.extend(frames)
117
+
118
+ # Get timing for inference
119
+ start = time()
120
+ print('len frames passed: ', len(frames))
121
+
122
+ if len(frames) > 0:
123
+ pred_results = inference_frame_par_ready(frames)
124
+ print(f'inference time: {time()-start}')
125
+ pred_results_all.extend(pred_results)
126
+
127
+ # break while loop when return of the image reader is False
128
+ if ret == False:
129
+ break
130
+
131
+ print('exited prediction loop')
132
+ # Release everything if job is finished
133
+ cap.release()
134
+
135
+ start = time()
136
+ pool = mp.Pool(mp.cpu_count()-2)
137
+ pool_out = pool.map(process_frame,
138
+ list(zip(pred_results_all,
139
+ frames_all,
140
+ [i for i in range(len(pred_results_all))])))
141
+ pool.close()
142
+ print(f'pool time: {time()-start}')
143
+
144
+ start = time()
145
+ counter = 0
146
+ for pool_out_tmp in pool_out:
147
+ name = os.path.join(path,f'{counter:05d}.png')
148
+ cv2.imwrite(name, pool_out_tmp)
149
+ counter +=1
150
+ yield None,pool_out_tmp
151
+
152
+ print(f'write time: {time()-start}')
153
+
154
+ # Create video from predicted images
155
+ print(path)
156
+ os.system(f'''ffmpeg -framerate {frame_rate_out} -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
157
+ return outname, pool_out_tmp
158
+
159
+
160
+ def set_example_image(example: list) -> dict:
161
+ return gr.Video.update(value=example[0])
162
+
163
+ def show_video(example: list) -> dict:
164
+ return gr.Video.update(value=example[0])
165
+
166
+ with gr.Blocks(title='Shark Patrol',theme=gr.themes.Soft(),live=True,) as demo:
167
+ gr.Markdown("Alpha Demo of the Sharkpatrol Oceanlife Detector.")
168
+ with gr.Tab("Preloaded Examples"):
169
+
170
+ with gr.Row():
171
+ video_example = gr.Video(source='upload',include_audio=False,stream=True)
172
+ with gr.Row():
173
+ paths = sorted(pathlib.Path('videos_example/').rglob('*rgb.mp4'))
174
+ example_preds = gr.Dataset(components=[video_example],
175
+ samples=[[path.as_posix()]
176
+ for path in paths])
177
+ example_preds.click(fn=show_video,
178
+ inputs=example_preds,
179
+ outputs=video_example)
180
+
181
+ with gr.Tab("Test your own Video"):
182
+ with gr.Row():
183
+ video_input = gr.Video(source='upload',include_audio=False)
184
+ #video_input.style(witdh='50%',height='50%')
185
+ image_temp = gr.Image()
186
+ with gr.Row():
187
+ video_output = gr.Video()
188
+
189
+ #video_output.style(witdh='50%',height='50%')
190
+
191
+ video_button = gr.Button("Analyze your Video")
192
+ with gr.Row():
193
+ paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
194
+ example_images = gr.Dataset(components=[video_input],
195
+ samples=[[path.as_posix()]
196
+ for path in paths if 'raw_videos' in str(path)])
197
+
198
+ video_button.click(analize_video_serial, inputs=video_input, outputs=[video_output,image_temp])
199
+
200
+ example_images.click(fn=set_example_image,
201
+ inputs=example_images,
202
+ outputs=video_input)
203
+
204
+
205
+ demo.queue()
206
+ if os.getenv('SYSTEM') == 'spaces':
207
+ demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
208
+ else:
209
+ demo.launch()
inference.py CHANGED
@@ -15,10 +15,6 @@ from huggingface_hub import hf_hub_download
15
  from huggingface_hub import snapshot_download
16
  from time import time
17
 
18
- import concurrent.futures
19
- import threading
20
-
21
-
22
  classes= ['Beach',
23
  'Sea',
24
  'Wave',
@@ -73,23 +69,16 @@ classes= ['Beach',
73
  'Bull shark']*3
74
 
75
 
76
-
77
-
78
-
79
-
80
  REPO_ID = "SharkSpace/maskformer_model"
81
  FILENAME = "mask2former"
82
 
83
  snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
84
 
85
-
86
-
87
-
88
  # Choose to use a config and initialize the detector
89
  config_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'
90
  #'/content/mmdetection/configs/panoptic_fpn/panoptic-fpn_r50_fpn_ms-3x_coco.py'
91
  # Setup a checkpoint file to load
92
- checkpoint_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/checkpoint.pth'
93
  # '/content/drive/MyDrive/Algorithms/weights/shark_panoptic_weights_16_4_23/panoptic-fpn_r50_fpn_ms-3x_coco/epoch_36.pth'
94
 
95
  # register all modules in mmdet into the registries
@@ -106,7 +95,9 @@ print(dir(visualizer))
106
  # then pass to the model in init_detector
107
  visualizer.dataset_meta = model.dataset_meta
108
  def inference_frame_serial(image):
 
109
  result = inference_detector(model, image)
 
110
  # show the results
111
  visualizer.add_datasample(
112
  'result',
@@ -118,7 +109,6 @@ def inference_frame_serial(image):
118
  frame = visualizer.get_image()
119
  return frame
120
 
121
-
122
  def inference_frame(image):
123
  result = inference_detector(model, image)
124
  # show the results
 
15
  from huggingface_hub import snapshot_download
16
  from time import time
17
 
 
 
 
 
18
  classes= ['Beach',
19
  'Sea',
20
  'Wave',
 
69
  'Bull shark']*3
70
 
71
 
 
 
 
 
72
  REPO_ID = "SharkSpace/maskformer_model"
73
  FILENAME = "mask2former"
74
 
75
  snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
76
 
 
 
 
77
  # Choose to use a config and initialize the detector
78
  config_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py'
79
  #'/content/mmdetection/configs/panoptic_fpn/panoptic-fpn_r50_fpn_ms-3x_coco.py'
80
  # Setup a checkpoint file to load
81
+ checkpoint_file ='model/mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic/checkpoint_v2.pth'
82
  # '/content/drive/MyDrive/Algorithms/weights/shark_panoptic_weights_16_4_23/panoptic-fpn_r50_fpn_ms-3x_coco/epoch_36.pth'
83
 
84
  # register all modules in mmdet into the registries
 
95
  # then pass to the model in init_detector
96
  visualizer.dataset_meta = model.dataset_meta
97
  def inference_frame_serial(image):
98
+ start = time()
99
  result = inference_detector(model, image)
100
+ print(f'inference time: {time()-start}')
101
  # show the results
102
  visualizer.add_datasample(
103
  'result',
 
109
  frame = visualizer.get_image()
110
  return frame
111
 
 
112
  def inference_frame(image):
113
  result = inference_detector(model, image)
114
  # show the results