wissemkarous committed on
Commit
8e606bb
·
verified ·
1 Parent(s): 8c79f36
Files changed (2) hide show
  1. demo.py +242 -0
  2. two_stream_infer.py +38 -0
demo.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from dataset import MyDataset
4
+ import numpy as np
5
+ import cv2
6
+ import face_alignment
7
+ import streamlit as st
8
+
9
+
10
def get_position(size, padding=0.25):
    """Return the reference landmark layout scaled to a square canvas.

    The 51 hard-coded (x, y) pairs below are normalised coordinates. Each
    coordinate is shifted by ``padding``, divided by ``2 * padding + 1`` and
    multiplied by ``size``.

    Args:
        size: Edge length of the target canvas, in pixels.
        padding: Margin added around the normalised layout before scaling.

    Returns:
        A ``(51, 2)`` float array whose rows are ``(x, y)`` positions.
    """
    x = [
        0.000213256,
        0.0752622,
        0.18113,
        0.29077,
        0.393397,
        0.586856,
        0.689483,
        0.799124,
        0.904991,
        0.98004,
        0.490127,
        0.490127,
        0.490127,
        0.490127,
        0.36688,
        0.426036,
        0.490127,
        0.554217,
        0.613373,
        0.121737,
        0.187122,
        0.265825,
        0.334606,
        0.260918,
        0.182743,
        0.645647,
        0.714428,
        0.793132,
        0.858516,
        0.79751,
        0.719335,
        0.254149,
        0.340985,
        0.428858,
        0.490127,
        0.551395,
        0.639268,
        0.726104,
        0.642159,
        0.556721,
        0.490127,
        0.423532,
        0.338094,
        0.290379,
        0.428096,
        0.490127,
        0.552157,
        0.689874,
        0.553364,
        0.490127,
        0.42689,
    ]

    y = [
        0.106454,
        0.038915,
        0.0187482,
        0.0344891,
        0.0773906,
        0.0773906,
        0.0344891,
        0.0187482,
        0.038915,
        0.106454,
        0.203352,
        0.307009,
        0.409805,
        0.515625,
        0.587326,
        0.609345,
        0.628106,
        0.609345,
        0.587326,
        0.216423,
        0.178758,
        0.179852,
        0.231733,
        0.245099,
        0.244077,
        0.231733,
        0.179852,
        0.178758,
        0.216423,
        0.244077,
        0.245099,
        0.780233,
        0.745405,
        0.727388,
        0.742578,
        0.727388,
        0.745405,
        0.780233,
        0.864805,
        0.902192,
        0.909281,
        0.902192,
        0.864805,
        0.784792,
        0.778746,
        0.785343,
        0.778746,
        0.784792,
        0.824182,
        0.831803,
        0.824182,
    ]

    x, y = np.array(x), np.array(y)

    # Shift by the padding, then renormalise so the padded layout spans [0, 1].
    x = (x + padding) / (2 * padding + 1)
    y = (y + padding) / (2 * padding + 1)
    x = x * size
    y = y * size

    # Pair the coordinates column-wise without a Python-level zip round trip.
    return np.column_stack((x, y))
126
+
127
+
128
def output_video(p, txt, output_path):
    """Caption the frames in ``p`` and encode them into ``output.mp4``.

    Each frame is paired with one entry of ``txt``, gets that text drawn on
    it, is halved in size and is written back over the original file. The
    frames are then handed to ffmpeg.

    Args:
        p: Directory of frame images whose file stems are integers
            (ffmpeg reads them with the ``%04d.jpg`` pattern).
        txt: Iterable of caption strings, one per frame in sorted order.
        output_path: Directory that receives ``output.mp4``; created if
            missing.
    """
    # Local import so this block does not depend on the file's import header.
    import subprocess

    frame_names = sorted(
        os.listdir(p), key=lambda name: int(os.path.splitext(name)[0])
    )

    font = cv2.FONT_HERSHEY_SIMPLEX

    for frame_name, caption in zip(frame_names, txt):
        frame_path = os.path.join(p, frame_name)
        img = cv2.imread(frame_path)
        if img is None:
            # cv2.imread returns None for unreadable files; skip them rather
            # than crash on ``img.shape``.
            continue

        h, w, _ = img.shape
        anchor = (w // 8, 11 * h // 12)
        # The caption is drawn twice at the same spot: black with thickness 3
        # first, then white with thickness 0 on top of it.
        img = cv2.putText(
            img, caption, anchor, font, 1.2, (0, 0, 0), 3, cv2.LINE_AA
        )
        img = cv2.putText(
            img, caption, anchor, font, 1.2, (255, 255, 255), 0, cv2.LINE_AA
        )
        img = cv2.resize(img, (w // 2, h // 2))
        # Overwrites the source frame in place.
        cv2.imwrite(frame_path, img)

    # Create the output directory if it doesn't exist.
    os.makedirs(output_path, exist_ok=True)

    output = os.path.join(output_path, "output.mp4")
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact. check=False keeps the original best-effort
    # behaviour of ignoring ffmpeg's exit status.
    subprocess.run(
        [
            "ffmpeg",
            "-hide_banner",
            "-loglevel",
            "error",
            "-y",
            "-i",
            os.path.join(p, "%04d.jpg"),
            "-r",
            "25",
            output,
        ],
        check=False,
    )
164
+
165
+
166
def transformation_from_points(points1, points2):
    """Estimate the similarity transform that maps ``points1`` onto ``points2``.

    Both point sets are centred and scaled to unit standard deviation. The
    rotation is then taken from the SVD of their cross-covariance, and the
    scale, rotation and translation are assembled into one matrix.

    Args:
        points1: ``(N, 2)`` source points, as ``np.ndarray`` or ``np.matrix``.
        points2: ``(N, 2)`` destination points, in the same order.

    Returns:
        A ``3 x 3`` ``np.matrix`` ``[[s*R, t], [0, 0, 1]]``. The matrix type
        is kept so existing callers that rely on it keep working.
    """
    # np.array copies and drops the matrix subclass, so the caller's data is
    # never mutated and ``@`` below is a true matrix product for any input.
    src = np.array(points1, dtype=np.float64)
    dst = np.array(points2, dtype=np.float64)

    c1 = src.mean(axis=0)
    c2 = dst.mean(axis=0)
    src -= c1
    dst -= c2
    s1 = src.std()
    s2 = dst.std()
    src /= s1
    dst /= s2

    U, _, Vt = np.linalg.svd(src.T @ dst)
    R = (U @ Vt).T

    linear = (s2 / s1) * R
    translation = c2.reshape(-1, 1) - linear @ c1.reshape(-1, 1)

    return np.asmatrix(
        np.vstack(
            [
                np.hstack((linear, translation)),
                [0.0, 0.0, 1.0],
            ]
        )
    )
187
+
188
+
189
@st.cache_data(show_spinner=False, persist=True)
def load_video(file, device: str):
    """Extract frames from a video and build the aligned crop tensor.

    The video is split into JPEG frames with ffmpeg at 25 fps. Landmarks are
    detected on each frame, and the frame is warped onto the 256 x 256
    reference layout from ``get_position``. A fixed window around the mean of
    the last 20 reference points is cropped and resized to 128 x 64.

    Args:
        file: Path of the input video.
        device: Device string passed to ``face_alignment.FaceAlignment``.

    Returns:
        A tuple ``(video, p, files)``: ``video`` is a float tensor of shape
        ``(channels, frames, 64, 128)`` scaled to ``[0, 1]``, ``p`` is the
        frame directory, and ``files`` is the sorted list of frame names.

    Raises:
        ValueError: If no frame produced any landmarks.
    """
    # Local import so this block does not depend on the file's import header.
    import subprocess

    # splitext strips only the final extension; ``file.split(".")[0]`` would
    # truncate paths such as "./clip.mp4" or "my.video.mp4" at the first dot.
    video_name = os.path.splitext(file)[0]

    # Create the samples directory if it doesn't exist.
    p = f"{video_name}_samples"
    os.makedirs(p, exist_ok=True)

    output = os.path.join(p, "%04d.jpg")
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact. check=False preserves the original behaviour of
    # ignoring ffmpeg's exit status.
    subprocess.run(
        [
            "ffmpeg",
            "-hide_banner",
            "-loglevel",
            "error",
            "-i",
            file,
            "-qscale:v",
            "2",
            "-r",
            "25",
            output,
        ],
        check=False,
    )

    files = sorted(
        os.listdir(p), key=lambda name: int(os.path.splitext(name)[0])
    )

    # cv2.imread returns None for unreadable files; drop those frames.
    array = [cv2.imread(os.path.join(p, name)) for name in files]
    array = [im for im in array if im is not None]

    # NOTE(review): newer face_alignment releases appear to name this member
    # TWO_D instead of _2D -- confirm against the pinned version. Falling back
    # keeps the original behaviour where only _2D exists.
    landmarks_type = getattr(face_alignment.LandmarksType, "TWO_D", None)
    if landmarks_type is None:
        landmarks_type = face_alignment.LandmarksType._2D

    fa = face_alignment.FaceAlignment(
        landmarks_type, flip_input=False, device=device
    )
    points = [fa.get_landmarks(frame) for frame in array]

    front256 = get_position(256)
    # The crop centre depends only on the reference layout, so compute it
    # once. Presumably the last 20 reference points are the mouth -- verify.
    (x, y) = front256[-20:].mean(0).astype(np.int32)
    w = 160 // 2

    video = []
    for point, scene in zip(points, array):
        # Skip frames with no landmarks (None or an empty result).
        if point is None or len(point) == 0:
            continue

        # Use the first landmark set and drop its first 17 points so the
        # remaining ones line up with the 51-point reference layout.
        shape = np.array(point[0])[17:]
        M = transformation_from_points(np.matrix(shape), np.matrix(front256))

        img = cv2.warpAffine(scene, M[:2], (256, 256))
        img = img[y - w // 2 : y + w // 2, x - w : x + w, ...]
        img = cv2.resize(img, (128, 64))
        video.append(img)

    if not video:
        # np.stack would raise an opaque ValueError on an empty list.
        raise ValueError(
            "no landmarks were detected in any frame of {}".format(file)
        )

    video = np.stack(video, axis=0).astype(np.float32)
    # (frames, H, W, C) -> (C, frames, H, W), scaled to [0, 1].
    video = torch.FloatTensor(video.transpose(3, 0, 1, 2)) / 255.0

    return video, p, files
234
+
235
+
236
def ctc_decode(y):
    """Decode every prefix of the arg-max label sequence into text.

    Args:
        y: Tensor of per-step scores; the arg-max is taken over the last
            axis and the first axis is treated as the step axis.

    Returns:
        A list of ``t + 1`` decoded strings, one for each prefix length from
        0 to ``t``, where ``t`` is the size of the first axis.
    """
    best_path = y.argmax(-1)
    num_steps = best_path.size(0)
    # Prefixes run from the empty slice up to the full sequence (inclusive).
    return [
        MyDataset.ctc_arr2txt(best_path[:end], start=1)
        for end in range(num_steps + 1)
    ]
two_stream_infer.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.two_stream_lipnet import TwoStreamLipNet
2
+ import options as opt
3
+ import os
4
+ import torch
5
+ import streamlit as st
6
+
7
# Restrict which GPUs CUDA may see, using the id(s) configured in options.
# Environment values must be strings, so coerce in case opt.gpu is an int.
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
8
+
9
+
10
@st.cache_resource
def load_model():
    """Build a ``TwoStreamLipNet`` on ``opt.device`` and load its weights.

    When ``opt`` defines ``two_stream_weights``, the checkpoint at that path
    is loaded. Only entries whose name and tensor size match the model are
    copied in; every other parameter keeps its initial value. The counts and
    the unmatched names are printed.

    Returns:
        The model, moved to ``opt.device``.
    """
    net = TwoStreamLipNet().to(opt.device)

    # Without a configured checkpoint there is nothing to load.
    if not hasattr(opt, "two_stream_weights"):
        return net

    # Load the pretrained weights.
    checkpoint = torch.load(
        opt.two_stream_weights, map_location=torch.device(opt.device)
    )
    state = net.state_dict()

    # Keep only checkpoint entries that exist in the model with equal size.
    compatible = {}
    for name, tensor in checkpoint.items():
        if name in state and tensor.size() == state[name].size():
            compatible[name] = tensor

    missed_params = [name for name in state if name not in compatible]

    print(
        "loaded params/tot params:{}/{}".format(len(compatible), len(state))
    )
    print("miss matched params:{}".format(missed_params))

    state.update(compatible)
    net.load_state_dict(state)
    return net