import gradio as gr
from ultralytics import YOLO
import tempfile
import os
import cv2
import numpy as np
import torch
import atexit
import uuid

# Load the YOLOv8 pose estimation model once at the start
model = YOLO("yolov8n-pose.pt")

# The 17 COCO keypoint names, in index order
COCO_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle"
]

# Define the skeleton as pairs of keypoint indices
SKELETON_CONNECTIONS = [
    (0, 1), (0, 2),      # Nose to eyes
    (1, 3), (2, 4),      # Eyes to ears
    (0, 5), (0, 6),      # Nose to shoulders
    (5, 6),              # Shoulders to each other
    (5, 7), (6, 8),      # Shoulders to elbows
    (7, 9), (8, 10),     # Elbows to wrists
    (5, 11), (6, 12),    # Shoulders to hips
    (11, 12),            # Hips to each other
    (11, 13), (12, 14),  # Hips to knees
    (13, 15), (14, 16)   # Knees to ankles
]
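
# Example of the mapping above: connection (5, 7) joins COCO_KEYPOINTS[5] ("left_shoulder")
# to COCO_KEYPOINTS[7] ("left_elbow").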

def calculate_torso_angle(keypoints, frame_height):
    """
    Calculate the angle of the torso with respect to the vertical axis.

    Args:
        keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
        frame_height (int): Height of the video frame in pixels (currently unused).

    Returns:
        float: Angle in degrees. Returns None if keypoints are not detected properly.
    """
    try:
        # COCO keypoint indices
        LEFT_SHOULDER = 5
        RIGHT_SHOULDER = 6
        LEFT_HIP = 11
        RIGHT_HIP = 12

        # Extract shoulder and hip coordinates
        left_shoulder = keypoints[LEFT_SHOULDER][:2]
        right_shoulder = keypoints[RIGHT_SHOULDER][:2]
        left_hip = keypoints[LEFT_HIP][:2]
        right_hip = keypoints[RIGHT_HIP][:2]

        # Require sufficient visibility (confidence > 0.3) for all four keypoints
        if (keypoints[LEFT_SHOULDER][2] < 0.3 or keypoints[RIGHT_SHOULDER][2] < 0.3 or
                keypoints[LEFT_HIP][2] < 0.3 or keypoints[RIGHT_HIP][2] < 0.3):
            return None

        # Calculate the midpoints of the shoulders and hips
        mid_shoulder = (left_shoulder + right_shoulder) / 2
        mid_hip = (left_hip + right_hip) / 2

        # Calculate the torso vector (shoulders to hips)
        vector = mid_hip - mid_shoulder

        # Calculate the angle with respect to the vertical axis
        angle_rad = np.arctan2(vector[0], vector[1])
        angle_deg = np.degrees(angle_rad)
        return angle_deg
    except Exception as e:
        print(f"Error calculating torso angle: {e}")
        return None
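
# Illustration of the convention used above (image coordinates, where y grows downward):
# for an upright torso, mid_hip lies below mid_shoulder, so the vector is roughly (0, +d)
# and arctan2(0, d) = 0 degrees; for a torso lying flat, the vector is roughly (+/-d, 0)
# and arctan2(+/-d, 0) = +/-90 degrees. detect_fall() flags a fall when |angle| exceeds
# the configured angle_threshold.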

def draw_skeleton(frame, keypoints, show_labels=True):
    """
    Draws the skeleton on the frame based on keypoints.

    Args:
        frame (numpy.ndarray): The current video frame.
        keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
        show_labels (bool): Whether to display keypoint indices.

    Returns:
        numpy.ndarray: Annotated frame with skeleton.
    """
    for connection in SKELETON_CONNECTIONS:
        start_idx, end_idx = connection
        x_start, y_start, conf_start = keypoints[start_idx]
        x_end, y_end, conf_end = keypoints[end_idx]
        # Only draw if both keypoints have sufficient confidence
        if conf_start > 0.5 and conf_end > 0.5:
            start_point = (int(x_start), int(y_start))
            end_point = (int(x_end), int(y_end))
            cv2.line(frame, start_point, end_point, (255, 0, 0), 2)  # Blue lines

    if show_labels:
        # Draw keypoint indices
        for idx, (x, y, conf) in enumerate(keypoints):
            if conf > 0.5:
                cv2.putText(frame, f"{idx}", (int(x), int(y)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 0, 0), 1)  # Blue labels

    return frame

def detect_fall(video_path, angle_threshold=30, consecutive_frames=3, frame_sampling_rate=1, confidence_threshold=0.3, show_labels=True):
    """
    Detects falls in the uploaded video using pose estimation.

    Args:
        video_path (str): The path to the input video file uploaded by the user.
        angle_threshold (float): Angle threshold to classify a fall (in degrees).
        consecutive_frames (int): Number of consecutive frames to confirm a fall.
        frame_sampling_rate (int): Process every nth frame.
        confidence_threshold (float): Minimum confidence required for keypoint detection.
        show_labels (bool): Whether to display keypoint indices.

    Returns:
        tuple: (annotated_video_path, notification_message)
    """
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Unable to open the video file.")

        # Video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        # Create a unique temporary file for the annotated video
        unique_id = uuid.uuid4().hex
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", prefix=f"annotated_{unique_id}_") as tmp:
            annotated_video_path = tmp.name
        out = cv2.VideoWriter(annotated_video_path, fourcc, fps, (width, height))

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        current_frame = 0
        consecutive_fall_frames = 0
        total_falls = 0
        fall_frames = []  # To store frames where falls were detected

        while True:
            ret, frame = cap.read()
            if not ret:
                break  # End of video

            current_frame += 1

            # Frame sampling: skipped frames are written out unannotated
            if current_frame % frame_sampling_rate != 0:
                out.write(frame)
                continue

            print(f"Processing frame {current_frame}/{frame_count}")

            # Run pose estimation
            results = model.predict(source=frame, conf=confidence_threshold, save=False, stream=False)

            # Iterate through detected persons
            for result in results:
                if not hasattr(result, 'keypoints') or result.keypoints is None:
                    continue
                for keypoints in result.keypoints.data:
                    # keypoints should be a tensor of shape (17, 3)
                    if keypoints is None or not hasattr(keypoints, 'cpu'):
                        continue

                    # Convert to a NumPy array
                    if isinstance(keypoints, torch.Tensor):
                        kpts = keypoints.cpu().numpy()
                    elif isinstance(keypoints, np.ndarray):
                        kpts = keypoints
                    else:
                        print(f"Unexpected keypoints data type: {type(keypoints)}")
                        continue

                    if kpts.size == 0 or kpts.shape[0] < 17:
                        print(f"Insufficient keypoints for processing in frame {current_frame}")
                        continue

                    angle = calculate_torso_angle(kpts, height)
                    if angle is None:
                        continue

                    # Determine if it's a fall
                    if abs(angle) > angle_threshold:
                        consecutive_fall_frames += 1
                        label = "Fall Detected!"
                        color = (0, 0, 255)  # Red
                    else:
                        if consecutive_fall_frames >= consecutive_frames:
                            total_falls += 1
                            fall_frames.append(current_frame)
                        consecutive_fall_frames = 0
                        label = "Normal"
                        color = (0, 255, 0)  # Green

                    # If the fall persists over consecutive frames, mark it as a fall
                    if consecutive_fall_frames >= consecutive_frames:
                        cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

                    # Draw keypoints and skeleton
                    frame = draw_skeleton(frame, kpts, show_labels=show_labels)

            # Write the annotated frame
            out.write(frame)

        # Release resources
        cap.release()
        out.release()

        # Final check for falls that persisted until the end of the video
        if consecutive_fall_frames >= consecutive_frames:
            total_falls += 1
            fall_frames.append(current_frame)

        # Generate the notification message
        if total_falls > 0:
            if total_falls == 1:
                notification = f"A fall was detected at frame {fall_frames[0]}."
            else:
                frames = ', '.join(map(str, fall_frames))
                notification = f"{total_falls} falls were detected at frames: {frames}."
        else:
            notification = "No falls were detected in the video."

        # Check that the annotated video was created
        if not os.path.exists(annotated_video_path):
            raise FileNotFoundError("Annotated video was not found. Please check the model and processing steps.")

        return annotated_video_path, notification
    except Exception as e:
        # Log the error and surface it in the UI
        print(f"Error during fall detection: {e}")
        return None, f"An error occurred during fall detection: {e}"

def create_gradio_interface():
    # Define the Gradio interface with adjustable parameters
    iface = gr.Interface(
        fn=detect_fall,
        inputs=[
            gr.Video(label="Upload Video"),
            gr.Slider(
                label="Angle Threshold (degrees)",
                minimum=0,
                maximum=90,
                step=1,
                value=30,
                interactive=True,
                info="Adjust the torso angle threshold used to classify a fall. Lower values increase sensitivity."
            ),
            gr.Slider(
                label="Consecutive Frames to Confirm Fall",
                minimum=1,
                maximum=10,
                step=1,
                value=3,
                interactive=True,
                info="Number of consecutive frames exceeding the angle threshold required to confirm a fall."
            ),
            gr.Slider(
                label="Frame Sampling Rate",
                minimum=1,
                maximum=10,
                step=1,
                value=1,
                interactive=True,
                info="Process every nth frame to speed up detection. Higher values reduce processing time."
            ),
            gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                interactive=True,
                info="Minimum confidence required for keypoint detection. Higher values reduce false positives."
            ),
            gr.Checkbox(
                label="Show Keypoint Labels",
                value=True,
                interactive=True,
                info="Toggle the display of keypoint indices on the video."
            )
        ],
        outputs=[
            gr.Video(label="Annotated Video"),
            gr.Textbox(label="Fall Detection Notification")
        ],
        title="Fall Detection App 🚨",
        description=(
            "Upload a video of a person falling, and the app will detect and annotate the fall "
            "using pose estimation. Adjust the angle threshold, consecutive frames, frame sampling rate, "
            "and confidence threshold to fine-tune detection sensitivity and performance. "
            "The annotated video displays keypoints and skeleton lines, and indicates when a fall is detected."
        ),
        examples=[
            ["demo/person falling.mp4", 30, 3, 1, 0.3, True]
        ],
        flagging_mode="never",
    )
    return iface

# Create the Gradio interface
iface = create_gradio_interface()

# Best-effort cleanup on exit: remove the annotated videos this app wrote to the temp directory
def cleanup_temp_dirs():
    temp_dir = tempfile.gettempdir()
    for name in os.listdir(temp_dir):
        if name.startswith("annotated_") and name.endswith(".mp4"):
            try:
                os.remove(os.path.join(temp_dir, name))
            except OSError:
                pass  # The file may still be in use or already removed
atexit.register(cleanup_temp_dirs)

# Launch the app
if __name__ == "__main__":
    iface.launch()
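
# Optional note: long videos can keep a single request busy for a while; depending on the
# Gradio version and hosting setup, enabling the request queue (iface.queue()) before
# launch() may help avoid timeouts. The defaults are left unchanged here.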