import cv2
import numpy as np
import torch
import gradio as gr

from ultralytics import YOLO
from deep_sort.deep_sort import DeepSort
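
# DeepSort is assumed here to follow the common deep_sort package layout
# (ZQPei-style), where the constructor takes model_path/max_age and
# tracker.update(bbox_xywh, confidences, ori_img) is the update signature;
# adjust the import path and arguments if your fork differs.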
deep_sort_weights = 'ckpt.t7'
tracker = DeepSort(model_path=deep_sort_weights, max_age=80)  # keep lost tracks alive for 80 frames
model = YOLO("person_gun.pt")  # YOLO weights fine-tuned for person/gun detection
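
# Hedged addition: the original hard-coded device=0 (first CUDA GPU) in the
# YOLO call, which fails on CPU-only machines; pick the device once and fall
# back to the CPU when no GPU is present.
device = 0 if torch.cuda.is_available() else 'cpu'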


class ObjectDetector:
    def __init__(self):
        # Per-track state, keyed by DeepSORT track ID.
        self.unique_track_ids = set()  # every ID seen so far (cumulative person count)
        self.track_labels = {}         # "Person" or "Running"
        self.track_times = {}          # frames the track has been alive
        self.track_positions = {}      # last (x1, y1) of the track's box
        self.running_counters = {}     # consecutive frames above the speed threshold
        self.alert_person_ids = []     # IDs flagged for a prolonged stay

        self.running_threshold = 0.5   # tuned against displacement / fps (see process_frame)
        self.fps = 30                  # overwritten with the real FPS in process_input

    def process_frame(self, frame):
        """
        Process a single frame for object detection and tracking
        """
        self.alert_person_ids.clear()
        og_frame = frame.copy()

        # Detect persons only (class 0) with a high confidence cutoff.
        results = model(frame, device=device, classes=0, conf=0.75)

        # Single-image inference yields one Results object; pull its boxes.
        for result in results:
            boxes = result.boxes
            conf = boxes.conf.detach().cpu().numpy()
            bboxes_xywh = boxes.xywh.cpu().numpy()

        # Feed the detections to DeepSORT, then read back its confirmed tracks.
        tracker.update(bboxes_xywh, conf, og_frame)
        active_track_ids = set()
        new_running_status = "No Running"

        for track in tracker.tracker.tracks:
            track_id = track.track_id
            x1, y1, x2, y2 = track.to_tlbr()
            w = x2 - x1
            h = y2 - y1

            # Cycle through three colors so neighboring boxes stay distinguishable.
            red_color = (0, 0, 255)
            blue_color = (255, 0, 0)
            green_color = (0, 255, 0)
            color_id = track_id % 3
            color = red_color if color_id == 0 else blue_color if color_id == 1 else green_color
            cv2.rectangle(og_frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

            # Initialize per-track state the first time an ID appears.
            if track_id not in self.track_labels:
                self.track_labels[track_id] = "Person"
                self.track_times[track_id] = 0
                self.track_positions[track_id] = (x1, y1)
                self.running_counters[track_id] = 0

            self.track_times[track_id] += 1
            prev_x1, prev_y1 = self.track_positions[track_id]
            displacement = np.sqrt((x1 - prev_x1) ** 2 + (y1 - prev_y1) ** 2)

            # Per-frame displacement scaled by 1/fps; running_threshold is tuned
            # against this quantity rather than a true pixels-per-second speed.
            speed = displacement / self.fps if self.fps > 0 else 0
            self.track_positions[track_id] = (x1, y1)

            # Label "Running" only once the speed stays high for half a second,
            # and only for boxes large enough (area > 5000 px) to trust the motion.
            if speed > self.running_threshold and w * h > 5000:
                self.running_counters[track_id] += 1
                if self.running_counters[track_id] > self.fps / 2:
                    self.track_labels[track_id] = "Running"
                    new_running_status = "Running Detected"
            else:
                self.running_counters[track_id] = 0
                self.track_labels[track_id] = "Person"

            # Convert the track's lifetime from frames to mm:ss for the overlay.
            total_seconds = self.track_times[track_id] / self.fps if self.fps > 0 else 0
            minutes = int(total_seconds // 60)
            seconds = int(total_seconds % 60)

            # Anyone on screen for over a minute triggers a prolonged-stay alert.
            if total_seconds > 60 and track_id not in self.alert_person_ids:
                self.alert_person_ids.append(track_id)

            cv2.putText(og_frame, f"{self.track_labels[track_id]} {minutes:02}:{seconds:02}",
                        (int(x1) + 10, int(y1) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

            active_track_ids.add(track_id)

        # Accumulate every ID ever seen so person_count is a running total.
        # (The original also called intersection_update(active_track_ids) first,
        # which shrank the set to the current frame's IDs and defeated the count.)
        self.unique_track_ids.update(active_track_ids)

        result_info = {
            'person_count': len(self.unique_track_ids),
            'running_status': new_running_status,
            'prolonged_stay_ids': list(self.alert_person_ids)
        }

        return og_frame, result_info
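
    # process_frame can also be exercised directly on a single image, e.g.:
    #   frame = cv2.imread("sample.jpg")  # hypothetical test image
    #   annotated, info = ObjectDetector().process_frame(frame)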

    def process_input(self, input_media):
        """
        Process either video or webcam input
        """
        # A string is treated as a video file path; anything else (e.g. None
        # from the UI) falls back to the default webcam.
        if isinstance(input_media, str):
            cap = cv2.VideoCapture(input_media)
        else:
            cap = cv2.VideoCapture(0)

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # webcams often report 0 FPS
        self.fps = fps
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter('output_detection.mp4', fourcc, fps, (width, height))

        # A live webcam never returns ret == False, so cap the capture length
        # (an assumed ~10 seconds here) to keep the loop from running forever.
        max_frames = None if isinstance(input_media, str) else int(fps * 10)

        frame_info_list = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            processed_frame, frame_info = self.process_frame(frame)
            out.write(processed_frame)
            frame_info_list.append(frame_info)

            if max_frames is not None and len(frame_info_list) >= max_frames:
                break

        cap.release()
        out.release()

        return 'output_detection.mp4', frame_info_list
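
    # NOTE: 'mp4v'-encoded MP4s do not play inline in every browser; if the
    # Gradio player stays blank, re-encoding the output to H.264 (e.g. with
    # ffmpeg) is a common workaround.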


detector = ObjectDetector()


def detect_interface(input_media):
    """
    Gradio interface function for detection
    """
    output_video, frame_info_list = detector.process_input(input_media)

    summary = "Detection Summary:\n"
    if frame_info_list:
        # The final frame carries the cumulative person count and last statuses.
        last_frame_info = frame_info_list[-1]
        summary += f"Total Persons Detected: {last_frame_info['person_count']}\n"
        summary += f"Running Status: {last_frame_info['running_status']}\n"
        if last_frame_info['prolonged_stay_ids']:
            summary += f"Prolonged Stay Detected - Person IDs: {last_frame_info['prolonged_stay_ids']}"
        else:
            summary += "No Prolonged Stay Detected"

    return output_video, summary


# The original wired two input components (gr.File plus a nonexistent
# gr.Webcam) to a one-argument function. A single gr.Video input matches
# detect_interface's signature; submitting with no file yields None, which
# process_input treats as "use the webcam".
iface = gr.Interface(
    fn=detect_interface,
    inputs=gr.Video(label="Upload Video"),
    outputs=[
        gr.Video(label="Processed Video"),
        gr.Textbox(label="Detection Summary")
    ],
    title="Object Detection with Tracking",
    description="Upload a video, or submit with no file to use the webcam, for object detection and tracking"
)


if __name__ == "__main__":
    iface.launch()