salomonsky committed on
Commit
1345bfc
·
1 Parent(s): 38804e8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +124 -110
inference.py CHANGED
@@ -7,7 +7,7 @@ from glob import glob
7
  import torch, face_detection
8
  from models import Wav2Lip
9
  import platform
10
-
11
 
12
  parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
13
 
@@ -67,116 +67,130 @@ def get_smoothened_boxes(boxes, T):
67
  return boxes
68
 
69
  def face_detect(images):
70
- # TODO 识别头像信息
71
-
72
- detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
73
- flip_input=False, device=device)
74
-
75
- batch_size = args.face_det_batch_size
76
-
77
- while 1:
78
- predictions = []
79
- try:
80
- for i in tqdm(range(0, len(images), batch_size)):
81
- predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
82
- except RuntimeError:
83
- if batch_size == 1:
84
- raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
85
- batch_size //= 2
86
- print('Recovering from OOM error; New batch size: {}'.format(batch_size))
87
- continue
88
- break
89
- head_exist = []
90
- results = []
91
- pady1, pady2, padx1, padx2 = args.pads
92
-
93
- first_head_rect = None
94
- first_head_image =None
95
- for rect, image in zip(predictions, images):
96
- if rect is not None:
97
- first_head_rect = rect
98
- first_head_image = image
99
- break
100
- for rect, image in zip(predictions, images):
101
- if rect is None:
102
- head_exist.append(False)
103
- if len(results)==0:
104
- y1 = max(0, first_head_rect[1] - pady1)
105
- y2 = min(first_head_image.shape[0], first_head_rect[3] + pady2)
106
- x1 = max(0, first_head_rect[0] - padx1)
107
- x2 = min(first_head_image.shape[1], first_head_rect[2] + padx2)
108
- results.append([x1, y1, x2, y2])
109
- else:
110
- results.append(results[-1])
111
- # cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
112
- # raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
113
- else:
114
- head_exist.append(True)
115
- y1 = max(0, rect[1] - pady1)
116
- y2 = min(image.shape[0], rect[3] + pady2)
117
- x1 = max(0, rect[0] - padx1)
118
- x2 = min(image.shape[1], rect[2] + padx2)
119
- results.append([x1, y1, x2, y2])
120
-
121
- boxes = np.array(results)
122
- if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
123
- results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
124
-
125
- del detector
126
- return results,head_exist
 
 
 
127
 
128
  def datagen(frames, mels):
129
- img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch = [], [], [], [],[]
130
-
131
- # ***************************1、识别人脸对应的位置坐标,未识别的人脸的帧对应为None ***************************
132
- if args.box[0] == -1:
133
- if not args.static:
134
- face_det_results,head_exist = face_detect(frames) # BGR2RGB for CNN face detection
135
- else:
136
- face_det_results,head_exist = face_detect([frames[0]])
137
- else:
138
- print('Using the specified bounding box instead of face detection...')
139
- y1, y2, x1, x2 = args.box
140
- face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
141
- head_exist = [True]*len(frames)
142
-
143
- for i, m in enumerate(mels):
144
- #获取对应的一组音频对应的帧下标idx
145
- idx = 0 if args.static else i%len(frames)
146
- #获取对应的一组音频对应的帧
147
- frame_to_save = frames[idx].copy()
148
- #获取对应的一组音频对应的帧对应的人脸坐标
149
- face, coords = face_det_results[idx].copy()
150
-
151
- face = cv2.resize(face, (args.img_size, args.img_size))
152
- head_exist_batch.append(head_exist[idx])
153
- img_batch.append(face)
154
- mel_batch.append(m)
155
- frame_batch.append(frame_to_save)
156
- coords_batch.append(coords)
157
-
158
- if len(img_batch) >= args.wav2lip_batch_size:
159
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
160
-
161
- img_masked = img_batch.copy()
162
- img_masked[:, args.img_size//2:] = 0
163
-
164
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
165
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
166
-
167
- yield img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch
168
- img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch = [],[], [], [], []
169
-
170
- if len(img_batch) > 0:
171
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
172
-
173
- img_masked = img_batch.copy()
174
- img_masked[:, args.img_size//2:] = 0
175
-
176
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
177
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
178
-
179
- yield img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  mel_step_size = 16
182
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -313,4 +327,4 @@ def main():
313
  subprocess.call(command, shell=platform.system() != 'Windows')
314
 
315
  if __name__ == '__main__':
316
- main()
 
7
  import torch, face_detection
8
  from models import Wav2Lip
9
  import platform
10
+ import cv2
11
 
12
  parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
13
 
 
67
  return boxes
68
 
69
  def face_detect(images):
70
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
71
+ flip_input=False, device=device)
72
+
73
+ batch_size = args.face_det_batch_size
74
+
75
+ last_face = None # Agregar la variable para guardar la última imagen detectada
76
+
77
+ while 1:
78
+ predictions = []
79
+ try:
80
+ for i in tqdm(range(0, len(images), batch_size)):
81
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
82
+ except RuntimeError:
83
+ if batch_size == 1:
84
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
85
+ batch_size //= 2
86
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
87
+ continue
88
+ break
89
+ head_exist = []
90
+ results = []
91
+ pady1, pady2, padx1, padx2 = args.pads
92
+
93
+ first_head_rect = None
94
+ first_head_image =None
95
+ for rect, image in zip(predictions, images):
96
+ if rect is not None:
97
+ first_head_rect = rect
98
+ first_head_image = image
99
+ break
100
+ for rect, image in zip(predictions, images):
101
+ if rect is None:
102
+ head_exist.append(False)
103
+ if len(results)==0:
104
+ y1 = max(0, first_head_rect[1] - pady1)
105
+ y2 = min(first_head_image.shape[0], first_head_rect[3] + pady2)
106
+ x1 = max(0, first_head_rect[0] - padx1)
107
+ x2 = min(first_head_image.shape[1], first_head_rect[2] + padx2)
108
+ results.append([x1, y1, x2, y2])
109
+ else:
110
+ results.append(results[-1])
111
+ else:
112
+ head_exist.append(True)
113
+ y1 = max(0, rect[1] - pady1)
114
+ y2 = min(image.shape[0], rect[3] + pady2)
115
+ x1 = max(0, rect[0] - padx1)
116
+ x2 = min(image.shape[1], rect[2] + padx2)
117
+ results.append([x1, y1, x2, y2])
118
+ # Agregar la línea de código para guardar la imagen
119
+ last_face = image[y1: y2, x1:x2]
120
+ cv2.imwrite("last_face.jpg", last_face)
121
+
122
+ boxes = np.array(results)
123
+ if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
124
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
125
+
126
+ del detector
127
+ return results,head_exist
128
+
129
+ import cv2
130
 
131
  def datagen(frames, mels):
132
+ img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch = [], [], [], [],[]
133
+
134
+ # ***************************1、识别人脸对应的位置坐标,未识别的人脸的帧对应为None ***************************
135
+ if args.box[0] == -1:
136
+ if not args.static:
137
+ face_det_results,head_exist = face_detect(frames) # BGR2RGB for CNN face detection
138
+ else:
139
+ face_det_results,head_exist = face_detect([frames[0]])
140
+ else:
141
+ print('Using the specified bounding box instead of face detection...')
142
+ y1, y2, x1, x2 = args.box
143
+ face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
144
+ head_exist = [True]*len(frames)
145
+
146
+ for i, m in enumerate(mels):
147
+ #获取对应的一组音频对应的帧下标idx
148
+ idx = 0 if args.static else i%len(frames)
149
+ #获取对应的一组音频对应的帧
150
+ frame_to_save = frames[idx].copy()
151
+ #获取对应的一组音频对应的帧对应的人脸坐标
152
+ face, coords = face_det_results[idx].copy()
153
+
154
+ face = cv2.resize(face, (args.img_size, args.img_size))
155
+ head_exist_batch.append(head_exist[idx])
156
+ img_batch.append(face)
157
+ melspec = m
158
+ mel_batch.append(melspec)
159
+ frame_batch.append(frame_to_save)
160
+ coords_batch.append(coords)
161
+
162
+ if len(img_batch) >= args.wav2lip_batch_size:
163
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
164
+
165
+ img_masked = img_batch.copy()
166
+ img_masked[:, args.img_size//2:] = 0
167
+
168
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
169
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
170
+
171
+ yield img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch
172
+ img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch = [],[], [], [], []
173
+
174
+ # Agregar la línea de código para leer la imagen guardada automáticamente
175
+ last_face = cv2.imread("last_face.jpg")
176
+ last_face = cv2.resize(last_face, (args.img_size, args.img_size))
177
+ img_batch.append(last_face)
178
+ melspec = mels[-1]
179
+ mel_batch.append(melspec)
180
+ frame_batch.append(frames[-1])
181
+ coords_batch.append(face_det_results[-1][1])
182
+ head_exist_batch.append(head_exist[-1])
183
+
184
+ if len(img_batch) > 0:
185
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
186
+
187
+ img_masked = img_batch.copy()
188
+ img_masked[:, args.img_size//2:] = 0
189
+
190
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
191
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
192
+
193
+ yield img_batch,head_exist_batch, mel_batch, frame_batch, coords_batch
194
 
195
  mel_step_size = 16
196
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
327
  subprocess.call(command, shell=platform.system() != 'Windows')
328
 
329
  if __name__ == '__main__':
330
+ main()