import gradio as gr
from ultralytics import YOLO
import tempfile
import os
import cv2
import numpy as np
import torch
import atexit
import uuid
# Load the YOLOv8 pose estimation model once at the start
model = YOLO("yolov8n-pose.pt")
# Define the skeleton connections based on COCO keypoints
COCO_KEYPOINTS = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"
]
# Define the skeleton as pairs of keypoint indices
SKELETON_CONNECTIONS = [
(0, 1), (0, 2), # Nose to eyes
(1, 3), (2, 4), # Eyes to ears
(0, 5), (0, 6), # Nose to shoulders
(5, 6), # Shoulders to each other
(5, 7), (6, 8), # Shoulders to elbows
(7, 9), (8, 10), # Elbows to wrists
(5, 11), (6, 12), # Shoulders to hips
(11, 12), # Hips to each other
(11, 13), (12, 14), # Hips to knees
(13, 15), (14, 16) # Knees to ankles
]
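# For reference, each pair indexes into COCO_KEYPOINTS above; for example,
# (5, 7) connects left_shoulder to left_elbow and (11, 13) connects left_hip to left_knee.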
def calculate_torso_angle(keypoints, frame_height):
"""
Calculate the angle of the torso with respect to the vertical axis.
Args:
keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
        frame_height (int): Height of the video frame in pixels (currently unused).
Returns:
float: Angle in degrees. Returns None if keypoints are not detected properly.
"""
try:
# COCO keypoint indices
LEFT_SHOULDER = 5
RIGHT_SHOULDER = 6
LEFT_HIP = 11
RIGHT_HIP = 12
# Extract shoulder and hip coordinates
left_shoulder = keypoints[LEFT_SHOULDER][:2]
right_shoulder = keypoints[RIGHT_SHOULDER][:2]
left_hip = keypoints[LEFT_HIP][:2]
right_hip = keypoints[RIGHT_HIP][:2]
        # Require a minimum keypoint confidence of 0.3 for both shoulders and both hips
if (keypoints[LEFT_SHOULDER][2] < 0.3 or keypoints[RIGHT_SHOULDER][2] < 0.3 or
keypoints[LEFT_HIP][2] < 0.3 or keypoints[RIGHT_HIP][2] < 0.3):
return None
# Calculate mid points
mid_shoulder = (left_shoulder + right_shoulder) / 2
mid_hip = (left_hip + right_hip) / 2
# Calculate the vector of the torso
vector = mid_hip - mid_shoulder
# Calculate angle with respect to the vertical axis
angle_rad = np.arctan2(vector[0], vector[1])
angle_deg = np.degrees(angle_rad)
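        # In image coordinates the y-axis points downward, so an upright torso gives a
        # shoulder->hip vector of roughly (0, +dy) and arctan2(dx, dy) ~ 0 degrees,
        # while a torso lying flat gives an angle near +/-90 degrees.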
return angle_deg
except Exception as e:
print(f"Error calculating torso angle: {e}")
return None
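# Illustrative sanity check for calculate_torso_angle (not executed by the app): with
# synthetic keypoints whose shoulders sit directly above the hips the angle is ~0 degrees,
# and with shoulders level with the hips it is ~90 degrees. For example:
#   kpts = np.zeros((17, 3))
#   kpts[5] = kpts[6] = (100, 100, 1.0)    # both shoulders at the same point
#   kpts[11] = kpts[12] = (100, 200, 1.0)  # both hips directly below them
#   calculate_torso_angle(kpts, 480)       # ~0.0 degrees (upright)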
def draw_skeleton(frame, keypoints, show_labels=True):
"""
Draws the skeleton on the frame based on keypoints.
Args:
frame (numpy.ndarray): The current video frame.
keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
show_labels (bool): Whether to display keypoint indices.
Returns:
numpy.ndarray: Annotated frame with skeleton.
"""
for connection in SKELETON_CONNECTIONS:
start_idx, end_idx = connection
x_start, y_start, conf_start = keypoints[start_idx]
x_end, y_end, conf_end = keypoints[end_idx]
        # Only draw if both keypoints have sufficient confidence
        # (fixed 0.5 drawing threshold, independent of the detection confidence slider)
if conf_start > 0.5 and conf_end > 0.5:
start_point = (int(x_start), int(y_start))
end_point = (int(x_end), int(y_end))
cv2.line(frame, start_point, end_point, (255, 0, 0), 2) # Blue lines
if show_labels:
        # Draw keypoint indices
for idx, (x, y, conf) in enumerate(keypoints):
if conf > 0.5:
cv2.putText(frame, f"{idx}", (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 0, 0), 1) # Blue labels
return frame
def detect_fall(video_path, angle_threshold=30, consecutive_frames=3, frame_sampling_rate=1, confidence_threshold=0.3, show_labels=True):
"""
Detects falls in the uploaded video using pose estimation.
Args:
video_path (str): The path to the input video file uploaded by the user.
angle_threshold (float): Angle threshold to classify a fall (in degrees).
consecutive_frames (int): Number of consecutive frames to confirm a fall.
frame_sampling_rate (int): Process every nth frame.
confidence_threshold (float): Minimum confidence required for keypoint detection.
show_labels (bool): Whether to display keypoint indices.
Returns:
tuple: (annotated_video_path, notification_message)
"""
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise ValueError("Unable to open the video file.")
# Video properties
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# Create a unique temporary file for the annotated video
unique_id = uuid.uuid4().hex
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", prefix=f"annotated_{unique_id}_") as tmp:
annotated_video_path = tmp.name
out = cv2.VideoWriter(annotated_video_path, fourcc, fps, (width, height))
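        # Note: 'mp4v' is widely available in OpenCV builds, but some browsers cannot play it;
        # if the annotated video does not render in the Gradio player, re-encoding to H.264
        # (e.g. with ffmpeg) may be required.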
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
current_frame = 0
consecutive_fall_frames = 0
total_falls = 0
fall_frames = [] # To store frames where falls were detected
while True:
ret, frame = cap.read()
if not ret:
break # End of video
current_frame += 1
            # Frame sampling: skip pose estimation on this frame but still write it
            # so the output video keeps its original length
if current_frame % frame_sampling_rate != 0:
out.write(frame)
continue
print(f"Processing frame {current_frame}/{frame_count}")
# Run pose estimation
results = model.predict(source=frame, conf=confidence_threshold, save=False, stream=False)
# Iterate through detected persons
for result in results:
if not hasattr(result, 'keypoints') or result.keypoints is None:
continue
for keypoints in result.keypoints.data:
# keypoints should be a tensor of shape (17,3)
if keypoints is None or not hasattr(keypoints, 'cpu'):
continue
# Convert to NumPy array
if isinstance(keypoints, torch.Tensor):
kpts = keypoints.cpu().numpy()
elif isinstance(keypoints, np.ndarray):
kpts = keypoints
else:
print(f"Unexpected keypoints data type: {type(keypoints)}")
continue
if kpts.size == 0 or kpts.shape[0] < 17:
print(f"Insufficient keypoints for processing in frame {current_frame}")
continue
angle = calculate_torso_angle(kpts, height)
if angle is None:
continue
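                    # Note: the consecutive-frame counter below is shared across all detections,
                    # so the logic effectively assumes a single person in view.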
# Determine if it's a fall
if abs(angle) > angle_threshold:
consecutive_fall_frames += 1
label = "Fall Detected!"
color = (0, 0, 255) # Red
else:
if consecutive_fall_frames >= consecutive_frames:
total_falls += 1
fall_frames.append(current_frame)
consecutive_fall_frames = 0
label = "Normal"
color = (0, 255, 0) # Green
# If fall persists over consecutive frames, mark as fall
if consecutive_fall_frames >= consecutive_frames:
cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
# Draw keypoints and skeleton
frame = draw_skeleton(frame, kpts, show_labels=show_labels)
# Write annotated frame
out.write(frame)
# Release resources
cap.release()
out.release()
# Final check for falls that persisted until the end of the video
if consecutive_fall_frames >= consecutive_frames:
total_falls += 1
fall_frames.append(current_frame)
# Generate notification message
if total_falls > 0:
if total_falls == 1:
notification = f"A fall was detected at frame {fall_frames[0]}."
else:
frames = ', '.join(map(str, fall_frames))
notification = f"{total_falls} falls were detected at frames: {frames}."
else:
notification = "No falls were detected in the video."
# Check if annotated video was created
if not os.path.exists(annotated_video_path):
raise FileNotFoundError("Annotated video was not found. Please check the model and processing steps.")
return annotated_video_path, notification
    except Exception as e:
        # Release any resources that were opened before the error occurred
        if 'cap' in locals():
            cap.release()
        if 'out' in locals():
            out.release()
        print(f"Error during fall detection: {e}")
        return None, f"An error occurred during fall detection: {e}"
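# The detector can also be exercised without the UI, which is handy for debugging.
# "sample.mp4" below is a placeholder path, not a file shipped with this app:
#   annotated_path, message = detect_fall("sample.mp4", angle_threshold=30,
#                                         consecutive_frames=3, frame_sampling_rate=1,
#                                         confidence_threshold=0.3, show_labels=True)
#   print(message)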
def create_gradio_interface():
# Define the Gradio interface with adjustable parameters
iface = gr.Interface(
fn=detect_fall,
inputs=[
gr.Video(label="Upload Video"),
gr.Slider(
label="Angle Threshold (degrees)",
minimum=0,
maximum=90,
step=1,
value=30,
interactive=True,
info="Adjust the torso angle threshold to classify a fall. Lower values increase sensitivity."
),
gr.Slider(
label="Consecutive Frames to Confirm Fall",
minimum=1,
maximum=10,
step=1,
value=3,
interactive=True,
info="Number of consecutive frames exceeding the angle threshold required to confirm a fall."
),
gr.Slider(
label="Frame Sampling Rate",
minimum=1,
maximum=10,
step=1,
value=1,
interactive=True,
info="Process every nth frame to speed up detection. Higher values reduce processing time."
),
gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
                value=0.3,
interactive=True,
info="Minimum confidence required for keypoint detection. Higher values reduce false positives."
),
gr.Checkbox(
label="Show Keypoint Labels",
value=True,
interactive=True,
info="Toggle the display of keypoint indices on the video."
)
],
outputs=[
gr.Video(label="Annotated Video"),
gr.Textbox(label="Fall Detection Notification")
],
title="Fall Detection App 🚨",
description=(
"Upload a video of a person falling, and the app will detect and annotate the fall "
"using pose estimation. Adjust the angle threshold, consecutive frames, frame sampling rate, "
"and confidence threshold to fine-tune detection sensitivity and performance. "
"The annotated video will display keypoints, skeleton lines, and indicate when a fall is detected."
),
examples=[
["demo/person falling.mp4", 30, 3, 1, 0.3, True]
        ],
        flagging_mode="never",
)
return iface
# Create the Gradio interface
iface = create_gradio_interface()
# Remove annotated videos written to the temp directory by this app on exit
def cleanup_temp_dirs():
    temp_dir = tempfile.gettempdir()
    for name in os.listdir(temp_dir):
        if name.startswith("annotated_") and name.endswith(".mp4"):
            try:
                os.remove(os.path.join(temp_dir, name))
            except OSError:
                pass
atexit.register(cleanup_temp_dirs)
# Launch the app
if __name__ == "__main__":
iface.launch()