Update modules/sadtalker_test.py
Browse files- modules/sadtalker_test.py +10 -5
modules/sadtalker_test.py
CHANGED
@@ -18,7 +18,7 @@ class SadTalker():
|
|
18 |
device = "cuda"
|
19 |
else:
|
20 |
device = "cpu"
|
21 |
-
|
22 |
current_code_path = sys.argv[0]
|
23 |
modules_path = os.path.split(current_code_path)[0]
|
24 |
|
@@ -53,7 +53,7 @@ class SadTalker():
|
|
53 |
facerender_yaml_path, device)
|
54 |
self.device = device
|
55 |
|
56 |
-
def test(self, source_image, driven_audio, result_dir):
|
57 |
|
58 |
time_tag = strftime("%Y_%m_%d_%H.%M.%S")
|
59 |
save_dir = os.path.join(result_dir, time_tag)
|
@@ -87,9 +87,14 @@ class SadTalker():
|
|
87 |
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
|
88 |
#coeff2video
|
89 |
batch_size = 4
|
90 |
-
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size)
|
91 |
-
self.animate_from_coeff.generate(data, save_dir)
|
92 |
video_name = data['video_name']
|
93 |
print(f'The generated video is named {video_name} in {save_dir}')
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
95 |
|
|
|
18 |
device = "cuda"
|
19 |
else:
|
20 |
device = "cpu"
|
21 |
+
|
22 |
current_code_path = sys.argv[0]
|
23 |
modules_path = os.path.split(current_code_path)[0]
|
24 |
|
|
|
53 |
facerender_yaml_path, device)
|
54 |
self.device = device
|
55 |
|
56 |
+
def test(self, source_image, driven_audio, still_mode, use_enhancer, result_dir):
|
57 |
|
58 |
time_tag = strftime("%Y_%m_%d_%H.%M.%S")
|
59 |
save_dir = os.path.join(result_dir, time_tag)
|
|
|
87 |
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
|
88 |
#coeff2video
|
89 |
batch_size = 4
|
90 |
+
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode)
|
91 |
+
self.animate_from_coeff.generate(data, save_dir, enhancer='gfpgan' if use_enhancer else None)
|
92 |
video_name = data['video_name']
|
93 |
print(f'The generated video is named {video_name} in {save_dir}')
|
94 |
+
|
95 |
+
if use_enhancer:
|
96 |
+
return os.path.join(save_dir, video_name+'_enhanced.mp4'), os.path.join(save_dir, video_name+'_enhanced.mp4')
|
97 |
+
|
98 |
+
else:
|
99 |
+
return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
|
100 |
|