Abs6187 commited on
Commit
713bd0b
1 Parent(s): fddfd57

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +54 -0
model.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+ import tensorflow as tf
3
+ from ultralytics import YOLO
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler
6
+
7
+ class SuspiciousActivityModel:
8
+ def __init__(self, lstm_model_path, yolo_model_path):
9
+ # Load YOLO model
10
+ self.yolo_model = YOLO(yolo_model_path)
11
+ # Load LSTM model
12
+ self.lstm_model = tf.keras.models.load_model(lstm_model_path)
13
+ self.scaler = StandardScaler()
14
+
15
+ def extract_keypoints(self, frame):
16
+ """
17
+ Extracts normalized keypoints from a frame using YOLO pose model.
18
+ """
19
+ results = self.yolo_model(frame, verbose=False)
20
+ for r in results:
21
+ if r.keypoints is not None and len(r.keypoints) > 0:
22
+ keypoints = r.keypoints.xyn.tolist()[0]
23
+ flattened_keypoints = [kp for keypoint in keypoints for kp in keypoint[:2]]
24
+ return flattened_keypoints
25
+ return None
26
+
27
+ def process_frame(self, frame):
28
+ results = self.yolo_model(frame, verbose=False)
29
+
30
+ for box in results[0].boxes:
31
+ cls = int(box.cls[0]) # Class ID
32
+ confidence = float(box.conf[0])
33
+
34
+ if cls == 0 and confidence > 0.5:
35
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
36
+
37
+ # Extract ROI for classification
38
+ roi = frame[y1:y2, x1:x2]
39
+ if roi.size > 0:
40
+ keypoints = self.extract_keypoints(roi)
41
+ if keypoints is not None and len(keypoints) > 0:
42
+ # Standardize and reshape keypoints for LSTM input
43
+ keypoints_scaled = self.scaler.fit_transform([keypoints])
44
+ keypoints_reshaped = keypoints_scaled.reshape((1, 1, len(keypoints)))
45
+
46
+ # Predict with LSTM model
47
+ prediction = (self.lstm_model.predict(keypoints_reshaped) > 0.5).astype(int)[0][0]
48
+
49
+ # Return label
50
+ return 'Suspicious' if prediction == 1 else 'Normal'
51
+ return 'Normal'
52
+
53
+ def detect_activity(self, frame):
54
+ return self.process_frame(frame)