#!/usr/bin/env python

from __future__ import annotations

import pathlib
import sys

import cv2
import gradio as gr
import numpy as np
import spaces
import torch

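# Make the locally bundled ibug packages importable as top-level modules.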
sys.path.insert(0, "face_detection")
sys.path.insert(0, "face_alignment")
sys.path.insert(0, "emotion_recognition")

from ibug.emotion_recognition import EmoNetPredictor
from ibug.face_alignment import FANPredictor
from ibug.face_detection import RetinaFacePredictor

DESCRIPTION = "# [ibug-group/emotion_recognition](https://github.com/ibug-group/emotion_recognition)"

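# Prefer the first CUDA device when available; fall back to CPU otherwise.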
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

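# RetinaFace detector (MobileNet-0.25 backbone). Each predictor is built on
# CPU and its underlying network is then moved to the target device.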
face_detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
face_detector.device = device
face_detector.net.to(device)

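# 2D-FAN landmark detector; its intermediate features are fed to EmoNet below.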
landmark_detector = FANPredictor(
    device="cpu", model=FANPredictor.get_model("2dfan2"), config=FANPredictor.create_config(use_jit=False)
)
landmark_detector.device = device
landmark_detector.net.to(device)


def load_model(model_name: str, device: torch.device) -> EmoNetPredictor:
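    """Build an EmoNet predictor on CPU, then move its network to the given device."""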
    model = EmoNetPredictor(
        device="cpu", model=EmoNetPredictor.get_model(model_name), config=EmoNetPredictor.create_config(use_jit=False)
    )
    model.device = device
    model.net.to(device)
    return model


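# Available EmoNet variants; the emonet248* models predict 8 emotion classes,
# the emonet245* models predict 5 (see the color tables in predict()).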
model_names = [
    "emonet248",
    "emonet245",
    "emonet248_alt",
    "emonet245_alt",
]
models = {name: load_model(name, device) for name in model_names}


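# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call.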
@spaces.GPU
def predict(image: np.ndarray, model_name: str, max_num_faces: int) -> np.ndarray:
    model = models[model_name]
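    # One BGR color per emotion label; the palette size must match the model's label count.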
    if len(model.config.emotion_labels) == 8:
        colors: tuple[tuple[int, int, int], ...] = (
            (192, 192, 192),
            (0, 255, 0),
            (255, 0, 0),
            (0, 255, 255),
            (0, 128, 255),
            (255, 0, 128),
            (0, 0, 255),
            (128, 255, 0),
        )
    else:
        colors = (
            (192, 192, 192),
            (0, 255, 0),
            (255, 0, 0),
            (0, 255, 255),
            (0, 0, 255),
        )

    # Gradio provides RGB images; the ibug predictors and OpenCV expect BGR.
    image = image[:, :, ::-1]

    faces = face_detector(image, rgb=False)
    if len(faces) == 0:
        raise gr.Error("No face was found.")
    faces = sorted(faces, key=lambda x: -x[4])[:max_num_faces]  # column 4 is the detection score
    faces = np.asarray(faces)
    _, _, features = landmark_detector(image, faces, rgb=False, return_features=True)  # FAN features feed EmoNet
    emotions = model(features)

    res = image.copy()
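    # Annotate each detected face with its box, top emotion, and (valence, arousal).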
    for index, face in enumerate(faces):
        box = np.round(face[:4]).astype(int)
        cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), 2)

        emotion = emotions["emotion"][index]
        valence = emotions["valence"][index]
        arousal = emotions["arousal"][index]
        emotion_label = model.config.emotion_labels[emotion].title()

        text_content = f"{emotion_label} ({valence: .1f}, {arousal: .1f})"
        cv2.putText(
            res, text_content, (box[0], box[1] - 10), cv2.FONT_HERSHEY_DUPLEX, 1, colors[emotion], lineType=cv2.LINE_AA
        )

    return res[:, :, ::-1]  # BGR -> RGB for display


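# UI: input image, model choice, and face-count slider on the left; annotated output on the right.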
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="numpy")
            model_name = gr.Radio(
                label="Model",
                choices=model_names,
                value=model_names[0],
                type="value",
            )
            max_num_of_faces = gr.Slider(
                label="Max Number of Faces",
                minimum=1,
                maximum=30,
                step=1,
                value=30,
            )
            run_button = gr.Button()
        with gr.Column():
            result = gr.Image(label="Output")
    gr.Examples(
        examples=[[path.as_posix(), model_names[0], 30] for path in sorted(pathlib.Path("images").rglob("*.jpg"))],
        inputs=[image, model_name, max_num_of_faces],
        outputs=result,
        fn=predict,
    )
    run_button.click(
        fn=predict,
        inputs=[image, model_name, max_num_of_faces],
        outputs=result,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
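
# A minimal client-side sketch of calling the "predict" endpoint exposed above.
# Untested illustration only: it assumes the app is deployed as a Gradio Space
# and that the gradio_client package is installed on the caller's side; the
# Space id and image path below are hypothetical.
#
#     from gradio_client import Client, handle_file
#
#     client = Client("<owner>/<space-name>")  # hypothetical Space id
#     annotated_image_path = client.predict(
#         handle_file("face.jpg"),  # hypothetical input image
#         "emonet248",              # model_name
#         30,                       # max_num_faces
#         api_name="/predict",
#     )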