luminoussg committed on
Commit a2f3593 · verified · 1 Parent(s): f0c8c69

Create app.py

Files changed (1)
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
+import gradio as gr
+from ultralytics import YOLO
+import tempfile
+import os
+import cv2
+import numpy as np
+import torch
+import atexit
+import uuid
+
+# Load the YOLOv8 pose estimation model once at the start
+model = YOLO("yolov8n-pose.pt")
+
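+# Note: if "yolov8n-pose.pt" is not present locally, ultralytics will try to
+# download the weights on first use, so the first launch needs network access.
+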
+# Define the skeleton connections based on COCO keypoints
+COCO_KEYPOINTS = [
+    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
+    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
+    "left_wrist", "right_wrist", "left_hip", "right_hip",
+    "left_knee", "right_knee", "left_ankle", "right_ankle"
+]
+
+# Define the skeleton as pairs of keypoint indices
+SKELETON_CONNECTIONS = [
+    (0, 1), (0, 2),      # Nose to eyes
+    (1, 3), (2, 4),      # Eyes to ears
+    (0, 5), (0, 6),      # Nose to shoulders
+    (5, 6),              # Shoulders to each other
+    (5, 7), (6, 8),      # Shoulders to elbows
+    (7, 9), (8, 10),     # Elbows to wrists
+    (5, 11), (6, 12),    # Shoulders to hips
+    (11, 12),            # Hips to each other
+    (11, 13), (12, 14),  # Hips to knees
+    (13, 15), (14, 16)   # Knees to ankles
+]
+
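+# Each pair indexes into COCO_KEYPOINTS above; for example, (5, 7) is the
+# left_shoulder -> left_elbow limb and (13, 15) the left_knee -> left_ankle limb.
+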
+def calculate_torso_angle(keypoints, frame_height):
+    """
+    Calculate the angle of the torso with respect to the vertical axis.
+
+    Args:
+        keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
+        frame_height (int): Height of the video frame in pixels (currently unused).
+
+    Returns:
+        float: Angle in degrees, or None if the required keypoints are not
+        detected with sufficient confidence.
+    """
+    try:
+        # COCO keypoint indices
+        LEFT_SHOULDER = 5
+        RIGHT_SHOULDER = 6
+        LEFT_HIP = 11
+        RIGHT_HIP = 12
+
+        # Extract shoulder and hip coordinates
+        left_shoulder = keypoints[LEFT_SHOULDER][:2]
+        right_shoulder = keypoints[RIGHT_SHOULDER][:2]
+        left_hip = keypoints[LEFT_HIP][:2]
+        right_hip = keypoints[RIGHT_HIP][:2]
+
+        # Require a minimum visibility score of 0.3 for all four keypoints
+        if (keypoints[LEFT_SHOULDER][2] < 0.3 or keypoints[RIGHT_SHOULDER][2] < 0.3 or
+                keypoints[LEFT_HIP][2] < 0.3 or keypoints[RIGHT_HIP][2] < 0.3):
+            return None
+
+        # Calculate midpoints
+        mid_shoulder = (left_shoulder + right_shoulder) / 2
+        mid_hip = (left_hip + right_hip) / 2
+
+        # Calculate the torso vector (mid-shoulder to mid-hip)
+        vector = mid_hip - mid_shoulder
+
+        # Calculate angle with respect to the vertical axis
+        angle_rad = np.arctan2(vector[0], vector[1])
+        angle_deg = np.degrees(angle_rad)
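+        # Convention check: image y grows downward, so an upright torso gives a
+        # vector near (0, +d) -> atan2(0, d) = 0 deg, while a person lying flat
+        # gives roughly (+/-d, 0) -> atan2(+/-d, 0) = +/-90 deg.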
+
+        return angle_deg
+    except Exception as e:
+        print(f"Error calculating torso angle: {e}")
+        return None
+
+def draw_skeleton(frame, keypoints, show_labels=True):
+    """
+    Draws the skeleton on the frame based on keypoints.
+
+    Args:
+        frame (numpy.ndarray): The current video frame.
+        keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
+        show_labels (bool): Whether to display keypoint indices.
+
+    Returns:
+        numpy.ndarray: Annotated frame with skeleton.
+    """
+    for connection in SKELETON_CONNECTIONS:
+        start_idx, end_idx = connection
+        x_start, y_start, conf_start = keypoints[start_idx]
+        x_end, y_end, conf_end = keypoints[end_idx]
+
+        # Only draw if both keypoints have sufficient confidence
+        if conf_start > 0.5 and conf_end > 0.5:
+            start_point = (int(x_start), int(y_start))
+            end_point = (int(x_end), int(y_end))
+            cv2.line(frame, start_point, end_point, (255, 0, 0), 2)  # Blue lines (BGR)
+
+    if show_labels:
+        # Draw keypoint indices
+        for idx, (x, y, conf) in enumerate(keypoints):
+            if conf > 0.5:
+                cv2.putText(frame, f"{idx}", (int(x), int(y)),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 0, 0), 1)  # Blue labels
+
+    return frame
+
+def detect_fall(video_path, angle_threshold=30, consecutive_frames=3, frame_sampling_rate=1, confidence_threshold=0.3, show_labels=True):
+    """
+    Detects falls in the uploaded video using pose estimation.
+
+    Args:
+        video_path (str): The path to the input video file uploaded by the user.
+        angle_threshold (float): Angle threshold to classify a fall (in degrees).
+        consecutive_frames (int): Number of consecutive frames to confirm a fall.
+        frame_sampling_rate (int): Process every nth frame.
+        confidence_threshold (float): Minimum confidence required for keypoint detection.
+        show_labels (bool): Whether to display keypoint indices.
+
+    Returns:
+        tuple: (annotated_video_path, notification_message)
+    """
+    try:
+        cap = cv2.VideoCapture(video_path)
+        if not cap.isOpened():
+            raise ValueError("Unable to open the video file.")
+
+        # Video properties
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
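+        # Note: the 'mp4v' fourcc is widely supported by OpenCV, but the output
+        # may not play inline in every browser; re-encoding to H.264 is a common
+        # workaround if the Gradio video preview stays blank.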
+
+        # Create a unique temporary file for the annotated video
+        unique_id = uuid.uuid4().hex
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", prefix=f"annotated_{unique_id}_") as tmp:
+            annotated_video_path = tmp.name
+
+        out = cv2.VideoWriter(annotated_video_path, fourcc, fps, (width, height))
+
+        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        current_frame = 0
+        consecutive_fall_frames = 0
+        total_falls = 0
+        fall_frames = []  # Frames where falls were detected
+
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break  # End of video
+
+            current_frame += 1
+
+            # Frame sampling: pass unsampled frames through without pose estimation
+            if current_frame % frame_sampling_rate != 0:
+                out.write(frame)
+                continue
+
+            print(f"Processing frame {current_frame}/{frame_count}")
+
+            # Run pose estimation
+            results = model.predict(source=frame, conf=confidence_threshold, save=False, stream=False)
+
+            # Iterate through detected persons
+            for result in results:
+                if not hasattr(result, 'keypoints') or result.keypoints is None:
+                    continue
+                for keypoints in result.keypoints.data:
+                    # keypoints should be a tensor of shape (17, 3)
+                    if keypoints is None or not hasattr(keypoints, 'cpu'):
+                        continue
+                    # Convert to a NumPy array
+                    if isinstance(keypoints, torch.Tensor):
+                        kpts = keypoints.cpu().numpy()
+                    elif isinstance(keypoints, np.ndarray):
+                        kpts = keypoints
+                    else:
+                        print(f"Unexpected keypoints data type: {type(keypoints)}")
+                        continue
+
+                    if kpts.size == 0 or kpts.shape[0] < 17:
+                        print(f"Insufficient keypoints for processing in frame {current_frame}")
+                        continue
+
+                    angle = calculate_torso_angle(kpts, height)
+                    if angle is None:
+                        continue
+
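+                    # Note: consecutive_fall_frames is shared across all detected
+                    # persons, so the logic below implicitly assumes a single
+                    # person per video; multiple people would interleave the count.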
+                    # Determine whether this frame looks like a fall
+                    if abs(angle) > angle_threshold:
+                        consecutive_fall_frames += 1
+                        label = "Fall Detected!"
+                        color = (0, 0, 255)  # Red
+                    else:
+                        if consecutive_fall_frames >= consecutive_frames:
+                            total_falls += 1
+                            fall_frames.append(current_frame)
+                        consecutive_fall_frames = 0
+                        label = "Normal"
+                        color = (0, 255, 0)  # Green
+
+                    # Only mark the frame once the fall has persisted over enough consecutive frames
+                    if consecutive_fall_frames >= consecutive_frames:
+                        cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
+
+                    # Draw keypoints and skeleton
+                    frame = draw_skeleton(frame, kpts, show_labels=show_labels)
+
+            # Write annotated frame
+            out.write(frame)
+
+        # Release resources
+        cap.release()
+        out.release()
+
+        # Final check for a fall that persisted until the end of the video
+        if consecutive_fall_frames >= consecutive_frames:
+            total_falls += 1
+            fall_frames.append(current_frame)
+
+        # Generate notification message
+        if total_falls > 0:
+            if total_falls == 1:
+                notification = f"A fall was detected at frame {fall_frames[0]}."
+            else:
+                frames = ', '.join(map(str, fall_frames))
+                notification = f"{total_falls} falls were detected at frames: {frames}."
+        else:
+            notification = "No falls were detected in the video."
+
+        # Check that the annotated video was created
+        if not os.path.exists(annotated_video_path):
+            raise FileNotFoundError("Annotated video was not found. Please check the model and processing steps.")
+
+        return annotated_video_path, notification
+
+    except Exception as e:
+        print(f"Error during fall detection: {e}")
+        return None, f"An error occurred during fall detection: {e}"
+
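+# The function can also be called directly, without the Gradio UI, e.g.:
+#   annotated_path, message = detect_fall("demo/person falling.mp4",
+#                                         angle_threshold=30, consecutive_frames=3)
+#   print(message)
+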
+def create_gradio_interface():
+    # Define the Gradio interface with adjustable parameters
+    iface = gr.Interface(
+        fn=detect_fall,
+        inputs=[
+            gr.Video(label="Upload Video"),
+            gr.Slider(
+                label="Angle Threshold (degrees)",
+                minimum=0,
+                maximum=90,
+                step=1,
+                value=30,
+                interactive=True,
+                info="Adjust the torso angle threshold to classify a fall. Lower values increase sensitivity."
+            ),
+            gr.Slider(
+                label="Consecutive Frames to Confirm Fall",
+                minimum=1,
+                maximum=10,
+                step=1,
+                value=3,
+                interactive=True,
+                info="Number of consecutive frames exceeding the angle threshold required to confirm a fall."
+            ),
+            gr.Slider(
+                label="Frame Sampling Rate",
+                minimum=1,
+                maximum=10,
+                step=1,
+                value=1,
+                interactive=True,
+                info="Process every nth frame to speed up detection. Higher values reduce processing time."
+            ),
+            gr.Slider(
+                label="Confidence Threshold",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.05,
+                value=0.3,
+                interactive=True,
+                info="Minimum confidence required for keypoint detection. Higher values reduce false positives."
+            ),
+            gr.Checkbox(
+                label="Show Keypoint Labels",
+                value=True,
+                interactive=True,
+                info="Toggle the display of keypoint indices on the video."
+            )
+        ],
+        outputs=[
+            gr.Video(label="Annotated Video"),
+            gr.Textbox(label="Fall Detection Notification")
+        ],
+        title="Fall Detection App 🚨",
+        description=(
+            "Upload a video of a person falling, and the app will detect and annotate the fall "
+            "using pose estimation. Adjust the angle threshold, consecutive frames, frame sampling rate, "
+            "and confidence threshold to fine-tune detection sensitivity and performance. "
+            "The annotated video will display keypoints, skeleton lines, and indicate when a fall is detected."
+        ),
+        examples=[
+            ["demo/person falling.mp4", 30, 3, 1, 0.3, True]
+        ],
+        flagging_mode="never",
+    )
+    return iface
+
+# Create the Gradio interface
+iface = create_gradio_interface()
+
+# Ensure temporary files are cleaned up on exit
+def cleanup_temp_dirs():
+    temp_dir = tempfile.gettempdir()
+    # Additional cleanup logic can go here; the function is currently a no-op.
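+    # A minimal sketch, assuming only this app writes files matching the
+    # "annotated_*" prefix used above (uncomment and `import glob` to enable):
+    # for f in glob.glob(os.path.join(temp_dir, "annotated_*.mp4")):
+    #     try:
+    #         os.remove(f)
+    #     except OSError:
+    #         pass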
+
+atexit.register(cleanup_temp_dirs)
+
+# Launch the app
+if __name__ == "__main__":
+    iface.launch()