ElenaRyumina committed on
Commit
c0f6432
·
1 Parent(s): 6e3fd26
Files changed (12) hide show
  1. app.css +18 -2
  2. app.py +60 -7
  3. app/app_utils.py +87 -4
  4. app/config.py +10 -0
  5. app/description.py +12 -2
  6. app/face_utils.py +34 -0
  7. app/model.py +4 -1
  8. app/plot.py +29 -0
  9. config.toml +6 -2
  10. result.mp4 +0 -0
  11. videos/video1.mp4 +0 -0
  12. videos/video2.mp4 +0 -0
app.css CHANGED
@@ -3,8 +3,8 @@ div.app-flex-container {
3
  align-items: left;
4
  }
5
 
6
- div.app-flex-container > img {
7
- margin-right: 6px;
8
  }
9
 
10
  div.dl1 div.upload-container {
@@ -20,6 +20,22 @@ div.dl2 img {
20
  max-height: 200px;
21
  }
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  .submit {
24
  display: inline-block;
25
  padding: 10px 20px;
 
3
  align-items: left;
4
  }
5
 
6
+ div.app-flex-container > a {
7
+ margin-left: 6px;
8
  }
9
 
10
  div.dl1 div.upload-container {
 
20
  max-height: 200px;
21
  }
22
 
23
+ div.video1 div.video-container {
24
+ height: 500px;
25
+ }
26
+
27
+ div.video2 {
28
+ height: 200px;
29
+ }
30
+
31
+ div.video3 {
32
+ height: 200px;
33
+ }
34
+
35
+ div.stat {
36
+ height: 286px;
37
+ }
38
+
39
  .submit {
40
  display: inline-block;
41
  padding: 10px 20px;
app.py CHANGED
@@ -10,21 +10,52 @@ License: MIT License
10
  import gradio as gr
11
 
12
  # Importing necessary components for the Gradio app
13
- from app.description import DESCRIPTION
14
  from app.authors import AUTHORS
15
- from app.app_utils import preprocess_and_predict
16
 
17
 
18
- def clear():
19
  return (
20
  gr.Image(value=None, type="pil"),
21
  gr.Image(value=None, scale=1, elem_classes="dl2"),
22
  gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
23
  )
24
 
 
 
 
 
 
 
 
 
25
  with gr.Blocks(css="app.css") as demo:
26
- with gr.Tab("App"):
27
- gr.Markdown(value=DESCRIPTION)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  with gr.Row():
29
  with gr.Column(scale=2, elem_classes="dl1"):
30
  input_image = gr.Image(type="pil")
@@ -54,17 +85,39 @@ with gr.Blocks(css="app.css") as demo:
54
  gr.Markdown(value=AUTHORS)
55
 
56
  submit.click(
57
- fn=preprocess_and_predict,
58
  inputs=[input_image],
59
  outputs=[output_image, output_label],
60
  queue=True,
61
  )
62
  clear_btn.click(
63
- fn=clear,
64
  inputs=[],
65
  outputs=[input_image, output_image, output_label],
66
  queue=True,
67
  )
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  if __name__ == "__main__":
70
  demo.queue(api_open=False).launch(share=False)
 
10
  import gradio as gr
11
 
12
  # Importing necessary components for the Gradio app
13
+ from app.description import DESCRIPTION_STATIC, DESCRIPTION_DYNAMIC
14
  from app.authors import AUTHORS
15
+ from app.app_utils import preprocess_image_and_predict, preprocess_video_and_predict
16
 
17
 
18
+ def clear_static_info():
19
  return (
20
  gr.Image(value=None, type="pil"),
21
  gr.Image(value=None, scale=1, elem_classes="dl2"),
22
  gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
23
  )
24
 
25
+ def clear_dynamic_info():
26
+ return (
27
+ gr.Video(value=None),
28
+ gr.Video(value=None),
29
+ gr.Video(value=None),
30
+ gr.Plot(value=None),
31
+ )
32
+
33
  with gr.Blocks(css="app.css") as demo:
34
+ with gr.Tab("Dynamic App"):
35
+ gr.Markdown(value=DESCRIPTION_DYNAMIC)
36
+ with gr.Row():
37
+ with gr.Column(scale=2):
38
+ input_video = gr.Video(elem_classes="video1")
39
+ with gr.Row():
40
+ clear_btn_dynamic = gr.Button(
41
+ value="Clear", interactive=True, scale=1
42
+ )
43
+ submit_dynamic = gr.Button(
44
+ value="Submit", interactive=True, scale=1, elem_classes="submit"
45
+ )
46
+ with gr.Column(scale=2, elem_classes="dl4"):
47
+ with gr.Row():
48
+ output_video = gr.Video(label="Original video", scale=2, elem_classes="video2")
49
+ output_face = gr.Video(label="Pre-processed video", scale=1, elem_classes="video3")
50
+ output_statistics = gr.Plot(label="Statistics of emotions", elem_classes="stat")
51
+ gr.Examples(
52
+ ["videos/video1.mp4",
53
+ "videos/video2.mp4"],
54
+ [input_video],
55
+ )
56
+
57
+ with gr.Tab("Static App"):
58
+ gr.Markdown(value=DESCRIPTION_STATIC)
59
  with gr.Row():
60
  with gr.Column(scale=2, elem_classes="dl1"):
61
  input_image = gr.Image(type="pil")
 
85
  gr.Markdown(value=AUTHORS)
86
 
87
  submit.click(
88
+ fn=preprocess_image_and_predict,
89
  inputs=[input_image],
90
  outputs=[output_image, output_label],
91
  queue=True,
92
  )
93
  clear_btn.click(
94
+ fn=clear_static_info,
95
  inputs=[],
96
  outputs=[input_image, output_image, output_label],
97
  queue=True,
98
  )
99
 
100
+ submit_dynamic.click(
101
+ fn=preprocess_video_and_predict,
102
+ inputs=input_video,
103
+ outputs=[
104
+ output_video,
105
+ output_face,
106
+ output_statistics
107
+ ],
108
+ queue=True,
109
+ )
110
+ clear_btn_dynamic.click(
111
+ fn=clear_dynamic_info,
112
+ inputs=[],
113
+ outputs=[
114
+ input_video,
115
+ output_video,
116
+ output_face,
117
+ output_statistics
118
+ ],
119
+ queue=True,
120
+ )
121
+
122
  if __name__ == "__main__":
123
  demo.queue(api_open=False).launch(share=False)
app/app_utils.py CHANGED
@@ -9,17 +9,19 @@ import torch
9
  import numpy as np
10
  import mediapipe as mp
11
  from PIL import Image
 
12
 
13
  # Importing necessary components for the Gradio app
14
- from app.model import pth_model, pth_processing
15
- from app.face_utils import get_box
16
  from app.config import DICT_EMO
 
17
 
18
 
19
  mp_face_mesh = mp.solutions.face_mesh
20
 
21
 
22
- def preprocess_and_predict(inp):
23
  inp = np.array(inp)
24
 
25
  if inp is None:
@@ -43,10 +45,91 @@ def preprocess_and_predict(inp):
43
  cur_face = inp[startY:endY, startX:endX]
44
  cur_face_n = pth_processing(Image.fromarray(cur_face))
45
  prediction = (
46
- torch.nn.functional.softmax(pth_model(cur_face_n), dim=1)
47
  .detach()
48
  .numpy()[0]
49
  )
50
  confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
51
 
52
  return cur_face, confidences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
  import mediapipe as mp
11
  from PIL import Image
12
+ import cv2
13
 
14
  # Importing necessary components for the Gradio app
15
+ from app.model import pth_model_static, pth_model_dynamic, pth_processing
16
+ from app.face_utils import get_box, display_info
17
  from app.config import DICT_EMO
18
+ from app.plot import statistics_plot
19
 
20
 
21
  mp_face_mesh = mp.solutions.face_mesh
22
 
23
 
24
+ def preprocess_image_and_predict(inp):
25
  inp = np.array(inp)
26
 
27
  if inp is None:
 
45
  cur_face = inp[startY:endY, startX:endX]
46
  cur_face_n = pth_processing(Image.fromarray(cur_face))
47
  prediction = (
48
+ torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
49
  .detach()
50
  .numpy()[0]
51
  )
52
  confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
53
 
54
  return cur_face, confidences
55
+
56
+
57
+ def preprocess_video_and_predict(video):
58
+
59
+ cap = cv2.VideoCapture(video)
60
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
61
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
62
+ fps = np.round(cap.get(cv2.CAP_PROP_FPS))
63
+
64
+ path_save_video = 'result.mp4'
65
+ vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
66
+
67
+ lstm_features = []
68
+ count_frame = 1
69
+ probs = []
70
+ frames = []
71
+ last_output = None
72
+
73
+ with mp_face_mesh.FaceMesh(
74
+ max_num_faces=1,
75
+ refine_landmarks=False,
76
+ min_detection_confidence=0.5,
77
+ min_tracking_confidence=0.5) as face_mesh:
78
+
79
+ while cap.isOpened():
80
+ _, frame = cap.read()
81
+ if frame is None: break
82
+
83
+ frame_copy = frame.copy()
84
+ frame_copy.flags.writeable = False
85
+ frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
86
+ results = face_mesh.process(frame_copy)
87
+ frame_copy.flags.writeable = True
88
+
89
+ if results.multi_face_landmarks:
90
+ for fl in results.multi_face_landmarks:
91
+ startX, startY, endX, endY = get_box(fl, w, h)
92
+ cur_face = frame_copy[startY:endY, startX: endX]
93
+
94
+ if (count_frame-1)%5 == 0:
95
+ cur_face_copy = pth_processing(Image.fromarray(cur_face))
96
+ features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy()
97
+
98
+ if len(lstm_features) == 0:
99
+ lstm_features = [features]*10
100
+ else:
101
+ lstm_features = lstm_features[1:] + [features]
102
+
103
+ lstm_f = torch.from_numpy(np.vstack(lstm_features))
104
+ lstm_f = torch.unsqueeze(lstm_f, 0)
105
+ output = pth_model_dynamic(lstm_f).detach().numpy()
106
+ last_output = output
107
+ else:
108
+ if last_output is not None:
109
+ output = last_output
110
+ elif last_output is None:
111
+ output = np.zeros((7))
112
+
113
+ probs.append(output[0])
114
+ frames.append(count_frame)
115
+ else:
116
+ empty = np.empty((7))
117
+ empty[:] = np.nan
118
+ probs.append(empty)
119
+ frames.append(count_frame)
120
+
121
+ cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
122
+ cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA)
123
+
124
+ cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
125
+ count_frame += 1
126
+ vid_writer.write(cur_face)
127
+
128
+ vid_writer.release()
129
+
130
+ stat = statistics_plot(frames, probs)
131
+
132
+ if not stat:
133
+ return None, None, None
134
+
135
+ return video, path_save_video, stat
app/config.py CHANGED
@@ -37,3 +37,13 @@ DICT_EMO = {
37
  5: "Disgust",
38
  6: "Anger",
39
  }
 
 
 
 
 
 
 
 
 
 
 
37
  5: "Disgust",
38
  6: "Anger",
39
  }
40
+
41
+ COLORS = {
42
+ 0: 'blue',
43
+ 1: 'orange',
44
+ 2: 'green',
45
+ 3: 'red',
46
+ 4: 'purple',
47
+ 5: 'brown',
48
+ 6: 'pink'
49
+ }
app/description.py CHANGED
@@ -8,10 +8,20 @@ License: MIT License
8
  # Importing necessary components for the Gradio app
9
  from app.config import config_data
10
 
11
- DESCRIPTION = f"""\
12
  # Static Facial Expression Recognition
13
  <div class="app-flex-container">
14
  <img src="https://img.shields.io/badge/version-v{config_data.APP_VERSION}-rc0" alt="Version">
15
  <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition&countColor=%23263759&style=flat" /></a>
16
- </div>
 
 
 
 
 
 
 
 
 
 
17
  """
 
8
  # Importing necessary components for the Gradio app
9
  from app.config import config_data
10
 
11
+ DESCRIPTION_STATIC = f"""\
12
  # Static Facial Expression Recognition
13
  <div class="app-flex-container">
14
  <img src="https://img.shields.io/badge/version-v{config_data.APP_VERSION}-rc0" alt="Version">
15
  <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition&countColor=%23263759&style=flat" /></a>
16
+ <a href="https://paperswithcode.com/paper/in-search-of-a-robust-facial-expressions"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/in-search-of-a-robust-facial-expressions/facial-expression-recognition-on-affectnet" /></a>
17
+ </div>
18
+ """
19
+
20
+ DESCRIPTION_DYNAMIC = f"""\
21
+ # Dynamic Facial Expression Recognition
22
+ <div class="app-flex-container">
23
+ <img src="https://img.shields.io/badge/version-v{config_data.APP_VERSION}-rc0" alt="Version">
24
+ <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition&countColor=%23263759&style=flat" /></a>
25
+ <a href="https://paperswithcode.com/paper/in-search-of-a-robust-facial-expressions"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/in-search-of-a-robust-facial-expressions/facial-expression-recognition-on-affectnet" /></a>
26
+ </div>
27
  """
app/face_utils.py CHANGED
@@ -7,6 +7,7 @@ License: MIT License
7
 
8
  import numpy as np
9
  import math
 
10
 
11
 
12
  def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
@@ -31,3 +32,36 @@ def get_box(fl, w, h):
31
  (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
32
 
33
  return startX, startY, endX, endY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  import numpy as np
9
  import math
10
+ import cv2
11
 
12
 
13
  def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
 
32
  (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
33
 
34
  return startX, startY, endX, endY
35
+
36
+ def display_info(img, text, margin=1.0, box_scale=1.0):
37
+ img_h, img_w, _ = img.shape
38
+ line_width = int(min(img_h, img_w) * 0.001)
39
+ thickness = max(int(line_width / 3), 1)
40
+
41
+ font_face = cv2.FONT_HERSHEY_SIMPLEX
42
+ font_color = (0, 0, 0)
43
+ font_scale = thickness / 1.5
44
+
45
+ t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]
46
+
47
+ margin_n = int(t_h * margin)
48
+ sub_img = img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
49
+ img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]
50
+
51
+ white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255
52
+
53
+ img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
54
+ img_w - t_w - margin_n - int(2 * t_h * box_scale):img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5, 1.0)
55
+
56
+ cv2.putText(img=img,
57
+ text=text,
58
+ org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
59
+ 0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),
60
+ fontFace=font_face,
61
+ fontScale=font_scale,
62
+ color=font_color,
63
+ thickness=thickness,
64
+ lineType=cv2.LINE_AA,
65
+ bottomLeftOrigin=False)
66
+
67
+ return img
app/model.py CHANGED
@@ -27,7 +27,10 @@ def load_model(model_url, model_path):
27
  return None
28
 
29
 
30
- pth_model = load_model(config_data.model_url, config_data.model_path)
 
 
 
31
 
32
 
33
  def pth_processing(fp):
 
27
  return None
28
 
29
 
30
+ pth_model_static = load_model(config_data.model_static_url, config_data.model_static_path)
31
+
32
+ pth_model_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
33
+
34
 
35
 
36
  def pth_processing(fp):
app/plot.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: plot.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: Plotting statistical information.
5
+ License: MIT License
6
+ """
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+ # Importing necessary components for the Gradio app
11
+ from app.config import DICT_EMO, COLORS
12
+
13
+
14
+ def statistics_plot(frames, probs):
15
+ fig, ax = plt.subplots(figsize=(10, 4))
16
+ fig.subplots_adjust(left=0.07, bottom=0.14, right=0.98, top=0.8, wspace=0, hspace=0)
17
+ # Set the left, bottom, right and top margins to leave room for the legend and the axis labels
18
+ probs = np.array(probs)
19
+ for i in range(7):
20
+ try:
21
+ ax.plot(frames, probs[:, i], label=DICT_EMO[i], c=COLORS[i])
22
+ except Exception:
23
+ return None
24
+
25
+ ax.legend(loc='upper center', bbox_to_anchor=(0.47, 1.2), ncol=7, fontsize=12)
26
+ ax.set_xlabel('Frames', fontsize=12) # Add the X-axis label
27
+ ax.set_ylabel('Probability', fontsize=12) # Add the Y-axis label
28
+ ax.grid(True)
29
+ return plt
config.toml CHANGED
@@ -1,5 +1,9 @@
1
- APP_VERSION = "0.1.0"
2
 
3
- [model]
4
  url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
5
  path = "FER_static_ResNet50_AffectNet.pth"
 
 
 
 
 
1
+ APP_VERSION = "0.2.0"
2
 
3
+ [model_static]
4
  url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
5
  path = "FER_static_ResNet50_AffectNet.pth"
6
+
7
+ [model_dynamic]
8
+ url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_dinamic_LSTM_IEMOCAP.pth"
9
+ path = "FER_dinamic_LSTM_IEMOCAP.pth"
result.mp4 ADDED
File without changes
videos/video1.mp4 ADDED
Binary file (680 kB). View file
 
videos/video2.mp4 ADDED
Binary file (182 kB). View file