# Scraped page header for app.py (Hugging Face, author: Abs6187)
# Commit b488bef (verified): "Update app.py" · raw | history | blame · 6.79 kB
import time
import torch
import cv2
import numpy as np
import gradio as gr
from ultralytics import YOLO
from deep_sort.utils.parser import get_config
from deep_sort.deep_sort import DeepSort
# Initialize YOLO and DeepSort once, at import time; ObjectDetector uses these
# module-level handles (`tracker`, `model`) for every frame it processes.
deep_sort_weights = 'ckpt.t7'  # checkpoint file handed to DeepSort as model_path
tracker = DeepSort(
    model_path=deep_sort_weights,
    max_age=80,
)
model = YOLO("person_gun.pt")  # detector weights; class 0 is used as "person" below
class ObjectDetector:
    """Detect people with YOLO, track them with DeepSort, and annotate frames.

    Uses the module-level ``model`` (YOLO) and ``tracker`` (DeepSort). Per-track
    bookkeeping is stored in dicts keyed by DeepSort track id.
    """

    # Box colours in OpenCV's BGR order (red, blue, green), chosen by track_id % 3.
    _TRACK_COLORS = ((0, 0, 255), (255, 0, 0), (0, 255, 0))

    def __init__(self):
        # Detection parameters
        self.running_threshold = 0.5
        self.fps = 30  # Default FPS; replaced by the source's FPS in process_input
        # Fall back to CPU inference when no CUDA device is present
        # (a hard-coded device=0 fails on CPU-only machines).
        self.device = 0 if torch.cuda.is_available() else "cpu"
        self._reset_state()

    def _reset_state(self):
        """Clear all per-track bookkeeping so a new input starts fresh."""
        self.unique_track_ids = set()  # track ids present in the latest frame
        self.track_labels = {}  # track_id -> "Person" / "Running"
        self.track_times = {}  # track_id -> number of frames the track was seen
        self.track_positions = {}  # track_id -> (x1, y1) from the previous frame
        self.running_counters = {}  # track_id -> consecutive "fast" frames
        self.alert_person_ids = []  # ids over 60 s of observed time (current frame)

    def _prune_state(self, live_ids):
        """Forget per-track state for every id not in ``live_ids``."""
        for state in (self.track_labels, self.track_times,
                      self.track_positions, self.running_counters):
            # keys() - set builds a new set, so deleting while looping is safe.
            for stale_id in state.keys() - live_ids:
                del state[stale_id]

    @staticmethod
    def _resolve_source(input_media):
        """Map ``input_media`` to a value ``cv2.VideoCapture`` accepts.

        Returns a filesystem path for ``str`` / ``os.PathLike`` inputs, or for
        objects exposing the path as a string ``.name`` (e.g. an uploaded temp
        file). Anything else maps to camera index 0 (the webcam fallback).
        """
        import os

        if isinstance(input_media, (str, os.PathLike)):
            return os.fspath(input_media)
        name = getattr(input_media, "name", None)
        if isinstance(name, str):
            return name
        return 0

    def process_frame(self, frame):
        """Run detection and tracking on one frame and annotate a copy of it.

        Args:
            frame: an image array as returned by ``cv2.VideoCapture.read``.

        Returns:
            ``(annotated_frame, info)`` where ``info`` has the keys
            ``person_count``, ``running_status`` and ``prolonged_stay_ids``.
        """
        # Alerts are recomputed from scratch for every frame.
        self.alert_person_ids.clear()
        og_frame = frame.copy()

        # Detect persons (class 0) with confidence >= 0.75.
        results = model(frame, device=self.device, classes=0, conf=0.75)
        for result in results:
            boxes = result.boxes
            conf = boxes.conf.detach().cpu().numpy()
            bboxes_xywh = boxes.xywh.cpu().numpy()
            # The resulting tracks are read back from tracker.tracker.tracks below.
            tracker.update(bboxes_xywh, conf, og_frame)

        active_track_ids = set()
        new_running_status = "No Running"
        # NOTE(review): this iterates every track DeepSort currently holds, which may
        # include tentative / temporarily unmatched tracks; confirm whether only
        # confirmed tracks should be drawn and counted.
        for track in tracker.tracker.tracks:
            track_id = track.track_id
            x1, y1, x2, y2 = track.to_tlbr()
            w = x2 - x1
            h = y2 - y1

            color = self._TRACK_COLORS[track_id % 3]
            cv2.rectangle(og_frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

            # Initialize tracking for new tracks
            if track_id not in self.track_labels:
                self.track_labels[track_id] = "Person"
                self.track_times[track_id] = 0
                self.track_positions[track_id] = (x1, y1)
                self.running_counters[track_id] = 0
            self.track_times[track_id] += 1

            prev_x1, prev_y1 = self.track_positions[track_id]
            displacement = np.sqrt((x1 - prev_x1) ** 2 + (y1 - prev_y1) ** 2)
            # NOTE(review): displacement is pixels moved since the previous frame;
            # dividing by fps does not yield pixels/second. running_threshold is
            # tuned against this quantity, so the formula is kept unchanged.
            speed = displacement / self.fps if self.fps > 0 else 0
            self.track_positions[track_id] = (x1, y1)

            # "Running" needs a fast, reasonably large box for more than fps/2
            # consecutive frames; any slow frame resets the counter.
            if speed > self.running_threshold and w * h > 5000:
                self.running_counters[track_id] += 1
                if self.running_counters[track_id] > self.fps / 2:
                    self.track_labels[track_id] = "Running"
                    new_running_status = "Running Detected"
            else:
                self.running_counters[track_id] = 0
                self.track_labels[track_id] = "Person"

            # Observed time of this track, derived from its frame count.
            total_seconds = self.track_times[track_id] / self.fps if self.fps > 0 else 0
            minutes, seconds = divmod(int(total_seconds), 60)

            # Trigger alert for prolonged stay (> 60 s)
            if total_seconds > 60 and track_id not in self.alert_person_ids:
                self.alert_person_ids.append(track_id)

            cv2.putText(og_frame, f"{self.track_labels[track_id]} {minutes:02}:{seconds:02}",
                        (int(x1) + 10, int(y1) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 0, 0), 1, cv2.LINE_AA)
            active_track_ids.add(track_id)

        # person_count reflects the tracks present in this frame.
        self.unique_track_ids.clear()
        self.unique_track_ids.update(active_track_ids)
        # Drop state for tracks the tracker no longer holds, so the dicts do not
        # grow without bound on long streams.
        self._prune_state(active_track_ids)

        result_info = {
            'person_count': len(self.unique_track_ids),
            'running_status': new_running_status,
            'prolonged_stay_ids': list(self.alert_person_ids)
        }
        return og_frame, result_info

    def process_input(self, input_media):
        """Run detection over a whole video source and write an annotated MP4.

        Args:
            input_media: a video path (``str`` / ``os.PathLike``) or an object with
                a string ``.name`` path; anything else opens camera index 0.

        Returns:
            ``('output_detection.mp4', frame_info_list)`` with one info dict per
            processed frame.

        Raises:
            OSError: if OpenCV cannot open the source.
        """
        source = self._resolve_source(input_media)
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            cap.release()
            raise OSError(f"Could not open video source: {source!r}")

        # Each input is processed independently: stale per-track state from a
        # previous input must not leak into this one.
        # NOTE(review): the module-level DeepSort tracker is not reset here, so its
        # existing tracks may carry over into the first frames of a new input.
        self._reset_state()

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN, which `or 30` would let through
            fps = 30
        self.fps = fps

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # NOTE(review): fixed output name; concurrent runs would overwrite each other.
        out = cv2.VideoWriter('output_detection.mp4', fourcc, fps, (width, height))

        frame_info_list = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                processed_frame, frame_info = self.process_frame(frame)
                out.write(processed_frame)
                frame_info_list.append(frame_info)
        finally:
            # Release capture and writer even if processing raises.
            cap.release()
            out.release()
        return 'output_detection.mp4', frame_info_list
# Shared detector instance used by the Gradio callback below.
detector = ObjectDetector()


def detect_interface(input_media, webcam_input=None):
    """Gradio callback: run detection on the given input and summarise it.

    The interface declares two input components (file upload and webcam), and
    Gradio passes one positional value per component, so both are accepted.

    Args:
        input_media: value of the "Upload Video" file input, or None.
        webcam_input: value of the webcam input; used only when no file was
            uploaded.

    Returns:
        ``(output_video_path, summary_text)``.
    """
    source = input_media if input_media is not None else webcam_input
    output_video, frame_info_list = detector.process_input(source)

    # Generate text summary
    summary = "Detection Summary:\n"
    if frame_info_list:
        # Report the state at the last processed frame.
        last_frame_info = frame_info_list[-1]
        summary += f"Total Persons Detected: {last_frame_info['person_count']}\n"
        summary += f"Running Status: {last_frame_info['running_status']}\n"
        if last_frame_info['prolonged_stay_ids']:
            summary += f"Prolonged Stay Detected - Person IDs: {last_frame_info['prolonged_stay_ids']}"
        else:
            summary += "No Prolonged Stay Detected"
    else:
        # e.g. the source opened but yielded no frames
        summary += "No frames were processed."
    return output_video, summary
# Gradio Interface: inputs (file upload, webcam) are passed to detect_interface,
# which returns (annotated video path, text summary) for the two outputs.
_video_upload = gr.File(label="Upload Video", type="filepath")
_webcam_feed = gr.Webcam(label="Or Use Webcam")
_video_output = gr.Video(label="Processed Video")
_summary_output = gr.Textbox(label="Detection Summary")

iface = gr.Interface(
    fn=detect_interface,
    inputs=[_video_upload, _webcam_feed],
    outputs=[_video_output, _summary_output],
    title="Object Detection with Tracking",
    description="Upload a video or use webcam for real-time object detection and tracking",
)

# Launch the interface as soon as this module runs.
iface.launch()