from models.PosterV2_7cls import pyramid_trans_expr2
import cv2
import torch
import os
import time
from PIL import Image
# RecorderMeter classes must be importable so torch.load can unpickle
# checkpoints that reference them
from main import RecorderMeter1, RecorderMeter  # noqa: F401

script_dir = os.path.dirname(os.path.abspath(__file__))

# Construct the full path to the model file
model_path = os.path.join(
    script_dir, "models", "checkpoints", "raf-db-model_best.pth"
)

# Determine the available device for model execution
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Initialize the model with specified image size and number of classes
model = pyramid_trans_expr2(img_size=224, num_classes=7)

# Wrap the model with DataParallel for potential multi-GPU usage
model = torch.nn.DataParallel(model)

# Move the model to the chosen device
model = model.to(device)

# Print the current time
currtime = time.strftime("%H:%M:%S")
print(currtime)


def main():
    # Load the model checkpoint if it exists (model_path is always a
    # string here, so only the file's existence needs checking)
    if os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        best_acc = checkpoint["best_acc"]
        print(f"best_acc: {best_acc}")
        model.load_state_dict(checkpoint["state_dict"])
        print(
            "=> loaded checkpoint '{}' (epoch {})".format(
                model_path, checkpoint["epoch"]
            )
        )
    else:
        print(
            "[!] detectfaces.py => no checkpoint found at '{}'".format(model_path)
        )

    # Start webcam capture and prediction
    imagecapture(model)


def imagecapture(model):
    # Import once here, rather than on every loop iteration; kept local
    # to defer loading the prediction module until capture starts
    from prediction import predict

    # Initialize webcam capture
    cap = cv2.VideoCapture(0)
    time.sleep(5)  # Give the camera time to initialize

    # Retry until the webcam opens; re-open on each attempt, since
    # sleeping alone never changes an already-failed capture handle
    while not cap.isOpened():
        time.sleep(2)  # Wait 2 seconds before retrying
        cap = cv2.VideoCapture(0)

    # Load the Haar cascade once, outside the capture loop
    face_cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )

    # Flag to control webcam capture
    capturing = True
    while capturing:
        # Read a frame from the webcam
        ret, frame = cap.read()

        # Handle potential error reading the frame
        if not ret:
            print("Error: Could not read frame.")
            break

        # Convert the frame to grayscale for detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Detect faces using the Haar cascade
        faces = face_cascade.detectMultiScale(
            gray, scaleFactor=1.3, minNeighbors=5, minSize=(30, 30)
        )

        # If faces are detected, proceed with prediction
        if len(faces) > 0:
            currtimeimg = time.strftime("%H:%M:%S")
            print(f"[!] Face detected at {currtimeimg}")
            # Crop the first detected face region (x, y, width, height)
            x, y, w, h = faces[0]
            face_region = frame[y : y + h, x : x + w]
            # Convert the face region to a PIL image
            face_pil_image = Image.fromarray(
                cv2.cvtColor(face_region, cv2.COLOR_BGR2RGB)
            )
            print("[!]Start Expressions")
            # Record the prediction start time
            starttime = time.strftime("%H:%M:%S")
            print(f"-->Prediction starting at {starttime}")
            # Perform emotion prediction
            predict(model, image_path=face_pil_image)
            # Record the prediction end time
            endtime = time.strftime("%H:%M:%S")
            print(f"-->Done prediction at {endtime}")

            # Stop capturing once prediction is complete
            capturing = False

        # Exit the loop if the 'q' key is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    # Release webcam resources and close OpenCV windows
    cap.release()
    cv2.destroyAllWindows()


# Execute the main function if the script is run directly
if __name__ == "__main__":
    main()