Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,232 Bytes
b85fd0a b42c5db b85fd0a 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 b42c5db 366a48b b42c5db b85fd0a 18eea93 b42c5db 18eea93 b42c5db 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 b85fd0a 18eea93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
#!/usr/bin/env python
from __future__ import annotations
import pathlib
import sys
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
sys.path.insert(0, "face_detection")
sys.path.insert(0, "face_alignment")
sys.path.insert(0, "emotion_recognition")
from ibug.emotion_recognition import EmoNetPredictor
from ibug.face_alignment import FANPredictor
from ibug.face_detection import RetinaFacePredictor
DESCRIPTION = "# [ibug-group/emotion_recognition](https://github.com/ibug-group/emotion_recognition)"

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Both predictors are constructed with device="cpu" and only afterwards have
# their ``device`` attribute and network moved to ``device``.
# NOTE(review): presumably so construction itself never touches CUDA — confirm.
face_detector = RetinaFacePredictor(
    threshold=0.8,
    device="cpu",
    model=RetinaFacePredictor.get_model("mobilenet0.25"),
)
face_detector.device = device
face_detector.net.to(device)

landmark_detector = FANPredictor(
    device="cpu",
    model=FANPredictor.get_model("2dfan2"),
    config=FANPredictor.create_config(use_jit=False),
)
landmark_detector.device = device
landmark_detector.net.to(device)
def load_model(model_name: str, device: torch.device) -> EmoNetPredictor:
    """Create the EmoNet predictor named ``model_name`` and place it on ``device``.

    The predictor is built on the CPU with JIT disabled; its ``device``
    attribute is then updated and its network transferred to ``device``.

    Args:
        model_name: Checkpoint identifier passed to ``EmoNetPredictor.get_model``.
        device: Target torch device for inference.

    Returns:
        The configured ``EmoNetPredictor``.
    """
    weights = EmoNetPredictor.get_model(model_name)
    cfg = EmoNetPredictor.create_config(use_jit=False)
    predictor = EmoNetPredictor(device="cpu", model=weights, config=cfg)
    # Re-target the already-built predictor instead of constructing it on
    # ``device`` directly (same pattern as the face/landmark detectors).
    predictor.device = device
    predictor.net.to(device)
    return predictor
# Checkpoints selectable in the UI; the first one is the default choice.
model_names = ["emonet248", "emonet245", "emonet248_alt", "emonet245_alt"]
# Every checkpoint is loaded eagerly at import time, in the order listed.
models = dict(zip(model_names, (load_model(m, device) for m in model_names)))
@spaces.GPU
def predict(image: np.ndarray, model_name: str, max_num_faces: int) -> np.ndarray:
    """Detect faces in ``image`` and annotate each one with its predicted emotion.

    Args:
        image: Input picture from the ``gr.Image(type="numpy")`` component.
            The last axis is treated as RGB channels.
        model_name: Key into ``models`` selecting the EmoNet checkpoint.
        max_num_faces: Maximum number of faces to annotate. The detections
            with the highest score (column 4 of each detection row) are kept.

    Returns:
        An RGB copy of ``image``. Each kept face has a green rectangle and a
        ``"<Emotion> (valence, arousal)"`` label drawn just above it.

    Raises:
        gr.Error: If no image was supplied or no face was detected.
    """
    # Pressing "Run" without an upload delivers ``None``. Report it to the
    # user instead of failing with a TypeError on the channel flip below.
    if image is None:
        raise gr.Error("No image was provided.")

    model = models[model_name]
    # One colour per emotion id. The table size must match the model's label
    # count. The colours are drawn onto the BGR working copy below, so cv2
    # interprets them as BGR.
    if len(model.config.emotion_labels) == 8:
        colors: tuple[tuple[int, int, int], ...] = (
            (192, 192, 192),
            (0, 255, 0),
            (255, 0, 0),
            (0, 255, 255),
            (0, 128, 255),
            (255, 0, 128),
            (0, 0, 255),
            (128, 255, 0),
        )
    else:
        colors = (
            (192, 192, 192),
            (0, 255, 0),
            (255, 0, 0),
            (0, 255, 255),
            (0, 0, 255),
        )

    # RGB -> BGR. Both detectors are called with rgb=False accordingly.
    image = image[:, :, ::-1]

    faces = face_detector(image, rgb=False)
    if len(faces) == 0:
        raise gr.Error("No face was found.")
    # Keep the most confident detections (score in column 4). int() guards
    # against the slider value arriving as a float, which is not a valid slice
    # bound. sorted(..., reverse=True) stays stable for equal scores.
    faces = sorted(faces, key=lambda face: face[4], reverse=True)[: int(max_num_faces)]
    faces = np.asarray(faces)

    _, _, features = landmark_detector(image, faces, rgb=False, return_features=True)
    emotions = model(features)

    # .copy() gives a C-contiguous array (the flipped view is not) that cv2
    # can draw on in place.
    res = image.copy()
    for index, face in enumerate(faces):
        # Pass plain Python ints to cv2 rather than numpy int64 scalars.
        x0, y0, x1, y1 = (int(v) for v in np.round(face[:4]))
        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), 2)
        emotion = int(emotions["emotion"][index])
        valence = emotions["valence"][index]
        arousal = emotions["arousal"][index]
        emotion_label = model.config.emotion_labels[emotion].title()
        text_content = f"{emotion_label} ({valence: .01f}, {arousal: .01f})"
        cv2.putText(
            res,
            text_content,
            (x0, y0 - 10),
            cv2.FONT_HERSHEY_DUPLEX,
            1,
            colors[emotion],
            lineType=cv2.LINE_AA,
        )
    # BGR -> RGB for display.
    return res[:, :, ::-1]
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: input picture and inference controls.
        with gr.Column():
            image = gr.Image(type="numpy", label="Input")
            model_name = gr.Radio(
                choices=model_names,
                value=model_names[0],
                type="value",
                label="Model",
            )
            max_num_of_faces = gr.Slider(
                minimum=1,
                maximum=30,
                step=1,
                value=30,
                label="Max Number of Faces",
            )
            run_button = gr.Button()
        # Right column: annotated output.
        with gr.Column():
            result = gr.Image(label="Output")

    # Each *.jpg found recursively under ./images (sorted by path) becomes an
    # example row that uses the default model and a face limit of 30.
    gr.Examples(
        examples=[
            [example_path.as_posix(), model_names[0], 30]
            for example_path in sorted(pathlib.Path("images").rglob("*.jpg"))
        ],
        inputs=[image, model_name, max_num_of_faces],
        outputs=result,
        fn=predict,
    )

    # Button handler; also exposed on the API under the name "predict".
    run_button.click(
        fn=predict,
        inputs=[image, model_name, max_num_of_faces],
        outputs=result,
        api_name="predict",
    )
if __name__ == "__main__":
    # Enable the request queue (max_size=20) and start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|