pesi
/

Luigi commited on
Commit
3939a91
1 Parent(s): 8072759

Use RTMO_GPU_Batch to demo

Browse files
Files changed (2) hide show
  1. demo.sh +1 -1
  2. rtmo_demo_batch.py +75 -0
demo.sh CHANGED
@@ -1,2 +1,2 @@
1
  #!/bin/sh
2
- python3 rtmo_demo.py ./video rtmo-t.fp16.onnx
 
1
  #!/bin/sh
2
+ python3 rtmo_demo_batch.py ./video rtmo-t.fp16.onnx
rtmo_demo_batch.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+
3
+ import time
4
+ import cv2
5
+ from pathlib import Path
6
+ import argparse
7
+ from rtmo_gpu import RTMO_GPU_Batch, draw_skeleton # Ensure to import RTMO_GPU_Batch
8
+
9
+ def process_video(video_path, body_estimator):
10
+ cap = cv2.VideoCapture(video_path)
11
+
12
+ batch_frames = []
13
+ frame_idxs = []
14
+ batch_size = 4 # Define a suitable batch size based on your GPU memory
15
+
16
+ while cap.isOpened():
17
+ success, frame = cap.read()
18
+
19
+ if not success:
20
+ break
21
+
22
+ batch_frames.append(frame)
23
+ frame_idxs.append(cap.get(cv2.CAP_PROP_POS_FRAMES))
24
+
25
+ # Process the batch when it's full
26
+ if len(batch_frames) == batch_size:
27
+ s = time.time()
28
+ batch_keypoints, batch_scores = body_estimator(batch_frames)
29
+ det_time = time.time() - s
30
+ print(f'Batch det: {round(batch_size / det_time, 1)} FPS')
31
+
32
+ for i, keypoints in enumerate(batch_keypoints):
33
+ scores = batch_scores[i]
34
+ frame = batch_frames[i]
35
+ img_show = frame.copy()
36
+ img_show = draw_skeleton(img_show, keypoints, scores, kpt_thr=0.3, line_width=2)
37
+ img_show = cv2.resize(img_show, (788, 525))
38
+ cv2.imshow(f'{video_path}', img_show)
39
+ cv2.waitKey(10)
40
+
41
+ # Clear the batch
42
+ batch_frames = []
43
+
44
+ # Process remaining frames if any
45
+ if batch_frames:
46
+ batch_keypoints, batch_scores = body_estimator(batch_frames)
47
+ for i, keypoints in enumerate(batch_keypoints):
48
+ scores = batch_scores[i]
49
+ frame = batch_frames[i]
50
+ img_show = frame.copy()
51
+ img_show = draw_skeleton(img_show, keypoints, scores, kpt_thr=0.3, line_width=2)
52
+ img_show = cv2.resize(img_show, (720, 480))
53
+ cv2.imshow(f'{video_path}', img_show)
54
+ #cv2.waitKey(10)
55
+
56
+ cap.release()
57
+ cv2.destroyAllWindows()
58
+
59
+ if __name__ == "__main__":
60
+ # Set up argument parsing
61
+ parser = argparse.ArgumentParser(description='Process the path to a video file folder.')
62
+ parser.add_argument('path', type=str, help='Path to the folder containing video files (required)')
63
+ parser.add_argument('model_path', type=str, help='Path to a RTMO ONNX model file (required)')
64
+
65
+ # Parse the command-line arguments
66
+ args = parser.parse_args()
67
+
68
+ onnx_model = args.model_path # Example: 'rtmo-s_8xb32-600e_body7-640x640.onnx'
69
+ model_input_size = (416, 416) if 'rtmo-t' in onnx_model.lower() else (640, 640)
70
+
71
+ # Instantiate the RTMO_GPU_Batch instead of RTMO_GPU
72
+ body_estimator = RTMO_GPU_Batch(onnx_model=onnx_model, model_input_size=model_input_size)
73
+
74
+ for mp4_path in Path(args.path).glob('*'):
75
+ process_video(str(mp4_path), body_estimator)