Alexander Fengler committed
Commit 5636b5c
1 Parent(s): f3a075d

adding multiprocessing version of image analyze

Files changed (3):
  1. .gitignore +3 -0
  2. app.py +116 -19
  3. inference.py +68 -9
.gitignore CHANGED
@@ -1,3 +1,6 @@
+# Added
+tmp/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
app.py CHANGED
@@ -28,8 +28,11 @@ import numpy as np
 import gradio as gr
 import glob
 from inference import inference_frame
+from inference import inference_frame_par_ready
+from inference import process_frame
 import os
 import pathlib
+import multiprocessing as mp
 
 from time import time
 
@@ -37,7 +40,7 @@ def analyze_video(x, skip_frames = 5, frame_rate_out = 8):
     print(x)
 
     #Define path to saved images
-    path = '/tmp/test/'
+    path = 'tmp/test/'
     os.makedirs(path, exist_ok=True)
 
     # Define name of current video as number of videos in path
@@ -56,34 +59,128 @@ def analyze_video(x, skip_frames = 5, frame_rate_out = 8):
     counter = 0
 
     while(cap.isOpened()):
+        frames = []
         start = time()
-        ret, frame = cap.read()
+        for i in range(16):
+            start = time()
+            ret, frame = cap.read()
+            frames.append(frame)
+            if ret == False:
+                break
         print(f'read time: {time()-start}')
 
-        if ret==True:
-            if counter % skip_frames == 0:
-                name = os.path.join(path,f'{counter:05d}.png')
-                start = time()
-                frame = inference_frame(frame)
-                print(f'inference time: {time()-start}')
-                # write the flipped frame
-                start = time()
-                cv2.imwrite(name, frame)
-                print(f'write time: {time()-start}')
-            else:
-                pass
-            print(counter)
-            counter +=1
-        else:
-            break
+        #if ret==True:
+
+        #if counter % skip_frames == 0:
+        name = os.path.join(path,f'{counter:05d}.png')
+        # Get timing for inference
+        start = time()
+        frames = inference_frame(frames)
+        print(f'inference time: {time()-start}')
+        # write the flipped frame
+
+        start = time()
+        for frame in frames:
+            name = os.path.join(path,f'{counter:05d}.png')
+            cv2.imwrite(name, frame)
+            counter +=1
+        print(f'write time: {time()-start}')
+        # else:
+
+        #     print(counter)
+        #     counter +=1
+        # else:
+        #     break
+
+    # Release everything if job is finished
+    cap.release()
+
+    # Create video from predicted images
+    print(path)
+    os.system(f'''ffmpeg -framerate {frame_rate_out} -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
+    return outname
+
+
+def analyze_video_parallel(x, skip_frames = 5,
+                           frame_rate_out = 8, batch_size = 16):
+    print(x)
+
+    #Define path to saved images
+    path = '/tmp/test/'
+    os.makedirs(path, exist_ok=True)
+
+    # Define name of current video as number of videos in path
+    n_videos_in_path = len(os.listdir(path))
+    path = f'{path}{n_videos_in_path}'
+    os.makedirs(path, exist_ok=True)
+
+    # Define name of output video
+    outname = f'{path}_processed.mp4'
+
+    if os.path.exists(outname):
+        print('video already processed')
+        return outname
 
+    cap = cv2.VideoCapture(x)
+    counter = 0
+    pred_results_all = []
+    frames_all = []
+    while(cap.isOpened()):
+        frames = []
+        #start = time()
+
+        while len(frames) < batch_size:
+            #start = time()
+            ret, frame = cap.read()
+            if ret == False:
+                break
+            elif counter % skip_frames == 0:
+                frames.append(frame)
+            counter += 1
+
+        #print(f'read time: {time()-start}')
+
+        frames_all.extend(frames)
+
+        # Get timing for inference
+        start = time()
+        print('len frames passed: ', len(frames))
+
+        if len(frames) > 0:
+            pred_results = inference_frame_par_ready(frames)
+            print(f'inference time: {time()-start}')
+            pred_results_all.extend(pred_results)
+
+        # break while loop when return of the image reader is False
+        if ret == False:
+            break
+
+    print('exited prediction loop')
     # Release everything if job is finished
     cap.release()
+
+    start = time()
+    pool = mp.Pool(mp.cpu_count()-2)
+    pool_out = pool.map(process_frame,
+                        list(zip(pred_results_all,
+                                 frames_all,
+                                 [i for i in range(len(pred_results_all))])))
+    pool.close()
+    print(f'pool time: {time()-start}')
+
+    start = time()
+    counter = 0
+    for pool_out_tmp in pool_out:
+        name = os.path.join(path,f'{counter:05d}.png')
+        cv2.imwrite(name, pool_out_tmp)
+        counter +=1
+    print(f'write time: {time()-start}')
 
     # Create video from predicted images
     print(path)
     os.system(f'''ffmpeg -framerate {frame_rate_out} -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
     return outname
+
 
 def set_example_image(example: list) -> dict:
     return gr.Video.update(value=example[0])
@@ -120,7 +217,7 @@ with gr.Blocks(title='Shark Patrol',theme=gr.themes.Soft(),live=True,) as demo:
         samples=[[path.as_posix()]
                  for path in paths if 'videos_side_by_side' not in str(path)])
 
-    video_button.click(analyze_video, inputs=video_input, outputs=video_output)
+    video_button.click(analyze_video_parallel, inputs=video_input, outputs=video_output)
 
     example_images.click(fn=set_example_image,
                          inputs=example_images,
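
A note on the structure introduced here: in analyze_video_parallel, frame reading and detector inference stay sequential (the detector call is already batched), and only the CPU-bound overlay drawing is fanned out to a process pool after the read loop finishes. Below is a minimal, self-contained sketch of that pattern; predict_batch and draw_result are hypothetical stand-ins for inference_frame_par_ready and process_frame, while the control flow mirrors the committed code.

    import multiprocessing as mp
    import os

    import cv2


    def predict_batch(frames):
        # Stand-in for inference_frame_par_ready: one detector call per
        # batch, returning one picklable result per frame.
        return [None] * len(frames)


    def draw_result(args):
        # Stand-in for process_frame: CPU-bound drawing that can run in a
        # worker because its inputs and output are plain, picklable arrays.
        result, frame, index = args
        return frame


    def analyze(video_path, out_dir, batch_size=16, skip_frames=5):
        os.makedirs(out_dir, exist_ok=True)
        cap = cv2.VideoCapture(video_path)
        results, kept, counter, ret = [], [], 0, True
        while cap.isOpened() and ret:
            batch = []
            while len(batch) < batch_size:
                ret, frame = cap.read()
                if not ret:
                    break
                if counter % skip_frames == 0:
                    batch.append(frame)
                counter += 1
            if batch:
                results.extend(predict_batch(batch))  # sequential GPU stage
                kept.extend(batch)
        cap.release()
        # Parallel CPU stage, leaving two cores free as mp.cpu_count()-2 does.
        with mp.Pool(max(1, mp.cpu_count() - 2)) as pool:
            drawn = pool.map(draw_result, zip(results, kept, range(len(kept))))
        for i, frame in enumerate(drawn):
            cv2.imwrite(os.path.join(out_dir, f'{i:05d}.png'), frame)

Everything handed to pool.map must pickle, which is why the commit converts detector outputs with .numpy() before pooling; on spawn-based platforms (macOS, Windows) the Pool creation also needs to sit under an if __name__ == '__main__': guard.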
inference.py CHANGED
@@ -13,6 +13,10 @@ from mmdet.registry import VISUALIZERS
 
 from huggingface_hub import hf_hub_download
 from huggingface_hub import snapshot_download
+from time import time
+
+import concurrent.futures
+import threading
 
 
 classes= ['Beach',
@@ -76,7 +80,7 @@ classes= ['Beach',
 REPO_ID = "SharkSpace/maskformer_model"
 FILENAME = "mask2former"
 
-snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
+# snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
 
 
 
@@ -94,20 +98,75 @@ register_all_modules()
 # build the model from a config file and a checkpoint file
 model = init_detector(config_file, checkpoint_file, device='cuda:0') # or device='cuda:0'
 model.dataset_meta['classes'] = classes
+print(model.cfg.visualizer)
 # init visualizer(run the block only once in jupyter notebook)
 visualizer = VISUALIZERS.build(model.cfg.visualizer)
+print(dir(visualizer))
 # the dataset_meta is loaded from the checkpoint and
 # then pass to the model in init_detector
 visualizer.dataset_meta = model.dataset_meta
+
 def inference_frame(image):
+    #import ipdb; ipdb.set_trace()
     result = inference_detector(model, image)
     # show the results
+    #import ipdb; ipdb.set_trace()
+    frames = []
+    cnt=0
+
+    for res in result:
+        visualizer.add_datasample(
+            'result',
+            image[cnt],
+            data_sample=res.numpy(),
+            draw_gt = None,
+            show=False
+        )
+        frame = visualizer.get_image()
+        frames.append(frame)
+        cnt+=1
+
+    #frames = process_frames(result, image, visualizer)
+    end = time()
+    print("Time taken for drawing: ", end-start)
+    return frames
+
+def inference_frame_par_ready(image):
+    result = inference_detector(model, image)
+    return [result[i].numpy() for i in range(len(result))]
+
+def process_frame(in_tuple = (None, None, None)):
     visualizer.add_datasample(
-        'result',
-        image,
-        data_sample=result,
-        draw_gt = None,
-        show=False
-    )
-    frame = visualizer.get_image()
-    return frame
+        'result',
+        in_tuple[1], #image,
+        data_sample=in_tuple[0], #result
+        draw_gt = None,
+        show=False
+    )
+
+    #frame = visualizer.get_image()
+    #print(in_tuple[2])
+    return visualizer.get_image()
+
+#def process_frame(frame):
+
+# def process_frames(result, image, visualizer):
+#     frames = []
+#     lock = threading.Lock()
+
+#     def process_data(cnt, res, img):
+#         visualizer.add_datasample('result', img, data_sample=res, draw_gt=None, show=False)
+#         frame = visualizer.get_image()
+#         with lock:
+#             frames.append(frame)
+
+#     threads = []
+#     for cnt, res in enumerate(result):
+#         t = threading.Thread(target=process_data, args=(cnt, res, image[cnt]))
+#         threads.append(t)
+#         t.start()
+
+#     for t in threads:
+#         t.join()
+
+#     return frames
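
Two details in the committed inference.py are worth flagging. First, inference_frame now expects a list of frames (it indexes image[cnt] and iterates per-frame results), so callers passing a single frame would break. Second, its timing print references start, which is never assigned inside the function and is not visible anywhere in this diff, so the "Time taken for drawing" line would raise a NameError as written; a start = time() before the drawing loop appears to be missing. The concurrent.futures and threading imports serve only the commented-out threading experiment at the bottom. A minimal sketch of the two-stage split that inference_frame_par_ready and process_frame implement, with infer_batch and the overlay logic as hypothetical stand-ins for the MMDetection detector and visualizer:

    import multiprocessing as mp

    import numpy as np


    def infer_batch(frames):
        # Stage 1 (parent process): one call for the whole batch. Results
        # should be plain numpy-backed objects -- GPU-resident outputs may
        # not survive pickling into workers, which is what the .numpy()
        # conversion in inference_frame_par_ready sidesteps.
        return [np.zeros(f.shape[:2], dtype=np.uint8) for f in frames]


    def process_frame(in_tuple):
        # Stage 2 (worker process): draw one result onto one frame. With the
        # fork start method each worker gets its own copy of module-level
        # state (the visualizer, in the real code), so no lock is needed,
        # unlike the commented-out threading variant with its shared lock.
        result, frame, index = in_tuple
        out = frame.copy()
        out[result > 0] = 255  # hypothetical mask overlay
        return out


    if __name__ == '__main__':
        frames = [np.zeros((8, 8, 3), dtype=np.uint8) for _ in range(4)]
        results = infer_batch(frames)
        with mp.Pool(2) as pool:
            drawn = pool.map(process_frame, zip(results, frames, range(4)))
        print(len(drawn), drawn[0].shape)

The default argument in the committed process_frame, in_tuple = (None, None, None), is only a placeholder; pool.map always supplies a real (result, frame, index) tuple.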