Non-playing-Character committed on
Commit 3d83ea0 · verified · 1 parent: 6a05650

Update inference.py

Files changed (1)
  1. inference.py +359 -359
inference.py CHANGED
@@ -1,359 +1,359 @@
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse, audio
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch, face_detection
from wav2lip_models import Wav2Lip
import platform
from face_parsing import init_parser, swap_regions
from basicsr.apply_sr import init_sr_model, enhance

parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

parser.add_argument('--checkpoint_path', type=str,
                    help='Name of saved checkpoint to load weights from', required=True)

parser.add_argument('--segmentation_path', type=str,
                    help='Name of saved checkpoint of segmentation network', required=True)

parser.add_argument('--sr_path', type=str,
                    help='Name of saved checkpoint of super-resolution network', required=True)

parser.add_argument('--face', type=str,
                    help='Filepath of video/image that contains faces to use', required=True)
parser.add_argument('--audio', type=str,
                    help='Filepath of video/audio file to use as raw audio source', required=True)
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
                    default='results/result_voice.mp4')


parser.add_argument('--static', type=bool,
                    help='If True, then use only first video frame for inference', default=False)
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
                    default=25., required=False)

parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')

parser.add_argument('--face_det_batch_size', type=int,
                    help='Batch size for face detection', default=16)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)

parser.add_argument('--resize_factor', default=1, type=int,
                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')

parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
                    'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')

parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
                    'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')

parser.add_argument('--rotate', default=False, action='store_true',
                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
                    'Use if you get a flipped result, despite feeding a normal looking video')

parser.add_argument('--nosmooth', default=False, action='store_true',
                    help='Prevent smoothing face detections over a short temporal window')
parser.add_argument('--no_segmentation', default=False, action='store_true',
                    help='Prevent using face segmentation')
parser.add_argument('--no_sr', default=False, action='store_true',
                    help='Prevent using super resolution')

- parser.add_argument('--save_frames', default=False, action='store_true',
+ parser.add_argument('--save_frames', default=True, action='store_true',
                    help='Save each frame as an image. Use with caution')
parser.add_argument('--gt_path', type=str,
                    help='Where to store saved ground truth frames', required=False)
parser.add_argument('--pred_path', type=str,
                    help='Where to store frames produced by algorithm', required=False)
parser.add_argument('--save_as_video', action="store_true", default=False,
                    help='Whether to save frames as video', required=False)
parser.add_argument('--image_prefix', type=str, default="",
                    help='Prefix to save frames with', required=False)

args = parser.parse_args()
args.img_size = 96

if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
    args.static = True

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = args.face_det_batch_size

    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results

def datagen(mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    """
    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
    """

    reader = read_frames()

    for i, m in enumerate(mels):
        try:
            frame_to_save = next(reader)
        except StopIteration:
            reader = read_frames()
            frame_to_save = next(reader)

        face, coords = face_detect([frame_to_save])[0]

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size//2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))

def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()

def read_frames():
    if args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
        face = cv2.imread(args.face)
        while 1:
            yield face

    video_stream = cv2.VideoCapture(args.face)
    fps = video_stream.get(cv2.CAP_PROP_FPS)

    print('Reading video frames from start...')

    while 1:
        still_reading, frame = video_stream.read()
        if not still_reading:
            video_stream.release()
            break
        if args.resize_factor > 1:
            frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

        if args.rotate:
            frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

        y1, y2, x1, x2 = args.crop
        if x2 == -1: x2 = frame.shape[1]
        if y2 == -1: y2 = frame.shape[0]

        frame = frame[y1:y2, x1:x2]

        yield frame

def main():
    if not os.path.isfile(args.face):
        raise ValueError('--face argument must be a valid path to video/image file')

    elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
        fps = args.fps
    else:
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)
        video_stream.release()


    if not args.audio.endswith('.wav'):
        print('Extracting raw audio...')
        command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

        subprocess.call(command, shell=True)
        args.audio = 'temp/temp.wav'

    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    print(mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    mel_chunks = []
    mel_idx_multiplier = 80./fps
    i = 0
    while 1:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    batch_size = args.wav2lip_batch_size
    gen = datagen(mel_chunks)



    if args.save_as_video:
        gt_out = cv2.VideoWriter("temp/gt.avi", cv2.VideoWriter_fourcc(*'DIVX'), fps, (384, 384))
        pred_out = cv2.VideoWriter("temp/pred.avi", cv2.VideoWriter_fourcc(*'DIVX'), fps, (96, 96))

    abs_idx = 0
    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
            total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
        if i == 0:
            print("Loading segmentation network...")
            seg_net = init_parser(args.segmentation_path)

            print("Loading super resolution model...")
            sr_net = init_sr_model(args.sr_path)

            model = load_model(args.checkpoint_path)
            print ("Model loaded")

            frame_h, frame_w = next(read_frames()).shape[:-1]
            out = cv2.VideoWriter('temp/result.avi',
                                  cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c

            if args.save_frames:
                print("saving frames or video...")
                if args.save_as_video:
                    print("videos...")
                    pred_out.write(p.astype(np.uint8))
                    gt_out.write(cv2.resize(f[y1:y2, x1:x2], (384, 384)))
                else:
                    print("frames...")
-                     print(f"{args.gt_path}/{args.image_prefix}{abs_idx}.png")
-                     cv2.imwrite(f"{args.gt_path}/{args.image_prefix}{abs_idx}.png", f[y1:y2, x1:x2])
-                     cv2.imwrite(f"{args.pred_path}/{args.image_prefix}{abs_idx}.png", p)
+                     print(f"{args.pred_path}/{args.image_prefix}{abs_idx}.png")
+                     cv2.imwrite(f"{args.pred_path}/{args.image_prefix}{abs_idx:05d}.png", p)
+                     cv2.imwrite(f"{args.gt_path}/{args.image_prefix}{abs_idx}.png", f[y1:y2, x1:x2])
                abs_idx += 1

            if not args.no_sr:
                p = enhance(sr_net, p)
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

            if not args.no_segmentation:
                p = swap_regions(f[y1:y2, x1:x2], p, seg_net)

            f[y1:y2, x1:x2] = p
            out.write(f)

    out.release()

    command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
    subprocess.call(command, shell=platform.system() != 'Windows')

    if args.save_frames and args.save_as_video:
        gt_out.release()
        pred_out.release()

        command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/gt.avi', args.gt_path)
        subprocess.call(command, shell=platform.system() != 'Windows')

        command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/pred.avi', args.pred_path)
        subprocess.call(command, shell=platform.system() != 'Windows')


if __name__ == '__main__':
    main()
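With this change, --save_frames is enabled by default and, when --save_as_video is not set, each predicted face crop is written under --pred_path with a zero-padded index (abs_idx:05d), so the PNGs sort lexicographically in frame order. A minimal sketch of stitching those saved prediction frames back into a clip with OpenCV; the preds/ directory, the output filename, the DIVX codec and the 25 fps value are placeholder assumptions, not values taken from this commit:

# Minimal sketch: rebuild a video from face crops saved by inference.py via --pred_path.
# Assumes the frames live in preds/ and carry zero-padded indices; directory, codec
# and fps below are placeholders.
from glob import glob
import cv2

frame_paths = sorted(glob('preds/*.png'))  # zero-padded names keep lexicographic == frame order
first = cv2.imread(frame_paths[0])
h, w = first.shape[:2]

writer = cv2.VideoWriter('preds.avi', cv2.VideoWriter_fourcc(*'DIVX'), 25, (w, h))
for frame_path in frame_paths:
    writer.write(cv2.imread(frame_path))
writer.release()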