Spanicin committed
Commit a57726b · verified · 1 Parent(s): 2f2399f

Update videoretalking/inference_function.py

Files changed (1)
  1. videoretalking/inference_function.py +346 -368
videoretalking/inference_function.py CHANGED
@@ -1,368 +1,346 @@
- import numpy as np
- import cv2, os, sys, subprocess, platform, torch
- from tqdm import tqdm
- from PIL import Image
- from scipy.io import loadmat
- from moviepy.editor import AudioFileClip, VideoFileClip
-
- sys.path.insert(0, 'third_part')
- sys.path.insert(0, 'third_part/GPEN')
-
- # 3dmm extraction
- from third_part.face3d.util.preprocess import align_img
- from third_part.face3d.util.load_mats import load_lm3d
- from third_part.face3d.extract_kp_videos import KeypointExtractor
- # face enhancement
- from third_part.GPEN.gpen_face_enhancer import FaceEnhancement
- # expression control
- from third_part.ganimation_replicate.model.ganimation import GANimationModel
-
- from utils import audio
- from utils.ffhq_preprocess import Croper
- from utils.alignment_stit import crop_faces, calc_alignment_coefficients, paste_image
- from utils.inference_utils import Laplacian_Pyramid_Blending_with_mask, face_detect, load_model, options, split_coeff, \
-     trans_image, transform_semantic, find_crop_norm_ratio, load_face3d_net, exp_aus_dict
- import warnings
- warnings.filterwarnings("ignore")
-
- def video_lipsync_correctness(face, audio_path, outfile=None, tmp_dir="temp", crop=[0, -1, 0, -1], re_preprocess=False, exp_img="neutral", face3d_net_path="checkpoints/face3d_pretrain_epoch_20.pth", one_shot=False, up_face="original", LNet_batch_size=16, without_rl1=False, static=False):
-     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-     print('[Info] Using {} for inference.'.format(device))
-     os.makedirs(os.path.join('temp', tmp_dir), exist_ok=True)
-
-     enhancer = FaceEnhancement(base_dir='checkpoints', size=512, model='GPEN-BFR-512', use_sr=False, \
-         sr_model='rrdb_realesrnet_psnr', channel_multiplier=2, narrow=1, device=device)
-
-     base_name = face.split('/')[-1]
-     print('base_name',base_name)
-     if os.path.isfile(face) and face.split('.')[1] in ['jpg', 'png', 'jpeg']:
-         static = True
-     if not os.path.isfile(face):
-         raise ValueError('--face argument must be a valid path to video/image file')
-     elif face.split('.')[1] in ['jpg', 'png', 'jpeg']:
-         full_frames = [cv2.imread(face)]
-         fps = fps
-     else:
-         video_stream = cv2.VideoCapture(face)
-         fps = video_stream.get(cv2.CAP_PROP_FPS)
-
-         full_frames = []
-         while True:
-             still_reading, frame = video_stream.read()
-             if not still_reading:
-                 video_stream.release()
-                 break
-             y1, y2, x1, x2 = crop
-             if x2 == -1: x2 = frame.shape[1]
-             if y2 == -1: y2 = frame.shape[0]
-             frame = frame[y1:y2, x1:x2]
-             full_frames.append(frame)
-
-     print ("[Step 0] Number of frames available for inference: "+str(len(full_frames)))
-     # face detection & cropping, cropping the first frame as the style of FFHQ
-     croper = Croper('checkpoints/shape_predictor_68_face_landmarks.dat')
-     full_frames_RGB = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
-     full_frames_RGB, crop, quad = croper.crop(full_frames_RGB, xsize=512)
-
-     clx, cly, crx, cry = crop
-     lx, ly, rx, ry = quad
-     lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
-     oy1, oy2, ox1, ox2 = cly+ly, min(cly+ry, full_frames[0].shape[0]), clx+lx, min(clx+rx, full_frames[0].shape[1])
-     # original_size = (ox2 - ox1, oy2 - oy1)
-     frames_pil = [Image.fromarray(cv2.resize(frame,(256,256))) for frame in full_frames_RGB]
-
-     # get the landmark according to the detected face.
-     if not os.path.isfile('temp/'+base_name+'_landmarks.txt') or re_preprocess:
-         print('[Step 1] Landmarks Extraction in Video.')
-         kp_extractor = KeypointExtractor()
-         lm = kp_extractor.extract_keypoint(frames_pil, 'temp/'+base_name+'_landmarks.txt')
-     else:
-         print('[Step 1] Using saved landmarks.')
-         lm = np.loadtxt('temp/'+base_name+'_landmarks.txt').astype(np.float32)
-         lm = lm.reshape([len(full_frames), -1, 2])
-
-     if not os.path.isfile('temp/'+base_name+'_coeffs.npy') or exp_img is not None or re_preprocess:
-         net_recon = load_face3d_net(face3d_net_path, device)
-         lm3d_std = load_lm3d('checkpoints/BFM_Fitting')
-
-         video_coeffs = []
-         for idx in tqdm(range(len(frames_pil)), desc="[Step 2] 3DMM Extraction In Video:"):
-             frame = frames_pil[idx]
-             W, H = frame.size
-             lm_idx = lm[idx].reshape([-1, 2])
-             if np.mean(lm_idx) == -1:
-                 lm_idx = (lm3d_std[:, :2]+1) / 2.
-                 lm_idx = np.concatenate([lm_idx[:, :1] * W, lm_idx[:, 1:2] * H], 1)
-             else:
-                 lm_idx[:, -1] = H - 1 - lm_idx[:, -1]
-
-             trans_params, im_idx, lm_idx, _ = align_img(frame, lm_idx, lm3d_std)
-             trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
-             im_idx_tensor = torch.tensor(np.array(im_idx)/255., dtype=torch.float32).permute(2, 0, 1).to(device).unsqueeze(0)
-             with torch.no_grad():
-                 coeffs = split_coeff(net_recon(im_idx_tensor))
-
-             pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
-             pred_coeff = np.concatenate([pred_coeff['id'], pred_coeff['exp'], pred_coeff['tex'], pred_coeff['angle'],\
-                 pred_coeff['gamma'], pred_coeff['trans'], trans_params[None]], 1)
-             video_coeffs.append(pred_coeff)
-         semantic_npy = np.array(video_coeffs)[:,0]
-         np.save('temp/'+base_name+'_coeffs.npy', semantic_npy)
-     else:
-         print('[Step 2] Using saved coeffs.')
-         semantic_npy = np.load('temp/'+base_name+'_coeffs.npy').astype(np.float32)
-
-     # generate the 3dmm coeff from a single image
-     if exp_img is not None and ('.png' in exp_img or '.jpg' in exp_img):
-         print('extract the exp from',exp_img)
-         exp_pil = Image.open(exp_img).convert('RGB')
-         lm3d_std = load_lm3d('third_part/face3d/BFM')
-
-         W, H = exp_pil.size
-         kp_extractor = KeypointExtractor()
-         lm_exp = kp_extractor.extract_keypoint([exp_pil], 'temp/'+base_name+'_temp.txt')[0]
-         if np.mean(lm_exp) == -1:
-             lm_exp = (lm3d_std[:, :2] + 1) / 2.
-             lm_exp = np.concatenate(
-                 [lm_exp[:, :1] * W, lm_exp[:, 1:2] * H], 1)
-         else:
-             lm_exp[:, -1] = H - 1 - lm_exp[:, -1]
-
-         trans_params, im_exp, lm_exp, _ = align_img(exp_pil, lm_exp, lm3d_std)
-         trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
-         im_exp_tensor = torch.tensor(np.array(im_exp)/255., dtype=torch.float32).permute(2, 0, 1).to(device).unsqueeze(0)
-         with torch.no_grad():
-             expression = split_coeff(net_recon(im_exp_tensor))['exp'][0]
-         del net_recon
-     elif exp_img == 'smile':
-         expression = torch.tensor(loadmat('checkpoints/expression.mat')['expression_mouth'])[0]
-     else:
-         print('using expression center')
-         expression = torch.tensor(loadmat('checkpoints/expression.mat')['expression_center'])[0]
-
-     # load DNet, model(LNet and ENet)
-     D_Net, model = load_model(device,DNet_path='checkpoints/DNet.pt',LNet_path='checkpoints/LNet.pth',ENet_path='checkpoints/ENet.pth')
-
-     if not os.path.isfile('temp/'+base_name+'_stablized.npy') or re_preprocess:
-         imgs = []
-         for idx in tqdm(range(len(frames_pil)), desc="[Step 3] Stabilize the expression In Video:"):
-             if one_shot:
-                 source_img = trans_image(frames_pil[0]).unsqueeze(0).to(device)
-                 semantic_source_numpy = semantic_npy[0:1]
-             else:
-                 source_img = trans_image(frames_pil[idx]).unsqueeze(0).to(device)
-                 semantic_source_numpy = semantic_npy[idx:idx+1]
-             ratio = find_crop_norm_ratio(semantic_source_numpy, semantic_npy)
-             coeff = transform_semantic(semantic_npy, idx, ratio).unsqueeze(0).to(device)
-
-             # hacking the new expression
-             coeff[:, :64, :] = expression[None, :64, None].to(device)
-             with torch.no_grad():
-                 output = D_Net(source_img, coeff)
-             img_stablized = np.uint8((output['fake_image'].squeeze(0).permute(1,2,0).cpu().clamp_(-1, 1).numpy() + 1 )/2. * 255)
-             imgs.append(cv2.cvtColor(img_stablized,cv2.COLOR_RGB2BGR))
-         np.save('temp/'+base_name+'_stablized.npy',imgs)
-         del D_Net
-     else:
-         print('[Step 3] Using saved stabilized video.')
-         imgs = np.load('temp/'+base_name+'_stablized.npy')
-     torch.cuda.empty_cache()
-
-     if not audio_path.endswith('.wav'):
-         # command = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(audio_path, 'temp/{}/temp.wav'.format(tmp_dir))
-         # subprocess.call(command, shell=True)
-         converted_audio_path = os.path.join('temp', tmp_dir, 'temp.wav')
-         audio_clip = AudioFileClip(audio_path)
-         audio_clip.write_audiofile(converted_audio_path, codec='pcm_s16le')
-         audio_clip.close()
-         audio_path = converted_audio_path
-         # audio_path = 'temp/{}/temp.wav'.format(tmp_dir)
-     wav = audio.load_wav(audio_path, 16000)
-     mel = audio.melspectrogram(wav)
-     if np.isnan(mel.reshape(-1)).sum() > 0:
-         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
-
-     mel_step_size, mel_idx_multiplier, i, mel_chunks = 16, 80./fps, 0, []
-     while True:
-         start_idx = int(i * mel_idx_multiplier)
-         if start_idx + mel_step_size > len(mel[0]):
-             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
-             break
-         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
-         i += 1
-
-     print("[Step 4] Load audio; Length of mel chunks: {}".format(len(mel_chunks)))
-     imgs = imgs[:len(mel_chunks)]
-     full_frames = full_frames[:len(mel_chunks)]
-     lm = lm[:len(mel_chunks)]
-
-     imgs_enhanced = []
-     for idx in tqdm(range(len(imgs)), desc='[Step 5] Reference Enhancement'):
-         img = imgs[idx]
-         pred, _, _ = enhancer.process(img, img, face_enhance=True, possion_blending=False)
-         imgs_enhanced.append(pred)
-     gen = datagen(imgs_enhanced.copy(), mel_chunks, full_frames, None, (oy1,oy2,ox1,ox2), face, static, LNet_batch_size, img_size=384)
-
-     frame_h, frame_w = full_frames[0].shape[:-1]
-     out = cv2.VideoWriter('temp/{}/result.mp4'.format(tmp_dir), cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h))
-
-     if up_face != 'original':
-         instance = GANimationModel()
-         instance.initialize()
-         instance.setup()
-
-     kp_extractor = KeypointExtractor()
-     for i, (img_batch, mel_batch, frames, coords, img_original, f_frames) in enumerate(tqdm(gen, desc='[Step 6] Lip Synthesis:', total=int(np.ceil(float(len(mel_chunks)) / LNet_batch_size)))):
-         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
-         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
-         img_original = torch.FloatTensor(np.transpose(img_original, (0, 3, 1, 2))).to(device)/255. # BGR -> RGB
-
-         with torch.no_grad():
-             incomplete, reference = torch.split(img_batch, 3, dim=1)
-             pred, low_res = model(mel_batch, img_batch, reference)
-             pred = torch.clamp(pred, 0, 1)
-
-             if up_face in ['sad', 'angry', 'surprise']:
-                 tar_aus = exp_aus_dict[up_face]
-             else:
-                 pass
-
-             if up_face == 'original':
-                 cur_gen_faces = img_original
-             else:
-                 test_batch = {'src_img': torch.nn.functional.interpolate((img_original * 2 - 1), size=(128, 128), mode='bilinear'),
-                               'tar_aus': tar_aus.repeat(len(incomplete), 1)}
-                 instance.feed_batch(test_batch)
-                 instance.forward()
-                 cur_gen_faces = torch.nn.functional.interpolate(instance.fake_img / 2. + 0.5, size=(384, 384), mode='bilinear')
-
-             if without_rl1 is not False:
-                 incomplete, reference = torch.split(img_batch, 3, dim=1)
-                 mask = torch.where(incomplete==0, torch.ones_like(incomplete), torch.zeros_like(incomplete))
-                 pred = pred * mask + cur_gen_faces * (1 - mask)
-
-         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
-
-         torch.cuda.empty_cache()
-         for p, f, xf, c in zip(pred, frames, f_frames, coords):
-             y1, y2, x1, x2 = c
-             p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
-
-             ff = xf.copy()
-             ff[y1:y2, x1:x2] = p
-
-             restored_img = ff
-             mm = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0]
-             mouse_mask = np.zeros_like(restored_img)
-             tmp_mask = enhancer.faceparser.process(restored_img[y1:y2, x1:x2], mm)[0]
-             mouse_mask[y1:y2, x1:x2]= cv2.resize(tmp_mask, (x2 - x1, y2 - y1))[:, :, np.newaxis] / 255.
-
-             height, width = ff.shape[:2]
-             restored_img, ff, full_mask = [cv2.resize(x, (512, 512)) for x in (restored_img, ff, np.float32(mouse_mask))]
-             img = Laplacian_Pyramid_Blending_with_mask(restored_img, ff, full_mask[:, :, 0], 10)
-             pp = np.uint8(cv2.resize(np.clip(img, 0 ,255), (width, height)))
-
-             pp, orig_faces, enhanced_faces = enhancer.process(pp, xf, bbox=c, face_enhance=False, possion_blending=True)
-             out.write(pp)
-     out.release()
-
-     if not os.path.isdir(os.path.dirname(outfile)):
-         os.makedirs(os.path.dirname(outfile), exist_ok=True)
-     # command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, 'temp/{}/result.mp4'.format(tmp_dir), outfile)
-     # subprocess.call(command, shell=platform.system() != 'Windows')
-     video_path = 'temp/{}/result.mp4'.format(tmp_dir)
-     audio_clip = AudioFileClip(audio_path)
-     video_clip = VideoFileClip(video_path)
-     video_clip = video_clip.set_audio(audio_clip)
-
-     # Write the result to the output file
-     video_clip.write_videofile(outfile, codec='libx264', audio_codec='aac')
-     print('outfile:', outfile)
-
- # frames:256x256, full_frames: original size
- def datagen(frames, mels, full_frames, frames_pil, cox, face, static, LNet_batch_size, img_size):
-     img_batch, mel_batch, frame_batch, coords_batch, ref_batch, full_frame_batch = [], [], [], [], [], []
-     base_name = face.split('/')[-1]
-     refs = []
-     image_size = 256
-
-     # original frames
-     kp_extractor = KeypointExtractor()
-     fr_pil = [Image.fromarray(frame) for frame in frames]
-     lms = kp_extractor.extract_keypoint(fr_pil, 'temp/'+base_name+'x12_landmarks.txt')
-     frames_pil = [ (lm, frame) for frame,lm in zip(fr_pil, lms)] # frames is the croped version of modified face
-     crops, orig_images, quads = crop_faces(image_size, frames_pil, scale=1.0, use_fa=True)
-     inverse_transforms = [calc_alignment_coefficients(quad + 0.5, [[0, 0], [0, image_size], [image_size, image_size], [image_size, 0]]) for quad in quads]
-     del kp_extractor.detector
-
-     oy1,oy2,ox1,ox2 = cox
-     face_det_results = face_detect(full_frames, face_det_batch_size=4, nosmooth=False, pads=[0, 20, 0, 0], jaw_correction=True, detector=None)
-
-     for inverse_transform, crop, full_frame, face_det in zip(inverse_transforms, crops, full_frames, face_det_results):
-         imc_pil = paste_image(inverse_transform, crop, Image.fromarray(
-             cv2.resize(full_frame[int(oy1):int(oy2), int(ox1):int(ox2)], (256, 256))))
-
-         ff = full_frame.copy()
-         ff[int(oy1):int(oy2), int(ox1):int(ox2)] = cv2.resize(np.array(imc_pil.convert('RGB')), (ox2 - ox1, oy2 - oy1))
-         oface, coords = face_det
-         y1, y2, x1, x2 = coords
-         refs.append(ff[y1: y2, x1:x2])
-
-     for i, m in enumerate(mels):
-         idx = 0 if static else i % len(frames)
-         frame_to_save = frames[idx].copy()
-         face = refs[idx]
-         oface, coords = face_det_results[idx].copy()
-
-         face = cv2.resize(face, (img_size, img_size))
-         oface = cv2.resize(oface, (img_size, img_size))
-
-         img_batch.append(oface)
-         ref_batch.append(face)
-         mel_batch.append(m)
-         coords_batch.append(coords)
-         frame_batch.append(frame_to_save)
-         full_frame_batch.append(full_frames[idx].copy())
-
-         if len(img_batch) >= LNet_batch_size:
-             img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)
-             img_masked = img_batch.copy()
-             img_original = img_batch.copy()
-             img_masked[:, img_size//2:] = 0
-             img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
-             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-
-             yield img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch
-             img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch, ref_batch = [], [], [], [], [], [], []
-
-     if len(img_batch) > 0:
-         img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)
-         img_masked = img_batch.copy()
-         img_original = img_batch.copy()
-         img_masked[:, img_size//2:] = 0
-         img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
-         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-         yield img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch
-
-
-
- if __name__ == "__main__":
-     face_path = "C:/Users/fd01076/Downloads/download_1.mp4" # Replace with the path to your face image or video
-     audio_path = "C:/Users/fd01076/Downloads/audio_1.mp3" # Replace with the path to your audio file
-     output_path = "C:/Users/fd01076/Downloads/result.mp4" # Replace with the path for the output video
-
-     # Call the function
-     video_lipsync_correctness(
-         face=face_path,
-         audio_path=audio_path,
-         outfile=output_path,
-         tmp_dir="temp",
-         crop=[0, -1, 0, -1],
-         re_preprocess=True, # Set to True if you want to reprocess; False otherwise
-         exp_img="neutral", # Can be 'smile', 'neutral', or path to an expression image
-         face3d_net_path="checkpoints/face3d_pretrain_epoch_20.pth",
-         one_shot=False,
-         up_face="original", # Options: 'original', 'sad', 'angry', 'surprise'
-         LNet_batch_size=16,
-         without_rl1=False
-     )
 
+ import numpy as np
+ import cv2, os, sys, subprocess, platform, torch
+ from tqdm import tqdm
+ from PIL import Image
+ from scipy.io import loadmat
+ from moviepy.editor import AudioFileClip, VideoFileClip
+
+ sys.path.insert(0, 'third_part')
+ sys.path.insert(0, 'third_part/GPEN')
+
+ # 3dmm extraction
+ from third_part.face3d.util.preprocess import align_img
+ from third_part.face3d.util.load_mats import load_lm3d
+ from third_part.face3d.extract_kp_videos import KeypointExtractor
+ # face enhancement
+ from third_part.GPEN.gpen_face_enhancer import FaceEnhancement
+ # expression control
+ from third_part.ganimation_replicate.model.ganimation import GANimationModel
+
+ from utils import audio
+ from utils.ffhq_preprocess import Croper
+ from utils.alignment_stit import crop_faces, calc_alignment_coefficients, paste_image
+ from utils.inference_utils import Laplacian_Pyramid_Blending_with_mask, face_detect, load_model, options, split_coeff, \
+     trans_image, transform_semantic, find_crop_norm_ratio, load_face3d_net, exp_aus_dict
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ def video_lipsync_correctness(face, audio_path, outfile=None, tmp_dir="temp", crop=[0, -1, 0, -1], re_preprocess=False, exp_img="neutral", face3d_net_path="checkpoints/face3d_pretrain_epoch_20.pth", one_shot=False, up_face="original", LNet_batch_size=16, without_rl1=False, static=False):
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     print('[Info] Using {} for inference.'.format(device))
+     os.makedirs(os.path.join('temp', tmp_dir), exist_ok=True)
+
+     enhancer = FaceEnhancement(base_dir='checkpoints', size=512, model='GPEN-BFR-512', use_sr=False, \
+         sr_model='rrdb_realesrnet_psnr', channel_multiplier=2, narrow=1, device=device)
+
+     base_name = face.split('/')[-1]
+     print('base_name',base_name)
+     if os.path.isfile(face) and face.split('.')[1] in ['jpg', 'png', 'jpeg']:
+         static = True
+     if not os.path.isfile(face):
+         raise ValueError('--face argument must be a valid path to video/image file')
+     elif face.split('.')[1] in ['jpg', 'png', 'jpeg']:
+         full_frames = [cv2.imread(face)]
+         fps = fps
+     else:
+         video_stream = cv2.VideoCapture(face)
+         fps = video_stream.get(cv2.CAP_PROP_FPS)
+
+         full_frames = []
+         while True:
+             still_reading, frame = video_stream.read()
+             if not still_reading:
+                 video_stream.release()
+                 break
+             y1, y2, x1, x2 = crop
+             if x2 == -1: x2 = frame.shape[1]
+             if y2 == -1: y2 = frame.shape[0]
+             frame = frame[y1:y2, x1:x2]
+             full_frames.append(frame)
+
+     print ("[Step 0] Number of frames available for inference: "+str(len(full_frames)))
+     # face detection & cropping, cropping the first frame as the style of FFHQ
+     croper = Croper('checkpoints/shape_predictor_68_face_landmarks.dat')
+     full_frames_RGB = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+     full_frames_RGB, crop, quad = croper.crop(full_frames_RGB, xsize=512)
+
+     clx, cly, crx, cry = crop
+     lx, ly, rx, ry = quad
+     lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+     oy1, oy2, ox1, ox2 = cly+ly, min(cly+ry, full_frames[0].shape[0]), clx+lx, min(clx+rx, full_frames[0].shape[1])
+     # original_size = (ox2 - ox1, oy2 - oy1)
+     frames_pil = [Image.fromarray(cv2.resize(frame,(256,256))) for frame in full_frames_RGB]
+
+     # get the landmark according to the detected face.
+     if not os.path.isfile('temp/'+base_name+'_landmarks.txt') or re_preprocess:
+         print('[Step 1] Landmarks Extraction in Video.')
+         kp_extractor = KeypointExtractor()
+         lm = kp_extractor.extract_keypoint(frames_pil, 'temp/'+base_name+'_landmarks.txt')
+     else:
+         print('[Step 1] Using saved landmarks.')
+         lm = np.loadtxt('temp/'+base_name+'_landmarks.txt').astype(np.float32)
+         lm = lm.reshape([len(full_frames), -1, 2])
+
+     if not os.path.isfile('temp/'+base_name+'_coeffs.npy') or exp_img is not None or re_preprocess:
+         net_recon = load_face3d_net(face3d_net_path, device)
+         lm3d_std = load_lm3d('checkpoints/BFM_Fitting')
+
+         video_coeffs = []
+         for idx in tqdm(range(len(frames_pil)), desc="[Step 2] 3DMM Extraction In Video:"):
+             frame = frames_pil[idx]
+             W, H = frame.size
+             lm_idx = lm[idx].reshape([-1, 2])
+             if np.mean(lm_idx) == -1:
+                 lm_idx = (lm3d_std[:, :2]+1) / 2.
+                 lm_idx = np.concatenate([lm_idx[:, :1] * W, lm_idx[:, 1:2] * H], 1)
+             else:
+                 lm_idx[:, -1] = H - 1 - lm_idx[:, -1]
+
+             trans_params, im_idx, lm_idx, _ = align_img(frame, lm_idx, lm3d_std)
+             trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
+             im_idx_tensor = torch.tensor(np.array(im_idx)/255., dtype=torch.float32).permute(2, 0, 1).to(device).unsqueeze(0)
+             with torch.no_grad():
+                 coeffs = split_coeff(net_recon(im_idx_tensor))
+
+             pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+             pred_coeff = np.concatenate([pred_coeff['id'], pred_coeff['exp'], pred_coeff['tex'], pred_coeff['angle'],\
+                 pred_coeff['gamma'], pred_coeff['trans'], trans_params[None]], 1)
+             video_coeffs.append(pred_coeff)
+         semantic_npy = np.array(video_coeffs)[:,0]
+         np.save('temp/'+base_name+'_coeffs.npy', semantic_npy)
+     else:
+         print('[Step 2] Using saved coeffs.')
+         semantic_npy = np.load('temp/'+base_name+'_coeffs.npy').astype(np.float32)
+
+     # generate the 3dmm coeff from a single image
+     if exp_img is not None and ('.png' in exp_img or '.jpg' in exp_img):
+         print('extract the exp from',exp_img)
+         exp_pil = Image.open(exp_img).convert('RGB')
+         lm3d_std = load_lm3d('third_part/face3d/BFM')
+
+         W, H = exp_pil.size
+         kp_extractor = KeypointExtractor()
+         lm_exp = kp_extractor.extract_keypoint([exp_pil], 'temp/'+base_name+'_temp.txt')[0]
+         if np.mean(lm_exp) == -1:
+             lm_exp = (lm3d_std[:, :2] + 1) / 2.
+             lm_exp = np.concatenate(
+                 [lm_exp[:, :1] * W, lm_exp[:, 1:2] * H], 1)
+         else:
+             lm_exp[:, -1] = H - 1 - lm_exp[:, -1]
+
+         trans_params, im_exp, lm_exp, _ = align_img(exp_pil, lm_exp, lm3d_std)
+         trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
+         im_exp_tensor = torch.tensor(np.array(im_exp)/255., dtype=torch.float32).permute(2, 0, 1).to(device).unsqueeze(0)
+         with torch.no_grad():
+             expression = split_coeff(net_recon(im_exp_tensor))['exp'][0]
+         del net_recon
+     elif exp_img == 'smile':
+         expression = torch.tensor(loadmat('checkpoints/expression.mat')['expression_mouth'])[0]
+     else:
+         print('using expression center')
+         expression = torch.tensor(loadmat('checkpoints/expression.mat')['expression_center'])[0]
+
+     # load DNet, model(LNet and ENet)
+     D_Net, model = load_model(device,DNet_path='checkpoints/DNet.pt',LNet_path='checkpoints/LNet.pth',ENet_path='checkpoints/ENet.pth')
+
+     if not os.path.isfile('temp/'+base_name+'_stablized.npy') or re_preprocess:
+         imgs = []
+         for idx in tqdm(range(len(frames_pil)), desc="[Step 3] Stabilize the expression In Video:"):
+             if one_shot:
+                 source_img = trans_image(frames_pil[0]).unsqueeze(0).to(device)
+                 semantic_source_numpy = semantic_npy[0:1]
+             else:
+                 source_img = trans_image(frames_pil[idx]).unsqueeze(0).to(device)
+                 semantic_source_numpy = semantic_npy[idx:idx+1]
+             ratio = find_crop_norm_ratio(semantic_source_numpy, semantic_npy)
+             coeff = transform_semantic(semantic_npy, idx, ratio).unsqueeze(0).to(device)
+
+             # hacking the new expression
+             coeff[:, :64, :] = expression[None, :64, None].to(device)
+             with torch.no_grad():
+                 output = D_Net(source_img, coeff)
+             img_stablized = np.uint8((output['fake_image'].squeeze(0).permute(1,2,0).cpu().clamp_(-1, 1).numpy() + 1 )/2. * 255)
+             imgs.append(cv2.cvtColor(img_stablized,cv2.COLOR_RGB2BGR))
+         np.save('temp/'+base_name+'_stablized.npy',imgs)
+         del D_Net
+     else:
+         print('[Step 3] Using saved stabilized video.')
+         imgs = np.load('temp/'+base_name+'_stablized.npy')
+     torch.cuda.empty_cache()
+
+     if not audio_path.endswith('.wav'):
+         # command = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(audio_path, 'temp/{}/temp.wav'.format(tmp_dir))
+         # subprocess.call(command, shell=True)
+         converted_audio_path = os.path.join('temp', tmp_dir, 'temp.wav')
+         audio_clip = AudioFileClip(audio_path)
+         audio_clip.write_audiofile(converted_audio_path, codec='pcm_s16le')
+         audio_clip.close()
+         audio_path = converted_audio_path
+         # audio_path = 'temp/{}/temp.wav'.format(tmp_dir)
+     wav = audio.load_wav(audio_path, 16000)
+     mel = audio.melspectrogram(wav)
+     if np.isnan(mel.reshape(-1)).sum() > 0:
+         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
+     mel_step_size, mel_idx_multiplier, i, mel_chunks = 16, 80./fps, 0, []
+     while True:
+         start_idx = int(i * mel_idx_multiplier)
+         if start_idx + mel_step_size > len(mel[0]):
+             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+             break
+         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+         i += 1
+
+     print("[Step 4] Load audio; Length of mel chunks: {}".format(len(mel_chunks)))
+     imgs = imgs[:len(mel_chunks)]
+     full_frames = full_frames[:len(mel_chunks)]
+     lm = lm[:len(mel_chunks)]
+
+     imgs_enhanced = []
+     for idx in tqdm(range(len(imgs)), desc='[Step 5] Reference Enhancement'):
+         img = imgs[idx]
+         pred, _, _ = enhancer.process(img, img, face_enhance=True, possion_blending=False)
+         imgs_enhanced.append(pred)
+     gen = datagen(imgs_enhanced.copy(), mel_chunks, full_frames, None, (oy1,oy2,ox1,ox2), face, static, LNet_batch_size, img_size=384)
+
+     frame_h, frame_w = full_frames[0].shape[:-1]
+     out = cv2.VideoWriter('temp/{}/result.mp4'.format(tmp_dir), cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h))
+
+     if up_face != 'original':
+         instance = GANimationModel()
+         instance.initialize()
+         instance.setup()
+
+     kp_extractor = KeypointExtractor()
+     for i, (img_batch, mel_batch, frames, coords, img_original, f_frames) in enumerate(tqdm(gen, desc='[Step 6] Lip Synthesis:', total=int(np.ceil(float(len(mel_chunks)) / LNet_batch_size)))):
+         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+         img_original = torch.FloatTensor(np.transpose(img_original, (0, 3, 1, 2))).to(device)/255. # BGR -> RGB
+
+         with torch.no_grad():
+             incomplete, reference = torch.split(img_batch, 3, dim=1)
+             pred, low_res = model(mel_batch, img_batch, reference)
+             pred = torch.clamp(pred, 0, 1)
+
+             if up_face in ['sad', 'angry', 'surprise']:
+                 tar_aus = exp_aus_dict[up_face]
+             else:
+                 pass
+
+             if up_face == 'original':
+                 cur_gen_faces = img_original
+             else:
+                 test_batch = {'src_img': torch.nn.functional.interpolate((img_original * 2 - 1), size=(128, 128), mode='bilinear'),
+                               'tar_aus': tar_aus.repeat(len(incomplete), 1)}
+                 instance.feed_batch(test_batch)
+                 instance.forward()
+                 cur_gen_faces = torch.nn.functional.interpolate(instance.fake_img / 2. + 0.5, size=(384, 384), mode='bilinear')
+
+             if without_rl1 is not False:
+                 incomplete, reference = torch.split(img_batch, 3, dim=1)
+                 mask = torch.where(incomplete==0, torch.ones_like(incomplete), torch.zeros_like(incomplete))
+                 pred = pred * mask + cur_gen_faces * (1 - mask)
+
+         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+         torch.cuda.empty_cache()
+         for p, f, xf, c in zip(pred, frames, f_frames, coords):
+             y1, y2, x1, x2 = c
+             p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
+
+             ff = xf.copy()
+             ff[y1:y2, x1:x2] = p
+
+             restored_img = ff
+             mm = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0]
+             mouse_mask = np.zeros_like(restored_img)
+             tmp_mask = enhancer.faceparser.process(restored_img[y1:y2, x1:x2], mm)[0]
+             mouse_mask[y1:y2, x1:x2]= cv2.resize(tmp_mask, (x2 - x1, y2 - y1))[:, :, np.newaxis] / 255.
+
+             height, width = ff.shape[:2]
+             restored_img, ff, full_mask = [cv2.resize(x, (512, 512)) for x in (restored_img, ff, np.float32(mouse_mask))]
+             img = Laplacian_Pyramid_Blending_with_mask(restored_img, ff, full_mask[:, :, 0], 10)
+             pp = np.uint8(cv2.resize(np.clip(img, 0 ,255), (width, height)))
+
+             pp, orig_faces, enhanced_faces = enhancer.process(pp, xf, bbox=c, face_enhance=False, possion_blending=True)
+             out.write(pp)
+     out.release()
+
+     if not os.path.isdir(os.path.dirname(outfile)):
+         os.makedirs(os.path.dirname(outfile), exist_ok=True)
+     # command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, 'temp/{}/result.mp4'.format(tmp_dir), outfile)
+     # subprocess.call(command, shell=platform.system() != 'Windows')
+     video_path = 'temp/{}/result.mp4'.format(tmp_dir)
+     audio_clip = AudioFileClip(audio_path)
+     video_clip = VideoFileClip(video_path)
+     video_clip = video_clip.set_audio(audio_clip)
+
+     # Write the result to the output file
+     video_clip.write_videofile(outfile, codec='libx264', audio_codec='aac')
+     print('outfile:', outfile)
+
+ # frames:256x256, full_frames: original size
+ def datagen(frames, mels, full_frames, frames_pil, cox, face, static, LNet_batch_size, img_size):
+     img_batch, mel_batch, frame_batch, coords_batch, ref_batch, full_frame_batch = [], [], [], [], [], []
+     base_name = face.split('/')[-1]
+     refs = []
+     image_size = 256
+
+     # original frames
+     kp_extractor = KeypointExtractor()
+     fr_pil = [Image.fromarray(frame) for frame in frames]
+     lms = kp_extractor.extract_keypoint(fr_pil, 'temp/'+base_name+'x12_landmarks.txt')
+     frames_pil = [ (lm, frame) for frame,lm in zip(fr_pil, lms)] # frames is the croped version of modified face
+     crops, orig_images, quads = crop_faces(image_size, frames_pil, scale=1.0, use_fa=True)
+     inverse_transforms = [calc_alignment_coefficients(quad + 0.5, [[0, 0], [0, image_size], [image_size, image_size], [image_size, 0]]) for quad in quads]
+     del kp_extractor.detector
+
+     oy1,oy2,ox1,ox2 = cox
+     face_det_results = face_detect(full_frames, face_det_batch_size=4, nosmooth=False, pads=[0, 20, 0, 0], jaw_correction=True, detector=None)
+
+     for inverse_transform, crop, full_frame, face_det in zip(inverse_transforms, crops, full_frames, face_det_results):
+         imc_pil = paste_image(inverse_transform, crop, Image.fromarray(
+             cv2.resize(full_frame[int(oy1):int(oy2), int(ox1):int(ox2)], (256, 256))))
+
+         ff = full_frame.copy()
+         ff[int(oy1):int(oy2), int(ox1):int(ox2)] = cv2.resize(np.array(imc_pil.convert('RGB')), (ox2 - ox1, oy2 - oy1))
+         oface, coords = face_det
+         y1, y2, x1, x2 = coords
+         refs.append(ff[y1: y2, x1:x2])
+
+     for i, m in enumerate(mels):
+         idx = 0 if static else i % len(frames)
+         frame_to_save = frames[idx].copy()
+         face = refs[idx]
+         oface, coords = face_det_results[idx].copy()
+
+         face = cv2.resize(face, (img_size, img_size))
+         oface = cv2.resize(oface, (img_size, img_size))
+
+         img_batch.append(oface)
+         ref_batch.append(face)
+         mel_batch.append(m)
+         coords_batch.append(coords)
+         frame_batch.append(frame_to_save)
+         full_frame_batch.append(full_frames[idx].copy())
+
+         if len(img_batch) >= LNet_batch_size:
+             img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)
+             img_masked = img_batch.copy()
+             img_original = img_batch.copy()
+             img_masked[:, img_size//2:] = 0
+             img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
+             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+             yield img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch
+             img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch, ref_batch = [], [], [], [], [], [], []
+
+     if len(img_batch) > 0:
+         img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)
+         img_masked = img_batch.copy()
+         img_original = img_batch.copy()
+         img_masked[:, img_size//2:] = 0
+         img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
+         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+         yield img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch
+