arad1367 commited on
Commit
1449a51
1 Parent(s): 2819ddc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -198
app.py CHANGED
@@ -1,198 +1,197 @@
1
- import time
2
- import uuid
3
-
4
- import cv2
5
- import gradio as gr
6
- import numpy as np
7
- import spaces
8
- import supervision as sv
9
- import torch # Ensuring torch import remains
10
-
11
- from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
12
-
13
- # Detect if CUDA is available and set the device accordingly
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
-
16
- # Load the processor and model from Hugging Face
17
- processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
18
- model = AutoModelForZeroShotObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(device)
19
-
20
- # Custom CSS to enhance text area visibility
21
- css = """
22
- .feedback textarea {font-size: 24px !important}
23
- """
24
-
25
- # Initialize global variables
26
- global classes
27
- global detections
28
- global labels
29
- global threshold
30
-
31
- # Set default values
32
- classes = "person, university, class, Liectenstein"
33
- detections = None
34
- labels = None
35
- threshold = 0.2
36
-
37
- # Instantiate annotators for bounding boxes, masks, and labels
38
- BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
39
- MASK_ANNOTATOR = sv.MaskAnnotator()
40
- LABEL_ANNOTATOR = sv.LabelAnnotator()
41
-
42
- # Frame subsampling factor for video processing efficiency
43
- SUBSAMPLE = 2
44
-
45
-
46
- def annotate_image(input_image, detections, labels) -> np.ndarray:
47
- """Applies mask, bounding box, and label annotations to a given image."""
48
- output_image = MASK_ANNOTATOR.annotate(input_image, detections)
49
- output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
50
- output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
51
- return output_image
52
-
53
-
54
- @spaces.GPU
55
- def process_video(input_video, confidence_threshold, classes_new, progress=gr.Progress(track_tqdm=True)):
56
- """Processes the input video frame by frame, performs object detection, and saves the output video."""
57
- global detections, labels, classes, threshold
58
- classes = classes_new
59
- threshold = confidence_threshold
60
-
61
- # Generate a unique file name for the output video
62
- result_file_name = f"output_{uuid.uuid4()}.mp4"
63
-
64
- # Read input video and set up output video writer
65
- cap = cv2.VideoCapture(input_video)
66
- video_codec = cv2.VideoWriter_fourcc(*"mp4v") # MP4 codec
67
- fps = int(cap.get(cv2.CAP_PROP_FPS))
68
- width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
69
- desired_fps = fps // SUBSAMPLE
70
- iterating, frame = cap.read()
71
-
72
- # Prepare video writer for output
73
- segment_file = cv2.VideoWriter(result_file_name, video_codec, desired_fps, (width, height))
74
- batch, frames, predict_index = [], [], []
75
- n_frames = 0
76
-
77
- while iterating:
78
- if n_frames % SUBSAMPLE == 0:
79
- predict_index.append(len(frames))
80
- batch.append(frame)
81
- frames.append(frame)
82
-
83
- # Process a batch of frames at once
84
- if len(batch) == desired_fps:
85
- classes_list = classes.strip().split(",")
86
- results, fps_value = query(batch, classes_list, threshold, (width, height))
87
-
88
- for i, frame in enumerate(frames):
89
- if i in predict_index:
90
- batch_idx = predict_index.index(i)
91
- detections = sv.Detections(
92
- xyxy=results[batch_idx]["boxes"].cpu().detach().numpy(),
93
- confidence=results[batch_idx]["scores"].cpu().detach().numpy(),
94
- class_id=np.array([classes_list.index(result_class) for result_class in results[batch_idx]["classes"]]),
95
- data={"class_name": results[batch_idx]["classes"]},
96
- )
97
- labels = results[batch_idx]["classes"]
98
-
99
- frame = annotate_image(input_image=frame, detections=detections, labels=labels)
100
- segment_file.write(frame)
101
-
102
- # Finalize and yield result
103
- segment_file.release()
104
- yield result_file_name, gr.Markdown(f'<h3 style="text-align: center;">Model inference FPS (batched): {fps_value * len(batch):.2f}</h3>')
105
- result_file_name = f"output_{uuid.uuid4()}.mp4"
106
- segment_file = cv2.VideoWriter(result_file_name, video_codec, desired_fps, (width, height))
107
- batch.clear()
108
- frames.clear()
109
- predict_index.clear()
110
-
111
- iterating, frame = cap.read()
112
- n_frames += 1
113
-
114
-
115
- def query(frame_batch, classes, confidence_threshold, size=(640, 480)):
116
- """Runs inference on a batch of frames and returns the results."""
117
- inputs = processor(images=frame_batch, text=[classes] * len(frame_batch), return_tensors="pt").to(device)
118
-
119
- with torch.no_grad():
120
- start_time = time.time()
121
- outputs = model(**inputs)
122
- fps_value = 1 / (time.time() - start_time)
123
-
124
- target_sizes = torch.tensor([size[::-1]] * len(frame_batch))
125
- results = processor.post_process_grounded_object_detection(
126
- outputs=outputs, classes=[classes] * len(frame_batch), score_threshold=confidence_threshold, target_sizes=target_sizes
127
- )
128
-
129
- return results, fps_value
130
-
131
-
132
- def set_classes(classes_input):
133
- """Updates the list of classes for detection."""
134
- global classes
135
- classes = classes_input
136
-
137
-
138
- def set_confidence_threshold(confidence_threshold_input):
139
- """Updates the confidence threshold for detection."""
140
- global threshold
141
- threshold = confidence_threshold_input
142
-
143
-
144
- # Custom footer for the Gradio interface
145
- footer = """
146
- <div style="text-align: center; margin-top: 20px;">
147
- <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
148
- <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
149
- <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a> |
150
- <a href="https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf" target="_blank">omdet-turbo-swin-tiny-hf repo in HF</a>
151
- <br>
152
- Made with 💖 by Pejman Ebrahimi
153
- </div>
154
- """
155
-
156
- # Gradio Interface with the customized theme and DuplicateButton
157
- with gr.Blocks(theme='ParityError/Anime', css=css) as demo:
158
- gr.Markdown("## Real Time Object Detection with OmDet-Turbo")
159
- gr.Markdown(
160
- """
161
- This is a demo for real-time open vocabulary object detection using OmDet-Turbo.<br>
162
- It utilizes ZeroGPU, which allocates GPU for the first inference.<br>
163
- The actual inference FPS is displayed after processing, providing an accurate assessment of performance.<br>
164
- """
165
- )
166
-
167
- with gr.Row():
168
- input_video = gr.Video(label="Upload Video")
169
- output_video = gr.Video(label="Processed Video", streaming=True, autoplay=True)
170
- actual_fps = gr.Markdown("", visible=False)
171
-
172
- with gr.Row():
173
- classes = gr.Textbox("person, university, class, Liectenstein", label="Objects to Detect (comma separated)", elem_classes="feedback", scale=3)
174
- conf = gr.Slider(label="Confidence Threshold", minimum=0.1, maximum=1.0, value=0.2, step=0.05)
175
-
176
- with gr.Row():
177
- submit = gr.Button("Run Detection", variant="primary")
178
- duplicate_space = gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
179
-
180
- example_videos = gr.Examples(
181
- examples=[["./UNI-LI.mp4", 0.3, "person, university, class, Liectenstein"]],
182
- inputs=[input_video, conf, classes],
183
- outputs=[output_video, actual_fps]
184
- )
185
-
186
- classes.submit(set_classes, classes)
187
- conf.change(set_confidence_threshold, conf)
188
-
189
- submit.click(
190
- fn=process_video,
191
- inputs=[input_video, conf, classes],
192
- outputs=[output_video, actual_fps]
193
- )
194
-
195
- gr.HTML(footer)
196
-
197
- if __name__ == "__main__":
198
- demo.launch(show_error=True)
 
1
+ import time
2
+ import uuid
3
+
4
+ import cv2
5
+ import gradio as gr
6
+ import numpy as np
7
+ import spaces
8
+ import supervision as sv
9
+ import torch
10
+
11
+ from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
12
+
13
+ # Detect if CUDA is available and set the device accordingly
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # Load the processor and model from Hugging Face
17
+ processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
18
+ model = AutoModelForZeroShotObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(device)
19
+
20
+ # Custom CSS to enhance text area visibility
21
+ css = """
22
+ .feedback textarea {font-size: 24px !important}
23
+ """
24
+
25
+ # Initialize global variables
26
+ global classes
27
+ global detections
28
+ global labels
29
+ global threshold
30
+
31
+ # Set default values
32
+ classes = "person, university, class, Liectenstein"
33
+ detections = None
34
+ labels = None
35
+ threshold = 0.2
36
+
37
+ # Instantiate annotators for bounding boxes, masks, and labels
38
+ BOX_ANNOTATOR = sv.BoxAnnotator() # Updated from BoundingBoxAnnotator
39
+ MASK_ANNOTATOR = sv.MaskAnnotator()
40
+ LABEL_ANNOTATOR = sv.LabelAnnotator()
41
+
42
+ # Frame subsampling factor for video processing efficiency
43
+ SUBSAMPLE = 2
44
+
45
+ def annotate_image(input_image, detections, labels) -> np.ndarray:
46
+ """Applies mask, bounding box, and label annotations to a given image."""
47
+ output_image = MASK_ANNOTATOR.annotate(input_image, detections)
48
+ output_image = BOX_ANNOTATOR.annotate(output_image, detections) # Updated
49
+ output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
50
+ return output_image
51
+
52
+
53
+ @spaces.GPU
54
+ def process_video(input_video, confidence_threshold, classes_new, progress=gr.Progress(track_tqdm=True)):
55
+ """Processes the input video frame by frame, performs object detection, and saves the output video."""
56
+ global detections, labels, classes, threshold
57
+ classes = classes_new
58
+ threshold = confidence_threshold
59
+
60
+ # Generate a unique file name for the output video
61
+ result_file_name = f"output_{uuid.uuid4()}.mp4"
62
+
63
+ # Read input video and set up output video writer
64
+ cap = cv2.VideoCapture(input_video)
65
+ video_codec = cv2.VideoWriter_fourcc(*"mp4v") # MP4 codec
66
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
67
+ width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
68
+ desired_fps = fps // SUBSAMPLE
69
+ iterating, frame = cap.read()
70
+
71
+ # Prepare video writer for output
72
+ segment_file = cv2.VideoWriter(result_file_name, video_codec, desired_fps, (width, height))
73
+ batch, frames, predict_index = [], [], []
74
+ n_frames = 0
75
+
76
+ while iterating:
77
+ if n_frames % SUBSAMPLE == 0:
78
+ predict_index.append(len(frames))
79
+ batch.append(frame)
80
+ frames.append(frame)
81
+
82
+ # Process a batch of frames at once
83
+ if len(batch) == desired_fps:
84
+ classes_list = classes.strip().split(",")
85
+ results, fps_value = query(batch, classes_list, threshold, (width, height))
86
+
87
+ for i, frame in enumerate(frames):
88
+ if i in predict_index:
89
+ batch_idx = predict_index.index(i)
90
+ detections = sv.Detections(
91
+ xyxy=results[batch_idx]["boxes"].cpu().detach().numpy(),
92
+ confidence=results[batch_idx]["scores"].cpu().detach().numpy(),
93
+ class_id=np.array([classes_list.index(result_class) for result_class in results[batch_idx]["classes"]]),
94
+ data={"class_name": results[batch_idx]["classes"]},
95
+ )
96
+ labels = results[batch_idx]["classes"]
97
+
98
+ frame = annotate_image(input_image=frame, detections=detections, labels=labels)
99
+ segment_file.write(frame)
100
+
101
+ # Finalize and yield result
102
+ segment_file.release()
103
+ yield result_file_name, gr.Markdown(f'<h3 style="text-align: center;">Model inference FPS (batched): {fps_value * len(batch):.2f}</h3>')
104
+ result_file_name = f"output_{uuid.uuid4()}.mp4"
105
+ segment_file = cv2.VideoWriter(result_file_name, video_codec, desired_fps, (width, height))
106
+ batch.clear()
107
+ frames.clear()
108
+ predict_index.clear()
109
+
110
+ iterating, frame = cap.read()
111
+ n_frames += 1
112
+
113
+
114
+ def query(frame_batch, classes, confidence_threshold, size=(640, 480)):
115
+ """Runs inference on a batch of frames and returns the results."""
116
+ inputs = processor(images=frame_batch, text=[classes] * len(frame_batch), return_tensors="pt").to(device)
117
+
118
+ with torch.no_grad():
119
+ start_time = time.time()
120
+ outputs = model(**inputs)
121
+ fps_value = 1 / (time.time() - start_time)
122
+
123
+ target_sizes = torch.tensor([size[::-1]] * len(frame_batch))
124
+ results = processor.post_process_grounded_object_detection(
125
+ outputs=outputs, classes=[classes] * len(frame_batch), score_threshold=confidence_threshold, target_sizes=target_sizes
126
+ )
127
+
128
+ return results, fps_value
129
+
130
+
131
+ def set_classes(classes_input):
132
+ """Updates the list of classes for detection."""
133
+ global classes
134
+ classes = classes_input
135
+
136
+
137
+ def set_confidence_threshold(confidence_threshold_input):
138
+ """Updates the confidence threshold for detection."""
139
+ global threshold
140
+ threshold = confidence_threshold_input
141
+
142
+
143
+ # Custom footer for the Gradio interface
144
+ footer = """
145
+ <div style="text-align: center; margin-top: 20px;">
146
+ <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
147
+ <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
148
+ <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a> |
149
+ <a href="https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf" target="_blank">omdet-turbo-swin-tiny-hf repo in HF</a>
150
+ <br>
151
+ Made with 💖 by Pejman Ebrahimi
152
+ </div>
153
+ """
154
+
155
+ # Gradio Interface with the customized theme and DuplicateButton
156
+ with gr.Blocks(theme='ParityError/Anime', css=css) as demo:
157
+ gr.Markdown("## Real Time Object Detection with OmDet-Turbo")
158
+ gr.Markdown(
159
+ """
160
+ This is a demo for real-time open vocabulary object detection using OmDet-Turbo.<br>
161
+ It utilizes ZeroGPU, which allocates GPU for the first inference.<br>
162
+ The actual inference FPS is displayed after processing, providing an accurate assessment of performance.<br>
163
+ """
164
+ )
165
+
166
+ with gr.Row():
167
+ input_video = gr.Video(label="Upload Video")
168
+ output_video = gr.Video(label="Processed Video", autoplay=True) # Removed 'streaming' argument
169
+ actual_fps = gr.Markdown("", visible=False)
170
+
171
+ with gr.Row():
172
+ classes = gr.Textbox("person, university, class, Liectenstein", label="Objects to Detect (comma separated)", elem_classes="feedback", scale=3)
173
+ conf = gr.Slider(label="Confidence Threshold", minimum=0.1, maximum=1.0, value=0.2, step=0.05)
174
+
175
+ with gr.Row():
176
+ submit = gr.Button("Run Detection", variant="primary")
177
+ duplicate_space = gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
178
+
179
+ example_videos = gr.Examples(
180
+ examples=[["./UNI-LI.mp4", 0.3, "person, university, class, Liectenstein"]],
181
+ inputs=[input_video, conf, classes],
182
+ outputs=[output_video, actual_fps]
183
+ )
184
+
185
+ classes.submit(set_classes, classes)
186
+ conf.change(set_confidence_threshold, conf)
187
+
188
+ submit.click(
189
+ fn=process_video,
190
+ inputs=[input_video, conf, classes],
191
+ outputs=[output_video, actual_fps]
192
+ )
193
+
194
+ gr.HTML(footer)
195
+
196
+ if __name__ == "__main__":
197
+ demo.launch(show_error=True)