Ivan Felipe Rodriguez committed on
Commit
021ea63
1 Parent(s): 5636b5c

testing new app for realtime pred

Files changed (3):
  1. app.py +36 -55
  2. app3.py +79 -0
  3. inference.py +15 -5
app.py CHANGED
@@ -6,7 +6,7 @@ from huggingface_hub import snapshot_download
 
 
 REPO_ID='SharkSpace/videos_examples'
-snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
+snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
 
 
 if os.getenv('SYSTEM') == 'spaces':
@@ -27,7 +27,7 @@ dotenv.load_dotenv()
 import numpy as np
 import gradio as gr
 import glob
-from inference import inference_frame
+from inference import inference_frame,inference_frame_serial
 from inference import inference_frame_par_ready
 from inference import process_frame
 import os
@@ -36,69 +36,42 @@ import multiprocessing as mp
 
 from time import time
 
-def analyze_video(x, skip_frames = 5, frame_rate_out = 8):
-    print(x)
 
-    #Define path to saved images
-    path = 'tmp/test/'
+def analize_video_serial(x):
+    print(x)
+    path = '/tmp/test/'
     os.makedirs(path, exist_ok=True)
-
-    # Define name of current video as number of videos in path
-    n_videos_in_path = len(os.listdir(path))
-    path = f'{path}{n_videos_in_path}'
+    videos = len(os.listdir(path))
+    path = f'{path}{videos}'
    os.makedirs(path, exist_ok=True)
-
-    # Define name of output video
    outname = f'{path}_processed.mp4'
-
    if os.path.exists(outname):
        print('video already processed')
        return outname
-
    cap = cv2.VideoCapture(x)
    counter = 0
-
+    import pdb;pdb.set_trace()
    while(cap.isOpened()):
-        frames = []
-        start = time()
-        for i in range(16):
-            start = time()
-            ret, frame = cap.read()
-            frames.append(frame)
-            if ret == False:
-                break
-        print(f'read time: {time()-start}')
-
-        #if ret==True:
-
-        #if counter % skip_frames == 0:
-        name = os.path.join(path,f'{counter:05d}.png')
-        # Get timing for inference
-        start = time()
-        frames = inference_frame(frames)
-        print(f'inference time: {time()-start}')
-        # write the flipped frame
-
-        start = time()
-        for frame in frames:
+        ret, frame = cap.read()
+        yield None, frame
+        if ret==True:
            name = os.path.join(path,f'{counter:05d}.png')
+            frame = inference_frame_serial(frame)
+            # write the flipped frame
+
            cv2.imwrite(name, frame)
            counter +=1
+
+            #yield None,frame
+        else:
+            break
-            print(f'write time: {time()-start}')
-        # else:
-
-        #     print(counter)
-        #     counter +=1
-        # else:
-        #     break
-
    # Release everything if job is finished
-    cap.release()
-
-    # Create video from predicted images
    print(path)
-    os.system(f'''ffmpeg -framerate {frame_rate_out} -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
-    return outname
+    os.system(f'''ffmpeg -framerate 20 -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
+    return outname,frame
+
+
+
 
 
 def analyze_video_parallel(x, skip_frames = 5,
@@ -174,12 +147,14 @@ def analyze_video_parallel(x, skip_frames = 5,
            name = os.path.join(path,f'{counter:05d}.png')
            cv2.imwrite(name, pool_out_tmp)
            counter +=1
+            yield None,pool_out_tmp
+
        print(f'write time: {time()-start}')
 
    # Create video from predicted images
    print(path)
    os.system(f'''ffmpeg -framerate {frame_rate_out} -pattern_type glob -i '{path}/*.png' -c:v libx264 -pix_fmt yuv420p {outname} -y''')
-    return outname
+    return outname, pool_out_tmp
 
 
 def set_example_image(example: list) -> dict:
@@ -207,7 +182,10 @@ with gr.Blocks(title='Shark Patrol',theme=gr.themes.Soft(),live=True,) as demo:
    with gr.Row():
        video_input = gr.Video(source='upload',include_audio=False)
        #video_input.style(witdh='50%',height='50%')
+        image_temp = gr.Image()
+    with gr.Row():
        video_output = gr.Video()
+
        #video_output.style(witdh='50%',height='50%')
 
    video_button = gr.Button("Analyze your Video")
@@ -215,14 +193,17 @@ with gr.Blocks(title='Shark Patrol',theme=gr.themes.Soft(),live=True,) as demo:
    paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
    example_images = gr.Dataset(components=[video_input],
                                samples=[[path.as_posix()]
-                                         for path in paths if 'videos_side_by_side' not in str(path)])
+                                         for path in paths if 'raw_videos' in str(path)])
 
-    video_button.click(analyze_video_parallel, inputs=video_input, outputs=video_output)
+    video_button.click(analize_video_serial, inputs=video_input, outputs=[video_output,image_temp])
 
    example_images.click(fn=set_example_image,
                         inputs=example_images,
                         outputs=video_input)
 
+
demo.queue()
-#if os.getenv('SYSTEM') == 'spaces':
-demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
+if os.getenv('SYSTEM') == 'spaces':
+    demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
+else:
+    demo.launch()
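
Note on the pattern above: the click handler can be a generator because Gradio streams each yielded value to the outputs while the handler keeps running, provided demo.queue() is enabled; that is what lets analize_video_serial push a live preview frame into image_temp before the final video exists. A minimal standalone sketch of that streaming pattern (component and function names here are illustrative, not part of this commit):

```python
import cv2
import gradio as gr

def preview_frames(video_path):
    # Generator handler: each yielded frame is streamed to the Image output.
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    while ret:
        # OpenCV decodes BGR; Gradio's Image component expects RGB.
        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        ret, frame = cap.read()
    cap.release()

with gr.Blocks() as demo:
    video_in = gr.Video()
    frame_out = gr.Image()
    gr.Button("Preview").click(preview_frames, inputs=video_in, outputs=frame_out)

demo.queue()   # queue() is required for generator (streaming) handlers
demo.launch()
```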
app3.py ADDED
@@ -0,0 +1,79 @@
+
+import subprocess
+import os
+if os.getenv('SYSTEM') == 'spaces':
+
+    subprocess.call('pip install -U openmim'.split())
+    subprocess.call('pip install python-dotenv'.split())
+    subprocess.call('pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113'.split())
+    subprocess.call('mim install mmcv>=2.0.0'.split())
+    subprocess.call('mim install mmengine'.split())
+    subprocess.call('mim install mmdet'.split())
+    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
+    subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())
+
+import gradio as gr
+
+from huggingface_hub import snapshot_download
+import cv2
+import dotenv
+dotenv.load_dotenv()
+import numpy as np
+import gradio as gr
+import glob
+from inference import inference_frame,inference_frame_serial
+from inference import inference_frame_par_ready
+from inference import process_frame
+import os
+import pathlib
+import multiprocessing as mp
+from time import time
+
+
+REPO_ID='SharkSpace/videos_examples'
+snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
+
+
+
+
+
+
+
+def process_video(input_video):
+    cap = cv2.VideoCapture(input_video)
+
+    output_path = "output.mp4"
+
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
+
+    iterating, frame = cap.read()
+    while iterating:
+        # flip frame vertically
+        display_frame = inference_frame_serial(frame)
+        video.write(frame)
+        yield display_frame, None
+        iterating, frame = cap.read()
+
+    video.release()
+    yield display_frame, output_path
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        input_video = gr.Video(label="input")
+        processed_frames = gr.Image(label="last frame")
+        output_video = gr.Video(label="output")
+
+    with gr.Row():
+        paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
+        samples=[[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
+        examples = gr.Examples(samples, inputs=input_video)
+        process_video_btn = gr.Button("process video")
+
+    process_video_btn.click(process_video, input_video, [processed_frames, output_video])
+
+demo.queue()
+demo.launch()
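
Note: unlike app.py, which writes PNGs and shells out to ffmpeg, app3.py streams annotated frames straight into a cv2.VideoWriter. One caveat: the "mp4v" FourCC encodes MPEG-4 Part 2 video, which many browsers will not play inline in a gr.Video component. A possible workaround, sketched under the assumption that ffmpeg is available on the Space (app.py already invokes it), is to re-encode to H.264 once process_video has finished writing output.mp4:

```python
import subprocess

# Hypothetical post-processing step: re-encode the "mp4v" file to
# browser-friendly H.264 after the VideoWriter has been released.
subprocess.run(
    ['ffmpeg', '-i', 'output.mp4',
     '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
     'output_h264.mp4', '-y'],
    check=True,  # raise if ffmpeg fails instead of continuing silently
)
```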
inference.py CHANGED
@@ -80,7 +80,7 @@ classes= ['Beach',
 REPO_ID = "SharkSpace/maskformer_model"
 FILENAME = "mask2former"
 
-# snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
+snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
 
 
 
@@ -105,12 +105,23 @@ print(dir(visualizer))
 # the dataset_meta is loaded from the checkpoint and
 # then pass to the model in init_detector
 visualizer.dataset_meta = model.dataset_meta
+def inference_frame_serial(image):
+    result = inference_detector(model, image)
+    # show the results
+    visualizer.add_datasample(
+        'result',
+        image,
+        data_sample=result,
+        draw_gt = None,
+        show=False
+    )
+    frame = visualizer.get_image()
+    return frame
+
 
 def inference_frame(image):
-    #import ipdb; ipdb.set_trace()
     result = inference_detector(model, image)
     # show the results
-    #import ipdb; ipdb.set_trace()
     frames = []
     cnt=0
 
@@ -127,8 +138,7 @@ def inference_frame(image):
         cnt+=1
 
     #frames = process_frames(result, image, visualizer)
-    end = time()
-    print("Time taken for drawing: ", end-start)
+
     return frames
 
 def inference_frame_par_ready(image):
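
Note: the new inference_frame_serial wraps mmdet's inference_detector plus a visualizer add_datasample/get_image round trip for a single frame, instead of batching frames like inference_frame. A minimal smoke-test sketch (the input file name is an assumption; importing inference triggers the model download and initialization at module load):

```python
import cv2
from inference import inference_frame_serial

frame = cv2.imread('test_frame.png')        # any BGR image readable by OpenCV
annotated = inference_frame_serial(frame)   # returns the visualizer's rendered image
cv2.imwrite('test_frame_annotated.png', annotated)
```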