hysts (HF staff) committed
Commit b85fd0a
1 parent: a5542d6
Files changed (7):
  1. .gitignore +1 -0
  2. .gitmodules +9 -0
  3. app.py +198 -0
  4. emotion_recognition +1 -0
  5. face_alignment +1 -0
  6. face_detection +1 -0
  7. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ images
.gitmodules ADDED
@@ -0,0 +1,9 @@
+ [submodule "face_detection"]
+ path = face_detection
+ url = https://github.com/ibug-group/face_detection
+ [submodule "face_alignment"]
+ path = face_alignment
+ url = https://github.com/ibug-group/face_alignment
+ [submodule "emotion_recognition"]
+ path = emotion_recognition
+ url = https://github.com/ibug-group/emotion_recognition
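Note that the iBug packages are vendored as git submodules rather than pip dependencies: app.py simply inserts the three submodule directories into sys.path. A local clone would therefore typically need git submodule update --init --recursive (plus each package's own dependencies) before the ibug imports succeed.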
app.py ADDED
@@ -0,0 +1,198 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import argparse
+ import functools
+ import os
+ import pathlib
+ import sys
+ import tarfile
+
+ import cv2
+ import gradio as gr
+ import huggingface_hub
+ import numpy as np
+ import torch
+
+ sys.path.insert(0, 'face_detection')
+ sys.path.insert(0, 'face_alignment')
+ sys.path.insert(0, 'emotion_recognition')
+
+ from ibug.emotion_recognition import EmoNetPredictor
+ from ibug.face_alignment import FANPredictor
+ from ibug.face_detection import RetinaFacePredictor
+
+ REPO_URL = 'https://github.com/ibug-group/emotion_recognition'
+ TITLE = 'ibug-group/emotion_recognition'
+ DESCRIPTION = f'This is a demo for {REPO_URL}.'
+ ARTICLE = None
+
+ TOKEN = os.environ['TOKEN']
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--live', action='store_true')
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     parser.add_argument('--allow-flagging', type=str, default='never')
+     parser.add_argument('--allow-screenshot', action='store_true')
+     return parser.parse_args()
+
+
+ def load_sample_images() -> list[pathlib.Path]:
+     image_dir = pathlib.Path('images')
+     if not image_dir.exists():
+         image_dir.mkdir()
+         dataset_repo = 'hysts/input-images'
+         filenames = ['004.tar']
+         for name in filenames:
+             path = huggingface_hub.hf_hub_download(dataset_repo,
+                                                    name,
+                                                    repo_type='dataset',
+                                                    use_auth_token=TOKEN)
+             with tarfile.open(path) as f:
+                 f.extractall(image_dir.as_posix())
+     return sorted(image_dir.rglob('*.jpg'))
+
+
+ def load_face_detector(device: torch.device) -> RetinaFacePredictor:
+     model = RetinaFacePredictor(
+         threshold=0.8,
+         device=device,
+         model=RetinaFacePredictor.get_model('mobilenet0.25'))
+     return model
+
+
+ def load_landmark_detector(device: torch.device) -> FANPredictor:
+     model = FANPredictor(device=device, model=FANPredictor.get_model('2dfan2'))
+     return model
+
+
+ def load_model(model_name: str, device: torch.device) -> EmoNetPredictor:
+     model = EmoNetPredictor(device=device,
+                             model=EmoNetPredictor.get_model(model_name))
+     return model
+
+
+ def predict(image: np.ndarray, model_name: str, max_num_faces: int,
+             face_detector: RetinaFacePredictor,
+             landmark_detector: FANPredictor,
+             models: dict[str, EmoNetPredictor]) -> np.ndarray:
+     model = models[model_name]
+     if len(model.config.emotion_labels) == 8:
+         colors = (
+             (192, 192, 192),
+             (0, 255, 0),
+             (255, 0, 0),
+             (0, 255, 255),
+             (0, 128, 255),
+             (255, 0, 128),
+             (0, 0, 255),
+             (128, 255, 0),
+         )
+     else:
+         colors = (
+             (192, 192, 192),
+             (0, 255, 0),
+             (255, 0, 0),
+             (0, 255, 255),
+             (0, 0, 255),
+         )
+
+     # RGB -> BGR
+     image = image[:, :, ::-1]
+
+     faces = face_detector(image, rgb=False)
+     if len(faces) == 0:
+         raise RuntimeError('No face was found.')
+     faces = sorted(list(faces), key=lambda x: -x[4])[:max_num_faces]
+     faces = np.asarray(faces)
+     _, _, features = landmark_detector(image,
+                                        faces,
+                                        rgb=False,
+                                        return_features=True)
+     emotions = model(features)
+
+     res = image.copy()
+     for index, face in enumerate(faces):
+         box = np.round(face[:4]).astype(int)
+         cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), 2)
+
+         emotion = emotions['emotion'][index]
+         valence = emotions['valence'][index]
+         arousal = emotions['arousal'][index]
+         emotion_label = model.config.emotion_labels[emotion].title()
+
+         text_content = f'{emotion_label} ({valence: .01f}, {arousal: .01f})'
+         cv2.putText(res,
+                     text_content, (box[0], box[1] - 10),
+                     cv2.FONT_HERSHEY_DUPLEX,
+                     1,
+                     colors[emotion],
+                     lineType=cv2.LINE_AA)
+
+     return res[:, :, ::-1]
+
+
+ def main():
+     gr.close_all()
+
+     args = parse_args()
+     device = torch.device(args.device)
+
+     face_detector = load_face_detector(device)
+     landmark_detector = load_landmark_detector(device)
+
+     model_names = [
+         'emonet248',
+         'emonet245',
+         'emonet248_alt',
+         'emonet245_alt',
+     ]
+     models = {name: load_model(name, device=device) for name in model_names}
+
+     func = functools.partial(predict,
+                              face_detector=face_detector,
+                              landmark_detector=landmark_detector,
+                              models=models)
+     func = functools.update_wrapper(func, predict)
+
+     image_paths = load_sample_images()
+     examples = [[path.as_posix(), model_names[0], 30] for path in image_paths]
+
+     gr.Interface(
+         func,
+         [
+             gr.inputs.Image(type='numpy', label='Input'),
+             gr.inputs.Radio(model_names,
+                             type='value',
+                             default=model_names[0],
+                             label='Model'),
+             gr.inputs.Slider(
+                 1, 30, step=1, default=30, label='Max Number of Faces'),
+         ],
+         gr.outputs.Image(type='numpy', label='Output'),
+         examples=examples,
+         title=TITLE,
+         description=DESCRIPTION,
+         article=ARTICLE,
+         theme=args.theme,
+         allow_screenshot=args.allow_screenshot,
+         allow_flagging=args.allow_flagging,
+         live=args.live,
+     ).launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
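For reference, a minimal sketch of the detection -> landmarks -> emotion pipeline that predict() wires into Gradio, distilled from the code above. It assumes the three submodules are checked out and importable; the image path 'sample.jpg' and the choice of 'emonet248' (one of the names listed in main()) are placeholders for illustration.

import sys

import cv2
import torch

# Make the vendored iBug packages importable, as app.py does.
sys.path.insert(0, 'face_detection')
sys.path.insert(0, 'face_alignment')
sys.path.insert(0, 'emotion_recognition')

from ibug.emotion_recognition import EmoNetPredictor
from ibug.face_alignment import FANPredictor
from ibug.face_detection import RetinaFacePredictor

device = torch.device('cpu')
face_detector = RetinaFacePredictor(
    threshold=0.8,
    device=device,
    model=RetinaFacePredictor.get_model('mobilenet0.25'))
landmark_detector = FANPredictor(device=device,
                                 model=FANPredictor.get_model('2dfan2'))
emotion_model = EmoNetPredictor(device=device,
                                model=EmoNetPredictor.get_model('emonet248'))

# cv2.imread returns BGR, which matches the rgb=False calls below.
image = cv2.imread('sample.jpg')  # placeholder path
faces = face_detector(image, rgb=False)
if len(faces) > 0:
    _, _, features = landmark_detector(image,
                                       faces,
                                       rgb=False,
                                       return_features=True)
    emotions = emotion_model(features)
    for i in range(len(faces)):
        label = emotion_model.config.emotion_labels[emotions['emotion'][i]]
        print(label, emotions['valence'][i], emotions['arousal'][i])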
emotion_recognition ADDED
@@ -0,0 +1 @@
+ Subproject commit 29e09ae91dcdc145a153eb793b9f451774e191ef
face_alignment ADDED
@@ -0,0 +1 @@
+ Subproject commit aef843c05be718fbd87ee2cb25fa3a015b7e59b0
face_detection ADDED
@@ -0,0 +1 @@
+ Subproject commit bc1e392b11d731fa20b1397c8ff3faed5e7fc76e
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ numpy==1.22.3
+ opencv-python-headless==4.5.5.64
+ torch==1.11.0
+ torchvision==0.12.0
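requirements.txt pins only the numeric/vision stack; gradio and huggingface_hub, which app.py imports, are left unpinned, presumably because the Spaces runtime provides them. Running the demo outside Spaces would also require a TOKEN environment variable (app.py reads os.environ['TOKEN']) so the sample-image tarball can be downloaded from the hysts/input-images dataset; after that, something like python app.py --device cpu (optionally with --share) starts the interface.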