Spaces:
Running
on
A10G
Running
on
A10G
Update app.py
Browse files
app.py
CHANGED
@@ -1,387 +1,118 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
import
|
4 |
-
|
5 |
-
from
|
6 |
-
import
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
nn.init.constant_(self.mapping1.bias, 0.)
|
31 |
-
self.use_ref = use_ref
|
32 |
-
|
33 |
-
def forward(self, x, ref, use_tanh=False):
|
34 |
-
x = self.audio_encoder.forward_feature(x).view(x.size(0), -1)
|
35 |
-
ref_reshape = ref.reshape(x.size(0), -1) #20, -1
|
36 |
-
|
37 |
-
y = self.mapping1(torch.cat([x, ref_reshape], dim=1))
|
38 |
-
|
39 |
-
if self.use_ref:
|
40 |
-
out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref # resudial
|
41 |
-
else:
|
42 |
-
out = y.reshape(ref.shape[0], ref.shape[1], -1)
|
43 |
-
|
44 |
-
if use_tanh:
|
45 |
-
out[:, :50] = torch.tanh(out[:, :50]) * 3
|
46 |
-
|
47 |
-
return out
|
48 |
-
|
49 |
-
class Audio2Mesh(object):
|
50 |
-
def __init__(self, args) -> None:
|
51 |
-
self.args = args
|
52 |
-
|
53 |
-
spectre_cfg.model.use_tex = True
|
54 |
-
spectre_cfg.model.mask_type = args.mask_type
|
55 |
-
spectre_cfg.debug = self.args.debug
|
56 |
-
spectre_cfg.model.netA_sync = 'ressesync'
|
57 |
-
spectre_cfg.model.gpu_ids = [0]
|
58 |
-
|
59 |
-
self.spectre = SPECTRE(spectre_cfg)
|
60 |
-
self.spectre.eval()
|
61 |
-
self.face_tracker = None #FaceTrackerV2() # face landmark detection
|
62 |
-
self.mel_step_size = 16
|
63 |
-
self.fps = args.fps
|
64 |
-
self.Nw = args.tframes
|
65 |
-
self.device = self.args.device
|
66 |
-
self.image_size = self.args.image_size
|
67 |
-
|
68 |
-
### only audio
|
69 |
-
args.netA_sync = 'ressesync'
|
70 |
-
args.gpu_ids = [0]
|
71 |
-
args.exp_dim = 53
|
72 |
-
args.use_tanh = False
|
73 |
-
args.K = 20
|
74 |
-
|
75 |
-
self.audio2exp = 'pcavs'
|
76 |
-
|
77 |
-
#
|
78 |
-
self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda()
|
79 |
-
self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt'])
|
80 |
-
|
81 |
-
# 5, 160 = 25fps
|
82 |
-
self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160)
|
83 |
-
|
84 |
-
with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f: # ?
|
85 |
-
self.fitting_coeffs = pickle.load(f, encoding='bytes')
|
86 |
-
|
87 |
-
self.coeffs_dict = { key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']}
|
88 |
-
|
89 |
-
#### find the close month
|
90 |
-
exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1)
|
91 |
-
ssss, sorted_indices = torch.sort(exp_tensors)
|
92 |
-
self.exp_id = sorted_indices[0].item()
|
93 |
-
|
94 |
-
if '.ts' in args.render_path:
|
95 |
-
self.render = torch.jit.load(args.render_path).cuda()
|
96 |
-
self.trt = True
|
97 |
-
else:
|
98 |
-
self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda()
|
99 |
-
self.render.load_state_dict(torch.load(args.render_path))
|
100 |
-
self.trt = False
|
101 |
-
|
102 |
-
print('loaded cached images...')
|
103 |
-
|
104 |
-
@torch.no_grad()
|
105 |
-
def cg2real(self, rendedimages, start_frame=0):
|
106 |
-
|
107 |
-
## load original image and the mask
|
108 |
-
self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'),\
|
109 |
-
resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
|
110 |
-
self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'),\
|
111 |
-
resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
|
112 |
-
|
113 |
-
self.source_masks = torch.FloatTensor(np.transpose(self.source_masks,(0,3,1,2))/255.)
|
114 |
-
self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images,(0,3,1,2))/255.)
|
115 |
-
|
116 |
-
## padding the rended_imgs
|
117 |
-
paded_tensor = torch.cat([rendedimages[0:1]]* (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]]* (self.Nw // 2)).contiguous()
|
118 |
-
paded_mask_tensor = torch.cat([self.source_masks[0:1]]* (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]]* (self.Nw // 2)).contiguous()
|
119 |
-
paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]]* (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]]* (self.Nw // 2)).contiguous()
|
120 |
-
|
121 |
-
# paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask)
|
122 |
-
padded_input = ((paded_real_tensor-0.5)*2 ) # *(1-paded_mask_tensor)
|
123 |
-
padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
|
124 |
-
paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
|
125 |
-
paded_tensor = (paded_tensor-0.5)*2
|
126 |
-
|
127 |
-
result = []
|
128 |
-
for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'):
|
129 |
-
list_A = []
|
130 |
-
list_R = []
|
131 |
-
list_M = []
|
132 |
-
for i in range(self.args.renderbs):
|
133 |
-
idx = index + i
|
134 |
-
if idx+self.Nw > len(padded_input):
|
135 |
-
list_A.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0))
|
136 |
-
list_R.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0))
|
137 |
-
list_M.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0))
|
138 |
-
else:
|
139 |
-
list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
|
140 |
-
list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
|
141 |
-
list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
|
142 |
-
|
143 |
-
list_A = torch.cat(list_A)
|
144 |
-
list_R = torch.cat(list_R)
|
145 |
-
list_M = torch.cat(list_M)
|
146 |
-
|
147 |
-
idx = (self.Nw//2) * 3
|
148 |
-
mask = list_M[:, idx:idx+3]
|
149 |
-
|
150 |
-
# list_A = padded_input
|
151 |
-
mask = maskErosion(mask, offY=self.args.mask)
|
152 |
-
list_A = list_A * (1 - mask[:,0:1])
|
153 |
-
A = torch.cat([list_A, list_R], 1)
|
154 |
-
|
155 |
-
if self.trt:
|
156 |
-
B = self.render(A.half().cuda())
|
157 |
-
elif self.args.netR == 'unet_256':
|
158 |
-
# import pdb; pdb.set_trace()
|
159 |
-
idx = (self.Nw//2) * 3
|
160 |
-
mask = list_M[:, idx:idx+3].cuda()
|
161 |
-
mask = maskErosion(mask, offY=self.args.mask)
|
162 |
-
B0 = list_A[:, idx:idx+3].cuda()
|
163 |
-
B = self.render(A.cuda()) * mask[:,0:1] + (1 - mask[:,0:1]) * B0
|
164 |
-
elif self.args.netR == 's2am':
|
165 |
-
# import pdb; pdb.set_trace()
|
166 |
-
idx = (self.Nw//2) * 3
|
167 |
-
mask = list_M[:, idx:idx+3].cuda()
|
168 |
-
mask = maskErosion(mask, offY=self.args.mask)
|
169 |
-
B0 = list_A[:, idx:idx+3].cuda()
|
170 |
-
B = self.render(A.cuda(), mask[:,0:1] ) * mask[:,0:1] + (1 - mask[:,0:1]) * B0
|
171 |
-
else:
|
172 |
-
B = self.render(A.cuda())
|
173 |
-
|
174 |
-
result.append((B.cpu() + 1) * 0.5) # -1,1 -> 0,1
|
175 |
-
|
176 |
-
return torch.cat(result)[:len(rendedimages)]
|
177 |
-
|
178 |
-
@torch.no_grad()
|
179 |
-
def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK = 20):
|
180 |
-
|
181 |
-
xlen = vertices.shape[0]
|
182 |
-
all_shape_images = []
|
183 |
-
landmark2d = []
|
184 |
-
|
185 |
-
#### find the most larger pose 51 in the coeffs.
|
186 |
-
max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1))
|
187 |
-
|
188 |
-
for i in tqdm(range(0, xlen, XK)):
|
189 |
-
|
190 |
-
if i + XK > xlen:
|
191 |
-
XK = xlen - i
|
192 |
-
|
193 |
-
codedictdecoder = {}
|
194 |
-
codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda())
|
195 |
-
codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda()
|
196 |
-
codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda()) # all_exps[i:i+XK, :50].cuda() # # # vid_exps[i:i+1].cuda() i:i+XK
|
197 |
-
codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK] # vid_poses[i:i+1].cuda()
|
198 |
-
codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda() # vid_poses[i:i+1].cuda()
|
199 |
-
codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda() # vid_poses[i:i+1].cuda()
|
200 |
-
codedictdecoder['images'] = torch.zeros((XK,3,256,256)).cuda()
|
201 |
-
|
202 |
-
codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51*0.9) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:])
|
203 |
-
codedictdecoder['pose'][..., 4:6] = 0 # coeffs[i:i+XK, 50:]*( - 0.25) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:])
|
204 |
-
|
205 |
-
sub_vertices = vertices[i:i+XK].cuda()
|
206 |
-
|
207 |
-
opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False)
|
208 |
-
|
209 |
-
landmark2d.append(opdict['landmarks2d'].cpu())
|
210 |
-
|
211 |
-
all_shape_images.append(opdict['rendered_images'].cpu())
|
212 |
-
|
213 |
-
rendedimages = torch.cat(all_shape_images)
|
214 |
-
|
215 |
-
lmk2d = torch.cat(landmark2d)
|
216 |
-
|
217 |
-
return rendedimages, lmk2d
|
218 |
-
|
219 |
-
|
220 |
-
@torch.no_grad()
|
221 |
-
def run_spectre_v3(self, wav=None, ds_features=None, L=20):
|
222 |
-
|
223 |
-
wav = audio_normalize(wav)
|
224 |
-
all_mel = self.audio.melspectrogram(wav).astype(np.float32).T
|
225 |
-
frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2) # 2,[]mmmmmmmmmmmmmmmmmmmmmmmmmmmm
|
226 |
-
audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame)
|
227 |
-
|
228 |
-
vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1]
|
229 |
-
vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1]
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
|
299 |
-
prediction = model.predict(audio_feature, template, one_hot, 1.0) # (1, seq_len, V*3)
|
300 |
-
|
301 |
-
return prediction.squeeze()
|
302 |
-
|
303 |
-
@torch.no_grad()
|
304 |
-
def run(self, face, audio, start_frame=0):
|
305 |
-
|
306 |
-
wav, sr = librosa.load(audio, sr=16000) # 16*80 ? 20*80
|
307 |
-
wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav)
|
308 |
-
_, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps)
|
309 |
-
|
310 |
-
##### audio-guided, only use the jaw movement
|
311 |
-
all_exps = self.run_spectre_v3(wav)
|
312 |
-
|
313 |
-
# #### temp. interpolation
|
314 |
-
all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear')
|
315 |
-
all_exps = all_exps.permute([0,2,1]).squeeze(0)
|
316 |
-
|
317 |
-
# run faceformer for face mesh generation
|
318 |
-
predicted_vertices = self.test_model(audio)
|
319 |
-
predicted_vertices = predicted_vertices.view(-1, 5023*3)
|
320 |
-
|
321 |
-
#### temp. interpolation
|
322 |
-
predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear')
|
323 |
-
predicted_vertices = predicted_vertices.permute([0,2,1]).squeeze(0).view(-1, 5023, 3)
|
324 |
-
|
325 |
-
all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu() # smooth GT
|
326 |
-
|
327 |
-
rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True)
|
328 |
-
debug_video_gen(rendedimages, self.args.result_dir+"/debug_before_ff.mp4", wav_tensor, self.args.fps, sr)
|
329 |
-
|
330 |
-
# cg2real
|
331 |
-
debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir+"/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr)
|
332 |
-
|
333 |
-
exit()
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
if __name__ == '__main__':
|
338 |
-
parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing')
|
339 |
-
parser.add_argument('--face', default='examples', type=str, help='')
|
340 |
-
parser.add_argument('--audio', default='examples', type=str, help='')
|
341 |
-
parser.add_argument('--source_dir', default='examples', type=str,help='TODO')
|
342 |
-
parser.add_argument('--result_dir', default='examples', type=str,help='TODO')
|
343 |
-
parser.add_argument('--backend', default='wav2lip', type=str,help='wav2lip or pcavs')
|
344 |
-
parser.add_argument('--result_tag', default='result', type=str,help='TODO')
|
345 |
-
parser.add_argument('--netR', default='unet_256', type=str,help='TODO')
|
346 |
-
parser.add_argument('--render_path', default='', type=str,help='TODO')
|
347 |
-
parser.add_argument('--ngf', default=16, type=int,help='TODO')
|
348 |
-
parser.add_argument('--fps', default=20, type=int,help='TODO')
|
349 |
-
parser.add_argument('--mask', default=100, type=int,help='TODO')
|
350 |
-
parser.add_argument('--mask_type', default='v3', type=str,help='TODO')
|
351 |
-
parser.add_argument('--image_size', default=256, type=int,help='TODO')
|
352 |
-
parser.add_argument('--input_nc', default=21, type=int,help='TODO')
|
353 |
-
parser.add_argument('--output_nc', default=3, type=int,help='TODO')
|
354 |
-
parser.add_argument('--renderbs', default=16, type=int,help='TODO')
|
355 |
-
parser.add_argument('--tframes', default=1, type=int,help='TODO')
|
356 |
-
parser.add_argument('--debug', action='store_true')
|
357 |
-
parser.add_argument('--enhance', action='store_true')
|
358 |
-
parser.add_argument('--phone', action='store_true')
|
359 |
-
|
360 |
-
#### faceformer
|
361 |
-
parser.add_argument("--model_name", type=str, default="VOCA")
|
362 |
-
parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI')
|
363 |
-
parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI')
|
364 |
-
parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI')
|
365 |
-
parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI')
|
366 |
-
parser.add_argument("--device", type=str, default="cuda")
|
367 |
-
parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ")
|
368 |
-
parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA")
|
369 |
-
parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects')
|
370 |
-
parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects')
|
371 |
-
parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background')
|
372 |
-
parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates')
|
373 |
-
parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology')
|
374 |
|
375 |
-
|
376 |
|
377 |
-
|
378 |
-
|
379 |
-
|
|
|
|
|
380 |
|
381 |
-
a2m = Audio2Mesh(opt)
|
382 |
|
383 |
-
print('link start!')
|
384 |
-
t = time.time()
|
385 |
-
# 02780
|
386 |
-
a2m.run(opt.face, opt.audio, 0)
|
387 |
-
print(time.time() - t)
|
|
|
1 |
+
import os, sys
|
2 |
+
import tempfile
|
3 |
+
import gradio as gr
|
4 |
+
from modules.text2speech import text2speech
|
5 |
+
from modules.gfpgan_inference import gfpgan
|
6 |
+
from modules.sadtalker_test import SadTalker
|
7 |
+
|
8 |
+
def get_driven_audio(audio):
|
9 |
+
if os.path.isfile(audio):
|
10 |
+
return audio
|
11 |
+
else:
|
12 |
+
save_path = tempfile.NamedTemporaryFile(
|
13 |
+
delete=False,
|
14 |
+
suffix=("." + "wav"),
|
15 |
+
)
|
16 |
+
gen_audio = text2speech(audio, save_path.name)
|
17 |
+
return gen_audio, gen_audio
|
18 |
+
|
19 |
+
def get_source_image(image):
|
20 |
+
return image
|
21 |
+
|
22 |
+
def sadtalker_demo(result_dir):
|
23 |
+
|
24 |
+
sad_talker = SadTalker()
|
25 |
+
with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
|
26 |
+
gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
|
27 |
+
<a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> \
|
28 |
+
<a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> \
|
29 |
+
<a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
with gr.Row().style(equal_height=False):
|
32 |
+
with gr.Column(variant='panel'):
|
33 |
+
with gr.Tabs(elem_id="sadtalker_source_image"):
|
34 |
+
with gr.TabItem('Upload image'):
|
35 |
+
with gr.Row():
|
36 |
+
source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
|
37 |
+
|
38 |
+
with gr.Tabs(elem_id="sadtalker_driven_audio"):
|
39 |
+
with gr.TabItem('Upload audio'):
|
40 |
+
with gr.Column(variant='panel'):
|
41 |
+
driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
|
42 |
+
# submit_audio_1 = gr.Button('Submit', variant='primary')
|
43 |
+
# submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)
|
44 |
+
|
45 |
+
|
46 |
+
with gr.Column(variant='panel'):
|
47 |
+
with gr.Tabs(elem_id="sadtalker_checkbox"):
|
48 |
+
with gr.TabItem('Settings'):
|
49 |
+
with gr.Column(variant='panel'):
|
50 |
+
is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion)")
|
51 |
+
enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
|
52 |
+
submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
|
53 |
+
|
54 |
+
with gr.Tabs(elem_id="sadtalker_genearted"):
|
55 |
+
gen_video = gr.Video(label="Generated video", format="mp4").style(height=256,width=256)
|
56 |
+
gen_text = gr.Textbox(visible=False)
|
57 |
+
|
58 |
+
|
59 |
+
with gr.Row():
|
60 |
+
examples = [
|
61 |
+
[
|
62 |
+
'examples/source_image/art_10.png',
|
63 |
+
'examples/driven_audio/deyu.wav',
|
64 |
+
True,
|
65 |
+
False
|
66 |
+
],
|
67 |
+
[
|
68 |
+
'examples/source_image/art_1.png',
|
69 |
+
'examples/driven_audio/chinese_poem1.wav',
|
70 |
+
True,
|
71 |
+
False
|
72 |
+
],
|
73 |
+
[
|
74 |
+
'examples/source_image/art_13.png',
|
75 |
+
'examples/driven_audio/fayu.wav',
|
76 |
+
True,
|
77 |
+
False
|
78 |
+
],
|
79 |
+
[
|
80 |
+
'examples/source_image/art_5.png',
|
81 |
+
'examples/driven_audio/chinese_news.wav',
|
82 |
+
True,
|
83 |
+
False
|
84 |
+
],
|
85 |
+
]
|
86 |
+
gr.Examples(examples=examples,
|
87 |
+
inputs=[
|
88 |
+
source_image,
|
89 |
+
driven_audio,
|
90 |
+
is_still_mode,
|
91 |
+
enhancer,
|
92 |
+
gr.Textbox(value=result_dir, visible=False)],
|
93 |
+
outputs=[gen_video, gen_text],
|
94 |
+
fn=sad_talker.test,
|
95 |
+
cache_examples=os.getenv('SYSTEM') == 'spaces')
|
96 |
+
|
97 |
+
submit.click(
|
98 |
+
fn=sad_talker.test,
|
99 |
+
inputs=[source_image,
|
100 |
+
driven_audio,
|
101 |
+
is_still_mode,
|
102 |
+
enhancer,
|
103 |
+
gr.Textbox(value=result_dir, visible=False)],
|
104 |
+
outputs=[gen_video, gen_text]
|
105 |
+
)
|
106 |
+
|
107 |
+
return sadtalker_interface
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
+
if __name__ == "__main__":
|
111 |
|
112 |
+
current_code_path = sys.argv[0]
|
113 |
+
current_root_dir = os.path.split(current_code_path)[0]
|
114 |
+
sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
|
115 |
+
demo = sadtalker_demo(sadtalker_result_dir)
|
116 |
+
demo.launch()
|
117 |
|
|
|
118 |
|
|
|
|
|
|
|
|
|
|