ElenaRyumina committed on
Commit
c83f8fa
·
1 Parent(s): cf60969
Files changed (9) hide show
  1. .gitignore +2 -1
  2. app.css +12 -0
  3. app.py +14 -7
  4. app/app_utils.py +54 -19
  5. app/face_utils.py +6 -5
  6. app/model.py +13 -5
  7. app/model_architectures.py +150 -0
  8. config.toml +4 -4
  9. result.mp4 +0 -0
.gitignore CHANGED
@@ -168,4 +168,5 @@ dmypy.json
168
  .pyre/
169
 
170
  # Custom
171
- *.pth
 
 
168
  .pyre/
169
 
170
  # Custom
171
+ *.pth
172
+ *.pt
app.css CHANGED
@@ -20,6 +20,14 @@ div.dl2 img {
20
  max-height: 200px;
21
  }
22
 
 
 
 
 
 
 
 
 
23
  div.video1 div.video-container {
24
  height: 500px;
25
  }
@@ -32,6 +40,10 @@ div.video3 {
32
  height: 200px;
33
  }
34
 
 
 
 
 
35
  div.stat {
36
  height: 286px;
37
  }
 
20
  max-height: 200px;
21
  }
22
 
23
+ div.dl5 {
24
+ max-height: 200px;
25
+ }
26
+
27
+ div.dl5 img {
28
+ max-height: 200px;
29
+ }
30
+
31
  div.video1 div.video-container {
32
  height: 500px;
33
  }
 
40
  height: 200px;
41
  }
42
 
43
+ div.video4 {
44
+ height: 200px;
45
+ }
46
+
47
  div.stat {
48
  height: 286px;
49
  }
app.py CHANGED
@@ -18,6 +18,7 @@ from app.app_utils import preprocess_image_and_predict, preprocess_video_and_pre
18
  def clear_static_info():
19
  return (
20
  gr.Image(value=None, type="pil"),
 
21
  gr.Image(value=None, scale=1, elem_classes="dl2"),
22
  gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
23
  )
@@ -27,6 +28,7 @@ def clear_dynamic_info():
27
  gr.Video(value=None),
28
  gr.Video(value=None),
29
  gr.Video(value=None),
 
30
  gr.Plot(value=None),
31
  )
32
 
@@ -45,8 +47,9 @@ with gr.Blocks(css="app.css") as demo:
45
  )
46
  with gr.Column(scale=2, elem_classes="dl4"):
47
  with gr.Row():
48
- output_video = gr.Video(label="Original video", scale=2, elem_classes="video2")
49
  output_face = gr.Video(label="Pre-processed video", scale=1, elem_classes="video3")
 
50
  output_statistics = gr.Plot(label="Statistics of emotions", elem_classes="stat")
51
  gr.Examples(
52
  ["videos/video1.mp4",
@@ -58,7 +61,7 @@ with gr.Blocks(css="app.css") as demo:
58
  gr.Markdown(value=DESCRIPTION_STATIC)
59
  with gr.Row():
60
  with gr.Column(scale=2, elem_classes="dl1"):
61
- input_image = gr.Image(type="pil")
62
  with gr.Row():
63
  clear_btn = gr.Button(
64
  value="Clear", interactive=True, scale=1, elem_classes="clear"
@@ -67,7 +70,9 @@ with gr.Blocks(css="app.css") as demo:
67
  value="Submit", interactive=True, scale=1, elem_classes="submit"
68
  )
69
  with gr.Column(scale=1, elem_classes="dl4"):
70
- output_image = gr.Image(scale=1, elem_classes="dl2")
 
 
71
  output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")
72
  gr.Examples(
73
  [
@@ -87,13 +92,13 @@ with gr.Blocks(css="app.css") as demo:
87
  submit.click(
88
  fn=preprocess_image_and_predict,
89
  inputs=[input_image],
90
- outputs=[output_image, output_label],
91
  queue=True,
92
  )
93
  clear_btn.click(
94
  fn=clear_static_info,
95
  inputs=[],
96
- outputs=[input_image, output_image, output_label],
97
  queue=True,
98
  )
99
 
@@ -102,7 +107,8 @@ with gr.Blocks(css="app.css") as demo:
102
  inputs=input_video,
103
  outputs=[
104
  output_video,
105
- output_face,
 
106
  output_statistics
107
  ],
108
  queue=True,
@@ -113,7 +119,8 @@ with gr.Blocks(css="app.css") as demo:
113
  outputs=[
114
  input_video,
115
  output_video,
116
- output_face,
 
117
  output_statistics
118
  ],
119
  queue=True,
 
18
  def clear_static_info():
19
  return (
20
  gr.Image(value=None, type="pil"),
21
+ gr.Image(value=None, scale=1, elem_classes="dl5"),
22
  gr.Image(value=None, scale=1, elem_classes="dl2"),
23
  gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
24
  )
 
28
  gr.Video(value=None),
29
  gr.Video(value=None),
30
  gr.Video(value=None),
31
+ gr.Video(value=None),
32
  gr.Plot(value=None),
33
  )
34
 
 
47
  )
48
  with gr.Column(scale=2, elem_classes="dl4"):
49
  with gr.Row():
50
+ output_video = gr.Video(label="Original video", scale=1, elem_classes="video2")
51
  output_face = gr.Video(label="Pre-processed video", scale=1, elem_classes="video3")
52
+ output_heatmaps = gr.Video(label="Heatmaps", scale=1, elem_classes="video4")
53
  output_statistics = gr.Plot(label="Statistics of emotions", elem_classes="stat")
54
  gr.Examples(
55
  ["videos/video1.mp4",
 
61
  gr.Markdown(value=DESCRIPTION_STATIC)
62
  with gr.Row():
63
  with gr.Column(scale=2, elem_classes="dl1"):
64
+ input_image = gr.Image(label="Original image", type="pil")
65
  with gr.Row():
66
  clear_btn = gr.Button(
67
  value="Clear", interactive=True, scale=1, elem_classes="clear"
 
70
  value="Submit", interactive=True, scale=1, elem_classes="submit"
71
  )
72
  with gr.Column(scale=1, elem_classes="dl4"):
73
+ with gr.Row():
74
+ output_image = gr.Image(label="Face", scale=1, elem_classes="dl5")
75
+ output_heatmap = gr.Image(label="Heatmap", scale=1, elem_classes="dl2")
76
  output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")
77
  gr.Examples(
78
  [
 
92
  submit.click(
93
  fn=preprocess_image_and_predict,
94
  inputs=[input_image],
95
+ outputs=[output_image, output_heatmap, output_label],
96
  queue=True,
97
  )
98
  clear_btn.click(
99
  fn=clear_static_info,
100
  inputs=[],
101
+ outputs=[input_image, output_image, output_heatmap, output_label],
102
  queue=True,
103
  )
104
 
 
107
  inputs=input_video,
108
  outputs=[
109
  output_video,
110
+ output_face,
111
+ output_heatmaps,
112
  output_statistics
113
  ],
114
  queue=True,
 
119
  outputs=[
120
  input_video,
121
  output_video,
122
+ output_face,
123
+ output_heatmaps,
124
  output_statistics
125
  ],
126
  queue=True,
app/app_utils.py CHANGED
@@ -10,9 +10,10 @@ import numpy as np
10
  import mediapipe as mp
11
  from PIL import Image
12
  import cv2
 
13
 
14
  # Importing necessary components for the Gradio app
15
- from app.model import pth_model_static, pth_model_dynamic, pth_processing
16
  from app.face_utils import get_box, display_info
17
  from app.config import DICT_EMO, config_data
18
  from app.plot import statistics_plot
@@ -49,8 +50,13 @@ def preprocess_image_and_predict(inp):
49
  .numpy()[0]
50
  )
51
  confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
 
 
 
 
 
52
 
53
- return cur_face, confidences
54
 
55
 
56
  def preprocess_video_and_predict(video):
@@ -60,14 +66,20 @@ def preprocess_video_and_predict(video):
60
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
61
  fps = np.round(cap.get(cv2.CAP_PROP_FPS))
62
 
63
- path_save_video = 'result.mp4'
64
- vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
 
 
 
65
 
66
  lstm_features = []
67
  count_frame = 1
 
68
  probs = []
69
  frames = []
70
- last_output = None
 
 
71
 
72
  with mp_face_mesh.FaceMesh(
73
  max_num_faces=1,
@@ -90,9 +102,16 @@ def preprocess_video_and_predict(video):
90
  startX, startY, endX, endY = get_box(fl, w, h)
91
  cur_face = frame_copy[startY:endY, startX: endX]
92
 
93
- if (count_frame-1)%config_data.FRAME_DOWNSAMPLING == 0:
94
  cur_face_copy = pth_processing(Image.fromarray(cur_face))
95
  features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy()
 
 
 
 
 
 
 
96
 
97
  if len(lstm_features) == 0:
98
  lstm_features = [features]*10
@@ -103,32 +122,48 @@ def preprocess_video_and_predict(video):
103
  lstm_f = torch.unsqueeze(lstm_f, 0)
104
  output = pth_model_dynamic(lstm_f).detach().numpy()
105
  last_output = output
 
 
 
 
106
  else:
107
  if last_output is not None:
108
  output = last_output
 
 
109
  elif last_output is None:
110
- output = np.zeros((7))
 
111
 
112
  probs.append(output[0])
113
  frames.append(count_frame)
114
  else:
115
- empty = np.empty((7))
116
- empty[:] = np.nan
117
- probs.append(empty)
118
- frames.append(count_frame)
119
-
120
- cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
121
- cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA)
 
 
 
 
 
 
 
 
122
 
123
- cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
124
  count_frame += 1
125
- vid_writer.write(cur_face)
 
126
 
127
- vid_writer.release()
 
128
 
129
  stat = statistics_plot(frames, probs)
130
 
131
  if not stat:
132
- return None, None, None
133
 
134
- return video, path_save_video, stat
 
10
  import mediapipe as mp
11
  from PIL import Image
12
  import cv2
13
+ from pytorch_grad_cam.utils.image import show_cam_on_image
14
 
15
  # Importing necessary components for the Gradio app
16
+ from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
17
  from app.face_utils import get_box, display_info
18
  from app.config import DICT_EMO, config_data
19
  from app.plot import statistics_plot
 
50
  .numpy()[0]
51
  )
52
  confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
53
+ grayscale_cam = cam(input_tensor=cur_face_n)
54
+ grayscale_cam = grayscale_cam[0, :]
55
+ cur_face_hm = cv2.resize(cur_face,(224,224))
56
+ cur_face_hm = np.float32(cur_face_hm) / 255
57
+ heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)
58
 
59
+ return cur_face, heatmap, confidences
60
 
61
 
62
  def preprocess_video_and_predict(video):
 
66
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
67
  fps = np.round(cap.get(cv2.CAP_PROP_FPS))
68
 
69
+ path_save_video_face = 'result_face.mp4'
70
+ vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
71
+
72
+ path_save_video_hm = 'result_hm.mp4'
73
+ vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
74
 
75
  lstm_features = []
76
  count_frame = 1
77
+ count_face = 0
78
  probs = []
79
  frames = []
80
+ last_output = None
81
+ last_heatmap = None
82
+ cur_face = None
83
 
84
  with mp_face_mesh.FaceMesh(
85
  max_num_faces=1,
 
102
  startX, startY, endX, endY = get_box(fl, w, h)
103
  cur_face = frame_copy[startY:endY, startX: endX]
104
 
105
+ if count_face%config_data.FRAME_DOWNSAMPLING == 0:
106
  cur_face_copy = pth_processing(Image.fromarray(cur_face))
107
  features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy()
108
+
109
+ grayscale_cam = cam(input_tensor=cur_face_copy)
110
+ grayscale_cam = grayscale_cam[0, :]
111
+ cur_face_hm = cv2.resize(cur_face,(224,224), interpolation = cv2.INTER_AREA)
112
+ cur_face_hm = np.float32(cur_face_hm) / 255
113
+ heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=False)
114
+ last_heatmap = heatmap
115
 
116
  if len(lstm_features) == 0:
117
  lstm_features = [features]*10
 
122
  lstm_f = torch.unsqueeze(lstm_f, 0)
123
  output = pth_model_dynamic(lstm_f).detach().numpy()
124
  last_output = output
125
+
126
+ if count_face == 0:
127
+ count_face += 1
128
+
129
  else:
130
  if last_output is not None:
131
  output = last_output
132
+ heatmap = last_heatmap
133
+
134
  elif last_output is None:
135
+ output = np.empty((1, 7))
136
+ output[:] = np.nan
137
 
138
  probs.append(output[0])
139
  frames.append(count_frame)
140
  else:
141
+ if last_output is not None:
142
+ lstm_features = []
143
+ empty = np.empty((7))
144
+ empty[:] = np.nan
145
+ probs.append(empty)
146
+ frames.append(count_frame)
147
+
148
+ if cur_face is not None:
149
+ heatmap_f = display_info(heatmap, 'Frame: {}'.format(count_frame), box_scale=.3)
150
+
151
+ cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
152
+ cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA)
153
+ cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
154
+ vid_writer_face.write(cur_face)
155
+ vid_writer_hm.write(heatmap_f)
156
 
 
157
  count_frame += 1
158
+ if count_face != 0:
159
+ count_face += 1
160
 
161
+ vid_writer_face.release()
162
+ vid_writer_hm.release()
163
 
164
  stat = statistics_plot(frames, probs)
165
 
166
  if not stat:
167
+ return None, None, None, None
168
 
169
+ return video, path_save_video_face, path_save_video_hm, stat
app/face_utils.py CHANGED
@@ -34,7 +34,8 @@ def get_box(fl, w, h):
34
  return startX, startY, endX, endY
35
 
36
  def display_info(img, text, margin=1.0, box_scale=1.0):
37
- img_h, img_w, _ = img.shape
 
38
  line_width = int(min(img_h, img_w) * 0.001)
39
  thickness = max(int(line_width / 3), 1)
40
 
@@ -45,15 +46,15 @@ def display_info(img, text, margin=1.0, box_scale=1.0):
45
  t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]
46
 
47
  margin_n = int(t_h * margin)
48
- sub_img = img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
49
  img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]
50
 
51
  white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255
52
 
53
- img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
54
  img_w - t_w - margin_n - int(2 * t_h * box_scale):img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5, 1.0)
55
 
56
- cv2.putText(img=img,
57
  text=text,
58
  org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
59
  0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),
@@ -64,4 +65,4 @@ def display_info(img, text, margin=1.0, box_scale=1.0):
64
  lineType=cv2.LINE_AA,
65
  bottomLeftOrigin=False)
66
 
67
- return img
 
34
  return startX, startY, endX, endY
35
 
36
  def display_info(img, text, margin=1.0, box_scale=1.0):
37
+ img_copy = img.copy()
38
+ img_h, img_w, _ = img_copy.shape
39
  line_width = int(min(img_h, img_w) * 0.001)
40
  thickness = max(int(line_width / 3), 1)
41
 
 
46
  t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]
47
 
48
  margin_n = int(t_h * margin)
49
+ sub_img = img_copy[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
50
  img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]
51
 
52
  white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255
53
 
54
+ img_copy[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
55
  img_w - t_w - margin_n - int(2 * t_h * box_scale):img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5, 1.0)
56
 
57
+ cv2.putText(img=img_copy,
58
  text=text,
59
  org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
60
  0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),
 
65
  lineType=cv2.LINE_AA,
66
  bottomLeftOrigin=False)
67
 
68
+ return img_copy
app/model.py CHANGED
@@ -10,9 +10,11 @@ import torch
10
  import requests
11
  from PIL import Image
12
  from torchvision import transforms
 
13
 
14
  # Importing necessary components for the Gradio app
15
  from app.config import config_data
 
16
 
17
 
18
  def load_model(model_url, model_path):
@@ -21,17 +23,23 @@ def load_model(model_url, model_path):
21
  with open(model_path, "wb") as file:
22
  for chunk in response.iter_content(chunk_size=8192):
23
  file.write(chunk)
24
- return torch.jit.load(model_path).eval()
25
  except Exception as e:
26
  print(f"Error loading model: {e}")
27
  return None
28
 
 
 
 
 
29
 
30
- pth_model_static = load_model(config_data.model_static_url, config_data.model_static_path)
31
-
32
- pth_model_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
33
-
34
 
 
 
35
 
36
  def pth_processing(fp):
37
  class PreprocessInput(torch.nn.Module):
 
10
  import requests
11
  from PIL import Image
12
  from torchvision import transforms
13
+ from pytorch_grad_cam import GradCAM
14
 
15
  # Importing necessary components for the Gradio app
16
  from app.config import config_data
17
+ from app.model_architectures import ResNet50, LSTMPyTorch
18
 
19
 
20
  def load_model(model_url, model_path):
 
23
  with open(model_path, "wb") as file:
24
  for chunk in response.iter_content(chunk_size=8192):
25
  file.write(chunk)
26
+ return model_path
27
  except Exception as e:
28
  print(f"Error loading model: {e}")
29
  return None
30
 
31
+ path_static = load_model(config_data.model_static_url, config_data.model_static_path)
32
+ pth_model_static = ResNet50(7, channels=3)
33
+ pth_model_static.load_state_dict(torch.load(path_static))
34
+ pth_model_static.eval()
35
 
36
+ path_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
37
+ pth_model_dynamic = LSTMPyTorch()
38
+ pth_model_dynamic.load_state_dict(torch.load(path_dynamic))
39
+ pth_model_dynamic.eval()
40
 
41
+ target_layers = [pth_model_static.layer4]
42
+ cam = GradCAM(model=pth_model_static, target_layers=target_layers)
43
 
44
  def pth_processing(fp):
45
  class PreprocessInput(torch.nn.Module):
app/model_architectures.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: model_architectures.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: This module provides model architectures.
5
+ License: MIT License
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import math
12
+
13
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    Each convolution is followed by batch normalisation; ReLU is applied after
    the first two stages and once more after the residual addition. The block
    outputs ``out_channels * expansion`` channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Width of the two inner convolutions.
        i_downsample: Optional module applied to the block input before it is
            added back; it must produce a tensor with the same shape as the
            main branch output. When ``None`` the input is added unchanged.
        stride: Stride of the first (1x1) convolution.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()

        # NOTE(review): momentum=0.99 looks like a Keras-style value; in PyTorch
        # ``momentum`` weights the *new* batch statistic. Harmless in eval mode,
        # but confirm before training with these layers.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same', bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm3 = nn.BatchNorm2d(out_channels * self.expansion, eps=0.001, momentum=0.99)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        # The input is never mutated below (``out`` is a fresh tensor from the
        # first convolution onwards), so it can serve as the skip branch
        # directly; copying it would only cost memory.
        identity = x

        out = self.relu(self.batch_norm1(self.conv1(x)))
        out = self.relu(self.batch_norm2(self.conv2(out)))
        out = self.batch_norm3(self.conv3(out))

        # Project the skip branch when the main branch changed shape.
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        # In-place add targets the batch-norm output, not the caller's tensor.
        out += identity
        return self.relu(out)
48
+
49
class Conv2dSame(torch.nn.Conv2d):
    """``Conv2d`` variant that zero-pads its input at call time.

    The amount of padding is derived from the actual input size on every
    forward pass and split between both sides of each spatial axis, with the
    odd remainder going to the bottom/right.
    """

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Return the total padding required along one axis (never negative).

        Args:
            i: Input extent along the axis.
            k: Kernel size along the axis.
            s: Stride along the axis.
            d: Dilation along the axis.
        """
        target_steps = math.ceil(i / s)
        covered = (target_steps - 1) * s + (k - 1) * d + 1
        return max(covered - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad ``x`` as needed, then apply the convolution with the layer's own settings."""
        height, width = x.size()[-2:]

        total_h = self.calc_same_pad(i=height, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        total_w = self.calc_same_pad(i=width, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if total_h > 0 or total_w > 0:
            left = total_w // 2
            top = total_h // 2
            # F.pad takes the last dimension first: (left, right, top, bottom).
            x = F.pad(x, [left, total_w - left, top, total_h - top])

        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
73
+
74
class ResNet(nn.Module):
    """ResNet-style backbone with a two-layer fully connected head.

    Args:
        ResBlock: Residual block class; must expose an ``expansion`` attribute
            and accept ``(in_channels, out_channels, i_downsample=..., stride=...)``.
        layer_list: Number of blocks in each of the four stages.
        num_classes: Size of the final output layer.
        num_channels: Number of channels of the input image.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        # Running channel count; updated by ``_make_layer`` as stages are built.
        self.in_channels = 64

        # Stem: strided 7x7 convolution, batch norm, ReLU, max pooling.
        self.conv_layer_s2_same = Conv2dSame(num_channels, 64, 7, stride=2, groups=1, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        # Four residual stages, registered in order as layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, step) in enumerate(stage_specs):
            stage = self._make_layer(ResBlock, layer_list[index], planes=width, stride=step)
            setattr(self, "layer{}".format(index + 1), stage)

        # Head: global average pooling followed by two linear layers.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * ResBlock.expansion, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def extract_features(self, x):
        """Return the output of ``fc1`` for ``x`` (no activation applied)."""
        feats = self.conv_layer_s2_same(x)
        feats = self.batch_norm1(feats)
        feats = self.relu(feats)
        feats = self.max_pool(feats)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        pooled = self.avgpool(feats)
        flat = pooled.reshape(pooled.shape[0], -1)
        return self.fc1(flat)

    def forward(self, x):
        """Return the ``fc2`` outputs for ``x``."""
        return self.fc2(self.relu1(self.extract_features(x)))

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.in_channels``."""
        stage_width = planes * ResBlock.expansion

        # A projection on the skip branch is needed whenever the first block
        # changes the spatial size or the channel count.
        shortcut = None
        if stride != 1 or self.in_channels != stage_width:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channels, stage_width, kernel_size=1, stride=stride, bias=False, padding=0),
                nn.BatchNorm2d(stage_width, eps=0.001, momentum=0.99),
            )

        stack = [ResBlock(self.in_channels, planes, i_downsample=shortcut, stride=stride)]
        self.in_channels = stage_width

        # Remaining blocks keep the shape, so they use the defaults.
        stack.extend(ResBlock(self.in_channels, planes) for _ in range(blocks - 1))

        return nn.Sequential(*stack)
131
+
132
def ResNet50(num_classes, channels=3):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3, 4, 6 and 3.

    Args:
        num_classes: Size of the network's output layer.
        channels: Number of channels of the input image.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        ResBlock=Bottleneck,
        layer_list=stage_depths,
        num_classes=num_classes,
        num_channels=channels,
    )
134
+
135
+
136
class LSTMPyTorch(nn.Module):
    """Two stacked unidirectional LSTMs followed by a 7-way softmax classifier.

    The input is batch-first with 512 features per time step; only the last
    time step of the second LSTM feeds the linear layer.
    """

    def __init__(self):
        super().__init__()

        self.lstm1 = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=False)
        self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True, bidirectional=False)
        self.fc = nn.Linear(256, 7)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-sample probabilities over the 7 outputs."""
        first_seq, _ = self.lstm1(x)
        second_seq, _ = self.lstm2(first_seq)
        # Keep only the final time step of every sequence in the batch.
        final_step = second_seq[:, -1, :]
        logits = self.fc(final_step)
        return self.softmax(logits)
config.toml CHANGED
@@ -2,9 +2,9 @@ APP_VERSION = "0.2.0"
2
  FRAME_DOWNSAMPLING = 5
3
 
4
  [model_static]
5
- url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
6
- path = "FER_static_ResNet50_AffectNet.pth"
7
 
8
  [model_dynamic]
9
- url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_dinamic_LSTM_IEMOCAP.pth"
10
- path = "FER_dinamic_LSTM_IEMOCAP.pth"
 
2
  FRAME_DOWNSAMPLING = 5
3
 
4
  [model_static]
5
+ url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pt"
6
+ path = "FER_static_ResNet50_AffectNet.pt"
7
 
8
  [model_dynamic]
9
+ url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_dinamic_LSTM_IEMOCAP.pt"
10
+ path = "FER_dinamic_LSTM_IEMOCAP.pt"
result.mp4 DELETED
Binary file (108 kB)