Abs6187 commited on
Commit
83b09db
1 Parent(s): 06372a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -22
app.py CHANGED
@@ -1,32 +1,63 @@
1
- import cv2
2
  import gradio as gr
3
- from transformers import pipeline
 
 
 
 
4
 
5
- # Load YOLOv8 model and suspicious activity classification model
6
- pose_detection = pipeline("object-detection", model="yolov8-pose") # Correct path if it's inside same folder.
7
- suspicious_activity_detection = pipeline("text-classification", model="suspicious_activity_model")
 
8
 
 
 
 
 
 
 
 
 
 
 
 
9
  def process_frame(frame):
10
- results = pose_detection(frame)
 
 
 
 
11
 
12
- for person in results:
13
- if person['label'] == 'person':
14
- x1, y1, x2, y2 = map(int, person['box'].values())
15
- cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
16
- cv2.putText(frame, 'Detected', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
 
 
 
 
 
 
 
 
 
 
 
 
17
  return frame
18
 
19
- def live_detection(frame):
20
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
21
  processed_frame = process_frame(frame)
22
  return processed_frame
23
 
24
- interface = gr.Interface(
25
- fn=live_detection,
26
- inputs=gr.Image(source="webcam", tool="editor", type="numpy"),
27
- outputs=gr.Image(type="numpy"),
28
- live=True
29
- )
30
-
31
- if __name__ == "__main__":
32
- interface.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from tensorflow.keras.models import load_model
5
+ from sklearn.preprocessing import StandardScaler
6
+ from ultralytics import YOLO
7
 
8
+ # Load models
9
+ lstm_model = load_model('suspicious_activity_model.h5')
10
+ yolo_model = YOLO('yolov8n-pose.pt') # Ensure this model supports keypoint detection
11
+ scaler = StandardScaler()
12
 
13
+ # Function to extract keypoints from a frame
14
+ def extract_keypoints(frame):
15
+ results = yolo_model(frame, verbose=False)
16
+ for r in results:
17
+ if r.keypoints is not None and len(r.keypoints) > 0:
18
+ keypoints = r.keypoints.xyn.tolist()[0] # Use the first person's keypoints
19
+ flattened_keypoints = [kp for keypoint in keypoints for kp in keypoint[:2]] # Flatten x, y values
20
+ return flattened_keypoints
21
+ return None # Return None if no keypoints are detected
22
+
23
+ # Function to process each frame
24
  def process_frame(frame):
25
+ results = yolo_model(frame, verbose=False)
26
+
27
+ for box in results[0].boxes:
28
+ cls = int(box.cls[0]) # Class ID
29
+ confidence = float(box.conf[0])
30
 
31
+ if cls == 0 and confidence > 0.5: # Detect persons only
32
+ x1, y1, x2, y2 = map(int, box.xyxy[0]) # Bounding box coordinates
33
+
34
+ # Extract ROI for classification
35
+ roi = frame[y1:y2, x1:x2]
36
+ if roi.size > 0:
37
+ keypoints = extract_keypoints(roi)
38
+ if keypoints is not None and len(keypoints) > 0:
39
+ keypoints_scaled = scaler.fit_transform([keypoints])
40
+ keypoints_reshaped = keypoints_scaled.reshape((1, 1, len(keypoints)))
41
+
42
+ prediction = (lstm_model.predict(keypoints_reshaped) > 0.5).astype(int)[0][0]
43
+
44
+ color = (0, 0, 255) if prediction == 1 else (0, 255, 0)
45
+ label = 'Suspicious' if prediction == 1 else 'Normal'
46
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
47
+ cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
48
  return frame
49
 
50
+ # Gradio video streaming function
51
+ def video_processing(video_frame):
52
+ frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB) # Convert to RGB
53
  processed_frame = process_frame(frame)
54
  return processed_frame
55
 
56
+ # Launch Gradio app
57
+ gr.Interface(
58
+ fn=video_processing,
59
+ inputs=gr.Video(source="webcam", streaming=True),
60
+ outputs="video",
61
+ live=True,
62
+ title="Suspicious Activity Detection"
63
+ ).launch(debug=True)