Spanicin commited on
Commit
42e8c1c
·
verified ·
1 Parent(s): 98eca54

Update videoretalking/utils/inference_utils.py

Browse files
Files changed (1) hide show
  1. videoretalking/utils/inference_utils.py +253 -253
videoretalking/utils/inference_utils.py CHANGED
@@ -1,254 +1,254 @@
1
- import numpy as np
2
- import cv2, argparse, torch
3
- import torchvision.transforms.functional as TF
4
-
5
- from models import load_network, load_DNet
6
- from tqdm import tqdm
7
- from PIL import Image
8
- from scipy.spatial import ConvexHull
9
- from third_part import face_detection
10
- from third_part.face3d.models import networks
11
-
12
- import warnings
13
- warnings.filterwarnings("ignore")
14
-
15
- def options():
16
- parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
17
-
18
- parser.add_argument('--DNet_path', type=str, default='checkpoints/DNet.pt')
19
- parser.add_argument('--LNet_path', type=str, default='checkpoints/LNet.pth')
20
- parser.add_argument('--ENet_path', type=str, default='checkpoints/ENet.pth')
21
- parser.add_argument('--face3d_net_path', type=str, default='checkpoints/face3d_pretrain_epoch_20.pth')
22
- parser.add_argument('--face', type=str, help='Filepath of video/image that contains faces to use', required=True)
23
- parser.add_argument('--audio', type=str, help='Filepath of video/audio file to use as raw audio source', required=True)
24
- parser.add_argument('--exp_img', type=str, help='Expression template. neutral, smile or image path', default='neutral')
25
- parser.add_argument('--outfile', type=str, help='Video path to save result')
26
-
27
- parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', default=25., required=False)
28
- parser.add_argument('--pads', nargs='+', type=int, default=[0, 20, 0, 0], help='Padding (top, bottom, left, right). Please adjust to include chin at least')
29
- parser.add_argument('--face_det_batch_size', type=int, help='Batch size for face detection', default=4)
30
- parser.add_argument('--LNet_batch_size', type=int, help='Batch size for LNet', default=16)
31
- parser.add_argument('--img_size', type=int, default=384)
32
- parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
33
- help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
34
- 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
35
- parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
36
- help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
37
- 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
38
- parser.add_argument('--nosmooth', default=False, action='store_true', help='Prevent smoothing face detections over a short temporal window')
39
- parser.add_argument('--static', default=False, action='store_true')
40
-
41
-
42
- parser.add_argument('--up_face', default='original')
43
- parser.add_argument('--one_shot', action='store_true')
44
- parser.add_argument('--without_rl1', default=False, action='store_true', help='Do not use the relative l1')
45
- parser.add_argument('--tmp_dir', type=str, default='temp', help='Folder to save tmp results')
46
- parser.add_argument('--re_preprocess', action='store_true')
47
-
48
- args = parser.parse_args()
49
- return args
50
-
51
- exp_aus_dict = { # AU01_r, AU02_r, AU04_r, AU05_r, AU06_r, AU07_r, AU09_r, AU10_r, AU12_r, AU14_r, AU15_r, AU17_r, AU20_r, AU23_r, AU25_r, AU26_r, AU45_r.
52
- 'sad': torch.Tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
53
- 'angry':torch.Tensor([[0, 0, 0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
54
- 'surprise': torch.Tensor([[0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
55
- }
56
-
57
- def mask_postprocess(mask, thres=20):
58
- mask[:thres, :] = 0; mask[-thres:, :] = 0
59
- mask[:, :thres] = 0; mask[:, -thres:] = 0
60
- mask = cv2.GaussianBlur(mask, (101, 101), 11)
61
- mask = cv2.GaussianBlur(mask, (101, 101), 11)
62
- return mask.astype(np.float32)
63
-
64
- def trans_image(image):
65
- image = TF.resize(
66
- image, size=256, interpolation=Image.BICUBIC)
67
- image = TF.to_tensor(image)
68
- image = TF.normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
69
- return image
70
-
71
- def obtain_seq_index(index, num_frames):
72
- seq = list(range(index-13, index+13))
73
- seq = [ min(max(item, 0), num_frames-1) for item in seq ]
74
- return seq
75
-
76
- def transform_semantic(semantic, frame_index, crop_norm_ratio=None):
77
- index = obtain_seq_index(frame_index, semantic.shape[0])
78
-
79
- coeff_3dmm = semantic[index,...]
80
- ex_coeff = coeff_3dmm[:,80:144] #expression # 64
81
- angles = coeff_3dmm[:,224:227] #euler angles for pose
82
- translation = coeff_3dmm[:,254:257] #translation
83
- crop = coeff_3dmm[:,259:262] #crop param
84
-
85
- if crop_norm_ratio:
86
- crop[:, -3] = crop[:, -3] * crop_norm_ratio
87
-
88
- coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1)
89
- return torch.Tensor(coeff_3dmm).permute(1,0)
90
-
91
- def find_crop_norm_ratio(source_coeff, target_coeffs):
92
- alpha = 0.3
93
- exp_diff = np.mean(np.abs(target_coeffs[:,80:144] - source_coeff[:,80:144]), 1) # mean different exp
94
- angle_diff = np.mean(np.abs(target_coeffs[:,224:227] - source_coeff[:,224:227]), 1) # mean different angle
95
- index = np.argmin(alpha*exp_diff + (1-alpha)*angle_diff) # find the smallerest index
96
- crop_norm_ratio = source_coeff[:,-3] / target_coeffs[index:index+1, -3]
97
- return crop_norm_ratio
98
-
99
- def get_smoothened_boxes(boxes, T):
100
- for i in range(len(boxes)):
101
- if i + T > len(boxes):
102
- window = boxes[len(boxes) - T:]
103
- else:
104
- window = boxes[i : i + T]
105
- boxes[i] = np.mean(window, axis=0)
106
- return boxes
107
-
108
- def face_detect(images, face_det_batch_size, nosmooth, pads, jaw_correction, detector=None):
109
- # def face_detect(images, args, jaw_correction=False, detector=None):
110
- if detector == None:
111
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
112
- detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
113
- flip_input=False, device=device)
114
-
115
- batch_size = face_det_batch_size
116
- while 1:
117
- predictions = []
118
- try:
119
- for i in tqdm(range(0, len(images), batch_size),desc='FaceDet:'):
120
- predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
121
- except RuntimeError:
122
- if batch_size == 1:
123
- raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
124
- batch_size //= 2
125
- print('Recovering from OOM error; New batch size: {}'.format(batch_size))
126
- continue
127
- break
128
-
129
- results = []
130
- pady1, pady2, padx1, padx2 = pads if jaw_correction else (0,20,0,0)
131
- for rect, image in zip(predictions, images):
132
- if rect is None:
133
- cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
134
- raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
135
-
136
- y1 = max(0, rect[1] - pady1)
137
- y2 = min(image.shape[0], rect[3] + pady2)
138
- x1 = max(0, rect[0] - padx1)
139
- x2 = min(image.shape[1], rect[2] + padx2)
140
- results.append([x1, y1, x2, y2])
141
-
142
- boxes = np.array(results)
143
- if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
144
- results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
145
-
146
- del detector
147
- torch.cuda.empty_cache()
148
- return results
149
-
150
- def _load(checkpoint_path, device):
151
- if device == 'cuda':
152
- checkpoint = torch.load(checkpoint_path)
153
- else:
154
- checkpoint = torch.load(checkpoint_path,
155
- map_location=lambda storage, loc: storage)
156
- return checkpoint
157
-
158
- def split_coeff(coeffs):
159
- """
160
- Return:
161
- coeffs_dict -- a dict of torch.tensors
162
-
163
- Parameters:
164
- coeffs -- torch.tensor, size (B, 256)
165
- """
166
- id_coeffs = coeffs[:, :80]
167
- exp_coeffs = coeffs[:, 80: 144]
168
- tex_coeffs = coeffs[:, 144: 224]
169
- angles = coeffs[:, 224: 227]
170
- gammas = coeffs[:, 227: 254]
171
- translations = coeffs[:, 254:]
172
- return {
173
- 'id': id_coeffs,
174
- 'exp': exp_coeffs,
175
- 'tex': tex_coeffs,
176
- 'angle': angles,
177
- 'gamma': gammas,
178
- 'trans': translations
179
- }
180
-
181
- def Laplacian_Pyramid_Blending_with_mask(A, B, m, num_levels = 6):
182
- # generate Gaussian pyramid for A,B and mask
183
- GA = A.copy()
184
- GB = B.copy()
185
- GM = m.copy()
186
- gpA = [GA]
187
- gpB = [GB]
188
- gpM = [GM]
189
- for i in range(num_levels):
190
- GA = cv2.pyrDown(GA)
191
- GB = cv2.pyrDown(GB)
192
- GM = cv2.pyrDown(GM)
193
- gpA.append(np.float32(GA))
194
- gpB.append(np.float32(GB))
195
- gpM.append(np.float32(GM))
196
-
197
- # generate Laplacian Pyramids for A,B and masks
198
- lpA = [gpA[num_levels-1]] # the bottom of the Lap-pyr holds the last (smallest) Gauss level
199
- lpB = [gpB[num_levels-1]]
200
- gpMr = [gpM[num_levels-1]]
201
- for i in range(num_levels-1,0,-1):
202
- # Laplacian: subtract upscaled version of lower level from current level
203
- # to get the high frequencies
204
- LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
205
- LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
206
- lpA.append(LA)
207
- lpB.append(LB)
208
- gpMr.append(gpM[i-1]) # also reverse the masks
209
-
210
- # Now blend images according to mask in each level
211
- LS = []
212
- for la,lb,gm in zip(lpA,lpB,gpMr):
213
- gm = gm[:,:,np.newaxis]
214
- ls = la * gm + lb * (1.0 - gm)
215
- LS.append(ls)
216
-
217
- # now reconstruct
218
- ls_ = LS[0]
219
- for i in range(1,num_levels):
220
- ls_ = cv2.pyrUp(ls_)
221
- ls_ = cv2.add(ls_, LS[i])
222
- return ls_
223
-
224
- def load_model(device,DNet_path,LNet_path,ENet_path):
225
- D_Net = load_DNet(DNet_path).to(device)
226
- model = load_network(LNet_path,ENet_path).to(device)
227
- return D_Net, model
228
-
229
- def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
230
- use_relative_movement=False, use_relative_jacobian=False):
231
- if adapt_movement_scale:
232
- source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
233
- driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
234
- adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
235
- else:
236
- adapt_movement_scale = 1
237
-
238
- kp_new = {k: v for k, v in kp_driving.items()}
239
- if use_relative_movement:
240
- kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
241
- kp_value_diff *= adapt_movement_scale
242
- kp_new['value'] = kp_value_diff + kp_source['value']
243
-
244
- if use_relative_jacobian:
245
- jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
246
- kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
247
- return kp_new
248
-
249
- def load_face3d_net(ckpt_path, device):
250
- net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
251
- checkpoint = torch.load(ckpt_path, map_location=device)
252
- net_recon.load_state_dict(checkpoint['net_recon'])
253
- net_recon.eval()
254
  return net_recon
 
1
+ import numpy as np
2
+ import cv2, argparse, torch
3
+ import torchvision.transforms.functional as TF
4
+
5
+ from videoretalking.models import load_network, load_DNet
6
+ from tqdm import tqdm
7
+ from PIL import Image
8
+ from scipy.spatial import ConvexHull
9
+ from videoretalking.third_part import face_detection
10
+ from videoretalking.third_part.face3d.models import networks
11
+
12
+ import warnings
13
+ warnings.filterwarnings("ignore")
14
+
15
+ def options():
16
+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
17
+
18
+ parser.add_argument('--DNet_path', type=str, default='checkpoints/DNet.pt')
19
+ parser.add_argument('--LNet_path', type=str, default='checkpoints/LNet.pth')
20
+ parser.add_argument('--ENet_path', type=str, default='checkpoints/ENet.pth')
21
+ parser.add_argument('--face3d_net_path', type=str, default='checkpoints/face3d_pretrain_epoch_20.pth')
22
+ parser.add_argument('--face', type=str, help='Filepath of video/image that contains faces to use', required=True)
23
+ parser.add_argument('--audio', type=str, help='Filepath of video/audio file to use as raw audio source', required=True)
24
+ parser.add_argument('--exp_img', type=str, help='Expression template. neutral, smile or image path', default='neutral')
25
+ parser.add_argument('--outfile', type=str, help='Video path to save result')
26
+
27
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', default=25., required=False)
28
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 20, 0, 0], help='Padding (top, bottom, left, right). Please adjust to include chin at least')
29
+ parser.add_argument('--face_det_batch_size', type=int, help='Batch size for face detection', default=4)
30
+ parser.add_argument('--LNet_batch_size', type=int, help='Batch size for LNet', default=16)
31
+ parser.add_argument('--img_size', type=int, default=384)
32
+ parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
33
+ help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
34
+ 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
35
+ parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
36
+ help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
37
+ 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
38
+ parser.add_argument('--nosmooth', default=False, action='store_true', help='Prevent smoothing face detections over a short temporal window')
39
+ parser.add_argument('--static', default=False, action='store_true')
40
+
41
+
42
+ parser.add_argument('--up_face', default='original')
43
+ parser.add_argument('--one_shot', action='store_true')
44
+ parser.add_argument('--without_rl1', default=False, action='store_true', help='Do not use the relative l1')
45
+ parser.add_argument('--tmp_dir', type=str, default='temp', help='Folder to save tmp results')
46
+ parser.add_argument('--re_preprocess', action='store_true')
47
+
48
+ args = parser.parse_args()
49
+ return args
50
+
51
+ exp_aus_dict = { # AU01_r, AU02_r, AU04_r, AU05_r, AU06_r, AU07_r, AU09_r, AU10_r, AU12_r, AU14_r, AU15_r, AU17_r, AU20_r, AU23_r, AU25_r, AU26_r, AU45_r.
52
+ 'sad': torch.Tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
53
+ 'angry':torch.Tensor([[0, 0, 0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
54
+ 'surprise': torch.Tensor([[0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
55
+ }
56
+
57
+ def mask_postprocess(mask, thres=20):
58
+ mask[:thres, :] = 0; mask[-thres:, :] = 0
59
+ mask[:, :thres] = 0; mask[:, -thres:] = 0
60
+ mask = cv2.GaussianBlur(mask, (101, 101), 11)
61
+ mask = cv2.GaussianBlur(mask, (101, 101), 11)
62
+ return mask.astype(np.float32)
63
+
64
+ def trans_image(image):
65
+ image = TF.resize(
66
+ image, size=256, interpolation=Image.BICUBIC)
67
+ image = TF.to_tensor(image)
68
+ image = TF.normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
69
+ return image
70
+
71
+ def obtain_seq_index(index, num_frames):
72
+ seq = list(range(index-13, index+13))
73
+ seq = [ min(max(item, 0), num_frames-1) for item in seq ]
74
+ return seq
75
+
76
+ def transform_semantic(semantic, frame_index, crop_norm_ratio=None):
77
+ index = obtain_seq_index(frame_index, semantic.shape[0])
78
+
79
+ coeff_3dmm = semantic[index,...]
80
+ ex_coeff = coeff_3dmm[:,80:144] #expression # 64
81
+ angles = coeff_3dmm[:,224:227] #euler angles for pose
82
+ translation = coeff_3dmm[:,254:257] #translation
83
+ crop = coeff_3dmm[:,259:262] #crop param
84
+
85
+ if crop_norm_ratio:
86
+ crop[:, -3] = crop[:, -3] * crop_norm_ratio
87
+
88
+ coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1)
89
+ return torch.Tensor(coeff_3dmm).permute(1,0)
90
+
91
+ def find_crop_norm_ratio(source_coeff, target_coeffs):
92
+ alpha = 0.3
93
+ exp_diff = np.mean(np.abs(target_coeffs[:,80:144] - source_coeff[:,80:144]), 1) # mean different exp
94
+ angle_diff = np.mean(np.abs(target_coeffs[:,224:227] - source_coeff[:,224:227]), 1) # mean different angle
95
+ index = np.argmin(alpha*exp_diff + (1-alpha)*angle_diff) # find the smallerest index
96
+ crop_norm_ratio = source_coeff[:,-3] / target_coeffs[index:index+1, -3]
97
+ return crop_norm_ratio
98
+
99
+ def get_smoothened_boxes(boxes, T):
100
+ for i in range(len(boxes)):
101
+ if i + T > len(boxes):
102
+ window = boxes[len(boxes) - T:]
103
+ else:
104
+ window = boxes[i : i + T]
105
+ boxes[i] = np.mean(window, axis=0)
106
+ return boxes
107
+
108
+ def face_detect(images, face_det_batch_size, nosmooth, pads, jaw_correction, detector=None):
109
+ # def face_detect(images, args, jaw_correction=False, detector=None):
110
+ if detector == None:
111
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
112
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
113
+ flip_input=False, device=device)
114
+
115
+ batch_size = face_det_batch_size
116
+ while 1:
117
+ predictions = []
118
+ try:
119
+ for i in tqdm(range(0, len(images), batch_size),desc='FaceDet:'):
120
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
121
+ except RuntimeError:
122
+ if batch_size == 1:
123
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
124
+ batch_size //= 2
125
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
126
+ continue
127
+ break
128
+
129
+ results = []
130
+ pady1, pady2, padx1, padx2 = pads if jaw_correction else (0,20,0,0)
131
+ for rect, image in zip(predictions, images):
132
+ if rect is None:
133
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
134
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
135
+
136
+ y1 = max(0, rect[1] - pady1)
137
+ y2 = min(image.shape[0], rect[3] + pady2)
138
+ x1 = max(0, rect[0] - padx1)
139
+ x2 = min(image.shape[1], rect[2] + padx2)
140
+ results.append([x1, y1, x2, y2])
141
+
142
+ boxes = np.array(results)
143
+ if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
144
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
145
+
146
+ del detector
147
+ torch.cuda.empty_cache()
148
+ return results
149
+
150
+ def _load(checkpoint_path, device):
151
+ if device == 'cuda':
152
+ checkpoint = torch.load(checkpoint_path)
153
+ else:
154
+ checkpoint = torch.load(checkpoint_path,
155
+ map_location=lambda storage, loc: storage)
156
+ return checkpoint
157
+
158
+ def split_coeff(coeffs):
159
+ """
160
+ Return:
161
+ coeffs_dict -- a dict of torch.tensors
162
+
163
+ Parameters:
164
+ coeffs -- torch.tensor, size (B, 256)
165
+ """
166
+ id_coeffs = coeffs[:, :80]
167
+ exp_coeffs = coeffs[:, 80: 144]
168
+ tex_coeffs = coeffs[:, 144: 224]
169
+ angles = coeffs[:, 224: 227]
170
+ gammas = coeffs[:, 227: 254]
171
+ translations = coeffs[:, 254:]
172
+ return {
173
+ 'id': id_coeffs,
174
+ 'exp': exp_coeffs,
175
+ 'tex': tex_coeffs,
176
+ 'angle': angles,
177
+ 'gamma': gammas,
178
+ 'trans': translations
179
+ }
180
+
181
+ def Laplacian_Pyramid_Blending_with_mask(A, B, m, num_levels = 6):
182
+ # generate Gaussian pyramid for A,B and mask
183
+ GA = A.copy()
184
+ GB = B.copy()
185
+ GM = m.copy()
186
+ gpA = [GA]
187
+ gpB = [GB]
188
+ gpM = [GM]
189
+ for i in range(num_levels):
190
+ GA = cv2.pyrDown(GA)
191
+ GB = cv2.pyrDown(GB)
192
+ GM = cv2.pyrDown(GM)
193
+ gpA.append(np.float32(GA))
194
+ gpB.append(np.float32(GB))
195
+ gpM.append(np.float32(GM))
196
+
197
+ # generate Laplacian Pyramids for A,B and masks
198
+ lpA = [gpA[num_levels-1]] # the bottom of the Lap-pyr holds the last (smallest) Gauss level
199
+ lpB = [gpB[num_levels-1]]
200
+ gpMr = [gpM[num_levels-1]]
201
+ for i in range(num_levels-1,0,-1):
202
+ # Laplacian: subtract upscaled version of lower level from current level
203
+ # to get the high frequencies
204
+ LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
205
+ LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
206
+ lpA.append(LA)
207
+ lpB.append(LB)
208
+ gpMr.append(gpM[i-1]) # also reverse the masks
209
+
210
+ # Now blend images according to mask in each level
211
+ LS = []
212
+ for la,lb,gm in zip(lpA,lpB,gpMr):
213
+ gm = gm[:,:,np.newaxis]
214
+ ls = la * gm + lb * (1.0 - gm)
215
+ LS.append(ls)
216
+
217
+ # now reconstruct
218
+ ls_ = LS[0]
219
+ for i in range(1,num_levels):
220
+ ls_ = cv2.pyrUp(ls_)
221
+ ls_ = cv2.add(ls_, LS[i])
222
+ return ls_
223
+
224
+ def load_model(device,DNet_path,LNet_path,ENet_path):
225
+ D_Net = load_DNet(DNet_path).to(device)
226
+ model = load_network(LNet_path,ENet_path).to(device)
227
+ return D_Net, model
228
+
229
+ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
230
+ use_relative_movement=False, use_relative_jacobian=False):
231
+ if adapt_movement_scale:
232
+ source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
233
+ driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
234
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
235
+ else:
236
+ adapt_movement_scale = 1
237
+
238
+ kp_new = {k: v for k, v in kp_driving.items()}
239
+ if use_relative_movement:
240
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
241
+ kp_value_diff *= adapt_movement_scale
242
+ kp_new['value'] = kp_value_diff + kp_source['value']
243
+
244
+ if use_relative_jacobian:
245
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
246
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
247
+ return kp_new
248
+
249
+ def load_face3d_net(ckpt_path, device):
250
+ net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
251
+ checkpoint = torch.load(ckpt_path, map_location=device)
252
+ net_recon.load_state_dict(checkpoint['net_recon'])
253
+ net_recon.eval()
254
  return net_recon